package mq import ( "context" "encoding/json" "fmt" "log" "time" "github.com/redis/go-redis/v9" ) // TaskType 定义任务类型 type TaskType string const ( // TaskTypeRecurringTransaction 周期性交易处理任务 TaskTypeRecurringTransaction TaskType = "recurring_transaction" // TaskTypeAllocationRule 分配规则执行任务 TaskTypeAllocationRule TaskType = "allocation_rule" // Redis Key Prefixes KeyPrefixDefault = "novault:tasks" KeySuffixDelayed = ":delayed" KeySuffixReady = ":ready" KeySuffixProcessing = ":processing" KeyPrefixLock = "novault:lock:" ) // DelayedTask 延迟任务结构 type DelayedTask struct { ID string `json:"id"` Type TaskType `json:"type"` UserID uint `json:"user_id"` Payload json.RawMessage `json:"payload"` ScheduledAt time.Time `json:"scheduled_at"` CreatedAt time.Time `json:"created_at"` RetryCount int `json:"retry_count"` MaxRetries int `json:"max_retries"` } // RecurringTransactionPayload 周期性交易任务载荷 type RecurringTransactionPayload struct { RecurringTransactionID uint `json:"recurring_transaction_id"` NextOccurrence time.Time `json:"next_occurrence"` } // TaskQueue 基于 Redis 的延迟任务队列 type TaskQueue struct { client *redis.Client keyPrefix string delayedKey string // 延迟任务有序集合 readyKey string // 就绪任务列表 processingKey string // 处理中任务 } // NewTaskQueue 创建任务队列实例 func NewTaskQueue(client *redis.Client, keyPrefix string) *TaskQueue { if keyPrefix == "" { keyPrefix = KeyPrefixDefault } return &TaskQueue{ client: client, keyPrefix: keyPrefix, delayedKey: keyPrefix + KeySuffixDelayed, readyKey: keyPrefix + KeySuffixReady, processingKey: keyPrefix + KeySuffixProcessing, } } // GenerateLockKey 生成分布式锁 Key func GenerateLockKey(taskID string) string { return KeyPrefixLock + "recurring:" + taskID } // Schedule 调度延迟任务 func (q *TaskQueue) Schedule(ctx context.Context, task *DelayedTask) error { // 序列化任务 taskJSON, err := json.Marshal(task) if err != nil { return fmt.Errorf("failed to marshal task: %w", err) } // 使用 ZADD 添加到延迟队列,score 为执行时间戳 score := float64(task.ScheduledAt.Unix()) err = q.client.ZAdd(ctx, q.delayedKey, redis.Z{ Score: score, Member: string(taskJSON), }).Err() if err != nil { return fmt.Errorf("failed to schedule task: %w", err) } log.Printf("[TaskQueue] Scheduled task %s for %s", task.ID, task.ScheduledAt.Format("2006-01-02 15:04:05")) return nil } // ScheduleRecurringTransaction 调度周期性交易任务 func (q *TaskQueue) ScheduleRecurringTransaction(ctx context.Context, userID uint, recurringID uint, scheduledAt time.Time) error { payload := RecurringTransactionPayload{ RecurringTransactionID: recurringID, NextOccurrence: scheduledAt, } payloadJSON, _ := json.Marshal(payload) task := &DelayedTask{ ID: fmt.Sprintf("recurring_%d_%d", recurringID, scheduledAt.Unix()), Type: TaskTypeRecurringTransaction, UserID: userID, Payload: payloadJSON, ScheduledAt: scheduledAt, CreatedAt: time.Now(), MaxRetries: 3, } return q.Schedule(ctx, task) } // MoveReadyTasks 将到期的延迟任务移动到就绪队列 func (q *TaskQueue) MoveReadyTasks(ctx context.Context) (int64, error) { now := time.Now().Unix() // 使用 Lua 脚本原子性地移动任务 script := redis.NewScript(` local delayed_key = KEYS[1] local ready_key = KEYS[2] local now = tonumber(ARGV[1]) -- 获取所有到期的任务 local tasks = redis.call('ZRANGEBYSCORE', delayed_key, '-inf', now) if #tasks == 0 then return 0 end -- 移动到就绪队列 for i, task in ipairs(tasks) do redis.call('LPUSH', ready_key, task) end -- 从延迟队列中删除 redis.call('ZREMRANGEBYSCORE', delayed_key, '-inf', now) return #tasks `) result, err := script.Run(ctx, q.client, []string{q.delayedKey, q.readyKey}, now).Int64() if err != nil { return 0, fmt.Errorf("failed to move ready tasks: %w", err) } if result > 0 { log.Printf("[TaskQueue] Moved %d tasks to ready queue", result) } return result, nil } // PopTask 从就绪队列中取出一个任务 func (q *TaskQueue) PopTask(ctx context.Context, timeout time.Duration) (*DelayedTask, error) { // 使用 BRPOPLPUSH 原子性地获取任务并移动到处理中队列 result, err := q.client.BRPopLPush(ctx, q.readyKey, q.processingKey, timeout).Result() if err != nil { if err == redis.Nil { return nil, nil // 没有任务 } return nil, fmt.Errorf("failed to pop task: %w", err) } var task DelayedTask if err := json.Unmarshal([]byte(result), &task); err != nil { return nil, fmt.Errorf("failed to unmarshal task: %w", err) } return &task, nil } // CompleteTask 标记任务完成 func (q *TaskQueue) CompleteTask(ctx context.Context, task *DelayedTask) error { taskJSON, _ := json.Marshal(task) return q.client.LRem(ctx, q.processingKey, 1, string(taskJSON)).Err() } // RetryTask 重试任务 func (q *TaskQueue) RetryTask(ctx context.Context, task *DelayedTask, delay time.Duration) error { task.RetryCount++ if task.RetryCount > task.MaxRetries { log.Printf("[TaskQueue] Task %s exceeded max retries, discarding", task.ID) return q.CompleteTask(ctx, task) } task.ScheduledAt = time.Now().Add(delay) // 先从处理中队列移除 taskJSON, _ := json.Marshal(task) q.client.LRem(ctx, q.processingKey, 1, string(taskJSON)) // 重新调度 return q.Schedule(ctx, task) } // GetQueueStats 获取队列统计信息 func (q *TaskQueue) GetQueueStats(ctx context.Context) (map[string]int64, error) { delayed, _ := q.client.ZCard(ctx, q.delayedKey).Result() ready, _ := q.client.LLen(ctx, q.readyKey).Result() processing, _ := q.client.LLen(ctx, q.processingKey).Result() return map[string]int64{ "delayed": delayed, "ready": ready, "processing": processing, }, nil } // GetPendingTasks 获取即将执行的任务(用于调试) func (q *TaskQueue) GetPendingTasks(ctx context.Context, limit int64) ([]DelayedTask, error) { results, err := q.client.ZRange(ctx, q.delayedKey, 0, limit-1).Result() if err != nil { return nil, err } tasks := make([]DelayedTask, 0, len(results)) for _, result := range results { var task DelayedTask if err := json.Unmarshal([]byte(result), &task); err == nil { tasks = append(tasks, task) } } return tasks, nil } // AcquireLock 尝试获取分布式锁(用于幂等性检查) func (q *TaskQueue) AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error) { return q.client.SetNX(ctx, key, "1", ttl).Result() }