diff --git a/internal/mq/recurring_handler.go b/internal/mq/recurring_handler.go index 26b39d7..c6258e7 100644 --- a/internal/mq/recurring_handler.go +++ b/internal/mq/recurring_handler.go @@ -39,12 +39,27 @@ func NewRecurringTransactionHandler( } // Handle 处理周期性交易任务 +// 遵循 v7.0 规范:并发安全与幂等性 func (h *RecurringTransactionHandler) Handle(ctx context.Context, task *DelayedTask) error { payload, err := ParseRecurringPayload(task) if err != nil { return fmt.Errorf("failed to parse payload: %w", err) } + // 幂等性检查:防止同一笔交易重复执行 + // 使用 Redis SETNX 锁,锁 24 小时(足以覆盖重试窗口) + // TaskID 本身包含 recurringID 和执行时间戳,是天然的幂等 Key + lockKey := fmt.Sprintf("novault:lock:recurring:%s", task.ID) + // 使用 TaskQueue 提供的公开方法获取锁 + isNew, err := h.taskQueue.AcquireLock(ctx, lockKey, 24*time.Hour) + if err != nil { + return fmt.Errorf("failed to check idempotency: %w", err) + } + if !isNew { + log.Printf("[RecurringHandler] Task %s already processed (idempotency check), skipping", task.ID) + return nil + } + log.Printf("[RecurringHandler] Processing recurring transaction %d for user %d", payload.RecurringTransactionID, task.UserID) @@ -94,23 +109,24 @@ func (h *RecurringTransactionHandler) Handle(ctx context.Context, task *DelayedT return fmt.Errorf("failed to create transaction: %w", err) } - // 更新账户余额 - var account models.Account - if err := tx.First(&account, recurring.AccountID).Error; err != nil { - tx.Rollback() - return fmt.Errorf("failed to get account: %w", err) - } - + // 更新账户余额 - 使用原子操作修复并发安全问题 (Race Condition Fix) + // UPDATE accounts SET balance = balance +/- ? WHERE id = ? + var updateExpr interface{} switch recurring.Type { case models.TransactionTypeIncome: - account.Balance += recurring.Amount + updateExpr = gorm.Expr("balance + ?", recurring.Amount) case models.TransactionTypeExpense: - account.Balance -= recurring.Amount + updateExpr = gorm.Expr("balance - ?", recurring.Amount) + default: + tx.Rollback() + return fmt.Errorf("unknown transaction type: %s", recurring.Type) } - if err := tx.Save(&account).Error; err != nil { + if err := tx.Model(&models.Account{}). + Where("id = ?", recurring.AccountID). + Update("balance", updateExpr).Error; err != nil { tx.Rollback() - return fmt.Errorf("failed to update account balance: %w", err) + return fmt.Errorf("failed to update account balance atomic: %w", err) } // 计算下一次执行时间 @@ -140,7 +156,7 @@ func (h *RecurringTransactionHandler) Handle(ctx context.Context, task *DelayedT if recurring.IsActive { if err := h.scheduleNext(ctx, recurring); err != nil { log.Printf("[RecurringHandler] Warning: failed to schedule next execution: %v", err) - // 不返回错误,当前任务已成功完成 + // 不返回错误,当前任务已成功完成,仅仅是下次调度失败(会有补偿机制兜底) } } diff --git a/internal/mq/task_queue.go b/internal/mq/task_queue.go index c2579e1..e5ceb20 100644 --- a/internal/mq/task_queue.go +++ b/internal/mq/task_queue.go @@ -218,3 +218,8 @@ func (q *TaskQueue) GetPendingTasks(ctx context.Context, limit int64) ([]Delayed return tasks, nil } + +// AcquireLock 尝试获取分布式锁(用于幂等性检查) +func (q *TaskQueue) AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error) { + return q.client.SetNX(ctx, key, "1", ttl).Result() +}