From 5ff680ee43be3defe88bb667ba69659bb88f2d46 Mon Sep 17 00:00:00 2001 From: admin <1297598740@qq.com> Date: Wed, 28 Jan 2026 16:36:56 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=E5=9F=BA=E4=BA=8E=20?= =?UTF-8?q?Redis=20=E7=9A=84=E5=BB=B6=E8=BF=9F=E4=BB=BB=E5=8A=A1=E9=98=9F?= =?UTF-8?q?=E5=88=97=E5=92=8C=E5=B7=A5=E4=BD=9C=E5=99=A8=EF=BC=8C=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E5=91=A8=E6=9C=9F=E6=80=A7=E4=BA=A4=E6=98=93=E5=A4=84?= =?UTF-8?q?=E7=90=86=E7=AD=89=E5=90=8E=E5=8F=B0=E4=BB=BB=E5=8A=A1=E8=B0=83?= =?UTF-8?q?=E5=BA=A6=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/server/main.go | 19 ++ internal/handler/task_queue_handler.go | 78 ++++++++ internal/mq/init.go | 95 ++++++++++ internal/mq/recurring_handler.go | 245 +++++++++++++++++++++++++ internal/mq/task_queue.go | 220 ++++++++++++++++++++++ internal/mq/task_worker.go | 227 +++++++++++++++++++++++ 6 files changed, 884 insertions(+) create mode 100644 internal/handler/task_queue_handler.go create mode 100644 internal/mq/init.go create mode 100644 internal/mq/recurring_handler.go create mode 100644 internal/mq/task_queue.go create mode 100644 internal/mq/task_worker.go diff --git a/cmd/server/main.go b/cmd/server/main.go index 5ef3616..5dee426 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -5,10 +5,12 @@ import ( "log" "os" "path/filepath" + "time" "accounting-app/internal/cache" "accounting-app/internal/config" "accounting-app/internal/database" + "accounting-app/internal/mq" "accounting-app/internal/repository" "accounting-app/internal/router" "accounting-app/internal/service" @@ -97,6 +99,7 @@ func main() { if err != nil { // Redis not available - fall back to old system log.Printf("Warning: Redis connection failed (%v), falling back to non-cached exchange rate system", err) + log.Printf("Warning: Recurring transaction MQ system disabled (requires Redis)") // Use old scheduler scheduler := service.NewExchangeRateScheduler(yunAPIClient, cfg.SyncInterval) @@ -116,6 +119,22 @@ func main() { // Start the new SyncScheduler in background // This will perform initial sync immediately (Requirement 3.1) go syncScheduler.Start(ctx) + + // Initialize and start MQ task system for recurring transactions + recurringTaskSystem, err := mq.InitRecurringTaskSystem( + ctx, + redisClient.Client(), + db, + 5*time.Second, // Poll interval: 检查延迟任务的间隔 + 2, // Worker count: 并发处理任务的数量 + ) + if err != nil { + log.Printf("Warning: Failed to initialize recurring task system: %v", err) + } else { + // Start MQ worker in background + go recurringTaskSystem.Start(ctx) + log.Println("Recurring transaction MQ system started") + } } // Get port from config or environment diff --git a/internal/handler/task_queue_handler.go b/internal/handler/task_queue_handler.go new file mode 100644 index 0000000..c1f5f13 --- /dev/null +++ b/internal/handler/task_queue_handler.go @@ -0,0 +1,78 @@ +package handler + +import ( + "net/http" + + "accounting-app/internal/mq" + + "github.com/gin-gonic/gin" +) + +// TaskQueueHandler 任务队列状态查询 Handler +type TaskQueueHandler struct { + taskSystem *mq.RecurringTaskSystem +} + +// NewTaskQueueHandler 创建任务队列 Handler +func NewTaskQueueHandler(taskSystem *mq.RecurringTaskSystem) *TaskQueueHandler { + return &TaskQueueHandler{ + taskSystem: taskSystem, + } +} + +// GetStats 获取任务队列统计信息 +// GET /api/v1/admin/tasks/stats +func (h *TaskQueueHandler) GetStats(c *gin.Context) { + if h.taskSystem == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "success": false, + "error": "Task system not available (Redis required)", + }) + return + } + + stats := h.taskSystem.GetStats(c.Request.Context()) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": stats, + }) +} + +// GetPendingTasks 获取即将执行的任务列表 +// GET /api/v1/admin/tasks/pending +func (h *TaskQueueHandler) GetPendingTasks(c *gin.Context) { + if h.taskSystem == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "success": false, + "error": "Task system not available (Redis required)", + }) + return + } + + tasks, err := h.taskSystem.Queue.GetPendingTasks(c.Request.Context(), 50) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "success": false, + "error": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "count": len(tasks), + "tasks": tasks, + }, + }) +} + +// RegisterRoutes 注册路由 +func (h *TaskQueueHandler) RegisterRoutes(rg *gin.RouterGroup) { + tasks := rg.Group("/admin/tasks") + { + tasks.GET("/stats", h.GetStats) + tasks.GET("/pending", h.GetPendingTasks) + } +} diff --git a/internal/mq/init.go b/internal/mq/init.go new file mode 100644 index 0000000..c41c143 --- /dev/null +++ b/internal/mq/init.go @@ -0,0 +1,95 @@ +package mq + +import ( + "context" + "log" + "time" + + "accounting-app/internal/repository" + + "github.com/redis/go-redis/v9" + "gorm.io/gorm" +) + +// RecurringTaskSystem 周期性交易任务系统 +type RecurringTaskSystem struct { + Queue *TaskQueue + Worker *TaskWorker + Handler *RecurringTransactionHandler +} + +// InitRecurringTaskSystem 初始化周期性交易任务系统 +// 需要 Redis 连接来存储任务队列 +func InitRecurringTaskSystem( + ctx context.Context, + redisClient *redis.Client, + db *gorm.DB, + pollInterval time.Duration, + workerCount int, +) (*RecurringTaskSystem, error) { + log.Println("[MQ] Initializing recurring task system...") + + // 创建任务队列 + taskQueue := NewTaskQueue(redisClient, "novault:tasks") + + // 创建任务 Worker + worker := NewTaskWorker(taskQueue, pollInterval, workerCount) + + // 创建周期性交易处理器 + recurringRepo := repository.NewRecurringTransactionRepository(db) + transactionRepo := repository.NewTransactionRepository(db) + accountRepo := repository.NewAccountRepository(db) + + handler := NewRecurringTransactionHandler( + db, + recurringRepo, + transactionRepo, + accountRepo, + taskQueue, + ) + + // 注册处理器 + worker.RegisterHandler(TaskTypeRecurringTransaction, handler.Handle) + + system := &RecurringTaskSystem{ + Queue: taskQueue, + Worker: worker, + Handler: handler, + } + + // 启动时调度所有活跃的周期性交易 + if err := handler.ScheduleAllActive(ctx); err != nil { + log.Printf("[MQ] Warning: failed to schedule active recurring transactions: %v", err) + } + + // 处理逾期任务(补偿机制) + if err := handler.ProcessOverdue(ctx); err != nil { + log.Printf("[MQ] Warning: failed to process overdue transactions: %v", err) + } + + log.Println("[MQ] Recurring task system initialized successfully") + return system, nil +} + +// Start 启动任务系统 +func (s *RecurringTaskSystem) Start(ctx context.Context) { + log.Println("[MQ] Starting recurring task worker...") + s.Worker.Start(ctx) +} + +// Stop 停止任务系统 +func (s *RecurringTaskSystem) Stop() { + log.Println("[MQ] Stopping recurring task worker...") + s.Worker.Stop() +} + +// GetStats 获取系统统计信息 +func (s *RecurringTaskSystem) GetStats(ctx context.Context) map[string]interface{} { + queueStats, _ := s.Queue.GetQueueStats(ctx) + workerStats := s.Worker.GetStats() + + return map[string]interface{}{ + "queue": queueStats, + "worker": workerStats, + } +} diff --git a/internal/mq/recurring_handler.go b/internal/mq/recurring_handler.go new file mode 100644 index 0000000..26b39d7 --- /dev/null +++ b/internal/mq/recurring_handler.go @@ -0,0 +1,245 @@ +package mq + +import ( + "context" + "fmt" + "log" + "time" + + "accounting-app/internal/models" + "accounting-app/internal/repository" + + "gorm.io/gorm" +) + +// RecurringTransactionHandler 周期性交易任务处理器 +type RecurringTransactionHandler struct { + db *gorm.DB + recurringRepo *repository.RecurringTransactionRepository + transactionRepo *repository.TransactionRepository + accountRepo *repository.AccountRepository + taskQueue *TaskQueue +} + +// NewRecurringTransactionHandler 创建处理器实例 +func NewRecurringTransactionHandler( + db *gorm.DB, + recurringRepo *repository.RecurringTransactionRepository, + transactionRepo *repository.TransactionRepository, + accountRepo *repository.AccountRepository, + taskQueue *TaskQueue, +) *RecurringTransactionHandler { + return &RecurringTransactionHandler{ + db: db, + recurringRepo: recurringRepo, + transactionRepo: transactionRepo, + accountRepo: accountRepo, + taskQueue: taskQueue, + } +} + +// Handle 处理周期性交易任务 +func (h *RecurringTransactionHandler) Handle(ctx context.Context, task *DelayedTask) error { + payload, err := ParseRecurringPayload(task) + if err != nil { + return fmt.Errorf("failed to parse payload: %w", err) + } + + log.Printf("[RecurringHandler] Processing recurring transaction %d for user %d", + payload.RecurringTransactionID, task.UserID) + + // 获取周期性交易记录 + recurring, err := h.recurringRepo.GetByID(task.UserID, payload.RecurringTransactionID) + if err != nil { + if err == gorm.ErrRecordNotFound { + log.Printf("[RecurringHandler] Recurring transaction %d not found, skipping", payload.RecurringTransactionID) + return nil // 已被删除,不需要重试 + } + return fmt.Errorf("failed to get recurring transaction: %w", err) + } + + // 检查是否已禁用或已过期 + if !recurring.IsActive { + log.Printf("[RecurringHandler] Recurring transaction %d is inactive, skipping", recurring.ID) + return nil + } + + if recurring.EndDate != nil && time.Now().After(*recurring.EndDate) { + log.Printf("[RecurringHandler] Recurring transaction %d has ended, deactivating", recurring.ID) + recurring.IsActive = false + return h.recurringRepo.Update(recurring) + } + + // 开始数据库事务 + tx := h.db.Begin() + if tx.Error != nil { + return fmt.Errorf("failed to begin transaction: %w", tx.Error) + } + + // 创建实际交易记录 + transaction := &models.Transaction{ + UserID: recurring.UserID, + Amount: recurring.Amount, + Type: recurring.Type, + CategoryID: recurring.CategoryID, + AccountID: recurring.AccountID, + Currency: recurring.Currency, + TransactionDate: recurring.NextOccurrence, + Note: recurring.Note, + RecurringID: &recurring.ID, + } + + if err := tx.Create(transaction).Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to create transaction: %w", err) + } + + // 更新账户余额 + var account models.Account + if err := tx.First(&account, recurring.AccountID).Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to get account: %w", err) + } + + switch recurring.Type { + case models.TransactionTypeIncome: + account.Balance += recurring.Amount + case models.TransactionTypeExpense: + account.Balance -= recurring.Amount + } + + if err := tx.Save(&account).Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to update account balance: %w", err) + } + + // 计算下一次执行时间 + nextOccurrence := h.calculateNextOccurrence(recurring.NextOccurrence, recurring.Frequency) + recurring.NextOccurrence = nextOccurrence + + // 检查下次执行是否超过结束日期 + if recurring.EndDate != nil && nextOccurrence.After(*recurring.EndDate) { + recurring.IsActive = false + log.Printf("[RecurringHandler] Recurring transaction %d will end after this execution", recurring.ID) + } + + if err := tx.Save(recurring).Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to update recurring transaction: %w", err) + } + + // 提交事务 + if err := tx.Commit().Error; err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + log.Printf("[RecurringHandler] Created transaction %d from recurring %d, amount: %.2f %s", + transaction.ID, recurring.ID, recurring.Amount, recurring.Currency) + + // 如果仍然活跃,调度下一次执行 + if recurring.IsActive { + if err := h.scheduleNext(ctx, recurring); err != nil { + log.Printf("[RecurringHandler] Warning: failed to schedule next execution: %v", err) + // 不返回错误,当前任务已成功完成 + } + } + + return nil +} + +// calculateNextOccurrence 计算下一次执行时间 +func (h *RecurringTransactionHandler) calculateNextOccurrence(current time.Time, frequency models.FrequencyType) time.Time { + switch frequency { + case models.FrequencyDaily: + return current.AddDate(0, 0, 1) + case models.FrequencyWeekly: + return current.AddDate(0, 0, 7) + case models.FrequencyMonthly: + return current.AddDate(0, 1, 0) + case models.FrequencyYearly: + return current.AddDate(1, 0, 0) + default: + return current.AddDate(0, 1, 0) // 默认月度 + } +} + +// scheduleNext 调度下一次执行 +func (h *RecurringTransactionHandler) scheduleNext(ctx context.Context, recurring *models.RecurringTransaction) error { + return h.taskQueue.ScheduleRecurringTransaction( + ctx, + recurring.UserID, + recurring.ID, + recurring.NextOccurrence, + ) +} + +// ScheduleAllActive 调度所有活跃的周期性交易(启动时调用) +func (h *RecurringTransactionHandler) ScheduleAllActive(ctx context.Context) error { + log.Println("[RecurringHandler] Scheduling all active recurring transactions...") + + // 获取所有活跃的周期性交易 + var recurringList []models.RecurringTransaction + err := h.db.Where("is_active = ?", true).Find(&recurringList).Error + if err != nil { + return fmt.Errorf("failed to get active recurring transactions: %w", err) + } + + scheduledCount := 0 + for _, recurring := range recurringList { + // 跳过已过期的 + if recurring.EndDate != nil && time.Now().After(*recurring.EndDate) { + recurring.IsActive = false + h.db.Save(&recurring) + continue + } + + // 调度任务 + err := h.taskQueue.ScheduleRecurringTransaction( + ctx, + recurring.UserID, + recurring.ID, + recurring.NextOccurrence, + ) + if err != nil { + log.Printf("[RecurringHandler] Failed to schedule recurring %d: %v", recurring.ID, err) + continue + } + scheduledCount++ + } + + log.Printf("[RecurringHandler] Scheduled %d recurring transactions", scheduledCount) + return nil +} + +// ProcessOverdue 处理所有逾期的周期性交易(补偿机制) +func (h *RecurringTransactionHandler) ProcessOverdue(ctx context.Context) error { + log.Println("[RecurringHandler] Processing overdue recurring transactions...") + + now := time.Now() + var overdueList []models.RecurringTransaction + err := h.db.Where("is_active = ? AND next_occurrence <= ?", true, now).Find(&overdueList).Error + if err != nil { + return fmt.Errorf("failed to get overdue recurring transactions: %w", err) + } + + if len(overdueList) == 0 { + log.Println("[RecurringHandler] No overdue recurring transactions found") + return nil + } + + log.Printf("[RecurringHandler] Found %d overdue recurring transactions", len(overdueList)) + + for _, recurring := range overdueList { + err := h.taskQueue.ScheduleRecurringTransaction( + ctx, + recurring.UserID, + recurring.ID, + recurring.NextOccurrence, + ) + if err != nil { + log.Printf("[RecurringHandler] Failed to schedule overdue recurring %d: %v", recurring.ID, err) + } + } + + return nil +} diff --git a/internal/mq/task_queue.go b/internal/mq/task_queue.go new file mode 100644 index 0000000..c2579e1 --- /dev/null +++ b/internal/mq/task_queue.go @@ -0,0 +1,220 @@ +package mq + +import ( + "context" + "encoding/json" + "fmt" + "log" + "time" + + "github.com/redis/go-redis/v9" +) + +// TaskType 定义任务类型 +type TaskType string + +const ( + // TaskTypeRecurringTransaction 周期性交易处理任务 + TaskTypeRecurringTransaction TaskType = "recurring_transaction" + // TaskTypeAllocationRule 分配规则执行任务 + TaskTypeAllocationRule TaskType = "allocation_rule" +) + +// DelayedTask 延迟任务结构 +type DelayedTask struct { + ID string `json:"id"` + Type TaskType `json:"type"` + UserID uint `json:"user_id"` + Payload json.RawMessage `json:"payload"` + ScheduledAt time.Time `json:"scheduled_at"` + CreatedAt time.Time `json:"created_at"` + RetryCount int `json:"retry_count"` + MaxRetries int `json:"max_retries"` +} + +// RecurringTransactionPayload 周期性交易任务载荷 +type RecurringTransactionPayload struct { + RecurringTransactionID uint `json:"recurring_transaction_id"` + NextOccurrence time.Time `json:"next_occurrence"` +} + +// TaskQueue 基于 Redis 的延迟任务队列 +type TaskQueue struct { + client *redis.Client + keyPrefix string + delayedKey string // 延迟任务有序集合 + readyKey string // 就绪任务列表 + processingKey string // 处理中任务 +} + +// NewTaskQueue 创建任务队列实例 +func NewTaskQueue(client *redis.Client, keyPrefix string) *TaskQueue { + if keyPrefix == "" { + keyPrefix = "novault:tasks" + } + return &TaskQueue{ + client: client, + keyPrefix: keyPrefix, + delayedKey: keyPrefix + ":delayed", + readyKey: keyPrefix + ":ready", + processingKey: keyPrefix + ":processing", + } +} + +// Schedule 调度延迟任务 +func (q *TaskQueue) Schedule(ctx context.Context, task *DelayedTask) error { + // 序列化任务 + taskJSON, err := json.Marshal(task) + if err != nil { + return fmt.Errorf("failed to marshal task: %w", err) + } + + // 使用 ZADD 添加到延迟队列,score 为执行时间戳 + score := float64(task.ScheduledAt.Unix()) + err = q.client.ZAdd(ctx, q.delayedKey, redis.Z{ + Score: score, + Member: string(taskJSON), + }).Err() + + if err != nil { + return fmt.Errorf("failed to schedule task: %w", err) + } + + log.Printf("[TaskQueue] Scheduled task %s for %s", task.ID, task.ScheduledAt.Format("2006-01-02 15:04:05")) + return nil +} + +// ScheduleRecurringTransaction 调度周期性交易任务 +func (q *TaskQueue) ScheduleRecurringTransaction(ctx context.Context, userID uint, recurringID uint, scheduledAt time.Time) error { + payload := RecurringTransactionPayload{ + RecurringTransactionID: recurringID, + NextOccurrence: scheduledAt, + } + payloadJSON, _ := json.Marshal(payload) + + task := &DelayedTask{ + ID: fmt.Sprintf("recurring_%d_%d", recurringID, scheduledAt.Unix()), + Type: TaskTypeRecurringTransaction, + UserID: userID, + Payload: payloadJSON, + ScheduledAt: scheduledAt, + CreatedAt: time.Now(), + MaxRetries: 3, + } + + return q.Schedule(ctx, task) +} + +// MoveReadyTasks 将到期的延迟任务移动到就绪队列 +func (q *TaskQueue) MoveReadyTasks(ctx context.Context) (int64, error) { + now := time.Now().Unix() + + // 使用 Lua 脚本原子性地移动任务 + script := redis.NewScript(` + local delayed_key = KEYS[1] + local ready_key = KEYS[2] + local now = tonumber(ARGV[1]) + + -- 获取所有到期的任务 + local tasks = redis.call('ZRANGEBYSCORE', delayed_key, '-inf', now) + + if #tasks == 0 then + return 0 + end + + -- 移动到就绪队列 + for i, task in ipairs(tasks) do + redis.call('LPUSH', ready_key, task) + end + + -- 从延迟队列中删除 + redis.call('ZREMRANGEBYSCORE', delayed_key, '-inf', now) + + return #tasks + `) + + result, err := script.Run(ctx, q.client, []string{q.delayedKey, q.readyKey}, now).Int64() + if err != nil { + return 0, fmt.Errorf("failed to move ready tasks: %w", err) + } + + if result > 0 { + log.Printf("[TaskQueue] Moved %d tasks to ready queue", result) + } + + return result, nil +} + +// PopTask 从就绪队列中取出一个任务 +func (q *TaskQueue) PopTask(ctx context.Context, timeout time.Duration) (*DelayedTask, error) { + // 使用 BRPOPLPUSH 原子性地获取任务并移动到处理中队列 + result, err := q.client.BRPopLPush(ctx, q.readyKey, q.processingKey, timeout).Result() + if err != nil { + if err == redis.Nil { + return nil, nil // 没有任务 + } + return nil, fmt.Errorf("failed to pop task: %w", err) + } + + var task DelayedTask + if err := json.Unmarshal([]byte(result), &task); err != nil { + return nil, fmt.Errorf("failed to unmarshal task: %w", err) + } + + return &task, nil +} + +// CompleteTask 标记任务完成 +func (q *TaskQueue) CompleteTask(ctx context.Context, task *DelayedTask) error { + taskJSON, _ := json.Marshal(task) + return q.client.LRem(ctx, q.processingKey, 1, string(taskJSON)).Err() +} + +// RetryTask 重试任务 +func (q *TaskQueue) RetryTask(ctx context.Context, task *DelayedTask, delay time.Duration) error { + task.RetryCount++ + if task.RetryCount > task.MaxRetries { + log.Printf("[TaskQueue] Task %s exceeded max retries, discarding", task.ID) + return q.CompleteTask(ctx, task) + } + + task.ScheduledAt = time.Now().Add(delay) + + // 先从处理中队列移除 + taskJSON, _ := json.Marshal(task) + q.client.LRem(ctx, q.processingKey, 1, string(taskJSON)) + + // 重新调度 + return q.Schedule(ctx, task) +} + +// GetQueueStats 获取队列统计信息 +func (q *TaskQueue) GetQueueStats(ctx context.Context) (map[string]int64, error) { + delayed, _ := q.client.ZCard(ctx, q.delayedKey).Result() + ready, _ := q.client.LLen(ctx, q.readyKey).Result() + processing, _ := q.client.LLen(ctx, q.processingKey).Result() + + return map[string]int64{ + "delayed": delayed, + "ready": ready, + "processing": processing, + }, nil +} + +// GetPendingTasks 获取即将执行的任务(用于调试) +func (q *TaskQueue) GetPendingTasks(ctx context.Context, limit int64) ([]DelayedTask, error) { + results, err := q.client.ZRange(ctx, q.delayedKey, 0, limit-1).Result() + if err != nil { + return nil, err + } + + tasks := make([]DelayedTask, 0, len(results)) + for _, result := range results { + var task DelayedTask + if err := json.Unmarshal([]byte(result), &task); err == nil { + tasks = append(tasks, task) + } + } + + return tasks, nil +} diff --git a/internal/mq/task_worker.go b/internal/mq/task_worker.go new file mode 100644 index 0000000..512e6e9 --- /dev/null +++ b/internal/mq/task_worker.go @@ -0,0 +1,227 @@ +package mq + +import ( + "context" + "encoding/json" + "log" + "sync" + "time" +) + +// TaskHandler 任务处理函数类型 +type TaskHandler func(ctx context.Context, task *DelayedTask) error + +// TaskWorker 任务工作器,负责消费和执行任务 +type TaskWorker struct { + queue *TaskQueue + handlers map[TaskType]TaskHandler + pollInterval time.Duration // 检查延迟任务间隔 + workerCount int // 并发工作器数量 + + mu sync.RWMutex + running bool + stopChan chan struct{} + wg sync.WaitGroup + + // 统计信息 + processedCount int64 + errorCount int64 + lastProcessed time.Time +} + +// NewTaskWorker 创建任务工作器 +func NewTaskWorker(queue *TaskQueue, pollInterval time.Duration, workerCount int) *TaskWorker { + if pollInterval < time.Second { + pollInterval = time.Second + } + if workerCount < 1 { + workerCount = 1 + } + + return &TaskWorker{ + queue: queue, + handlers: make(map[TaskType]TaskHandler), + pollInterval: pollInterval, + workerCount: workerCount, + stopChan: make(chan struct{}), + } +} + +// RegisterHandler 注册任务处理器 +func (w *TaskWorker) RegisterHandler(taskType TaskType, handler TaskHandler) { + w.mu.Lock() + defer w.mu.Unlock() + w.handlers[taskType] = handler + log.Printf("[TaskWorker] Registered handler for task type: %s", taskType) +} + +// Start 启动工作器 +func (w *TaskWorker) Start(ctx context.Context) { + w.mu.Lock() + if w.running { + w.mu.Unlock() + log.Println("[TaskWorker] Already running") + return + } + w.running = true + w.stopChan = make(chan struct{}) + w.mu.Unlock() + + log.Printf("[TaskWorker] Starting with %d workers, poll interval: %v", w.workerCount, w.pollInterval) + + // 启动延迟任务轮询器 + w.wg.Add(1) + go w.pollDelayedTasks(ctx) + + // 启动工作器 + for i := 0; i < w.workerCount; i++ { + w.wg.Add(1) + go w.worker(ctx, i) + } +} + +// Stop 停止工作器 +func (w *TaskWorker) Stop() { + w.mu.Lock() + if !w.running { + w.mu.Unlock() + return + } + w.running = false + close(w.stopChan) + w.mu.Unlock() + + log.Println("[TaskWorker] Stopping...") + w.wg.Wait() + log.Println("[TaskWorker] Stopped") +} + +// pollDelayedTasks 定期检查并移动到期的延迟任务 +func (w *TaskWorker) pollDelayedTasks(ctx context.Context) { + defer w.wg.Done() + + ticker := time.NewTicker(w.pollInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-w.stopChan: + return + case <-ticker.C: + moved, err := w.queue.MoveReadyTasks(ctx) + if err != nil { + log.Printf("[TaskWorker] Error moving ready tasks: %v", err) + } else if moved > 0 { + log.Printf("[TaskWorker] Moved %d tasks to ready queue", moved) + } + } + } +} + +// worker 工作器协程 +func (w *TaskWorker) worker(ctx context.Context, id int) { + defer w.wg.Done() + + log.Printf("[TaskWorker] Worker %d started", id) + + for { + select { + case <-ctx.Done(): + log.Printf("[TaskWorker] Worker %d stopping (context done)", id) + return + case <-w.stopChan: + log.Printf("[TaskWorker] Worker %d stopping (stop signal)", id) + return + default: + // 尝试获取任务 + task, err := w.queue.PopTask(ctx, time.Second) + if err != nil { + log.Printf("[TaskWorker] Worker %d error popping task: %v", id, err) + time.Sleep(time.Second) // 错误后等待 + continue + } + + if task == nil { + continue // 没有任务,继续轮询 + } + + // 处理任务 + w.processTask(ctx, id, task) + } + } +} + +// processTask 处理单个任务 +func (w *TaskWorker) processTask(ctx context.Context, workerID int, task *DelayedTask) { + startTime := time.Now() + log.Printf("[TaskWorker] Worker %d processing task %s (type: %s, user: %d)", + workerID, task.ID, task.Type, task.UserID) + + w.mu.RLock() + handler, exists := w.handlers[task.Type] + w.mu.RUnlock() + + if !exists { + log.Printf("[TaskWorker] No handler for task type: %s", task.Type) + w.queue.CompleteTask(ctx, task) + return + } + + // 执行任务 + err := handler(ctx, task) + duration := time.Since(startTime) + + if err != nil { + log.Printf("[TaskWorker] Task %s failed (attempt %d/%d): %v", + task.ID, task.RetryCount+1, task.MaxRetries, err) + + w.mu.Lock() + w.errorCount++ + w.mu.Unlock() + + // 重试(指数退避) + retryDelay := time.Duration(1<