feat: 实现基于 Redis 的延迟任务队列和工作器,支持周期性交易处理等后台任务调度。
This commit is contained in:
@@ -5,10 +5,12 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
"accounting-app/internal/cache"
|
"accounting-app/internal/cache"
|
||||||
"accounting-app/internal/config"
|
"accounting-app/internal/config"
|
||||||
"accounting-app/internal/database"
|
"accounting-app/internal/database"
|
||||||
|
"accounting-app/internal/mq"
|
||||||
"accounting-app/internal/repository"
|
"accounting-app/internal/repository"
|
||||||
"accounting-app/internal/router"
|
"accounting-app/internal/router"
|
||||||
"accounting-app/internal/service"
|
"accounting-app/internal/service"
|
||||||
@@ -97,6 +99,7 @@ func main() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
// Redis not available - fall back to old system
|
// Redis not available - fall back to old system
|
||||||
log.Printf("Warning: Redis connection failed (%v), falling back to non-cached exchange rate system", err)
|
log.Printf("Warning: Redis connection failed (%v), falling back to non-cached exchange rate system", err)
|
||||||
|
log.Printf("Warning: Recurring transaction MQ system disabled (requires Redis)")
|
||||||
|
|
||||||
// Use old scheduler
|
// Use old scheduler
|
||||||
scheduler := service.NewExchangeRateScheduler(yunAPIClient, cfg.SyncInterval)
|
scheduler := service.NewExchangeRateScheduler(yunAPIClient, cfg.SyncInterval)
|
||||||
@@ -116,6 +119,22 @@ func main() {
|
|||||||
// Start the new SyncScheduler in background
|
// Start the new SyncScheduler in background
|
||||||
// This will perform initial sync immediately (Requirement 3.1)
|
// This will perform initial sync immediately (Requirement 3.1)
|
||||||
go syncScheduler.Start(ctx)
|
go syncScheduler.Start(ctx)
|
||||||
|
|
||||||
|
// Initialize and start MQ task system for recurring transactions
|
||||||
|
recurringTaskSystem, err := mq.InitRecurringTaskSystem(
|
||||||
|
ctx,
|
||||||
|
redisClient.Client(),
|
||||||
|
db,
|
||||||
|
5*time.Second, // Poll interval: 检查延迟任务的间隔
|
||||||
|
2, // Worker count: 并发处理任务的数量
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Warning: Failed to initialize recurring task system: %v", err)
|
||||||
|
} else {
|
||||||
|
// Start MQ worker in background
|
||||||
|
go recurringTaskSystem.Start(ctx)
|
||||||
|
log.Println("Recurring transaction MQ system started")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get port from config or environment
|
// Get port from config or environment
|
||||||
|
|||||||
78
internal/handler/task_queue_handler.go
Normal file
78
internal/handler/task_queue_handler.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"accounting-app/internal/mq"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TaskQueueHandler 任务队列状态查询 Handler
|
||||||
|
type TaskQueueHandler struct {
|
||||||
|
taskSystem *mq.RecurringTaskSystem
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTaskQueueHandler 创建任务队列 Handler
|
||||||
|
func NewTaskQueueHandler(taskSystem *mq.RecurringTaskSystem) *TaskQueueHandler {
|
||||||
|
return &TaskQueueHandler{
|
||||||
|
taskSystem: taskSystem,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStats 获取任务队列统计信息
|
||||||
|
// GET /api/v1/admin/tasks/stats
|
||||||
|
func (h *TaskQueueHandler) GetStats(c *gin.Context) {
|
||||||
|
if h.taskSystem == nil {
|
||||||
|
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"error": "Task system not available (Redis required)",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := h.taskSystem.GetStats(c.Request.Context())
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": stats,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPendingTasks 获取即将执行的任务列表
|
||||||
|
// GET /api/v1/admin/tasks/pending
|
||||||
|
func (h *TaskQueueHandler) GetPendingTasks(c *gin.Context) {
|
||||||
|
if h.taskSystem == nil {
|
||||||
|
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"error": "Task system not available (Redis required)",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks, err := h.taskSystem.Queue.GetPendingTasks(c.Request.Context(), 50)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": gin.H{
|
||||||
|
"count": len(tasks),
|
||||||
|
"tasks": tasks,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *TaskQueueHandler) RegisterRoutes(rg *gin.RouterGroup) {
|
||||||
|
tasks := rg.Group("/admin/tasks")
|
||||||
|
{
|
||||||
|
tasks.GET("/stats", h.GetStats)
|
||||||
|
tasks.GET("/pending", h.GetPendingTasks)
|
||||||
|
}
|
||||||
|
}
|
||||||
95
internal/mq/init.go
Normal file
95
internal/mq/init.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package mq
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"accounting-app/internal/repository"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RecurringTaskSystem 周期性交易任务系统
|
||||||
|
type RecurringTaskSystem struct {
|
||||||
|
Queue *TaskQueue
|
||||||
|
Worker *TaskWorker
|
||||||
|
Handler *RecurringTransactionHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitRecurringTaskSystem 初始化周期性交易任务系统
|
||||||
|
// 需要 Redis 连接来存储任务队列
|
||||||
|
func InitRecurringTaskSystem(
|
||||||
|
ctx context.Context,
|
||||||
|
redisClient *redis.Client,
|
||||||
|
db *gorm.DB,
|
||||||
|
pollInterval time.Duration,
|
||||||
|
workerCount int,
|
||||||
|
) (*RecurringTaskSystem, error) {
|
||||||
|
log.Println("[MQ] Initializing recurring task system...")
|
||||||
|
|
||||||
|
// 创建任务队列
|
||||||
|
taskQueue := NewTaskQueue(redisClient, "novault:tasks")
|
||||||
|
|
||||||
|
// 创建任务 Worker
|
||||||
|
worker := NewTaskWorker(taskQueue, pollInterval, workerCount)
|
||||||
|
|
||||||
|
// 创建周期性交易处理器
|
||||||
|
recurringRepo := repository.NewRecurringTransactionRepository(db)
|
||||||
|
transactionRepo := repository.NewTransactionRepository(db)
|
||||||
|
accountRepo := repository.NewAccountRepository(db)
|
||||||
|
|
||||||
|
handler := NewRecurringTransactionHandler(
|
||||||
|
db,
|
||||||
|
recurringRepo,
|
||||||
|
transactionRepo,
|
||||||
|
accountRepo,
|
||||||
|
taskQueue,
|
||||||
|
)
|
||||||
|
|
||||||
|
// 注册处理器
|
||||||
|
worker.RegisterHandler(TaskTypeRecurringTransaction, handler.Handle)
|
||||||
|
|
||||||
|
system := &RecurringTaskSystem{
|
||||||
|
Queue: taskQueue,
|
||||||
|
Worker: worker,
|
||||||
|
Handler: handler,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动时调度所有活跃的周期性交易
|
||||||
|
if err := handler.ScheduleAllActive(ctx); err != nil {
|
||||||
|
log.Printf("[MQ] Warning: failed to schedule active recurring transactions: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理逾期任务(补偿机制)
|
||||||
|
if err := handler.ProcessOverdue(ctx); err != nil {
|
||||||
|
log.Printf("[MQ] Warning: failed to process overdue transactions: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("[MQ] Recurring task system initialized successfully")
|
||||||
|
return system, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start 启动任务系统
|
||||||
|
func (s *RecurringTaskSystem) Start(ctx context.Context) {
|
||||||
|
log.Println("[MQ] Starting recurring task worker...")
|
||||||
|
s.Worker.Start(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop 停止任务系统
|
||||||
|
func (s *RecurringTaskSystem) Stop() {
|
||||||
|
log.Println("[MQ] Stopping recurring task worker...")
|
||||||
|
s.Worker.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStats 获取系统统计信息
|
||||||
|
func (s *RecurringTaskSystem) GetStats(ctx context.Context) map[string]interface{} {
|
||||||
|
queueStats, _ := s.Queue.GetQueueStats(ctx)
|
||||||
|
workerStats := s.Worker.GetStats()
|
||||||
|
|
||||||
|
return map[string]interface{}{
|
||||||
|
"queue": queueStats,
|
||||||
|
"worker": workerStats,
|
||||||
|
}
|
||||||
|
}
|
||||||
245
internal/mq/recurring_handler.go
Normal file
245
internal/mq/recurring_handler.go
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
package mq
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"accounting-app/internal/models"
|
||||||
|
"accounting-app/internal/repository"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RecurringTransactionHandler 周期性交易任务处理器
|
||||||
|
type RecurringTransactionHandler struct {
|
||||||
|
db *gorm.DB
|
||||||
|
recurringRepo *repository.RecurringTransactionRepository
|
||||||
|
transactionRepo *repository.TransactionRepository
|
||||||
|
accountRepo *repository.AccountRepository
|
||||||
|
taskQueue *TaskQueue
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRecurringTransactionHandler 创建处理器实例
|
||||||
|
func NewRecurringTransactionHandler(
|
||||||
|
db *gorm.DB,
|
||||||
|
recurringRepo *repository.RecurringTransactionRepository,
|
||||||
|
transactionRepo *repository.TransactionRepository,
|
||||||
|
accountRepo *repository.AccountRepository,
|
||||||
|
taskQueue *TaskQueue,
|
||||||
|
) *RecurringTransactionHandler {
|
||||||
|
return &RecurringTransactionHandler{
|
||||||
|
db: db,
|
||||||
|
recurringRepo: recurringRepo,
|
||||||
|
transactionRepo: transactionRepo,
|
||||||
|
accountRepo: accountRepo,
|
||||||
|
taskQueue: taskQueue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle 处理周期性交易任务
|
||||||
|
func (h *RecurringTransactionHandler) Handle(ctx context.Context, task *DelayedTask) error {
|
||||||
|
payload, err := ParseRecurringPayload(task)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[RecurringHandler] Processing recurring transaction %d for user %d",
|
||||||
|
payload.RecurringTransactionID, task.UserID)
|
||||||
|
|
||||||
|
// 获取周期性交易记录
|
||||||
|
recurring, err := h.recurringRepo.GetByID(task.UserID, payload.RecurringTransactionID)
|
||||||
|
if err != nil {
|
||||||
|
if err == gorm.ErrRecordNotFound {
|
||||||
|
log.Printf("[RecurringHandler] Recurring transaction %d not found, skipping", payload.RecurringTransactionID)
|
||||||
|
return nil // 已被删除,不需要重试
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to get recurring transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否已禁用或已过期
|
||||||
|
if !recurring.IsActive {
|
||||||
|
log.Printf("[RecurringHandler] Recurring transaction %d is inactive, skipping", recurring.ID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if recurring.EndDate != nil && time.Now().After(*recurring.EndDate) {
|
||||||
|
log.Printf("[RecurringHandler] Recurring transaction %d has ended, deactivating", recurring.ID)
|
||||||
|
recurring.IsActive = false
|
||||||
|
return h.recurringRepo.Update(recurring)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 开始数据库事务
|
||||||
|
tx := h.db.Begin()
|
||||||
|
if tx.Error != nil {
|
||||||
|
return fmt.Errorf("failed to begin transaction: %w", tx.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建实际交易记录
|
||||||
|
transaction := &models.Transaction{
|
||||||
|
UserID: recurring.UserID,
|
||||||
|
Amount: recurring.Amount,
|
||||||
|
Type: recurring.Type,
|
||||||
|
CategoryID: recurring.CategoryID,
|
||||||
|
AccountID: recurring.AccountID,
|
||||||
|
Currency: recurring.Currency,
|
||||||
|
TransactionDate: recurring.NextOccurrence,
|
||||||
|
Note: recurring.Note,
|
||||||
|
RecurringID: &recurring.ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Create(transaction).Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return fmt.Errorf("failed to create transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新账户余额
|
||||||
|
var account models.Account
|
||||||
|
if err := tx.First(&account, recurring.AccountID).Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return fmt.Errorf("failed to get account: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch recurring.Type {
|
||||||
|
case models.TransactionTypeIncome:
|
||||||
|
account.Balance += recurring.Amount
|
||||||
|
case models.TransactionTypeExpense:
|
||||||
|
account.Balance -= recurring.Amount
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Save(&account).Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return fmt.Errorf("failed to update account balance: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算下一次执行时间
|
||||||
|
nextOccurrence := h.calculateNextOccurrence(recurring.NextOccurrence, recurring.Frequency)
|
||||||
|
recurring.NextOccurrence = nextOccurrence
|
||||||
|
|
||||||
|
// 检查下次执行是否超过结束日期
|
||||||
|
if recurring.EndDate != nil && nextOccurrence.After(*recurring.EndDate) {
|
||||||
|
recurring.IsActive = false
|
||||||
|
log.Printf("[RecurringHandler] Recurring transaction %d will end after this execution", recurring.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Save(recurring).Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return fmt.Errorf("failed to update recurring transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提交事务
|
||||||
|
if err := tx.Commit().Error; err != nil {
|
||||||
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[RecurringHandler] Created transaction %d from recurring %d, amount: %.2f %s",
|
||||||
|
transaction.ID, recurring.ID, recurring.Amount, recurring.Currency)
|
||||||
|
|
||||||
|
// 如果仍然活跃,调度下一次执行
|
||||||
|
if recurring.IsActive {
|
||||||
|
if err := h.scheduleNext(ctx, recurring); err != nil {
|
||||||
|
log.Printf("[RecurringHandler] Warning: failed to schedule next execution: %v", err)
|
||||||
|
// 不返回错误,当前任务已成功完成
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateNextOccurrence 计算下一次执行时间
|
||||||
|
func (h *RecurringTransactionHandler) calculateNextOccurrence(current time.Time, frequency models.FrequencyType) time.Time {
|
||||||
|
switch frequency {
|
||||||
|
case models.FrequencyDaily:
|
||||||
|
return current.AddDate(0, 0, 1)
|
||||||
|
case models.FrequencyWeekly:
|
||||||
|
return current.AddDate(0, 0, 7)
|
||||||
|
case models.FrequencyMonthly:
|
||||||
|
return current.AddDate(0, 1, 0)
|
||||||
|
case models.FrequencyYearly:
|
||||||
|
return current.AddDate(1, 0, 0)
|
||||||
|
default:
|
||||||
|
return current.AddDate(0, 1, 0) // 默认月度
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// scheduleNext 调度下一次执行
|
||||||
|
func (h *RecurringTransactionHandler) scheduleNext(ctx context.Context, recurring *models.RecurringTransaction) error {
|
||||||
|
return h.taskQueue.ScheduleRecurringTransaction(
|
||||||
|
ctx,
|
||||||
|
recurring.UserID,
|
||||||
|
recurring.ID,
|
||||||
|
recurring.NextOccurrence,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScheduleAllActive 调度所有活跃的周期性交易(启动时调用)
|
||||||
|
func (h *RecurringTransactionHandler) ScheduleAllActive(ctx context.Context) error {
|
||||||
|
log.Println("[RecurringHandler] Scheduling all active recurring transactions...")
|
||||||
|
|
||||||
|
// 获取所有活跃的周期性交易
|
||||||
|
var recurringList []models.RecurringTransaction
|
||||||
|
err := h.db.Where("is_active = ?", true).Find(&recurringList).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get active recurring transactions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
scheduledCount := 0
|
||||||
|
for _, recurring := range recurringList {
|
||||||
|
// 跳过已过期的
|
||||||
|
if recurring.EndDate != nil && time.Now().After(*recurring.EndDate) {
|
||||||
|
recurring.IsActive = false
|
||||||
|
h.db.Save(&recurring)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调度任务
|
||||||
|
err := h.taskQueue.ScheduleRecurringTransaction(
|
||||||
|
ctx,
|
||||||
|
recurring.UserID,
|
||||||
|
recurring.ID,
|
||||||
|
recurring.NextOccurrence,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[RecurringHandler] Failed to schedule recurring %d: %v", recurring.ID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
scheduledCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[RecurringHandler] Scheduled %d recurring transactions", scheduledCount)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessOverdue 处理所有逾期的周期性交易(补偿机制)
|
||||||
|
func (h *RecurringTransactionHandler) ProcessOverdue(ctx context.Context) error {
|
||||||
|
log.Println("[RecurringHandler] Processing overdue recurring transactions...")
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
var overdueList []models.RecurringTransaction
|
||||||
|
err := h.db.Where("is_active = ? AND next_occurrence <= ?", true, now).Find(&overdueList).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get overdue recurring transactions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(overdueList) == 0 {
|
||||||
|
log.Println("[RecurringHandler] No overdue recurring transactions found")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[RecurringHandler] Found %d overdue recurring transactions", len(overdueList))
|
||||||
|
|
||||||
|
for _, recurring := range overdueList {
|
||||||
|
err := h.taskQueue.ScheduleRecurringTransaction(
|
||||||
|
ctx,
|
||||||
|
recurring.UserID,
|
||||||
|
recurring.ID,
|
||||||
|
recurring.NextOccurrence,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[RecurringHandler] Failed to schedule overdue recurring %d: %v", recurring.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
220
internal/mq/task_queue.go
Normal file
220
internal/mq/task_queue.go
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
package mq
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TaskType 定义任务类型
|
||||||
|
type TaskType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// TaskTypeRecurringTransaction 周期性交易处理任务
|
||||||
|
TaskTypeRecurringTransaction TaskType = "recurring_transaction"
|
||||||
|
// TaskTypeAllocationRule 分配规则执行任务
|
||||||
|
TaskTypeAllocationRule TaskType = "allocation_rule"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DelayedTask 延迟任务结构
|
||||||
|
type DelayedTask struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Type TaskType `json:"type"`
|
||||||
|
UserID uint `json:"user_id"`
|
||||||
|
Payload json.RawMessage `json:"payload"`
|
||||||
|
ScheduledAt time.Time `json:"scheduled_at"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
RetryCount int `json:"retry_count"`
|
||||||
|
MaxRetries int `json:"max_retries"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecurringTransactionPayload 周期性交易任务载荷
|
||||||
|
type RecurringTransactionPayload struct {
|
||||||
|
RecurringTransactionID uint `json:"recurring_transaction_id"`
|
||||||
|
NextOccurrence time.Time `json:"next_occurrence"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskQueue 基于 Redis 的延迟任务队列
|
||||||
|
type TaskQueue struct {
|
||||||
|
client *redis.Client
|
||||||
|
keyPrefix string
|
||||||
|
delayedKey string // 延迟任务有序集合
|
||||||
|
readyKey string // 就绪任务列表
|
||||||
|
processingKey string // 处理中任务
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTaskQueue 创建任务队列实例
|
||||||
|
func NewTaskQueue(client *redis.Client, keyPrefix string) *TaskQueue {
|
||||||
|
if keyPrefix == "" {
|
||||||
|
keyPrefix = "novault:tasks"
|
||||||
|
}
|
||||||
|
return &TaskQueue{
|
||||||
|
client: client,
|
||||||
|
keyPrefix: keyPrefix,
|
||||||
|
delayedKey: keyPrefix + ":delayed",
|
||||||
|
readyKey: keyPrefix + ":ready",
|
||||||
|
processingKey: keyPrefix + ":processing",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schedule 调度延迟任务
|
||||||
|
func (q *TaskQueue) Schedule(ctx context.Context, task *DelayedTask) error {
|
||||||
|
// 序列化任务
|
||||||
|
taskJSON, err := json.Marshal(task)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal task: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 ZADD 添加到延迟队列,score 为执行时间戳
|
||||||
|
score := float64(task.ScheduledAt.Unix())
|
||||||
|
err = q.client.ZAdd(ctx, q.delayedKey, redis.Z{
|
||||||
|
Score: score,
|
||||||
|
Member: string(taskJSON),
|
||||||
|
}).Err()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to schedule task: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[TaskQueue] Scheduled task %s for %s", task.ID, task.ScheduledAt.Format("2006-01-02 15:04:05"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScheduleRecurringTransaction 调度周期性交易任务
|
||||||
|
func (q *TaskQueue) ScheduleRecurringTransaction(ctx context.Context, userID uint, recurringID uint, scheduledAt time.Time) error {
|
||||||
|
payload := RecurringTransactionPayload{
|
||||||
|
RecurringTransactionID: recurringID,
|
||||||
|
NextOccurrence: scheduledAt,
|
||||||
|
}
|
||||||
|
payloadJSON, _ := json.Marshal(payload)
|
||||||
|
|
||||||
|
task := &DelayedTask{
|
||||||
|
ID: fmt.Sprintf("recurring_%d_%d", recurringID, scheduledAt.Unix()),
|
||||||
|
Type: TaskTypeRecurringTransaction,
|
||||||
|
UserID: userID,
|
||||||
|
Payload: payloadJSON,
|
||||||
|
ScheduledAt: scheduledAt,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
MaxRetries: 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
return q.Schedule(ctx, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MoveReadyTasks 将到期的延迟任务移动到就绪队列
|
||||||
|
func (q *TaskQueue) MoveReadyTasks(ctx context.Context) (int64, error) {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
|
||||||
|
// 使用 Lua 脚本原子性地移动任务
|
||||||
|
script := redis.NewScript(`
|
||||||
|
local delayed_key = KEYS[1]
|
||||||
|
local ready_key = KEYS[2]
|
||||||
|
local now = tonumber(ARGV[1])
|
||||||
|
|
||||||
|
-- 获取所有到期的任务
|
||||||
|
local tasks = redis.call('ZRANGEBYSCORE', delayed_key, '-inf', now)
|
||||||
|
|
||||||
|
if #tasks == 0 then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
|
||||||
|
-- 移动到就绪队列
|
||||||
|
for i, task in ipairs(tasks) do
|
||||||
|
redis.call('LPUSH', ready_key, task)
|
||||||
|
end
|
||||||
|
|
||||||
|
-- 从延迟队列中删除
|
||||||
|
redis.call('ZREMRANGEBYSCORE', delayed_key, '-inf', now)
|
||||||
|
|
||||||
|
return #tasks
|
||||||
|
`)
|
||||||
|
|
||||||
|
result, err := script.Run(ctx, q.client, []string{q.delayedKey, q.readyKey}, now).Int64()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to move ready tasks: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result > 0 {
|
||||||
|
log.Printf("[TaskQueue] Moved %d tasks to ready queue", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PopTask 从就绪队列中取出一个任务
|
||||||
|
func (q *TaskQueue) PopTask(ctx context.Context, timeout time.Duration) (*DelayedTask, error) {
|
||||||
|
// 使用 BRPOPLPUSH 原子性地获取任务并移动到处理中队列
|
||||||
|
result, err := q.client.BRPopLPush(ctx, q.readyKey, q.processingKey, timeout).Result()
|
||||||
|
if err != nil {
|
||||||
|
if err == redis.Nil {
|
||||||
|
return nil, nil // 没有任务
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to pop task: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var task DelayedTask
|
||||||
|
if err := json.Unmarshal([]byte(result), &task); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal task: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &task, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompleteTask 标记任务完成
|
||||||
|
func (q *TaskQueue) CompleteTask(ctx context.Context, task *DelayedTask) error {
|
||||||
|
taskJSON, _ := json.Marshal(task)
|
||||||
|
return q.client.LRem(ctx, q.processingKey, 1, string(taskJSON)).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetryTask 重试任务
|
||||||
|
func (q *TaskQueue) RetryTask(ctx context.Context, task *DelayedTask, delay time.Duration) error {
|
||||||
|
task.RetryCount++
|
||||||
|
if task.RetryCount > task.MaxRetries {
|
||||||
|
log.Printf("[TaskQueue] Task %s exceeded max retries, discarding", task.ID)
|
||||||
|
return q.CompleteTask(ctx, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
task.ScheduledAt = time.Now().Add(delay)
|
||||||
|
|
||||||
|
// 先从处理中队列移除
|
||||||
|
taskJSON, _ := json.Marshal(task)
|
||||||
|
q.client.LRem(ctx, q.processingKey, 1, string(taskJSON))
|
||||||
|
|
||||||
|
// 重新调度
|
||||||
|
return q.Schedule(ctx, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQueueStats 获取队列统计信息
|
||||||
|
func (q *TaskQueue) GetQueueStats(ctx context.Context) (map[string]int64, error) {
|
||||||
|
delayed, _ := q.client.ZCard(ctx, q.delayedKey).Result()
|
||||||
|
ready, _ := q.client.LLen(ctx, q.readyKey).Result()
|
||||||
|
processing, _ := q.client.LLen(ctx, q.processingKey).Result()
|
||||||
|
|
||||||
|
return map[string]int64{
|
||||||
|
"delayed": delayed,
|
||||||
|
"ready": ready,
|
||||||
|
"processing": processing,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPendingTasks 获取即将执行的任务(用于调试)
|
||||||
|
func (q *TaskQueue) GetPendingTasks(ctx context.Context, limit int64) ([]DelayedTask, error) {
|
||||||
|
results, err := q.client.ZRange(ctx, q.delayedKey, 0, limit-1).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks := make([]DelayedTask, 0, len(results))
|
||||||
|
for _, result := range results {
|
||||||
|
var task DelayedTask
|
||||||
|
if err := json.Unmarshal([]byte(result), &task); err == nil {
|
||||||
|
tasks = append(tasks, task)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tasks, nil
|
||||||
|
}
|
||||||
227
internal/mq/task_worker.go
Normal file
227
internal/mq/task_worker.go
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
package mq
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TaskHandler 任务处理函数类型
|
||||||
|
type TaskHandler func(ctx context.Context, task *DelayedTask) error
|
||||||
|
|
||||||
|
// TaskWorker 任务工作器,负责消费和执行任务
|
||||||
|
type TaskWorker struct {
|
||||||
|
queue *TaskQueue
|
||||||
|
handlers map[TaskType]TaskHandler
|
||||||
|
pollInterval time.Duration // 检查延迟任务间隔
|
||||||
|
workerCount int // 并发工作器数量
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
running bool
|
||||||
|
stopChan chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
// 统计信息
|
||||||
|
processedCount int64
|
||||||
|
errorCount int64
|
||||||
|
lastProcessed time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTaskWorker 创建任务工作器
|
||||||
|
func NewTaskWorker(queue *TaskQueue, pollInterval time.Duration, workerCount int) *TaskWorker {
|
||||||
|
if pollInterval < time.Second {
|
||||||
|
pollInterval = time.Second
|
||||||
|
}
|
||||||
|
if workerCount < 1 {
|
||||||
|
workerCount = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return &TaskWorker{
|
||||||
|
queue: queue,
|
||||||
|
handlers: make(map[TaskType]TaskHandler),
|
||||||
|
pollInterval: pollInterval,
|
||||||
|
workerCount: workerCount,
|
||||||
|
stopChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterHandler 注册任务处理器
|
||||||
|
func (w *TaskWorker) RegisterHandler(taskType TaskType, handler TaskHandler) {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
w.handlers[taskType] = handler
|
||||||
|
log.Printf("[TaskWorker] Registered handler for task type: %s", taskType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start 启动工作器
|
||||||
|
func (w *TaskWorker) Start(ctx context.Context) {
|
||||||
|
w.mu.Lock()
|
||||||
|
if w.running {
|
||||||
|
w.mu.Unlock()
|
||||||
|
log.Println("[TaskWorker] Already running")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.running = true
|
||||||
|
w.stopChan = make(chan struct{})
|
||||||
|
w.mu.Unlock()
|
||||||
|
|
||||||
|
log.Printf("[TaskWorker] Starting with %d workers, poll interval: %v", w.workerCount, w.pollInterval)
|
||||||
|
|
||||||
|
// 启动延迟任务轮询器
|
||||||
|
w.wg.Add(1)
|
||||||
|
go w.pollDelayedTasks(ctx)
|
||||||
|
|
||||||
|
// 启动工作器
|
||||||
|
for i := 0; i < w.workerCount; i++ {
|
||||||
|
w.wg.Add(1)
|
||||||
|
go w.worker(ctx, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop 停止工作器
|
||||||
|
func (w *TaskWorker) Stop() {
|
||||||
|
w.mu.Lock()
|
||||||
|
if !w.running {
|
||||||
|
w.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.running = false
|
||||||
|
close(w.stopChan)
|
||||||
|
w.mu.Unlock()
|
||||||
|
|
||||||
|
log.Println("[TaskWorker] Stopping...")
|
||||||
|
w.wg.Wait()
|
||||||
|
log.Println("[TaskWorker] Stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
// pollDelayedTasks 定期检查并移动到期的延迟任务
|
||||||
|
func (w *TaskWorker) pollDelayedTasks(ctx context.Context) {
|
||||||
|
defer w.wg.Done()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(w.pollInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-w.stopChan:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
moved, err := w.queue.MoveReadyTasks(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[TaskWorker] Error moving ready tasks: %v", err)
|
||||||
|
} else if moved > 0 {
|
||||||
|
log.Printf("[TaskWorker] Moved %d tasks to ready queue", moved)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// worker 工作器协程
|
||||||
|
func (w *TaskWorker) worker(ctx context.Context, id int) {
|
||||||
|
defer w.wg.Done()
|
||||||
|
|
||||||
|
log.Printf("[TaskWorker] Worker %d started", id)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Printf("[TaskWorker] Worker %d stopping (context done)", id)
|
||||||
|
return
|
||||||
|
case <-w.stopChan:
|
||||||
|
log.Printf("[TaskWorker] Worker %d stopping (stop signal)", id)
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// 尝试获取任务
|
||||||
|
task, err := w.queue.PopTask(ctx, time.Second)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[TaskWorker] Worker %d error popping task: %v", id, err)
|
||||||
|
time.Sleep(time.Second) // 错误后等待
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if task == nil {
|
||||||
|
continue // 没有任务,继续轮询
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理任务
|
||||||
|
w.processTask(ctx, id, task)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processTask 处理单个任务
|
||||||
|
func (w *TaskWorker) processTask(ctx context.Context, workerID int, task *DelayedTask) {
|
||||||
|
startTime := time.Now()
|
||||||
|
log.Printf("[TaskWorker] Worker %d processing task %s (type: %s, user: %d)",
|
||||||
|
workerID, task.ID, task.Type, task.UserID)
|
||||||
|
|
||||||
|
w.mu.RLock()
|
||||||
|
handler, exists := w.handlers[task.Type]
|
||||||
|
w.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
log.Printf("[TaskWorker] No handler for task type: %s", task.Type)
|
||||||
|
w.queue.CompleteTask(ctx, task)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行任务
|
||||||
|
err := handler(ctx, task)
|
||||||
|
duration := time.Since(startTime)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[TaskWorker] Task %s failed (attempt %d/%d): %v",
|
||||||
|
task.ID, task.RetryCount+1, task.MaxRetries, err)
|
||||||
|
|
||||||
|
w.mu.Lock()
|
||||||
|
w.errorCount++
|
||||||
|
w.mu.Unlock()
|
||||||
|
|
||||||
|
// 重试(指数退避)
|
||||||
|
retryDelay := time.Duration(1<<task.RetryCount) * time.Minute
|
||||||
|
if err := w.queue.RetryTask(ctx, task, retryDelay); err != nil {
|
||||||
|
log.Printf("[TaskWorker] Failed to retry task %s: %v", task.ID, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 完成任务
|
||||||
|
if err := w.queue.CompleteTask(ctx, task); err != nil {
|
||||||
|
log.Printf("[TaskWorker] Failed to complete task %s: %v", task.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.mu.Lock()
|
||||||
|
w.processedCount++
|
||||||
|
w.lastProcessed = time.Now()
|
||||||
|
w.mu.Unlock()
|
||||||
|
|
||||||
|
log.Printf("[TaskWorker] Task %s completed in %v", task.ID, duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStats 获取工作器统计信息
|
||||||
|
func (w *TaskWorker) GetStats() map[string]interface{} {
|
||||||
|
w.mu.RLock()
|
||||||
|
defer w.mu.RUnlock()
|
||||||
|
|
||||||
|
return map[string]interface{}{
|
||||||
|
"running": w.running,
|
||||||
|
"worker_count": w.workerCount,
|
||||||
|
"poll_interval": w.pollInterval.String(),
|
||||||
|
"processed_count": w.processedCount,
|
||||||
|
"error_count": w.errorCount,
|
||||||
|
"last_processed": w.lastProcessed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseRecurringPayload 解析周期性交易任务载荷
|
||||||
|
func ParseRecurringPayload(task *DelayedTask) (*RecurringTransactionPayload, error) {
|
||||||
|
var payload RecurringTransactionPayload
|
||||||
|
if err := json.Unmarshal(task.Payload, &payload); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &payload, nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user