feat: 实现基于 Redis 的延迟任务队列和工作器,支持周期性交易处理等后台任务调度。

This commit is contained in:
2026-01-28 16:36:56 +08:00
parent a9ee8856ba
commit 5ff680ee43
6 changed files with 884 additions and 0 deletions

View File

@@ -5,10 +5,12 @@ import (
"log"
"os"
"path/filepath"
"time"
"accounting-app/internal/cache"
"accounting-app/internal/config"
"accounting-app/internal/database"
"accounting-app/internal/mq"
"accounting-app/internal/repository"
"accounting-app/internal/router"
"accounting-app/internal/service"
@@ -97,6 +99,7 @@ func main() {
if err != nil {
// Redis not available - fall back to old system
log.Printf("Warning: Redis connection failed (%v), falling back to non-cached exchange rate system", err)
log.Printf("Warning: Recurring transaction MQ system disabled (requires Redis)")
// Use old scheduler
scheduler := service.NewExchangeRateScheduler(yunAPIClient, cfg.SyncInterval)
@@ -116,6 +119,22 @@ func main() {
// Start the new SyncScheduler in background
// This will perform initial sync immediately (Requirement 3.1)
go syncScheduler.Start(ctx)
// Initialize and start MQ task system for recurring transactions
recurringTaskSystem, err := mq.InitRecurringTaskSystem(
ctx,
redisClient.Client(),
db,
5*time.Second, // Poll interval: 检查延迟任务的间隔
2, // Worker count: 并发处理任务的数量
)
if err != nil {
log.Printf("Warning: Failed to initialize recurring task system: %v", err)
} else {
// Start MQ worker in background
go recurringTaskSystem.Start(ctx)
log.Println("Recurring transaction MQ system started")
}
}
// Get port from config or environment

View File

@@ -0,0 +1,78 @@
package handler
import (
"net/http"
"accounting-app/internal/mq"
"github.com/gin-gonic/gin"
)
// TaskQueueHandler 任务队列状态查询 Handler
type TaskQueueHandler struct {
taskSystem *mq.RecurringTaskSystem
}
// NewTaskQueueHandler 创建任务队列 Handler
func NewTaskQueueHandler(taskSystem *mq.RecurringTaskSystem) *TaskQueueHandler {
return &TaskQueueHandler{
taskSystem: taskSystem,
}
}
// GetStats 获取任务队列统计信息
// GET /api/v1/admin/tasks/stats
func (h *TaskQueueHandler) GetStats(c *gin.Context) {
if h.taskSystem == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
"success": false,
"error": "Task system not available (Redis required)",
})
return
}
stats := h.taskSystem.GetStats(c.Request.Context())
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": stats,
})
}
// GetPendingTasks 获取即将执行的任务列表
// GET /api/v1/admin/tasks/pending
func (h *TaskQueueHandler) GetPendingTasks(c *gin.Context) {
if h.taskSystem == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
"success": false,
"error": "Task system not available (Redis required)",
})
return
}
tasks, err := h.taskSystem.Queue.GetPendingTasks(c.Request.Context(), 50)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": gin.H{
"count": len(tasks),
"tasks": tasks,
},
})
}
// RegisterRoutes 注册路由
func (h *TaskQueueHandler) RegisterRoutes(rg *gin.RouterGroup) {
tasks := rg.Group("/admin/tasks")
{
tasks.GET("/stats", h.GetStats)
tasks.GET("/pending", h.GetPendingTasks)
}
}

95
internal/mq/init.go Normal file
View File

@@ -0,0 +1,95 @@
package mq
import (
"context"
"log"
"time"
"accounting-app/internal/repository"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
// RecurringTaskSystem 周期性交易任务系统
type RecurringTaskSystem struct {
Queue *TaskQueue
Worker *TaskWorker
Handler *RecurringTransactionHandler
}
// InitRecurringTaskSystem 初始化周期性交易任务系统
// 需要 Redis 连接来存储任务队列
func InitRecurringTaskSystem(
ctx context.Context,
redisClient *redis.Client,
db *gorm.DB,
pollInterval time.Duration,
workerCount int,
) (*RecurringTaskSystem, error) {
log.Println("[MQ] Initializing recurring task system...")
// 创建任务队列
taskQueue := NewTaskQueue(redisClient, "novault:tasks")
// 创建任务 Worker
worker := NewTaskWorker(taskQueue, pollInterval, workerCount)
// 创建周期性交易处理器
recurringRepo := repository.NewRecurringTransactionRepository(db)
transactionRepo := repository.NewTransactionRepository(db)
accountRepo := repository.NewAccountRepository(db)
handler := NewRecurringTransactionHandler(
db,
recurringRepo,
transactionRepo,
accountRepo,
taskQueue,
)
// 注册处理器
worker.RegisterHandler(TaskTypeRecurringTransaction, handler.Handle)
system := &RecurringTaskSystem{
Queue: taskQueue,
Worker: worker,
Handler: handler,
}
// 启动时调度所有活跃的周期性交易
if err := handler.ScheduleAllActive(ctx); err != nil {
log.Printf("[MQ] Warning: failed to schedule active recurring transactions: %v", err)
}
// 处理逾期任务(补偿机制)
if err := handler.ProcessOverdue(ctx); err != nil {
log.Printf("[MQ] Warning: failed to process overdue transactions: %v", err)
}
log.Println("[MQ] Recurring task system initialized successfully")
return system, nil
}
// Start 启动任务系统
func (s *RecurringTaskSystem) Start(ctx context.Context) {
log.Println("[MQ] Starting recurring task worker...")
s.Worker.Start(ctx)
}
// Stop 停止任务系统
func (s *RecurringTaskSystem) Stop() {
log.Println("[MQ] Stopping recurring task worker...")
s.Worker.Stop()
}
// GetStats 获取系统统计信息
func (s *RecurringTaskSystem) GetStats(ctx context.Context) map[string]interface{} {
queueStats, _ := s.Queue.GetQueueStats(ctx)
workerStats := s.Worker.GetStats()
return map[string]interface{}{
"queue": queueStats,
"worker": workerStats,
}
}

View File

@@ -0,0 +1,245 @@
package mq
import (
"context"
"fmt"
"log"
"time"
"accounting-app/internal/models"
"accounting-app/internal/repository"
"gorm.io/gorm"
)
// RecurringTransactionHandler 周期性交易任务处理器
type RecurringTransactionHandler struct {
db *gorm.DB
recurringRepo *repository.RecurringTransactionRepository
transactionRepo *repository.TransactionRepository
accountRepo *repository.AccountRepository
taskQueue *TaskQueue
}
// NewRecurringTransactionHandler 创建处理器实例
func NewRecurringTransactionHandler(
db *gorm.DB,
recurringRepo *repository.RecurringTransactionRepository,
transactionRepo *repository.TransactionRepository,
accountRepo *repository.AccountRepository,
taskQueue *TaskQueue,
) *RecurringTransactionHandler {
return &RecurringTransactionHandler{
db: db,
recurringRepo: recurringRepo,
transactionRepo: transactionRepo,
accountRepo: accountRepo,
taskQueue: taskQueue,
}
}
// Handle 处理周期性交易任务
func (h *RecurringTransactionHandler) Handle(ctx context.Context, task *DelayedTask) error {
payload, err := ParseRecurringPayload(task)
if err != nil {
return fmt.Errorf("failed to parse payload: %w", err)
}
log.Printf("[RecurringHandler] Processing recurring transaction %d for user %d",
payload.RecurringTransactionID, task.UserID)
// 获取周期性交易记录
recurring, err := h.recurringRepo.GetByID(task.UserID, payload.RecurringTransactionID)
if err != nil {
if err == gorm.ErrRecordNotFound {
log.Printf("[RecurringHandler] Recurring transaction %d not found, skipping", payload.RecurringTransactionID)
return nil // 已被删除,不需要重试
}
return fmt.Errorf("failed to get recurring transaction: %w", err)
}
// 检查是否已禁用或已过期
if !recurring.IsActive {
log.Printf("[RecurringHandler] Recurring transaction %d is inactive, skipping", recurring.ID)
return nil
}
if recurring.EndDate != nil && time.Now().After(*recurring.EndDate) {
log.Printf("[RecurringHandler] Recurring transaction %d has ended, deactivating", recurring.ID)
recurring.IsActive = false
return h.recurringRepo.Update(recurring)
}
// 开始数据库事务
tx := h.db.Begin()
if tx.Error != nil {
return fmt.Errorf("failed to begin transaction: %w", tx.Error)
}
// 创建实际交易记录
transaction := &models.Transaction{
UserID: recurring.UserID,
Amount: recurring.Amount,
Type: recurring.Type,
CategoryID: recurring.CategoryID,
AccountID: recurring.AccountID,
Currency: recurring.Currency,
TransactionDate: recurring.NextOccurrence,
Note: recurring.Note,
RecurringID: &recurring.ID,
}
if err := tx.Create(transaction).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to create transaction: %w", err)
}
// 更新账户余额
var account models.Account
if err := tx.First(&account, recurring.AccountID).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to get account: %w", err)
}
switch recurring.Type {
case models.TransactionTypeIncome:
account.Balance += recurring.Amount
case models.TransactionTypeExpense:
account.Balance -= recurring.Amount
}
if err := tx.Save(&account).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to update account balance: %w", err)
}
// 计算下一次执行时间
nextOccurrence := h.calculateNextOccurrence(recurring.NextOccurrence, recurring.Frequency)
recurring.NextOccurrence = nextOccurrence
// 检查下次执行是否超过结束日期
if recurring.EndDate != nil && nextOccurrence.After(*recurring.EndDate) {
recurring.IsActive = false
log.Printf("[RecurringHandler] Recurring transaction %d will end after this execution", recurring.ID)
}
if err := tx.Save(recurring).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to update recurring transaction: %w", err)
}
// 提交事务
if err := tx.Commit().Error; err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
log.Printf("[RecurringHandler] Created transaction %d from recurring %d, amount: %.2f %s",
transaction.ID, recurring.ID, recurring.Amount, recurring.Currency)
// 如果仍然活跃,调度下一次执行
if recurring.IsActive {
if err := h.scheduleNext(ctx, recurring); err != nil {
log.Printf("[RecurringHandler] Warning: failed to schedule next execution: %v", err)
// 不返回错误,当前任务已成功完成
}
}
return nil
}
// calculateNextOccurrence 计算下一次执行时间
func (h *RecurringTransactionHandler) calculateNextOccurrence(current time.Time, frequency models.FrequencyType) time.Time {
switch frequency {
case models.FrequencyDaily:
return current.AddDate(0, 0, 1)
case models.FrequencyWeekly:
return current.AddDate(0, 0, 7)
case models.FrequencyMonthly:
return current.AddDate(0, 1, 0)
case models.FrequencyYearly:
return current.AddDate(1, 0, 0)
default:
return current.AddDate(0, 1, 0) // 默认月度
}
}
// scheduleNext 调度下一次执行
func (h *RecurringTransactionHandler) scheduleNext(ctx context.Context, recurring *models.RecurringTransaction) error {
return h.taskQueue.ScheduleRecurringTransaction(
ctx,
recurring.UserID,
recurring.ID,
recurring.NextOccurrence,
)
}
// ScheduleAllActive 调度所有活跃的周期性交易(启动时调用)
func (h *RecurringTransactionHandler) ScheduleAllActive(ctx context.Context) error {
log.Println("[RecurringHandler] Scheduling all active recurring transactions...")
// 获取所有活跃的周期性交易
var recurringList []models.RecurringTransaction
err := h.db.Where("is_active = ?", true).Find(&recurringList).Error
if err != nil {
return fmt.Errorf("failed to get active recurring transactions: %w", err)
}
scheduledCount := 0
for _, recurring := range recurringList {
// 跳过已过期的
if recurring.EndDate != nil && time.Now().After(*recurring.EndDate) {
recurring.IsActive = false
h.db.Save(&recurring)
continue
}
// 调度任务
err := h.taskQueue.ScheduleRecurringTransaction(
ctx,
recurring.UserID,
recurring.ID,
recurring.NextOccurrence,
)
if err != nil {
log.Printf("[RecurringHandler] Failed to schedule recurring %d: %v", recurring.ID, err)
continue
}
scheduledCount++
}
log.Printf("[RecurringHandler] Scheduled %d recurring transactions", scheduledCount)
return nil
}
// ProcessOverdue 处理所有逾期的周期性交易(补偿机制)
func (h *RecurringTransactionHandler) ProcessOverdue(ctx context.Context) error {
log.Println("[RecurringHandler] Processing overdue recurring transactions...")
now := time.Now()
var overdueList []models.RecurringTransaction
err := h.db.Where("is_active = ? AND next_occurrence <= ?", true, now).Find(&overdueList).Error
if err != nil {
return fmt.Errorf("failed to get overdue recurring transactions: %w", err)
}
if len(overdueList) == 0 {
log.Println("[RecurringHandler] No overdue recurring transactions found")
return nil
}
log.Printf("[RecurringHandler] Found %d overdue recurring transactions", len(overdueList))
for _, recurring := range overdueList {
err := h.taskQueue.ScheduleRecurringTransaction(
ctx,
recurring.UserID,
recurring.ID,
recurring.NextOccurrence,
)
if err != nil {
log.Printf("[RecurringHandler] Failed to schedule overdue recurring %d: %v", recurring.ID, err)
}
}
return nil
}

220
internal/mq/task_queue.go Normal file
View File

@@ -0,0 +1,220 @@
package mq
import (
"context"
"encoding/json"
"fmt"
"log"
"time"
"github.com/redis/go-redis/v9"
)
// TaskType 定义任务类型
type TaskType string
const (
// TaskTypeRecurringTransaction 周期性交易处理任务
TaskTypeRecurringTransaction TaskType = "recurring_transaction"
// TaskTypeAllocationRule 分配规则执行任务
TaskTypeAllocationRule TaskType = "allocation_rule"
)
// DelayedTask 延迟任务结构
type DelayedTask struct {
ID string `json:"id"`
Type TaskType `json:"type"`
UserID uint `json:"user_id"`
Payload json.RawMessage `json:"payload"`
ScheduledAt time.Time `json:"scheduled_at"`
CreatedAt time.Time `json:"created_at"`
RetryCount int `json:"retry_count"`
MaxRetries int `json:"max_retries"`
}
// RecurringTransactionPayload 周期性交易任务载荷
type RecurringTransactionPayload struct {
RecurringTransactionID uint `json:"recurring_transaction_id"`
NextOccurrence time.Time `json:"next_occurrence"`
}
// TaskQueue 基于 Redis 的延迟任务队列
type TaskQueue struct {
client *redis.Client
keyPrefix string
delayedKey string // 延迟任务有序集合
readyKey string // 就绪任务列表
processingKey string // 处理中任务
}
// NewTaskQueue 创建任务队列实例
func NewTaskQueue(client *redis.Client, keyPrefix string) *TaskQueue {
if keyPrefix == "" {
keyPrefix = "novault:tasks"
}
return &TaskQueue{
client: client,
keyPrefix: keyPrefix,
delayedKey: keyPrefix + ":delayed",
readyKey: keyPrefix + ":ready",
processingKey: keyPrefix + ":processing",
}
}
// Schedule 调度延迟任务
func (q *TaskQueue) Schedule(ctx context.Context, task *DelayedTask) error {
// 序列化任务
taskJSON, err := json.Marshal(task)
if err != nil {
return fmt.Errorf("failed to marshal task: %w", err)
}
// 使用 ZADD 添加到延迟队列score 为执行时间戳
score := float64(task.ScheduledAt.Unix())
err = q.client.ZAdd(ctx, q.delayedKey, redis.Z{
Score: score,
Member: string(taskJSON),
}).Err()
if err != nil {
return fmt.Errorf("failed to schedule task: %w", err)
}
log.Printf("[TaskQueue] Scheduled task %s for %s", task.ID, task.ScheduledAt.Format("2006-01-02 15:04:05"))
return nil
}
// ScheduleRecurringTransaction 调度周期性交易任务
func (q *TaskQueue) ScheduleRecurringTransaction(ctx context.Context, userID uint, recurringID uint, scheduledAt time.Time) error {
payload := RecurringTransactionPayload{
RecurringTransactionID: recurringID,
NextOccurrence: scheduledAt,
}
payloadJSON, _ := json.Marshal(payload)
task := &DelayedTask{
ID: fmt.Sprintf("recurring_%d_%d", recurringID, scheduledAt.Unix()),
Type: TaskTypeRecurringTransaction,
UserID: userID,
Payload: payloadJSON,
ScheduledAt: scheduledAt,
CreatedAt: time.Now(),
MaxRetries: 3,
}
return q.Schedule(ctx, task)
}
// MoveReadyTasks 将到期的延迟任务移动到就绪队列
func (q *TaskQueue) MoveReadyTasks(ctx context.Context) (int64, error) {
now := time.Now().Unix()
// 使用 Lua 脚本原子性地移动任务
script := redis.NewScript(`
local delayed_key = KEYS[1]
local ready_key = KEYS[2]
local now = tonumber(ARGV[1])
-- 获取所有到期的任务
local tasks = redis.call('ZRANGEBYSCORE', delayed_key, '-inf', now)
if #tasks == 0 then
return 0
end
-- 移动到就绪队列
for i, task in ipairs(tasks) do
redis.call('LPUSH', ready_key, task)
end
-- 从延迟队列中删除
redis.call('ZREMRANGEBYSCORE', delayed_key, '-inf', now)
return #tasks
`)
result, err := script.Run(ctx, q.client, []string{q.delayedKey, q.readyKey}, now).Int64()
if err != nil {
return 0, fmt.Errorf("failed to move ready tasks: %w", err)
}
if result > 0 {
log.Printf("[TaskQueue] Moved %d tasks to ready queue", result)
}
return result, nil
}
// PopTask 从就绪队列中取出一个任务
func (q *TaskQueue) PopTask(ctx context.Context, timeout time.Duration) (*DelayedTask, error) {
// 使用 BRPOPLPUSH 原子性地获取任务并移动到处理中队列
result, err := q.client.BRPopLPush(ctx, q.readyKey, q.processingKey, timeout).Result()
if err != nil {
if err == redis.Nil {
return nil, nil // 没有任务
}
return nil, fmt.Errorf("failed to pop task: %w", err)
}
var task DelayedTask
if err := json.Unmarshal([]byte(result), &task); err != nil {
return nil, fmt.Errorf("failed to unmarshal task: %w", err)
}
return &task, nil
}
// CompleteTask 标记任务完成
func (q *TaskQueue) CompleteTask(ctx context.Context, task *DelayedTask) error {
taskJSON, _ := json.Marshal(task)
return q.client.LRem(ctx, q.processingKey, 1, string(taskJSON)).Err()
}
// RetryTask 重试任务
func (q *TaskQueue) RetryTask(ctx context.Context, task *DelayedTask, delay time.Duration) error {
task.RetryCount++
if task.RetryCount > task.MaxRetries {
log.Printf("[TaskQueue] Task %s exceeded max retries, discarding", task.ID)
return q.CompleteTask(ctx, task)
}
task.ScheduledAt = time.Now().Add(delay)
// 先从处理中队列移除
taskJSON, _ := json.Marshal(task)
q.client.LRem(ctx, q.processingKey, 1, string(taskJSON))
// 重新调度
return q.Schedule(ctx, task)
}
// GetQueueStats 获取队列统计信息
func (q *TaskQueue) GetQueueStats(ctx context.Context) (map[string]int64, error) {
delayed, _ := q.client.ZCard(ctx, q.delayedKey).Result()
ready, _ := q.client.LLen(ctx, q.readyKey).Result()
processing, _ := q.client.LLen(ctx, q.processingKey).Result()
return map[string]int64{
"delayed": delayed,
"ready": ready,
"processing": processing,
}, nil
}
// GetPendingTasks 获取即将执行的任务(用于调试)
func (q *TaskQueue) GetPendingTasks(ctx context.Context, limit int64) ([]DelayedTask, error) {
results, err := q.client.ZRange(ctx, q.delayedKey, 0, limit-1).Result()
if err != nil {
return nil, err
}
tasks := make([]DelayedTask, 0, len(results))
for _, result := range results {
var task DelayedTask
if err := json.Unmarshal([]byte(result), &task); err == nil {
tasks = append(tasks, task)
}
}
return tasks, nil
}

227
internal/mq/task_worker.go Normal file
View File

@@ -0,0 +1,227 @@
package mq
import (
"context"
"encoding/json"
"log"
"sync"
"time"
)
// TaskHandler 任务处理函数类型
type TaskHandler func(ctx context.Context, task *DelayedTask) error
// TaskWorker 任务工作器,负责消费和执行任务
type TaskWorker struct {
queue *TaskQueue
handlers map[TaskType]TaskHandler
pollInterval time.Duration // 检查延迟任务间隔
workerCount int // 并发工作器数量
mu sync.RWMutex
running bool
stopChan chan struct{}
wg sync.WaitGroup
// 统计信息
processedCount int64
errorCount int64
lastProcessed time.Time
}
// NewTaskWorker 创建任务工作器
func NewTaskWorker(queue *TaskQueue, pollInterval time.Duration, workerCount int) *TaskWorker {
if pollInterval < time.Second {
pollInterval = time.Second
}
if workerCount < 1 {
workerCount = 1
}
return &TaskWorker{
queue: queue,
handlers: make(map[TaskType]TaskHandler),
pollInterval: pollInterval,
workerCount: workerCount,
stopChan: make(chan struct{}),
}
}
// RegisterHandler 注册任务处理器
func (w *TaskWorker) RegisterHandler(taskType TaskType, handler TaskHandler) {
w.mu.Lock()
defer w.mu.Unlock()
w.handlers[taskType] = handler
log.Printf("[TaskWorker] Registered handler for task type: %s", taskType)
}
// Start 启动工作器
func (w *TaskWorker) Start(ctx context.Context) {
w.mu.Lock()
if w.running {
w.mu.Unlock()
log.Println("[TaskWorker] Already running")
return
}
w.running = true
w.stopChan = make(chan struct{})
w.mu.Unlock()
log.Printf("[TaskWorker] Starting with %d workers, poll interval: %v", w.workerCount, w.pollInterval)
// 启动延迟任务轮询器
w.wg.Add(1)
go w.pollDelayedTasks(ctx)
// 启动工作器
for i := 0; i < w.workerCount; i++ {
w.wg.Add(1)
go w.worker(ctx, i)
}
}
// Stop 停止工作器
func (w *TaskWorker) Stop() {
w.mu.Lock()
if !w.running {
w.mu.Unlock()
return
}
w.running = false
close(w.stopChan)
w.mu.Unlock()
log.Println("[TaskWorker] Stopping...")
w.wg.Wait()
log.Println("[TaskWorker] Stopped")
}
// pollDelayedTasks 定期检查并移动到期的延迟任务
func (w *TaskWorker) pollDelayedTasks(ctx context.Context) {
defer w.wg.Done()
ticker := time.NewTicker(w.pollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-w.stopChan:
return
case <-ticker.C:
moved, err := w.queue.MoveReadyTasks(ctx)
if err != nil {
log.Printf("[TaskWorker] Error moving ready tasks: %v", err)
} else if moved > 0 {
log.Printf("[TaskWorker] Moved %d tasks to ready queue", moved)
}
}
}
}
// worker 工作器协程
func (w *TaskWorker) worker(ctx context.Context, id int) {
defer w.wg.Done()
log.Printf("[TaskWorker] Worker %d started", id)
for {
select {
case <-ctx.Done():
log.Printf("[TaskWorker] Worker %d stopping (context done)", id)
return
case <-w.stopChan:
log.Printf("[TaskWorker] Worker %d stopping (stop signal)", id)
return
default:
// 尝试获取任务
task, err := w.queue.PopTask(ctx, time.Second)
if err != nil {
log.Printf("[TaskWorker] Worker %d error popping task: %v", id, err)
time.Sleep(time.Second) // 错误后等待
continue
}
if task == nil {
continue // 没有任务,继续轮询
}
// 处理任务
w.processTask(ctx, id, task)
}
}
}
// processTask 处理单个任务
func (w *TaskWorker) processTask(ctx context.Context, workerID int, task *DelayedTask) {
startTime := time.Now()
log.Printf("[TaskWorker] Worker %d processing task %s (type: %s, user: %d)",
workerID, task.ID, task.Type, task.UserID)
w.mu.RLock()
handler, exists := w.handlers[task.Type]
w.mu.RUnlock()
if !exists {
log.Printf("[TaskWorker] No handler for task type: %s", task.Type)
w.queue.CompleteTask(ctx, task)
return
}
// 执行任务
err := handler(ctx, task)
duration := time.Since(startTime)
if err != nil {
log.Printf("[TaskWorker] Task %s failed (attempt %d/%d): %v",
task.ID, task.RetryCount+1, task.MaxRetries, err)
w.mu.Lock()
w.errorCount++
w.mu.Unlock()
// 重试(指数退避)
retryDelay := time.Duration(1<<task.RetryCount) * time.Minute
if err := w.queue.RetryTask(ctx, task, retryDelay); err != nil {
log.Printf("[TaskWorker] Failed to retry task %s: %v", task.ID, err)
}
return
}
// 完成任务
if err := w.queue.CompleteTask(ctx, task); err != nil {
log.Printf("[TaskWorker] Failed to complete task %s: %v", task.ID, err)
}
w.mu.Lock()
w.processedCount++
w.lastProcessed = time.Now()
w.mu.Unlock()
log.Printf("[TaskWorker] Task %s completed in %v", task.ID, duration)
}
// GetStats 获取工作器统计信息
func (w *TaskWorker) GetStats() map[string]interface{} {
w.mu.RLock()
defer w.mu.RUnlock()
return map[string]interface{}{
"running": w.running,
"worker_count": w.workerCount,
"poll_interval": w.pollInterval.String(),
"processed_count": w.processedCount,
"error_count": w.errorCount,
"last_processed": w.lastProcessed,
}
}
// ParseRecurringPayload 解析周期性交易任务载荷
func ParseRecurringPayload(task *DelayedTask) (*RecurringTransactionPayload, error) {
var payload RecurringTransactionPayload
if err := json.Unmarshal(task.Payload, &payload); err != nil {
return nil, err
}
return &payload, nil
}