Files
Novault-backend/internal/mq/task_queue.go

238 lines
6.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package mq
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"time"

	"github.com/redis/go-redis/v9"
)
// TaskType names the kind of work a DelayedTask carries.
type TaskType string

// Task types understood by the queue.
const (
	// TaskTypeRecurringTransaction is a task that processes a recurring transaction.
	TaskTypeRecurringTransaction TaskType = "recurring_transaction"
	// TaskTypeAllocationRule is a task that executes an allocation rule.
	TaskTypeAllocationRule TaskType = "allocation_rule"
)

// Redis key layout. A queue's keys are its prefix (KeyPrefixDefault unless
// overridden) followed by one of the suffixes; locks use their own prefix.
const (
	KeyPrefixDefault    = "novault:tasks"
	KeySuffixDelayed    = ":delayed"
	KeySuffixReady      = ":ready"
	KeySuffixProcessing = ":processing"
	KeyPrefixLock       = "novault:lock:"
)
// DelayedTask is a unit of work held in the queue's delayed, ready and
// processing collections. Each entry is stored as the JSON encoding of this
// struct, and CompleteTask/RetryTask locate entries by re-encoding the task
// and matching that exact string. Field order and JSON tags are therefore
// part of the storage format: do not reorder or rename fields.
type DelayedTask struct {
// ID identifies the task; ScheduleRecurringTransaction builds it as
// "recurring_<recurringID>_<unix seconds>".
ID string `json:"id"`
// Type tells the consumer how to interpret Payload.
Type TaskType `json:"type"`
// UserID is the user ID passed in when the task was scheduled
// (presumably the owning user — confirm against callers).
UserID uint `json:"user_id"`
// Payload is type-specific data, kept as raw JSON so it is not decoded
// together with the envelope.
Payload json.RawMessage `json:"payload"`
// ScheduledAt is when the task becomes due; Schedule uses its Unix time in
// whole seconds as the sorted-set score.
ScheduledAt time.Time `json:"scheduled_at"`
// CreatedAt is when the task value was built.
CreatedAt time.Time `json:"created_at"`
// RetryCount is how many times RetryTask has re-queued the task.
RetryCount int `json:"retry_count"`
// MaxRetries caps RetryCount; RetryTask discards the task once RetryCount
// exceeds it.
MaxRetries int `json:"max_retries"`
}
// RecurringTransactionPayload is the Payload of a TaskTypeRecurringTransaction
// task, as built by ScheduleRecurringTransaction.
type RecurringTransactionPayload struct {
// RecurringTransactionID is the ID of the recurring transaction to process.
RecurringTransactionID uint `json:"recurring_transaction_id"`
// NextOccurrence is the occurrence time the task was scheduled for;
// ScheduleRecurringTransaction sets it to the task's ScheduledAt.
NextOccurrence time.Time `json:"next_occurrence"`
}
// TaskQueue is a delayed-task queue backed by Redis. Tasks wait in a sorted
// set (delayedKey) scored by due time. MoveReadyTasks moves due tasks to a
// list (readyKey). PopTask moves each popped task into a second list
// (processingKey), where it stays until CompleteTask or RetryTask removes it.
type TaskQueue struct {
client *redis.Client
keyPrefix string
delayedKey string // sorted set of delayed tasks, scored by ScheduledAt in Unix seconds
readyKey string // list of due tasks waiting to be popped
processingKey string // list of tasks popped by PopTask and not yet removed
}
// NewTaskQueue returns a TaskQueue that keeps its data in client under
// keyPrefix. An empty keyPrefix falls back to KeyPrefixDefault. The delayed,
// ready and processing keys are the prefix plus the matching KeySuffix*.
func NewTaskQueue(client *redis.Client, keyPrefix string) *TaskQueue {
	prefix := keyPrefix
	if len(prefix) == 0 {
		prefix = KeyPrefixDefault
	}

	q := new(TaskQueue)
	q.client = client
	q.keyPrefix = prefix
	q.delayedKey = prefix + KeySuffixDelayed
	q.readyKey = prefix + KeySuffixReady
	q.processingKey = prefix + KeySuffixProcessing
	return q
}
// GenerateLockKey returns the Redis key of the distributed lock for taskID.
// The key is KeyPrefixLock followed by the fixed "recurring:" namespace and
// taskID.
func GenerateLockKey(taskID string) string {
	recurringLockPrefix := KeyPrefixLock + "recurring:"
	return recurringLockPrefix + taskID
}
// Schedule adds task to the delayed sorted set. The score is task.ScheduledAt
// as Unix time in whole seconds, and the member is the task's JSON encoding,
// so re-adding a byte-identical task only updates its score.
func (q *TaskQueue) Schedule(ctx context.Context, task *DelayedTask) error {
	encoded, marshalErr := json.Marshal(task)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal task: %w", marshalErr)
	}

	entry := redis.Z{
		Score:  float64(task.ScheduledAt.Unix()),
		Member: string(encoded),
	}
	if addErr := q.client.ZAdd(ctx, q.delayedKey, entry).Err(); addErr != nil {
		return fmt.Errorf("failed to schedule task: %w", addErr)
	}

	log.Printf("[TaskQueue] Scheduled task %s for %s", task.ID, task.ScheduledAt.Format("2006-01-02 15:04:05"))
	return nil
}
// ScheduleRecurringTransaction schedules processing of recurring transaction
// recurringID on behalf of userID at scheduledAt, allowing up to 3 retries.
//
// The task ID is derived from recurringID and scheduledAt's Unix seconds, so
// it is deterministic. The stored member still includes CreatedAt, so calling
// this twice for the same occurrence adds two entries. Duplicate detection is
// presumably left to consumers via the ID (e.g. with GenerateLockKey and
// AcquireLock); confirm with callers.
//
// It returns an error if the payload cannot be encoded (time.Time refuses to
// marshal years outside [0, 9999]) or if Schedule fails.
func (q *TaskQueue) ScheduleRecurringTransaction(ctx context.Context, userID uint, recurringID uint, scheduledAt time.Time) error {
	payload := RecurringTransactionPayload{
		RecurringTransactionID: recurringID,
		NextOccurrence:         scheduledAt,
	}
	payloadJSON, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("failed to marshal recurring transaction payload: %w", err)
	}

	task := &DelayedTask{
		ID:          fmt.Sprintf("recurring_%d_%d", recurringID, scheduledAt.Unix()),
		Type:        TaskTypeRecurringTransaction,
		UserID:      userID,
		Payload:     payloadJSON,
		ScheduledAt: scheduledAt,
		CreatedAt:   time.Now(),
		MaxRetries:  3,
	}
	return q.Schedule(ctx, task)
}
// MoveReadyTasks 将到期的延迟任务移动到就绪队列
func (q *TaskQueue) MoveReadyTasks(ctx context.Context) (int64, error) {
now := time.Now().Unix()
// 使用 Lua 脚本原子性地移动任务
script := redis.NewScript(`
local delayed_key = KEYS[1]
local ready_key = KEYS[2]
local now = tonumber(ARGV[1])
-- 获取所有到期的任务
local tasks = redis.call('ZRANGEBYSCORE', delayed_key, '-inf', now)
if #tasks == 0 then
return 0
end
-- 移动到就绪队列
for i, task in ipairs(tasks) do
redis.call('LPUSH', ready_key, task)
end
-- 从延迟队列中删除
redis.call('ZREMRANGEBYSCORE', delayed_key, '-inf', now)
return #tasks
`)
result, err := script.Run(ctx, q.client, []string{q.delayedKey, q.readyKey}, now).Int64()
if err != nil {
return 0, fmt.Errorf("failed to move ready tasks: %w", err)
}
if result > 0 {
log.Printf("[TaskQueue] Moved %d tasks to ready queue", result)
}
return result, nil
}
// PopTask waits up to timeout for a ready task. It atomically moves the task
// from the tail of the ready list to the processing list (BRPOPLPUSH), where
// it stays until CompleteTask or RetryTask removes it. It returns (nil, nil)
// when no task arrives within the timeout.
//
// If the popped entry cannot be decoded, it is removed from the processing
// list before the error is returned. The caller never receives a task for it,
// so nothing could ever complete it and it would otherwise remain there.
func (q *TaskQueue) PopTask(ctx context.Context, timeout time.Duration) (*DelayedTask, error) {
	result, err := q.client.BRPopLPush(ctx, q.readyKey, q.processingKey, timeout).Result()
	if errors.Is(err, redis.Nil) {
		return nil, nil // timed out: no task available
	}
	if err != nil {
		return nil, fmt.Errorf("failed to pop task: %w", err)
	}

	var task DelayedTask
	if err := json.Unmarshal([]byte(result), &task); err != nil {
		// Drop the undecodable entry by its exact raw value.
		if remErr := q.client.LRem(ctx, q.processingKey, 1, result).Err(); remErr != nil {
			return nil, fmt.Errorf("failed to unmarshal task: %w (also failed to remove it from processing queue: %v)", err, remErr)
		}
		return nil, fmt.Errorf("failed to unmarshal task: %w", err)
	}
	return &task, nil
}
// CompleteTask removes task from the processing list. The list is searched
// for the task's JSON encoding, so task must be the value returned by PopTask
// with no fields modified; otherwise nothing matches and the entry remains.
// Removing an entry that is not present is not an error.
func (q *TaskQueue) CompleteTask(ctx context.Context, task *DelayedTask) error {
	taskJSON, err := json.Marshal(task)
	if err != nil {
		return fmt.Errorf("failed to marshal task: %w", err)
	}
	if err := q.client.LRem(ctx, q.processingKey, 1, string(taskJSON)).Err(); err != nil {
		return fmt.Errorf("failed to complete task: %w", err)
	}
	return nil
}
// RetryTask handles a task that failed processing:
//   - It removes the task from the processing list.
//   - It increments task.RetryCount; the caller's value is modified.
//   - If RetryCount now exceeds MaxRetries, it discards the task.
//   - Otherwise it reschedules the task to run delay from now.
//
// task must be the unmodified value returned by PopTask. The processing list
// is searched for the task's exact JSON encoding, which is why the task is
// encoded before RetryCount and ScheduledAt are changed.
func (q *TaskQueue) RetryTask(ctx context.Context, task *DelayedTask, delay time.Duration) error {
	// Encode the task exactly as it sits in the processing list. Encoding it
	// after mutation produces JSON that LREM can never match, so the entry
	// would leak in the processing list on every retry or discard.
	processingJSON, err := json.Marshal(task)
	if err != nil {
		return fmt.Errorf("failed to marshal task: %w", err)
	}
	if err := q.client.LRem(ctx, q.processingKey, 1, string(processingJSON)).Err(); err != nil {
		return fmt.Errorf("failed to remove task from processing queue: %w", err)
	}

	task.RetryCount++
	if task.RetryCount > task.MaxRetries {
		log.Printf("[TaskQueue] Task %s exceeded max retries, discarding", task.ID)
		return nil
	}

	task.ScheduledAt = time.Now().Add(delay)
	return q.Schedule(ctx, task)
}
// GetQueueStats reports how many entries the delayed, ready and processing
// collections hold, under the keys "delayed", "ready" and "processing".
// The three counts are fetched in one pipelined round trip. A pipeline is not
// a transaction, so a task moving between collections in the meantime may be
// counted twice or not at all. Any Redis error is returned instead of being
// reported as a zero count.
func (q *TaskQueue) GetQueueStats(ctx context.Context) (map[string]int64, error) {
	pipe := q.client.Pipeline()
	delayed := pipe.ZCard(ctx, q.delayedKey)
	ready := pipe.LLen(ctx, q.readyKey)
	processing := pipe.LLen(ctx, q.processingKey)
	if _, err := pipe.Exec(ctx); err != nil {
		return nil, fmt.Errorf("failed to get queue stats: %w", err)
	}

	return map[string]int64{
		"delayed":    delayed.Val(),
		"ready":      ready.Val(),
		"processing": processing.Val(),
	}, nil
}
// GetPendingTasks returns up to limit delayed tasks, earliest score first,
// for debugging. A limit <= 0 yields an empty (non-nil) slice. Entries that
// fail to decode are skipped, because this is a best-effort diagnostic view.
func (q *TaskQueue) GetPendingTasks(ctx context.Context, limit int64) ([]DelayedTask, error) {
	if limit <= 0 {
		// ZRANGE treats negative stop indexes as offsets from the end. Passing
		// limit-1 through unchecked would turn 0 into "everything" and -n into
		// "all but the last n".
		return []DelayedTask{}, nil
	}

	results, err := q.client.ZRange(ctx, q.delayedKey, 0, limit-1).Result()
	if err != nil {
		return nil, fmt.Errorf("failed to list pending tasks: %w", err)
	}

	tasks := make([]DelayedTask, 0, len(results))
	for _, raw := range results {
		var task DelayedTask
		if err := json.Unmarshal([]byte(raw), &task); err != nil {
			continue // best-effort: skip undecodable entries
		}
		tasks = append(tasks, task)
	}
	return tasks, nil
}
// AcquireLock tries to create key with the value "1" only if it does not
// already exist, expiring after ttl (go-redis treats a zero ttl as no
// expiration). It is intended for idempotency checks. It reports true when
// this call created the key and false when the key already existed. The
// stored value is constant, so the key does not record which caller holds it.
func (q *TaskQueue) AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
	cmd := q.client.SetNX(ctx, key, "1", ttl)
	acquired, err := cmd.Result()
	return acquired, err
}