228 lines
5.3 KiB
Go
228 lines
5.3 KiB
Go
|
|
package mq
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"encoding/json"
|
||
|
|
"log"
|
||
|
|
"sync"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
// TaskHandler 任务处理函数类型
|
||
|
|
type TaskHandler func(ctx context.Context, task *DelayedTask) error
|
||
|
|
|
||
|
|
// TaskWorker 任务工作器,负责消费和执行任务
|
||
|
|
type TaskWorker struct {
|
||
|
|
queue *TaskQueue
|
||
|
|
handlers map[TaskType]TaskHandler
|
||
|
|
pollInterval time.Duration // 检查延迟任务间隔
|
||
|
|
workerCount int // 并发工作器数量
|
||
|
|
|
||
|
|
mu sync.RWMutex
|
||
|
|
running bool
|
||
|
|
stopChan chan struct{}
|
||
|
|
wg sync.WaitGroup
|
||
|
|
|
||
|
|
// 统计信息
|
||
|
|
processedCount int64
|
||
|
|
errorCount int64
|
||
|
|
lastProcessed time.Time
|
||
|
|
}
|
||
|
|
|
||
|
|
// NewTaskWorker 创建任务工作器
|
||
|
|
func NewTaskWorker(queue *TaskQueue, pollInterval time.Duration, workerCount int) *TaskWorker {
|
||
|
|
if pollInterval < time.Second {
|
||
|
|
pollInterval = time.Second
|
||
|
|
}
|
||
|
|
if workerCount < 1 {
|
||
|
|
workerCount = 1
|
||
|
|
}
|
||
|
|
|
||
|
|
return &TaskWorker{
|
||
|
|
queue: queue,
|
||
|
|
handlers: make(map[TaskType]TaskHandler),
|
||
|
|
pollInterval: pollInterval,
|
||
|
|
workerCount: workerCount,
|
||
|
|
stopChan: make(chan struct{}),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// RegisterHandler 注册任务处理器
|
||
|
|
func (w *TaskWorker) RegisterHandler(taskType TaskType, handler TaskHandler) {
|
||
|
|
w.mu.Lock()
|
||
|
|
defer w.mu.Unlock()
|
||
|
|
w.handlers[taskType] = handler
|
||
|
|
log.Printf("[TaskWorker] Registered handler for task type: %s", taskType)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Start 启动工作器
|
||
|
|
func (w *TaskWorker) Start(ctx context.Context) {
|
||
|
|
w.mu.Lock()
|
||
|
|
if w.running {
|
||
|
|
w.mu.Unlock()
|
||
|
|
log.Println("[TaskWorker] Already running")
|
||
|
|
return
|
||
|
|
}
|
||
|
|
w.running = true
|
||
|
|
w.stopChan = make(chan struct{})
|
||
|
|
w.mu.Unlock()
|
||
|
|
|
||
|
|
log.Printf("[TaskWorker] Starting with %d workers, poll interval: %v", w.workerCount, w.pollInterval)
|
||
|
|
|
||
|
|
// 启动延迟任务轮询器
|
||
|
|
w.wg.Add(1)
|
||
|
|
go w.pollDelayedTasks(ctx)
|
||
|
|
|
||
|
|
// 启动工作器
|
||
|
|
for i := 0; i < w.workerCount; i++ {
|
||
|
|
w.wg.Add(1)
|
||
|
|
go w.worker(ctx, i)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Stop 停止工作器
|
||
|
|
func (w *TaskWorker) Stop() {
|
||
|
|
w.mu.Lock()
|
||
|
|
if !w.running {
|
||
|
|
w.mu.Unlock()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
w.running = false
|
||
|
|
close(w.stopChan)
|
||
|
|
w.mu.Unlock()
|
||
|
|
|
||
|
|
log.Println("[TaskWorker] Stopping...")
|
||
|
|
w.wg.Wait()
|
||
|
|
log.Println("[TaskWorker] Stopped")
|
||
|
|
}
|
||
|
|
|
||
|
|
// pollDelayedTasks 定期检查并移动到期的延迟任务
|
||
|
|
func (w *TaskWorker) pollDelayedTasks(ctx context.Context) {
|
||
|
|
defer w.wg.Done()
|
||
|
|
|
||
|
|
ticker := time.NewTicker(w.pollInterval)
|
||
|
|
defer ticker.Stop()
|
||
|
|
|
||
|
|
for {
|
||
|
|
select {
|
||
|
|
case <-ctx.Done():
|
||
|
|
return
|
||
|
|
case <-w.stopChan:
|
||
|
|
return
|
||
|
|
case <-ticker.C:
|
||
|
|
moved, err := w.queue.MoveReadyTasks(ctx)
|
||
|
|
if err != nil {
|
||
|
|
log.Printf("[TaskWorker] Error moving ready tasks: %v", err)
|
||
|
|
} else if moved > 0 {
|
||
|
|
log.Printf("[TaskWorker] Moved %d tasks to ready queue", moved)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// worker 工作器协程
|
||
|
|
func (w *TaskWorker) worker(ctx context.Context, id int) {
|
||
|
|
defer w.wg.Done()
|
||
|
|
|
||
|
|
log.Printf("[TaskWorker] Worker %d started", id)
|
||
|
|
|
||
|
|
for {
|
||
|
|
select {
|
||
|
|
case <-ctx.Done():
|
||
|
|
log.Printf("[TaskWorker] Worker %d stopping (context done)", id)
|
||
|
|
return
|
||
|
|
case <-w.stopChan:
|
||
|
|
log.Printf("[TaskWorker] Worker %d stopping (stop signal)", id)
|
||
|
|
return
|
||
|
|
default:
|
||
|
|
// 尝试获取任务
|
||
|
|
task, err := w.queue.PopTask(ctx, time.Second)
|
||
|
|
if err != nil {
|
||
|
|
log.Printf("[TaskWorker] Worker %d error popping task: %v", id, err)
|
||
|
|
time.Sleep(time.Second) // 错误后等待
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
if task == nil {
|
||
|
|
continue // 没有任务,继续轮询
|
||
|
|
}
|
||
|
|
|
||
|
|
// 处理任务
|
||
|
|
w.processTask(ctx, id, task)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// processTask 处理单个任务
|
||
|
|
func (w *TaskWorker) processTask(ctx context.Context, workerID int, task *DelayedTask) {
|
||
|
|
startTime := time.Now()
|
||
|
|
log.Printf("[TaskWorker] Worker %d processing task %s (type: %s, user: %d)",
|
||
|
|
workerID, task.ID, task.Type, task.UserID)
|
||
|
|
|
||
|
|
w.mu.RLock()
|
||
|
|
handler, exists := w.handlers[task.Type]
|
||
|
|
w.mu.RUnlock()
|
||
|
|
|
||
|
|
if !exists {
|
||
|
|
log.Printf("[TaskWorker] No handler for task type: %s", task.Type)
|
||
|
|
w.queue.CompleteTask(ctx, task)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// 执行任务
|
||
|
|
err := handler(ctx, task)
|
||
|
|
duration := time.Since(startTime)
|
||
|
|
|
||
|
|
if err != nil {
|
||
|
|
log.Printf("[TaskWorker] Task %s failed (attempt %d/%d): %v",
|
||
|
|
task.ID, task.RetryCount+1, task.MaxRetries, err)
|
||
|
|
|
||
|
|
w.mu.Lock()
|
||
|
|
w.errorCount++
|
||
|
|
w.mu.Unlock()
|
||
|
|
|
||
|
|
// 重试(指数退避)
|
||
|
|
retryDelay := time.Duration(1<<task.RetryCount) * time.Minute
|
||
|
|
if err := w.queue.RetryTask(ctx, task, retryDelay); err != nil {
|
||
|
|
log.Printf("[TaskWorker] Failed to retry task %s: %v", task.ID, err)
|
||
|
|
}
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// 完成任务
|
||
|
|
if err := w.queue.CompleteTask(ctx, task); err != nil {
|
||
|
|
log.Printf("[TaskWorker] Failed to complete task %s: %v", task.ID, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
w.mu.Lock()
|
||
|
|
w.processedCount++
|
||
|
|
w.lastProcessed = time.Now()
|
||
|
|
w.mu.Unlock()
|
||
|
|
|
||
|
|
log.Printf("[TaskWorker] Task %s completed in %v", task.ID, duration)
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetStats 获取工作器统计信息
|
||
|
|
func (w *TaskWorker) GetStats() map[string]interface{} {
|
||
|
|
w.mu.RLock()
|
||
|
|
defer w.mu.RUnlock()
|
||
|
|
|
||
|
|
return map[string]interface{}{
|
||
|
|
"running": w.running,
|
||
|
|
"worker_count": w.workerCount,
|
||
|
|
"poll_interval": w.pollInterval.String(),
|
||
|
|
"processed_count": w.processedCount,
|
||
|
|
"error_count": w.errorCount,
|
||
|
|
"last_processed": w.lastProcessed,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// ParseRecurringPayload 解析周期性交易任务载荷
|
||
|
|
func ParseRecurringPayload(task *DelayedTask) (*RecurringTransactionPayload, error) {
|
||
|
|
var payload RecurringTransactionPayload
|
||
|
|
if err := json.Unmarshal(task.Payload, &payload); err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
return &payload, nil
|
||
|
|
}
|