feat: 实现基于 Redis 的延迟任务队列和工作器,支持周期性交易处理等后台任务调度。
This commit is contained in:
227
internal/mq/task_worker.go
Normal file
227
internal/mq/task_worker.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package mq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TaskHandler 任务处理函数类型
|
||||
type TaskHandler func(ctx context.Context, task *DelayedTask) error
|
||||
|
||||
// TaskWorker 任务工作器,负责消费和执行任务
|
||||
type TaskWorker struct {
|
||||
queue *TaskQueue
|
||||
handlers map[TaskType]TaskHandler
|
||||
pollInterval time.Duration // 检查延迟任务间隔
|
||||
workerCount int // 并发工作器数量
|
||||
|
||||
mu sync.RWMutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
|
||||
// 统计信息
|
||||
processedCount int64
|
||||
errorCount int64
|
||||
lastProcessed time.Time
|
||||
}
|
||||
|
||||
// NewTaskWorker 创建任务工作器
|
||||
func NewTaskWorker(queue *TaskQueue, pollInterval time.Duration, workerCount int) *TaskWorker {
|
||||
if pollInterval < time.Second {
|
||||
pollInterval = time.Second
|
||||
}
|
||||
if workerCount < 1 {
|
||||
workerCount = 1
|
||||
}
|
||||
|
||||
return &TaskWorker{
|
||||
queue: queue,
|
||||
handlers: make(map[TaskType]TaskHandler),
|
||||
pollInterval: pollInterval,
|
||||
workerCount: workerCount,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHandler 注册任务处理器
|
||||
func (w *TaskWorker) RegisterHandler(taskType TaskType, handler TaskHandler) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.handlers[taskType] = handler
|
||||
log.Printf("[TaskWorker] Registered handler for task type: %s", taskType)
|
||||
}
|
||||
|
||||
// Start 启动工作器
|
||||
func (w *TaskWorker) Start(ctx context.Context) {
|
||||
w.mu.Lock()
|
||||
if w.running {
|
||||
w.mu.Unlock()
|
||||
log.Println("[TaskWorker] Already running")
|
||||
return
|
||||
}
|
||||
w.running = true
|
||||
w.stopChan = make(chan struct{})
|
||||
w.mu.Unlock()
|
||||
|
||||
log.Printf("[TaskWorker] Starting with %d workers, poll interval: %v", w.workerCount, w.pollInterval)
|
||||
|
||||
// 启动延迟任务轮询器
|
||||
w.wg.Add(1)
|
||||
go w.pollDelayedTasks(ctx)
|
||||
|
||||
// 启动工作器
|
||||
for i := 0; i < w.workerCount; i++ {
|
||||
w.wg.Add(1)
|
||||
go w.worker(ctx, i)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止工作器
|
||||
func (w *TaskWorker) Stop() {
|
||||
w.mu.Lock()
|
||||
if !w.running {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
w.running = false
|
||||
close(w.stopChan)
|
||||
w.mu.Unlock()
|
||||
|
||||
log.Println("[TaskWorker] Stopping...")
|
||||
w.wg.Wait()
|
||||
log.Println("[TaskWorker] Stopped")
|
||||
}
|
||||
|
||||
// pollDelayedTasks 定期检查并移动到期的延迟任务
|
||||
func (w *TaskWorker) pollDelayedTasks(ctx context.Context) {
|
||||
defer w.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(w.pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-w.stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
moved, err := w.queue.MoveReadyTasks(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[TaskWorker] Error moving ready tasks: %v", err)
|
||||
} else if moved > 0 {
|
||||
log.Printf("[TaskWorker] Moved %d tasks to ready queue", moved)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// worker 工作器协程
|
||||
func (w *TaskWorker) worker(ctx context.Context, id int) {
|
||||
defer w.wg.Done()
|
||||
|
||||
log.Printf("[TaskWorker] Worker %d started", id)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("[TaskWorker] Worker %d stopping (context done)", id)
|
||||
return
|
||||
case <-w.stopChan:
|
||||
log.Printf("[TaskWorker] Worker %d stopping (stop signal)", id)
|
||||
return
|
||||
default:
|
||||
// 尝试获取任务
|
||||
task, err := w.queue.PopTask(ctx, time.Second)
|
||||
if err != nil {
|
||||
log.Printf("[TaskWorker] Worker %d error popping task: %v", id, err)
|
||||
time.Sleep(time.Second) // 错误后等待
|
||||
continue
|
||||
}
|
||||
|
||||
if task == nil {
|
||||
continue // 没有任务,继续轮询
|
||||
}
|
||||
|
||||
// 处理任务
|
||||
w.processTask(ctx, id, task)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processTask 处理单个任务
|
||||
func (w *TaskWorker) processTask(ctx context.Context, workerID int, task *DelayedTask) {
|
||||
startTime := time.Now()
|
||||
log.Printf("[TaskWorker] Worker %d processing task %s (type: %s, user: %d)",
|
||||
workerID, task.ID, task.Type, task.UserID)
|
||||
|
||||
w.mu.RLock()
|
||||
handler, exists := w.handlers[task.Type]
|
||||
w.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
log.Printf("[TaskWorker] No handler for task type: %s", task.Type)
|
||||
w.queue.CompleteTask(ctx, task)
|
||||
return
|
||||
}
|
||||
|
||||
// 执行任务
|
||||
err := handler(ctx, task)
|
||||
duration := time.Since(startTime)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[TaskWorker] Task %s failed (attempt %d/%d): %v",
|
||||
task.ID, task.RetryCount+1, task.MaxRetries, err)
|
||||
|
||||
w.mu.Lock()
|
||||
w.errorCount++
|
||||
w.mu.Unlock()
|
||||
|
||||
// 重试(指数退避)
|
||||
retryDelay := time.Duration(1<<task.RetryCount) * time.Minute
|
||||
if err := w.queue.RetryTask(ctx, task, retryDelay); err != nil {
|
||||
log.Printf("[TaskWorker] Failed to retry task %s: %v", task.ID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 完成任务
|
||||
if err := w.queue.CompleteTask(ctx, task); err != nil {
|
||||
log.Printf("[TaskWorker] Failed to complete task %s: %v", task.ID, err)
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
w.processedCount++
|
||||
w.lastProcessed = time.Now()
|
||||
w.mu.Unlock()
|
||||
|
||||
log.Printf("[TaskWorker] Task %s completed in %v", task.ID, duration)
|
||||
}
|
||||
|
||||
// GetStats 获取工作器统计信息
|
||||
func (w *TaskWorker) GetStats() map[string]interface{} {
|
||||
w.mu.RLock()
|
||||
defer w.mu.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"running": w.running,
|
||||
"worker_count": w.workerCount,
|
||||
"poll_interval": w.pollInterval.String(),
|
||||
"processed_count": w.processedCount,
|
||||
"error_count": w.errorCount,
|
||||
"last_processed": w.lastProcessed,
|
||||
}
|
||||
}
|
||||
|
||||
// ParseRecurringPayload 解析周期性交易任务载荷
|
||||
func ParseRecurringPayload(task *DelayedTask) (*RecurringTransactionPayload, error) {
|
||||
var payload RecurringTransactionPayload
|
||||
if err := json.Unmarshal(task.Payload, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &payload, nil
|
||||
}
|
||||
Reference in New Issue
Block a user