Files
Novault-backend/internal/mq/task_queue.go

226 lines
6.4 KiB
Go
Raw Normal View History

package mq
import (
"context"
"encoding/json"
"fmt"
"log"
"time"
"github.com/redis/go-redis/v9"
)
// TaskType names the kind of work carried by a DelayedTask.
type TaskType string

// Known task types.
const (
	// TaskTypeRecurringTransaction processes a recurring transaction; its
	// payload is a RecurringTransactionPayload.
	TaskTypeRecurringTransaction = TaskType("recurring_transaction")
	// TaskTypeAllocationRule executes an allocation rule.
	TaskTypeAllocationRule = TaskType("allocation_rule")
)
// DelayedTask is a unit of work that becomes due at ScheduledAt.
//
// Tasks are stored in Redis as their JSON encoding. CompleteTask and
// RetryTask locate a task in the processing list by re-encoding it and
// matching the exact string. The field order and JSON tags below are
// therefore part of the storage format: reordering or renaming fields makes
// entries written by an older build unmatchable.
type DelayedTask struct {
	// ID identifies the task; ScheduleRecurringTransaction builds it as
	// "recurring_<recurringID>_<unix seconds>".
	ID string `json:"id"`
	// Type selects how the payload is interpreted.
	Type TaskType `json:"type"`
	// UserID is presumably the user that owns the task — confirm with callers.
	UserID uint `json:"user_id"`
	// Payload is type-specific JSON, e.g. a RecurringTransactionPayload for
	// TaskTypeRecurringTransaction.
	Payload json.RawMessage `json:"payload"`
	// ScheduledAt is the due time; Schedule stores it as the sorted-set
	// score truncated to whole Unix seconds.
	ScheduledAt time.Time `json:"scheduled_at"`
	// CreatedAt records when the task value was built.
	CreatedAt time.Time `json:"created_at"`
	// RetryCount is incremented by RetryTask on every retry attempt.
	RetryCount int `json:"retry_count"`
	// MaxRetries is the retry budget: RetryTask discards the task once
	// RetryCount exceeds it.
	MaxRetries int `json:"max_retries"`
}
// RecurringTransactionPayload is the JSON payload of a
// TaskTypeRecurringTransaction task.
//
// Field order and tags define the encoded payload bytes; keep them stable.
type RecurringTransactionPayload struct {
	// RecurringTransactionID is the ID of the recurring transaction to process.
	RecurringTransactionID uint `json:"recurring_transaction_id"`
	// NextOccurrence is the occurrence this task was scheduled for; it is set
	// to the same instant as the task's ScheduledAt.
	NextOccurrence time.Time `json:"next_occurrence"`
}
// TaskQueue is a delayed-task queue kept in Redis under three keys derived
// from keyPrefix:
//   - delayedKey (sorted set): tasks not yet due, scored by due time in Unix seconds;
//   - readyKey (list): due tasks waiting to be popped;
//   - processingKey (list): tasks popped by a worker but not yet completed or retried.
type TaskQueue struct {
	client    *redis.Client
	keyPrefix string

	// Redis keys; NewTaskQueue derives each one from keyPrefix.
	delayedKey, readyKey, processingKey string
}
// NewTaskQueue returns a TaskQueue backed by client. Every Redis key the
// queue touches is namespaced under keyPrefix; an empty keyPrefix falls back
// to "novault:tasks".
func NewTaskQueue(client *redis.Client, keyPrefix string) *TaskQueue {
	const defaultKeyPrefix = "novault:tasks"

	prefix := keyPrefix
	if prefix == "" {
		prefix = defaultKeyPrefix
	}

	q := &TaskQueue{client: client, keyPrefix: prefix}
	q.delayedKey = prefix + ":delayed"
	q.readyKey = prefix + ":ready"
	q.processingKey = prefix + ":processing"
	return q
}
// Schedule adds task to the delayed sorted set. The task's JSON encoding is
// the set member and its ScheduledAt, truncated to whole Unix seconds, is the
// score; MoveReadyTasks later promotes members whose score is <= now.
//
// Because the member is the full JSON, two byte-identical tasks collapse
// into one set entry.
func (q *TaskQueue) Schedule(ctx context.Context, task *DelayedTask) error {
	encoded, err := json.Marshal(task)
	if err != nil {
		return fmt.Errorf("failed to marshal task: %w", err)
	}

	entry := redis.Z{
		Score:  float64(task.ScheduledAt.Unix()),
		Member: string(encoded),
	}
	if err := q.client.ZAdd(ctx, q.delayedKey, entry).Err(); err != nil {
		return fmt.Errorf("failed to schedule task: %w", err)
	}

	log.Printf("[TaskQueue] Scheduled task %s for %s", task.ID, task.ScheduledAt.Format("2006-01-02 15:04:05"))
	return nil
}
// ScheduleRecurringTransaction schedules a TaskTypeRecurringTransaction task
// for recurringID that becomes due at scheduledAt. The task allows up to 3
// retries.
//
// The task ID is "recurring_<recurringID>_<scheduledAt unix seconds>".
// NOTE(review): the stored member also contains CreatedAt (time.Now()), so
// calling this twice for the same occurrence enqueues two tasks rather than
// deduplicating. Confirm whether callers rely on an idempotency lock
// (e.g. AcquireLock) for that.
//
// Returns an error if the payload cannot be encoded (for example a
// scheduledAt whose year lies outside [0, 9999]) or if Schedule fails.
func (q *TaskQueue) ScheduleRecurringTransaction(ctx context.Context, userID uint, recurringID uint, scheduledAt time.Time) error {
	payloadJSON, err := json.Marshal(RecurringTransactionPayload{
		RecurringTransactionID: recurringID,
		NextOccurrence:         scheduledAt,
	})
	if err != nil {
		return fmt.Errorf("failed to marshal recurring transaction payload: %w", err)
	}

	task := &DelayedTask{
		ID:          fmt.Sprintf("recurring_%d_%d", recurringID, scheduledAt.Unix()),
		Type:        TaskTypeRecurringTransaction,
		UserID:      userID,
		Payload:     payloadJSON,
		ScheduledAt: scheduledAt,
		CreatedAt:   time.Now(),
		MaxRetries:  3,
	}
	return q.Schedule(ctx, task)
}
// MoveReadyTasks 将到期的延迟任务移动到就绪队列
func (q *TaskQueue) MoveReadyTasks(ctx context.Context) (int64, error) {
now := time.Now().Unix()
// 使用 Lua 脚本原子性地移动任务
script := redis.NewScript(`
local delayed_key = KEYS[1]
local ready_key = KEYS[2]
local now = tonumber(ARGV[1])
-- 获取所有到期的任务
local tasks = redis.call('ZRANGEBYSCORE', delayed_key, '-inf', now)
if #tasks == 0 then
return 0
end
-- 移动到就绪队列
for i, task in ipairs(tasks) do
redis.call('LPUSH', ready_key, task)
end
-- 从延迟队列中删除
redis.call('ZREMRANGEBYSCORE', delayed_key, '-inf', now)
return #tasks
`)
result, err := script.Run(ctx, q.client, []string{q.delayedKey, q.readyKey}, now).Int64()
if err != nil {
return 0, fmt.Errorf("failed to move ready tasks: %w", err)
}
if result > 0 {
log.Printf("[TaskQueue] Moved %d tasks to ready queue", result)
}
return result, nil
}
// PopTask waits up to timeout for a ready task. It uses BRPOPLPUSH to move
// the raw entry from the ready list to the processing list in one step, so a
// copy stays in the processing list while the task is being handled.
//
// It returns (nil, nil) when the wait times out with nothing to pop. The
// returned task should be handed back, unmodified, to CompleteTask or
// RetryTask, which find the processing-list entry by re-encoding it.
//
// NOTE(review): if the entry cannot be decoded, an error is returned but the
// raw entry remains in the processing list. Nothing visible here ever removes
// such an entry.
func (q *TaskQueue) PopTask(ctx context.Context, timeout time.Duration) (*DelayedTask, error) {
	raw, err := q.client.BRPopLPush(ctx, q.readyKey, q.processingKey, timeout).Result()
	switch {
	case err == redis.Nil:
		return nil, nil // timed out: nothing ready
	case err != nil:
		return nil, fmt.Errorf("failed to pop task: %w", err)
	}

	task := new(DelayedTask)
	if err := json.Unmarshal([]byte(raw), task); err != nil {
		return nil, fmt.Errorf("failed to unmarshal task: %w", err)
	}
	return task, nil
}
// CompleteTask removes task from the processing list, marking it done.
//
// The entry is found by re-encoding task and removing one exactly matching
// list element (LREM count 1). This assumes the task has not been modified
// since PopTask returned it, and that re-encoding reproduces the stored bytes.
// If nothing matches, LREM removes nothing and no error is reported.
//
// Returns an error if task cannot be encoded or if the Redis call fails.
func (q *TaskQueue) CompleteTask(ctx context.Context, task *DelayedTask) error {
	taskJSON, err := json.Marshal(task)
	if err != nil {
		// Previously ignored: LREM would then run with an empty string and
		// silently leave the real entry in the processing list.
		return fmt.Errorf("failed to marshal task: %w", err)
	}
	return q.client.LRem(ctx, q.processingKey, 1, string(taskJSON)).Err()
}
// RetryTask reschedules a task that failed processing, to run again after
// delay. It increments task.RetryCount and sets task.ScheduledAt in place.
// Once RetryCount exceeds MaxRetries, the task is discarded (removed from the
// processing list and not rescheduled).
//
// The processing-list entry is located by the task's encoding *before* any
// field is changed. This assumes the caller has not modified the task since
// PopTask. The previous implementation encoded the already-mutated task, so
// LREM never matched and every retried or discarded task stayed in the
// processing list forever.
//
// The task is rescheduled before its processing entry is released. If
// Schedule fails, the task therefore remains in the processing list instead
// of being lost. If the final LREM fails, the task may exist both in the
// delayed set and in the processing list.
func (q *TaskQueue) RetryTask(ctx context.Context, task *DelayedTask, delay time.Duration) error {
	// Capture the exact bytes stored in the processing list before mutating.
	original, err := json.Marshal(task)
	if err != nil {
		return fmt.Errorf("failed to marshal task: %w", err)
	}

	task.RetryCount++
	if task.RetryCount > task.MaxRetries {
		log.Printf("[TaskQueue] Task %s exceeded max retries, discarding", task.ID)
		return q.client.LRem(ctx, q.processingKey, 1, string(original)).Err()
	}

	task.ScheduledAt = time.Now().Add(delay)
	if err := q.Schedule(ctx, task); err != nil {
		return err
	}

	if err := q.client.LRem(ctx, q.processingKey, 1, string(original)).Err(); err != nil {
		return fmt.Errorf("failed to remove retried task from processing list: %w", err)
	}
	return nil
}
// GetQueueStats reports the size of each queue stage. It returns a map with
// the keys "delayed" (sorted-set cardinality), "ready" and "processing"
// (list lengths).
//
// The three reads are sent in one pipeline (a single round-trip). They are
// not a transaction, so the counts are not an atomic snapshot.
//
// Previously every Redis error was discarded and reported as a zero count,
// with a nil error. Errors are now returned to the caller, as the signature
// promises.
func (q *TaskQueue) GetQueueStats(ctx context.Context) (map[string]int64, error) {
	pipe := q.client.Pipeline()
	delayed := pipe.ZCard(ctx, q.delayedKey)
	ready := pipe.LLen(ctx, q.readyKey)
	processing := pipe.LLen(ctx, q.processingKey)

	if _, err := pipe.Exec(ctx); err != nil {
		return nil, fmt.Errorf("failed to read queue stats: %w", err)
	}

	return map[string]int64{
		"delayed":    delayed.Val(),
		"ready":      ready.Val(),
		"processing": processing.Val(),
	}, nil
}
// GetPendingTasks returns up to limit delayed tasks, soonest-due first
// (ascending sorted-set score). It is intended for debugging.
//
// A non-positive limit returns an empty slice. Previously limit 0 became
// ZRANGE 0 -1, which returned the entire set, and negative limits returned
// all but the last |limit|+1 entries. Entries that fail to decode are
// skipped rather than reported.
func (q *TaskQueue) GetPendingTasks(ctx context.Context, limit int64) ([]DelayedTask, error) {
	if limit <= 0 {
		return []DelayedTask{}, nil
	}

	results, err := q.client.ZRange(ctx, q.delayedKey, 0, limit-1).Result()
	if err != nil {
		return nil, fmt.Errorf("failed to list pending tasks: %w", err)
	}

	tasks := make([]DelayedTask, 0, len(results))
	for _, raw := range results {
		var task DelayedTask
		if err := json.Unmarshal([]byte(raw), &task); err != nil {
			continue // debugging helper: skip malformed entries
		}
		tasks = append(tasks, task)
	}
	return tasks, nil
}
// AcquireLock tries to claim key for ttl using SET NX (intended for
// idempotency checks). It reports true if this call created the key, and
// false if the key already existed.
//
// NOTE(review): the stored value is the constant "1", so there is no owner
// token. No release method is visible here, so the lock presumably lasts
// until the TTL expires — confirm.
func (q *TaskQueue) AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
	cmd := q.client.SetNX(ctx, key, "1", ttl)
	return cmd.Val(), cmd.Err()
}