Files
Novault-backend/internal/mq/task_worker.go

228 lines
5.3 KiB
Go

package mq
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"sync"
	"time"
)
// TaskHandler is the function type that processes a single task.
// TaskWorker treats a non-nil error as a failure and hands the task back to
// the queue for retry with exponential backoff; a nil error marks the task
// complete.
type TaskHandler func(ctx context.Context, task *DelayedTask) error
// TaskWorker consumes tasks from a TaskQueue and runs each one with the
// handler registered for its TaskType.
//
// Create one with NewTaskWorker. Start launches one delayed-task poller plus
// workerCount consumer goroutines; Stop signals them to exit and waits for
// them to finish.
type TaskWorker struct {
queue *TaskQueue
handlers map[TaskType]TaskHandler // written by RegisterHandler, read by processTask (both under mu)
pollInterval time.Duration // interval between checks for delayed tasks that have become due
workerCount int // number of concurrent consumer goroutines
// mu guards handlers, running, replacement/closing of stopChan, and the statistics below.
mu sync.RWMutex
running bool // true between a successful Start and the matching Stop
stopChan chan struct{} // closed by Stop to tell goroutines to exit; recreated by Start
wg sync.WaitGroup // tracks the poller and worker goroutines so Stop can wait for them
// Statistics, reported by GetStats.
processedCount int64 // tasks whose handler returned nil
errorCount int64 // handler invocations that returned an error
lastProcessed time.Time // time of the most recent successful task; zero until then
}
// NewTaskWorker builds a stopped TaskWorker bound to queue.
//
// pollInterval controls how often delayed tasks are checked for readiness;
// anything below one second is raised to one second. workerCount is the
// number of consumer goroutines Start will launch; anything below one is
// raised to one. Call Start to begin processing.
func NewTaskWorker(queue *TaskQueue, pollInterval time.Duration, workerCount int) *TaskWorker {
	const (
		minPollInterval = time.Second
		minWorkers      = 1
	)

	tw := &TaskWorker{
		queue:    queue,
		handlers: make(map[TaskType]TaskHandler),
		stopChan: make(chan struct{}),
	}

	tw.pollInterval = pollInterval
	if tw.pollInterval < minPollInterval {
		tw.pollInterval = minPollInterval
	}

	tw.workerCount = workerCount
	if tw.workerCount < minWorkers {
		tw.workerCount = minWorkers
	}

	return tw
}
// RegisterHandler binds handler to taskType, replacing any handler that was
// previously registered for that type. The map update is done under the
// write lock, so it may be called while the worker is running.
func (w *TaskWorker) RegisterHandler(taskType TaskType, handler TaskHandler) {
	w.mu.Lock()
	w.handlers[taskType] = handler
	w.mu.Unlock()

	log.Printf("[TaskWorker] Registered handler for task type: %s", taskType)
}
// Start launches the background goroutines: one delayed-task poller and
// workerCount consumers. They all exit when ctx is cancelled or Stop is
// called. Calling Start on a worker that is already running only logs a
// message and returns.
//
// NOTE(review): Start replaces stopChan and reuses the WaitGroup, so it must
// not be called while a concurrent Stop is still waiting for the previous
// set of goroutines to exit.
func (w *TaskWorker) Start(ctx context.Context) {
	alreadyRunning := func() bool {
		w.mu.Lock()
		defer w.mu.Unlock()
		if w.running {
			return true
		}
		w.running = true
		// Fresh channel: after a Stop the previous one is already closed.
		w.stopChan = make(chan struct{})
		return false
	}()
	if alreadyRunning {
		log.Println("[TaskWorker] Already running")
		return
	}

	log.Printf("[TaskWorker] Starting with %d workers, poll interval: %v", w.workerCount, w.pollInterval)

	// Delayed-task poller.
	w.wg.Add(1)
	go w.pollDelayedTasks(ctx)

	// Consumer goroutines; id is only used in log output.
	for id := 0; id < w.workerCount; id++ {
		w.wg.Add(1)
		go w.worker(ctx, id)
	}
}
// Stop signals every goroutine started by Start to exit (by closing
// stopChan) and blocks until all of them have returned. Calling Stop on a
// worker that is not running is a no-op.
func (w *TaskWorker) Stop() {
	stopping := func() bool {
		w.mu.Lock()
		defer w.mu.Unlock()
		if !w.running {
			return false
		}
		w.running = false
		close(w.stopChan)
		return true
	}()
	if !stopping {
		return
	}

	log.Println("[TaskWorker] Stopping...")
	w.wg.Wait()
	log.Println("[TaskWorker] Stopped")
}
// pollDelayedTasks runs on its own goroutine. Every pollInterval it asks the
// queue to move delayed tasks that have become due into the ready queue.
// Queue errors are logged and polling continues on the next tick. It returns
// when ctx is cancelled or stopChan is closed.
func (w *TaskWorker) pollDelayedTasks(ctx context.Context) {
	defer w.wg.Done()

	tick := time.NewTicker(w.pollInterval)
	defer tick.Stop()

	for {
		select {
		case <-w.stopChan:
			return
		case <-ctx.Done():
			return
		case <-tick.C:
		}

		moved, err := w.queue.MoveReadyTasks(ctx)
		switch {
		case err != nil:
			log.Printf("[TaskWorker] Error moving ready tasks: %v", err)
		case moved > 0:
			log.Printf("[TaskWorker] Moved %d tasks to ready queue", moved)
		}
	}
}
// worker is the body of one consumer goroutine. It repeatedly pops a ready
// task from the queue (PopTask is given a one-second wait) and processes it,
// until ctx is cancelled or stopChan is closed. id only labels log lines.
//
// After a pop error it backs off for one second. The back-off is interrupted
// by ctx cancellation or Stop, so shutdown is not delayed by the sleep.
func (w *TaskWorker) worker(ctx context.Context, id int) {
	defer w.wg.Done()
	log.Printf("[TaskWorker] Worker %d started", id)

	for {
		// Non-blocking shutdown check before each pop.
		select {
		case <-ctx.Done():
			log.Printf("[TaskWorker] Worker %d stopping (context done)", id)
			return
		case <-w.stopChan:
			log.Printf("[TaskWorker] Worker %d stopping (stop signal)", id)
			return
		default:
		}

		task, err := w.queue.PopTask(ctx, time.Second)
		if err != nil {
			if ctx.Err() != nil {
				// The error comes from our own cancellation: loop back so the
				// ctx.Done() case above logs and exits, without a bogus error
				// log or back-off.
				continue
			}
			log.Printf("[TaskWorker] Worker %d error popping task: %v", id, err)
			// Back off before retrying, but wake immediately on shutdown
			// instead of holding up Stop's wg.Wait for the full second.
			backoff := time.NewTimer(time.Second)
			select {
			case <-backoff.C:
			case <-ctx.Done():
				backoff.Stop()
			case <-w.stopChan:
				backoff.Stop()
			}
			continue
		}
		if task == nil {
			continue // nothing ready within the wait; poll again
		}

		w.processTask(ctx, id, task)
	}
}
// processTask runs one task with the handler registered for its type and
// records the outcome:
//
//   - No handler registered: the task is logged and marked complete, which
//     means it is dropped.
//   - Handler returns an error or panics: errorCount is incremented and the
//     task is passed to RetryTask with a delay of 2^RetryCount minutes. The
//     exponent is clamped so the delay cannot overflow time.Duration.
//   - Handler succeeds: the task is marked complete, processedCount is
//     incremented and lastProcessed is updated.
//
// Errors from CompleteTask and RetryTask are logged, not returned.
//
// NOTE(review): MaxRetries is not checked here; presumably RetryTask enforces
// it. Confirm against the queue implementation.
func (w *TaskWorker) processTask(ctx context.Context, workerID int, task *DelayedTask) {
	// maxBackoffShift is the largest exponent for which
	// (1<<shift)*time.Minute still fits in a time.Duration (int64 ns).
	// 2^28 minutes would overflow to a garbage, possibly negative, delay.
	const maxBackoffShift = 27

	startTime := time.Now()
	log.Printf("[TaskWorker] Worker %d processing task %s (type: %s, user: %d)",
		workerID, task.ID, task.Type, task.UserID)

	w.mu.RLock()
	handler, exists := w.handlers[task.Type]
	w.mu.RUnlock()
	if !exists {
		log.Printf("[TaskWorker] No handler for task type: %s", task.Type)
		if err := w.queue.CompleteTask(ctx, task); err != nil {
			log.Printf("[TaskWorker] Failed to complete task %s: %v", task.ID, err)
		}
		return
	}

	// Run the handler. A panic is converted into an error so that one
	// misbehaving handler cannot crash the whole process. The task then goes
	// through the normal retry path.
	err := func() (err error) {
		defer func() {
			if r := recover(); r != nil {
				err = fmt.Errorf("task handler panicked: %v", r)
			}
		}()
		return handler(ctx, task)
	}()
	duration := time.Since(startTime)

	if err != nil {
		log.Printf("[TaskWorker] Task %s failed (attempt %d/%d): %v",
			task.ID, task.RetryCount+1, task.MaxRetries, err)
		w.mu.Lock()
		w.errorCount++
		w.mu.Unlock()

		// Exponential backoff. Clamp the exponent: a negative shift count
		// panics, and a large one overflows time.Duration.
		shift := int(task.RetryCount)
		if shift < 0 {
			shift = 0
		}
		if shift > maxBackoffShift {
			shift = maxBackoffShift
		}
		retryDelay := time.Duration(1<<uint(shift)) * time.Minute
		if err := w.queue.RetryTask(ctx, task, retryDelay); err != nil {
			log.Printf("[TaskWorker] Failed to retry task %s: %v", task.ID, err)
		}
		return
	}

	if err := w.queue.CompleteTask(ctx, task); err != nil {
		log.Printf("[TaskWorker] Failed to complete task %s: %v", task.ID, err)
	}

	w.mu.Lock()
	w.processedCount++
	w.lastProcessed = time.Now()
	w.mu.Unlock()

	log.Printf("[TaskWorker] Task %s completed in %v", task.ID, duration)
}
// GetStats returns a snapshot of the worker's configuration and counters,
// read under the read lock. Keys:
//   - running
//   - worker_count
//   - poll_interval (as Duration.String)
//   - processed_count
//   - error_count
//   - last_processed (the zero time.Time until the first successful task)
func (w *TaskWorker) GetStats() map[string]interface{} {
	stats := make(map[string]interface{}, 6)

	w.mu.RLock()
	stats["running"] = w.running
	stats["worker_count"] = w.workerCount
	stats["poll_interval"] = w.pollInterval.String()
	stats["processed_count"] = w.processedCount
	stats["error_count"] = w.errorCount
	stats["last_processed"] = w.lastProcessed
	w.mu.RUnlock()

	return stats
}
// ParseRecurringPayload decodes task.Payload as JSON into a
// RecurringTransactionPayload. If decoding fails it returns nil together
// with the error from json.Unmarshal, unwrapped.
func ParseRecurringPayload(task *DelayedTask) (*RecurringTransactionPayload, error) {
	out := new(RecurringTransactionPayload)
	if err := json.Unmarshal(task.Payload, out); err != nil {
		return nil, err
	}
	return out, nil
}