package mq import ( "context" "encoding/json" "log" "sync" "time" ) // TaskHandler 任务处理函数类型 type TaskHandler func(ctx context.Context, task *DelayedTask) error // TaskWorker 任务工作器,负责消费和执行任务 type TaskWorker struct { queue *TaskQueue handlers map[TaskType]TaskHandler pollInterval time.Duration // 检查延迟任务间隔 workerCount int // 并发工作器数量 mu sync.RWMutex running bool stopChan chan struct{} wg sync.WaitGroup // 统计信息 processedCount int64 errorCount int64 lastProcessed time.Time } // NewTaskWorker 创建任务工作器 func NewTaskWorker(queue *TaskQueue, pollInterval time.Duration, workerCount int) *TaskWorker { if pollInterval < time.Second { pollInterval = time.Second } if workerCount < 1 { workerCount = 1 } return &TaskWorker{ queue: queue, handlers: make(map[TaskType]TaskHandler), pollInterval: pollInterval, workerCount: workerCount, stopChan: make(chan struct{}), } } // RegisterHandler 注册任务处理器 func (w *TaskWorker) RegisterHandler(taskType TaskType, handler TaskHandler) { w.mu.Lock() defer w.mu.Unlock() w.handlers[taskType] = handler log.Printf("[TaskWorker] Registered handler for task type: %s", taskType) } // Start 启动工作器 func (w *TaskWorker) Start(ctx context.Context) { w.mu.Lock() if w.running { w.mu.Unlock() log.Println("[TaskWorker] Already running") return } w.running = true w.stopChan = make(chan struct{}) w.mu.Unlock() log.Printf("[TaskWorker] Starting with %d workers, poll interval: %v", w.workerCount, w.pollInterval) // 启动延迟任务轮询器 w.wg.Add(1) go w.pollDelayedTasks(ctx) // 启动工作器 for i := 0; i < w.workerCount; i++ { w.wg.Add(1) go w.worker(ctx, i) } } // Stop 停止工作器 func (w *TaskWorker) Stop() { w.mu.Lock() if !w.running { w.mu.Unlock() return } w.running = false close(w.stopChan) w.mu.Unlock() log.Println("[TaskWorker] Stopping...") w.wg.Wait() log.Println("[TaskWorker] Stopped") } // pollDelayedTasks 定期检查并移动到期的延迟任务 func (w *TaskWorker) pollDelayedTasks(ctx context.Context) { defer w.wg.Done() ticker := time.NewTicker(w.pollInterval) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-w.stopChan: return case <-ticker.C: moved, err := w.queue.MoveReadyTasks(ctx) if err != nil { log.Printf("[TaskWorker] Error moving ready tasks: %v", err) } else if moved > 0 { log.Printf("[TaskWorker] Moved %d tasks to ready queue", moved) } } } } // worker 工作器协程 func (w *TaskWorker) worker(ctx context.Context, id int) { defer w.wg.Done() log.Printf("[TaskWorker] Worker %d started", id) for { select { case <-ctx.Done(): log.Printf("[TaskWorker] Worker %d stopping (context done)", id) return case <-w.stopChan: log.Printf("[TaskWorker] Worker %d stopping (stop signal)", id) return default: // 尝试获取任务 task, err := w.queue.PopTask(ctx, time.Second) if err != nil { log.Printf("[TaskWorker] Worker %d error popping task: %v", id, err) time.Sleep(time.Second) // 错误后等待 continue } if task == nil { continue // 没有任务,继续轮询 } // 处理任务 w.processTask(ctx, id, task) } } } // processTask 处理单个任务 func (w *TaskWorker) processTask(ctx context.Context, workerID int, task *DelayedTask) { startTime := time.Now() log.Printf("[TaskWorker] Worker %d processing task %s (type: %s, user: %d)", workerID, task.ID, task.Type, task.UserID) w.mu.RLock() handler, exists := w.handlers[task.Type] w.mu.RUnlock() if !exists { log.Printf("[TaskWorker] No handler for task type: %s", task.Type) w.queue.CompleteTask(ctx, task) return } // 执行任务 err := handler(ctx, task) duration := time.Since(startTime) if err != nil { log.Printf("[TaskWorker] Task %s failed (attempt %d/%d): %v", task.ID, task.RetryCount+1, task.MaxRetries, err) w.mu.Lock() w.errorCount++ w.mu.Unlock() // 重试(指数退避) retryDelay := time.Duration(1<