feat: 实现基于 Redis 的延迟任务队列和工作器,支持周期性交易处理等后台任务调度。
This commit is contained in:
78
internal/handler/task_queue_handler.go
Normal file
78
internal/handler/task_queue_handler.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"accounting-app/internal/mq"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// TaskQueueHandler 任务队列状态查询 Handler
|
||||
type TaskQueueHandler struct {
|
||||
taskSystem *mq.RecurringTaskSystem
|
||||
}
|
||||
|
||||
// NewTaskQueueHandler 创建任务队列 Handler
|
||||
func NewTaskQueueHandler(taskSystem *mq.RecurringTaskSystem) *TaskQueueHandler {
|
||||
return &TaskQueueHandler{
|
||||
taskSystem: taskSystem,
|
||||
}
|
||||
}
|
||||
|
||||
// GetStats 获取任务队列统计信息
|
||||
// GET /api/v1/admin/tasks/stats
|
||||
func (h *TaskQueueHandler) GetStats(c *gin.Context) {
|
||||
if h.taskSystem == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"success": false,
|
||||
"error": "Task system not available (Redis required)",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
stats := h.taskSystem.GetStats(c.Request.Context())
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": stats,
|
||||
})
|
||||
}
|
||||
|
||||
// GetPendingTasks 获取即将执行的任务列表
|
||||
// GET /api/v1/admin/tasks/pending
|
||||
func (h *TaskQueueHandler) GetPendingTasks(c *gin.Context) {
|
||||
if h.taskSystem == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"success": false,
|
||||
"error": "Task system not available (Redis required)",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tasks, err := h.taskSystem.Queue.GetPendingTasks(c.Request.Context(), 50)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"count": len(tasks),
|
||||
"tasks": tasks,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *TaskQueueHandler) RegisterRoutes(rg *gin.RouterGroup) {
|
||||
tasks := rg.Group("/admin/tasks")
|
||||
{
|
||||
tasks.GET("/stats", h.GetStats)
|
||||
tasks.GET("/pending", h.GetPendingTasks)
|
||||
}
|
||||
}
|
||||
95
internal/mq/init.go
Normal file
95
internal/mq/init.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package mq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"accounting-app/internal/repository"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// RecurringTaskSystem 周期性交易任务系统
|
||||
type RecurringTaskSystem struct {
|
||||
Queue *TaskQueue
|
||||
Worker *TaskWorker
|
||||
Handler *RecurringTransactionHandler
|
||||
}
|
||||
|
||||
// InitRecurringTaskSystem 初始化周期性交易任务系统
|
||||
// 需要 Redis 连接来存储任务队列
|
||||
func InitRecurringTaskSystem(
|
||||
ctx context.Context,
|
||||
redisClient *redis.Client,
|
||||
db *gorm.DB,
|
||||
pollInterval time.Duration,
|
||||
workerCount int,
|
||||
) (*RecurringTaskSystem, error) {
|
||||
log.Println("[MQ] Initializing recurring task system...")
|
||||
|
||||
// 创建任务队列
|
||||
taskQueue := NewTaskQueue(redisClient, "novault:tasks")
|
||||
|
||||
// 创建任务 Worker
|
||||
worker := NewTaskWorker(taskQueue, pollInterval, workerCount)
|
||||
|
||||
// 创建周期性交易处理器
|
||||
recurringRepo := repository.NewRecurringTransactionRepository(db)
|
||||
transactionRepo := repository.NewTransactionRepository(db)
|
||||
accountRepo := repository.NewAccountRepository(db)
|
||||
|
||||
handler := NewRecurringTransactionHandler(
|
||||
db,
|
||||
recurringRepo,
|
||||
transactionRepo,
|
||||
accountRepo,
|
||||
taskQueue,
|
||||
)
|
||||
|
||||
// 注册处理器
|
||||
worker.RegisterHandler(TaskTypeRecurringTransaction, handler.Handle)
|
||||
|
||||
system := &RecurringTaskSystem{
|
||||
Queue: taskQueue,
|
||||
Worker: worker,
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
// 启动时调度所有活跃的周期性交易
|
||||
if err := handler.ScheduleAllActive(ctx); err != nil {
|
||||
log.Printf("[MQ] Warning: failed to schedule active recurring transactions: %v", err)
|
||||
}
|
||||
|
||||
// 处理逾期任务(补偿机制)
|
||||
if err := handler.ProcessOverdue(ctx); err != nil {
|
||||
log.Printf("[MQ] Warning: failed to process overdue transactions: %v", err)
|
||||
}
|
||||
|
||||
log.Println("[MQ] Recurring task system initialized successfully")
|
||||
return system, nil
|
||||
}
|
||||
|
||||
// Start 启动任务系统
|
||||
func (s *RecurringTaskSystem) Start(ctx context.Context) {
|
||||
log.Println("[MQ] Starting recurring task worker...")
|
||||
s.Worker.Start(ctx)
|
||||
}
|
||||
|
||||
// Stop 停止任务系统
|
||||
func (s *RecurringTaskSystem) Stop() {
|
||||
log.Println("[MQ] Stopping recurring task worker...")
|
||||
s.Worker.Stop()
|
||||
}
|
||||
|
||||
// GetStats 获取系统统计信息
|
||||
func (s *RecurringTaskSystem) GetStats(ctx context.Context) map[string]interface{} {
|
||||
queueStats, _ := s.Queue.GetQueueStats(ctx)
|
||||
workerStats := s.Worker.GetStats()
|
||||
|
||||
return map[string]interface{}{
|
||||
"queue": queueStats,
|
||||
"worker": workerStats,
|
||||
}
|
||||
}
|
||||
245
internal/mq/recurring_handler.go
Normal file
245
internal/mq/recurring_handler.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package mq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"accounting-app/internal/models"
|
||||
"accounting-app/internal/repository"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// RecurringTransactionHandler 周期性交易任务处理器
|
||||
type RecurringTransactionHandler struct {
|
||||
db *gorm.DB
|
||||
recurringRepo *repository.RecurringTransactionRepository
|
||||
transactionRepo *repository.TransactionRepository
|
||||
accountRepo *repository.AccountRepository
|
||||
taskQueue *TaskQueue
|
||||
}
|
||||
|
||||
// NewRecurringTransactionHandler 创建处理器实例
|
||||
func NewRecurringTransactionHandler(
|
||||
db *gorm.DB,
|
||||
recurringRepo *repository.RecurringTransactionRepository,
|
||||
transactionRepo *repository.TransactionRepository,
|
||||
accountRepo *repository.AccountRepository,
|
||||
taskQueue *TaskQueue,
|
||||
) *RecurringTransactionHandler {
|
||||
return &RecurringTransactionHandler{
|
||||
db: db,
|
||||
recurringRepo: recurringRepo,
|
||||
transactionRepo: transactionRepo,
|
||||
accountRepo: accountRepo,
|
||||
taskQueue: taskQueue,
|
||||
}
|
||||
}
|
||||
|
||||
// Handle 处理周期性交易任务
|
||||
func (h *RecurringTransactionHandler) Handle(ctx context.Context, task *DelayedTask) error {
|
||||
payload, err := ParseRecurringPayload(task)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse payload: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[RecurringHandler] Processing recurring transaction %d for user %d",
|
||||
payload.RecurringTransactionID, task.UserID)
|
||||
|
||||
// 获取周期性交易记录
|
||||
recurring, err := h.recurringRepo.GetByID(task.UserID, payload.RecurringTransactionID)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
log.Printf("[RecurringHandler] Recurring transaction %d not found, skipping", payload.RecurringTransactionID)
|
||||
return nil // 已被删除,不需要重试
|
||||
}
|
||||
return fmt.Errorf("failed to get recurring transaction: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否已禁用或已过期
|
||||
if !recurring.IsActive {
|
||||
log.Printf("[RecurringHandler] Recurring transaction %d is inactive, skipping", recurring.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if recurring.EndDate != nil && time.Now().After(*recurring.EndDate) {
|
||||
log.Printf("[RecurringHandler] Recurring transaction %d has ended, deactivating", recurring.ID)
|
||||
recurring.IsActive = false
|
||||
return h.recurringRepo.Update(recurring)
|
||||
}
|
||||
|
||||
// 开始数据库事务
|
||||
tx := h.db.Begin()
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", tx.Error)
|
||||
}
|
||||
|
||||
// 创建实际交易记录
|
||||
transaction := &models.Transaction{
|
||||
UserID: recurring.UserID,
|
||||
Amount: recurring.Amount,
|
||||
Type: recurring.Type,
|
||||
CategoryID: recurring.CategoryID,
|
||||
AccountID: recurring.AccountID,
|
||||
Currency: recurring.Currency,
|
||||
TransactionDate: recurring.NextOccurrence,
|
||||
Note: recurring.Note,
|
||||
RecurringID: &recurring.ID,
|
||||
}
|
||||
|
||||
if err := tx.Create(transaction).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to create transaction: %w", err)
|
||||
}
|
||||
|
||||
// 更新账户余额
|
||||
var account models.Account
|
||||
if err := tx.First(&account, recurring.AccountID).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to get account: %w", err)
|
||||
}
|
||||
|
||||
switch recurring.Type {
|
||||
case models.TransactionTypeIncome:
|
||||
account.Balance += recurring.Amount
|
||||
case models.TransactionTypeExpense:
|
||||
account.Balance -= recurring.Amount
|
||||
}
|
||||
|
||||
if err := tx.Save(&account).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to update account balance: %w", err)
|
||||
}
|
||||
|
||||
// 计算下一次执行时间
|
||||
nextOccurrence := h.calculateNextOccurrence(recurring.NextOccurrence, recurring.Frequency)
|
||||
recurring.NextOccurrence = nextOccurrence
|
||||
|
||||
// 检查下次执行是否超过结束日期
|
||||
if recurring.EndDate != nil && nextOccurrence.After(*recurring.EndDate) {
|
||||
recurring.IsActive = false
|
||||
log.Printf("[RecurringHandler] Recurring transaction %d will end after this execution", recurring.ID)
|
||||
}
|
||||
|
||||
if err := tx.Save(recurring).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to update recurring transaction: %w", err)
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[RecurringHandler] Created transaction %d from recurring %d, amount: %.2f %s",
|
||||
transaction.ID, recurring.ID, recurring.Amount, recurring.Currency)
|
||||
|
||||
// 如果仍然活跃,调度下一次执行
|
||||
if recurring.IsActive {
|
||||
if err := h.scheduleNext(ctx, recurring); err != nil {
|
||||
log.Printf("[RecurringHandler] Warning: failed to schedule next execution: %v", err)
|
||||
// 不返回错误,当前任务已成功完成
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateNextOccurrence 计算下一次执行时间
|
||||
func (h *RecurringTransactionHandler) calculateNextOccurrence(current time.Time, frequency models.FrequencyType) time.Time {
|
||||
switch frequency {
|
||||
case models.FrequencyDaily:
|
||||
return current.AddDate(0, 0, 1)
|
||||
case models.FrequencyWeekly:
|
||||
return current.AddDate(0, 0, 7)
|
||||
case models.FrequencyMonthly:
|
||||
return current.AddDate(0, 1, 0)
|
||||
case models.FrequencyYearly:
|
||||
return current.AddDate(1, 0, 0)
|
||||
default:
|
||||
return current.AddDate(0, 1, 0) // 默认月度
|
||||
}
|
||||
}
|
||||
|
||||
// scheduleNext 调度下一次执行
|
||||
func (h *RecurringTransactionHandler) scheduleNext(ctx context.Context, recurring *models.RecurringTransaction) error {
|
||||
return h.taskQueue.ScheduleRecurringTransaction(
|
||||
ctx,
|
||||
recurring.UserID,
|
||||
recurring.ID,
|
||||
recurring.NextOccurrence,
|
||||
)
|
||||
}
|
||||
|
||||
// ScheduleAllActive 调度所有活跃的周期性交易(启动时调用)
|
||||
func (h *RecurringTransactionHandler) ScheduleAllActive(ctx context.Context) error {
|
||||
log.Println("[RecurringHandler] Scheduling all active recurring transactions...")
|
||||
|
||||
// 获取所有活跃的周期性交易
|
||||
var recurringList []models.RecurringTransaction
|
||||
err := h.db.Where("is_active = ?", true).Find(&recurringList).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get active recurring transactions: %w", err)
|
||||
}
|
||||
|
||||
scheduledCount := 0
|
||||
for _, recurring := range recurringList {
|
||||
// 跳过已过期的
|
||||
if recurring.EndDate != nil && time.Now().After(*recurring.EndDate) {
|
||||
recurring.IsActive = false
|
||||
h.db.Save(&recurring)
|
||||
continue
|
||||
}
|
||||
|
||||
// 调度任务
|
||||
err := h.taskQueue.ScheduleRecurringTransaction(
|
||||
ctx,
|
||||
recurring.UserID,
|
||||
recurring.ID,
|
||||
recurring.NextOccurrence,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("[RecurringHandler] Failed to schedule recurring %d: %v", recurring.ID, err)
|
||||
continue
|
||||
}
|
||||
scheduledCount++
|
||||
}
|
||||
|
||||
log.Printf("[RecurringHandler] Scheduled %d recurring transactions", scheduledCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProcessOverdue 处理所有逾期的周期性交易(补偿机制)
|
||||
func (h *RecurringTransactionHandler) ProcessOverdue(ctx context.Context) error {
|
||||
log.Println("[RecurringHandler] Processing overdue recurring transactions...")
|
||||
|
||||
now := time.Now()
|
||||
var overdueList []models.RecurringTransaction
|
||||
err := h.db.Where("is_active = ? AND next_occurrence <= ?", true, now).Find(&overdueList).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get overdue recurring transactions: %w", err)
|
||||
}
|
||||
|
||||
if len(overdueList) == 0 {
|
||||
log.Println("[RecurringHandler] No overdue recurring transactions found")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Printf("[RecurringHandler] Found %d overdue recurring transactions", len(overdueList))
|
||||
|
||||
for _, recurring := range overdueList {
|
||||
err := h.taskQueue.ScheduleRecurringTransaction(
|
||||
ctx,
|
||||
recurring.UserID,
|
||||
recurring.ID,
|
||||
recurring.NextOccurrence,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("[RecurringHandler] Failed to schedule overdue recurring %d: %v", recurring.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
220
internal/mq/task_queue.go
Normal file
220
internal/mq/task_queue.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package mq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// TaskType 定义任务类型
|
||||
type TaskType string
|
||||
|
||||
const (
|
||||
// TaskTypeRecurringTransaction 周期性交易处理任务
|
||||
TaskTypeRecurringTransaction TaskType = "recurring_transaction"
|
||||
// TaskTypeAllocationRule 分配规则执行任务
|
||||
TaskTypeAllocationRule TaskType = "allocation_rule"
|
||||
)
|
||||
|
||||
// DelayedTask 延迟任务结构
|
||||
type DelayedTask struct {
|
||||
ID string `json:"id"`
|
||||
Type TaskType `json:"type"`
|
||||
UserID uint `json:"user_id"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
ScheduledAt time.Time `json:"scheduled_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
}
|
||||
|
||||
// RecurringTransactionPayload 周期性交易任务载荷
|
||||
type RecurringTransactionPayload struct {
|
||||
RecurringTransactionID uint `json:"recurring_transaction_id"`
|
||||
NextOccurrence time.Time `json:"next_occurrence"`
|
||||
}
|
||||
|
||||
// TaskQueue 基于 Redis 的延迟任务队列
|
||||
type TaskQueue struct {
|
||||
client *redis.Client
|
||||
keyPrefix string
|
||||
delayedKey string // 延迟任务有序集合
|
||||
readyKey string // 就绪任务列表
|
||||
processingKey string // 处理中任务
|
||||
}
|
||||
|
||||
// NewTaskQueue 创建任务队列实例
|
||||
func NewTaskQueue(client *redis.Client, keyPrefix string) *TaskQueue {
|
||||
if keyPrefix == "" {
|
||||
keyPrefix = "novault:tasks"
|
||||
}
|
||||
return &TaskQueue{
|
||||
client: client,
|
||||
keyPrefix: keyPrefix,
|
||||
delayedKey: keyPrefix + ":delayed",
|
||||
readyKey: keyPrefix + ":ready",
|
||||
processingKey: keyPrefix + ":processing",
|
||||
}
|
||||
}
|
||||
|
||||
// Schedule 调度延迟任务
|
||||
func (q *TaskQueue) Schedule(ctx context.Context, task *DelayedTask) error {
|
||||
// 序列化任务
|
||||
taskJSON, err := json.Marshal(task)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal task: %w", err)
|
||||
}
|
||||
|
||||
// 使用 ZADD 添加到延迟队列,score 为执行时间戳
|
||||
score := float64(task.ScheduledAt.Unix())
|
||||
err = q.client.ZAdd(ctx, q.delayedKey, redis.Z{
|
||||
Score: score,
|
||||
Member: string(taskJSON),
|
||||
}).Err()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to schedule task: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[TaskQueue] Scheduled task %s for %s", task.ID, task.ScheduledAt.Format("2006-01-02 15:04:05"))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ScheduleRecurringTransaction 调度周期性交易任务
|
||||
func (q *TaskQueue) ScheduleRecurringTransaction(ctx context.Context, userID uint, recurringID uint, scheduledAt time.Time) error {
|
||||
payload := RecurringTransactionPayload{
|
||||
RecurringTransactionID: recurringID,
|
||||
NextOccurrence: scheduledAt,
|
||||
}
|
||||
payloadJSON, _ := json.Marshal(payload)
|
||||
|
||||
task := &DelayedTask{
|
||||
ID: fmt.Sprintf("recurring_%d_%d", recurringID, scheduledAt.Unix()),
|
||||
Type: TaskTypeRecurringTransaction,
|
||||
UserID: userID,
|
||||
Payload: payloadJSON,
|
||||
ScheduledAt: scheduledAt,
|
||||
CreatedAt: time.Now(),
|
||||
MaxRetries: 3,
|
||||
}
|
||||
|
||||
return q.Schedule(ctx, task)
|
||||
}
|
||||
|
||||
// MoveReadyTasks 将到期的延迟任务移动到就绪队列
|
||||
func (q *TaskQueue) MoveReadyTasks(ctx context.Context) (int64, error) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
// 使用 Lua 脚本原子性地移动任务
|
||||
script := redis.NewScript(`
|
||||
local delayed_key = KEYS[1]
|
||||
local ready_key = KEYS[2]
|
||||
local now = tonumber(ARGV[1])
|
||||
|
||||
-- 获取所有到期的任务
|
||||
local tasks = redis.call('ZRANGEBYSCORE', delayed_key, '-inf', now)
|
||||
|
||||
if #tasks == 0 then
|
||||
return 0
|
||||
end
|
||||
|
||||
-- 移动到就绪队列
|
||||
for i, task in ipairs(tasks) do
|
||||
redis.call('LPUSH', ready_key, task)
|
||||
end
|
||||
|
||||
-- 从延迟队列中删除
|
||||
redis.call('ZREMRANGEBYSCORE', delayed_key, '-inf', now)
|
||||
|
||||
return #tasks
|
||||
`)
|
||||
|
||||
result, err := script.Run(ctx, q.client, []string{q.delayedKey, q.readyKey}, now).Int64()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to move ready tasks: %w", err)
|
||||
}
|
||||
|
||||
if result > 0 {
|
||||
log.Printf("[TaskQueue] Moved %d tasks to ready queue", result)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// PopTask 从就绪队列中取出一个任务
|
||||
func (q *TaskQueue) PopTask(ctx context.Context, timeout time.Duration) (*DelayedTask, error) {
|
||||
// 使用 BRPOPLPUSH 原子性地获取任务并移动到处理中队列
|
||||
result, err := q.client.BRPopLPush(ctx, q.readyKey, q.processingKey, timeout).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, nil // 没有任务
|
||||
}
|
||||
return nil, fmt.Errorf("failed to pop task: %w", err)
|
||||
}
|
||||
|
||||
var task DelayedTask
|
||||
if err := json.Unmarshal([]byte(result), &task); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal task: %w", err)
|
||||
}
|
||||
|
||||
return &task, nil
|
||||
}
|
||||
|
||||
// CompleteTask 标记任务完成
|
||||
func (q *TaskQueue) CompleteTask(ctx context.Context, task *DelayedTask) error {
|
||||
taskJSON, _ := json.Marshal(task)
|
||||
return q.client.LRem(ctx, q.processingKey, 1, string(taskJSON)).Err()
|
||||
}
|
||||
|
||||
// RetryTask 重试任务
|
||||
func (q *TaskQueue) RetryTask(ctx context.Context, task *DelayedTask, delay time.Duration) error {
|
||||
task.RetryCount++
|
||||
if task.RetryCount > task.MaxRetries {
|
||||
log.Printf("[TaskQueue] Task %s exceeded max retries, discarding", task.ID)
|
||||
return q.CompleteTask(ctx, task)
|
||||
}
|
||||
|
||||
task.ScheduledAt = time.Now().Add(delay)
|
||||
|
||||
// 先从处理中队列移除
|
||||
taskJSON, _ := json.Marshal(task)
|
||||
q.client.LRem(ctx, q.processingKey, 1, string(taskJSON))
|
||||
|
||||
// 重新调度
|
||||
return q.Schedule(ctx, task)
|
||||
}
|
||||
|
||||
// GetQueueStats 获取队列统计信息
|
||||
func (q *TaskQueue) GetQueueStats(ctx context.Context) (map[string]int64, error) {
|
||||
delayed, _ := q.client.ZCard(ctx, q.delayedKey).Result()
|
||||
ready, _ := q.client.LLen(ctx, q.readyKey).Result()
|
||||
processing, _ := q.client.LLen(ctx, q.processingKey).Result()
|
||||
|
||||
return map[string]int64{
|
||||
"delayed": delayed,
|
||||
"ready": ready,
|
||||
"processing": processing,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetPendingTasks 获取即将执行的任务(用于调试)
|
||||
func (q *TaskQueue) GetPendingTasks(ctx context.Context, limit int64) ([]DelayedTask, error) {
|
||||
results, err := q.client.ZRange(ctx, q.delayedKey, 0, limit-1).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tasks := make([]DelayedTask, 0, len(results))
|
||||
for _, result := range results {
|
||||
var task DelayedTask
|
||||
if err := json.Unmarshal([]byte(result), &task); err == nil {
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
227
internal/mq/task_worker.go
Normal file
227
internal/mq/task_worker.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package mq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TaskHandler 任务处理函数类型
|
||||
type TaskHandler func(ctx context.Context, task *DelayedTask) error
|
||||
|
||||
// TaskWorker 任务工作器,负责消费和执行任务
|
||||
type TaskWorker struct {
|
||||
queue *TaskQueue
|
||||
handlers map[TaskType]TaskHandler
|
||||
pollInterval time.Duration // 检查延迟任务间隔
|
||||
workerCount int // 并发工作器数量
|
||||
|
||||
mu sync.RWMutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
|
||||
// 统计信息
|
||||
processedCount int64
|
||||
errorCount int64
|
||||
lastProcessed time.Time
|
||||
}
|
||||
|
||||
// NewTaskWorker 创建任务工作器
|
||||
func NewTaskWorker(queue *TaskQueue, pollInterval time.Duration, workerCount int) *TaskWorker {
|
||||
if pollInterval < time.Second {
|
||||
pollInterval = time.Second
|
||||
}
|
||||
if workerCount < 1 {
|
||||
workerCount = 1
|
||||
}
|
||||
|
||||
return &TaskWorker{
|
||||
queue: queue,
|
||||
handlers: make(map[TaskType]TaskHandler),
|
||||
pollInterval: pollInterval,
|
||||
workerCount: workerCount,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHandler 注册任务处理器
|
||||
func (w *TaskWorker) RegisterHandler(taskType TaskType, handler TaskHandler) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.handlers[taskType] = handler
|
||||
log.Printf("[TaskWorker] Registered handler for task type: %s", taskType)
|
||||
}
|
||||
|
||||
// Start 启动工作器
|
||||
func (w *TaskWorker) Start(ctx context.Context) {
|
||||
w.mu.Lock()
|
||||
if w.running {
|
||||
w.mu.Unlock()
|
||||
log.Println("[TaskWorker] Already running")
|
||||
return
|
||||
}
|
||||
w.running = true
|
||||
w.stopChan = make(chan struct{})
|
||||
w.mu.Unlock()
|
||||
|
||||
log.Printf("[TaskWorker] Starting with %d workers, poll interval: %v", w.workerCount, w.pollInterval)
|
||||
|
||||
// 启动延迟任务轮询器
|
||||
w.wg.Add(1)
|
||||
go w.pollDelayedTasks(ctx)
|
||||
|
||||
// 启动工作器
|
||||
for i := 0; i < w.workerCount; i++ {
|
||||
w.wg.Add(1)
|
||||
go w.worker(ctx, i)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止工作器
|
||||
func (w *TaskWorker) Stop() {
|
||||
w.mu.Lock()
|
||||
if !w.running {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
w.running = false
|
||||
close(w.stopChan)
|
||||
w.mu.Unlock()
|
||||
|
||||
log.Println("[TaskWorker] Stopping...")
|
||||
w.wg.Wait()
|
||||
log.Println("[TaskWorker] Stopped")
|
||||
}
|
||||
|
||||
// pollDelayedTasks 定期检查并移动到期的延迟任务
|
||||
func (w *TaskWorker) pollDelayedTasks(ctx context.Context) {
|
||||
defer w.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(w.pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-w.stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
moved, err := w.queue.MoveReadyTasks(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[TaskWorker] Error moving ready tasks: %v", err)
|
||||
} else if moved > 0 {
|
||||
log.Printf("[TaskWorker] Moved %d tasks to ready queue", moved)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// worker 工作器协程
|
||||
func (w *TaskWorker) worker(ctx context.Context, id int) {
|
||||
defer w.wg.Done()
|
||||
|
||||
log.Printf("[TaskWorker] Worker %d started", id)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("[TaskWorker] Worker %d stopping (context done)", id)
|
||||
return
|
||||
case <-w.stopChan:
|
||||
log.Printf("[TaskWorker] Worker %d stopping (stop signal)", id)
|
||||
return
|
||||
default:
|
||||
// 尝试获取任务
|
||||
task, err := w.queue.PopTask(ctx, time.Second)
|
||||
if err != nil {
|
||||
log.Printf("[TaskWorker] Worker %d error popping task: %v", id, err)
|
||||
time.Sleep(time.Second) // 错误后等待
|
||||
continue
|
||||
}
|
||||
|
||||
if task == nil {
|
||||
continue // 没有任务,继续轮询
|
||||
}
|
||||
|
||||
// 处理任务
|
||||
w.processTask(ctx, id, task)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processTask 处理单个任务
|
||||
func (w *TaskWorker) processTask(ctx context.Context, workerID int, task *DelayedTask) {
|
||||
startTime := time.Now()
|
||||
log.Printf("[TaskWorker] Worker %d processing task %s (type: %s, user: %d)",
|
||||
workerID, task.ID, task.Type, task.UserID)
|
||||
|
||||
w.mu.RLock()
|
||||
handler, exists := w.handlers[task.Type]
|
||||
w.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
log.Printf("[TaskWorker] No handler for task type: %s", task.Type)
|
||||
w.queue.CompleteTask(ctx, task)
|
||||
return
|
||||
}
|
||||
|
||||
// 执行任务
|
||||
err := handler(ctx, task)
|
||||
duration := time.Since(startTime)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[TaskWorker] Task %s failed (attempt %d/%d): %v",
|
||||
task.ID, task.RetryCount+1, task.MaxRetries, err)
|
||||
|
||||
w.mu.Lock()
|
||||
w.errorCount++
|
||||
w.mu.Unlock()
|
||||
|
||||
// 重试(指数退避)
|
||||
retryDelay := time.Duration(1<<task.RetryCount) * time.Minute
|
||||
if err := w.queue.RetryTask(ctx, task, retryDelay); err != nil {
|
||||
log.Printf("[TaskWorker] Failed to retry task %s: %v", task.ID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 完成任务
|
||||
if err := w.queue.CompleteTask(ctx, task); err != nil {
|
||||
log.Printf("[TaskWorker] Failed to complete task %s: %v", task.ID, err)
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
w.processedCount++
|
||||
w.lastProcessed = time.Now()
|
||||
w.mu.Unlock()
|
||||
|
||||
log.Printf("[TaskWorker] Task %s completed in %v", task.ID, duration)
|
||||
}
|
||||
|
||||
// GetStats 获取工作器统计信息
|
||||
func (w *TaskWorker) GetStats() map[string]interface{} {
|
||||
w.mu.RLock()
|
||||
defer w.mu.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"running": w.running,
|
||||
"worker_count": w.workerCount,
|
||||
"poll_interval": w.pollInterval.String(),
|
||||
"processed_count": w.processedCount,
|
||||
"error_count": w.errorCount,
|
||||
"last_processed": w.lastProcessed,
|
||||
}
|
||||
}
|
||||
|
||||
// ParseRecurringPayload 解析周期性交易任务载荷
|
||||
func ParseRecurringPayload(task *DelayedTask) (*RecurringTransactionPayload, error) {
|
||||
var payload RecurringTransactionPayload
|
||||
if err := json.Unmarshal(task.Payload, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &payload, nil
|
||||
}
|
||||
Reference in New Issue
Block a user