// Package mq implements a Redis-backed delayed task queue.
package mq
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"log"
|
||
"time"
|
||
|
||
"github.com/redis/go-redis/v9"
|
||
)
|
||
|
||
// TaskType names the kind of work a queued task represents.
type TaskType string

// Known task types.
const (
	// TaskTypeRecurringTransaction marks a task that processes a recurring transaction.
	TaskTypeRecurringTransaction TaskType = "recurring_transaction"

	// TaskTypeAllocationRule marks a task that executes an allocation rule.
	TaskTypeAllocationRule TaskType = "allocation_rule"
)
|
||
|
||
// DelayedTask 延迟任务结构
|
||
type DelayedTask struct {
|
||
ID string `json:"id"`
|
||
Type TaskType `json:"type"`
|
||
UserID uint `json:"user_id"`
|
||
Payload json.RawMessage `json:"payload"`
|
||
ScheduledAt time.Time `json:"scheduled_at"`
|
||
CreatedAt time.Time `json:"created_at"`
|
||
RetryCount int `json:"retry_count"`
|
||
MaxRetries int `json:"max_retries"`
|
||
}
|
||
|
||
// RecurringTransactionPayload is the payload carried by tasks of type
// TaskTypeRecurringTransaction.
type RecurringTransactionPayload struct {
	// RecurringTransactionID is the ID of the recurring transaction to process.
	RecurringTransactionID uint `json:"recurring_transaction_id"`
	// NextOccurrence is the occurrence time this task was scheduled for.
	NextOccurrence time.Time `json:"next_occurrence"`
}
|
||
|
||
// TaskQueue 基于 Redis 的延迟任务队列
|
||
type TaskQueue struct {
|
||
client *redis.Client
|
||
keyPrefix string
|
||
delayedKey string // 延迟任务有序集合
|
||
readyKey string // 就绪任务列表
|
||
processingKey string // 处理中任务
|
||
}
|
||
|
||
// NewTaskQueue 创建任务队列实例
|
||
func NewTaskQueue(client *redis.Client, keyPrefix string) *TaskQueue {
|
||
if keyPrefix == "" {
|
||
keyPrefix = "novault:tasks"
|
||
}
|
||
return &TaskQueue{
|
||
client: client,
|
||
keyPrefix: keyPrefix,
|
||
delayedKey: keyPrefix + ":delayed",
|
||
readyKey: keyPrefix + ":ready",
|
||
processingKey: keyPrefix + ":processing",
|
||
}
|
||
}
|
||
|
||
// Schedule 调度延迟任务
|
||
func (q *TaskQueue) Schedule(ctx context.Context, task *DelayedTask) error {
|
||
// 序列化任务
|
||
taskJSON, err := json.Marshal(task)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to marshal task: %w", err)
|
||
}
|
||
|
||
// 使用 ZADD 添加到延迟队列,score 为执行时间戳
|
||
score := float64(task.ScheduledAt.Unix())
|
||
err = q.client.ZAdd(ctx, q.delayedKey, redis.Z{
|
||
Score: score,
|
||
Member: string(taskJSON),
|
||
}).Err()
|
||
|
||
if err != nil {
|
||
return fmt.Errorf("failed to schedule task: %w", err)
|
||
}
|
||
|
||
log.Printf("[TaskQueue] Scheduled task %s for %s", task.ID, task.ScheduledAt.Format("2006-01-02 15:04:05"))
|
||
return nil
|
||
}
|
||
|
||
// ScheduleRecurringTransaction 调度周期性交易任务
|
||
func (q *TaskQueue) ScheduleRecurringTransaction(ctx context.Context, userID uint, recurringID uint, scheduledAt time.Time) error {
|
||
payload := RecurringTransactionPayload{
|
||
RecurringTransactionID: recurringID,
|
||
NextOccurrence: scheduledAt,
|
||
}
|
||
payloadJSON, _ := json.Marshal(payload)
|
||
|
||
task := &DelayedTask{
|
||
ID: fmt.Sprintf("recurring_%d_%d", recurringID, scheduledAt.Unix()),
|
||
Type: TaskTypeRecurringTransaction,
|
||
UserID: userID,
|
||
Payload: payloadJSON,
|
||
ScheduledAt: scheduledAt,
|
||
CreatedAt: time.Now(),
|
||
MaxRetries: 3,
|
||
}
|
||
|
||
return q.Schedule(ctx, task)
|
||
}
|
||
|
||
// MoveReadyTasks 将到期的延迟任务移动到就绪队列
|
||
func (q *TaskQueue) MoveReadyTasks(ctx context.Context) (int64, error) {
|
||
now := time.Now().Unix()
|
||
|
||
// 使用 Lua 脚本原子性地移动任务
|
||
script := redis.NewScript(`
|
||
local delayed_key = KEYS[1]
|
||
local ready_key = KEYS[2]
|
||
local now = tonumber(ARGV[1])
|
||
|
||
-- 获取所有到期的任务
|
||
local tasks = redis.call('ZRANGEBYSCORE', delayed_key, '-inf', now)
|
||
|
||
if #tasks == 0 then
|
||
return 0
|
||
end
|
||
|
||
-- 移动到就绪队列
|
||
for i, task in ipairs(tasks) do
|
||
redis.call('LPUSH', ready_key, task)
|
||
end
|
||
|
||
-- 从延迟队列中删除
|
||
redis.call('ZREMRANGEBYSCORE', delayed_key, '-inf', now)
|
||
|
||
return #tasks
|
||
`)
|
||
|
||
result, err := script.Run(ctx, q.client, []string{q.delayedKey, q.readyKey}, now).Int64()
|
||
if err != nil {
|
||
return 0, fmt.Errorf("failed to move ready tasks: %w", err)
|
||
}
|
||
|
||
if result > 0 {
|
||
log.Printf("[TaskQueue] Moved %d tasks to ready queue", result)
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// PopTask 从就绪队列中取出一个任务
|
||
func (q *TaskQueue) PopTask(ctx context.Context, timeout time.Duration) (*DelayedTask, error) {
|
||
// 使用 BRPOPLPUSH 原子性地获取任务并移动到处理中队列
|
||
result, err := q.client.BRPopLPush(ctx, q.readyKey, q.processingKey, timeout).Result()
|
||
if err != nil {
|
||
if err == redis.Nil {
|
||
return nil, nil // 没有任务
|
||
}
|
||
return nil, fmt.Errorf("failed to pop task: %w", err)
|
||
}
|
||
|
||
var task DelayedTask
|
||
if err := json.Unmarshal([]byte(result), &task); err != nil {
|
||
return nil, fmt.Errorf("failed to unmarshal task: %w", err)
|
||
}
|
||
|
||
return &task, nil
|
||
}
|
||
|
||
// CompleteTask 标记任务完成
|
||
func (q *TaskQueue) CompleteTask(ctx context.Context, task *DelayedTask) error {
|
||
taskJSON, _ := json.Marshal(task)
|
||
return q.client.LRem(ctx, q.processingKey, 1, string(taskJSON)).Err()
|
||
}
|
||
|
||
// RetryTask 重试任务
|
||
func (q *TaskQueue) RetryTask(ctx context.Context, task *DelayedTask, delay time.Duration) error {
|
||
task.RetryCount++
|
||
if task.RetryCount > task.MaxRetries {
|
||
log.Printf("[TaskQueue] Task %s exceeded max retries, discarding", task.ID)
|
||
return q.CompleteTask(ctx, task)
|
||
}
|
||
|
||
task.ScheduledAt = time.Now().Add(delay)
|
||
|
||
// 先从处理中队列移除
|
||
taskJSON, _ := json.Marshal(task)
|
||
q.client.LRem(ctx, q.processingKey, 1, string(taskJSON))
|
||
|
||
// 重新调度
|
||
return q.Schedule(ctx, task)
|
||
}
|
||
|
||
// GetQueueStats 获取队列统计信息
|
||
func (q *TaskQueue) GetQueueStats(ctx context.Context) (map[string]int64, error) {
|
||
delayed, _ := q.client.ZCard(ctx, q.delayedKey).Result()
|
||
ready, _ := q.client.LLen(ctx, q.readyKey).Result()
|
||
processing, _ := q.client.LLen(ctx, q.processingKey).Result()
|
||
|
||
return map[string]int64{
|
||
"delayed": delayed,
|
||
"ready": ready,
|
||
"processing": processing,
|
||
}, nil
|
||
}
|
||
|
||
// GetPendingTasks 获取即将执行的任务(用于调试)
|
||
func (q *TaskQueue) GetPendingTasks(ctx context.Context, limit int64) ([]DelayedTask, error) {
|
||
results, err := q.client.ZRange(ctx, q.delayedKey, 0, limit-1).Result()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
tasks := make([]DelayedTask, 0, len(results))
|
||
for _, result := range results {
|
||
var task DelayedTask
|
||
if err := json.Unmarshal([]byte(result), &task); err == nil {
|
||
tasks = append(tasks, task)
|
||
}
|
||
}
|
||
|
||
return tasks, nil
|
||
}
|