Files
Novault-backend/internal/mq/recurring_handler.go

262 lines
8.2 KiB
Go
Raw Normal View History

package mq
import (
	"context"
	"errors"
	"fmt"
	"log"
	"time"

	"accounting-app/internal/models"
	"accounting-app/internal/repository"

	"gorm.io/gorm"
)
// RecurringTransactionHandler is the task handler for recurring transactions.
// It turns a due occurrence into a concrete transaction (see Handle) and puts
// follow-up runs on the task queue.
type RecurringTransactionHandler struct {
// db is used directly for the per-occurrence DB transaction and for the bulk
// queries in ScheduleAllActive / ProcessOverdue.
db *gorm.DB
// recurringRepo loads and updates recurring transaction definitions.
recurringRepo *repository.RecurringTransactionRepository
// transactionRepo and accountRepo are injected but not referenced by the
// methods visible in this file.
transactionRepo *repository.TransactionRepository
accountRepo *repository.AccountRepository
// taskQueue provides the idempotency lock and schedules delayed runs.
taskQueue *TaskQueue
}
// NewRecurringTransactionHandler builds a handler from its dependencies.
// Arguments are stored exactly as given; nothing is validated or copied.
func NewRecurringTransactionHandler(
	db *gorm.DB,
	recurringRepo *repository.RecurringTransactionRepository,
	transactionRepo *repository.TransactionRepository,
	accountRepo *repository.AccountRepository,
	taskQueue *TaskQueue,
) *RecurringTransactionHandler {
	h := new(RecurringTransactionHandler)
	h.db = db
	h.recurringRepo = recurringRepo
	h.transactionRepo = transactionRepo
	h.accountRepo = accountRepo
	h.taskQueue = taskQueue
	return h
}
// Handle executes one occurrence of a recurring transaction described by task.
// It follows the project's "v7.0" rules for concurrency safety and idempotency.
//
// Flow:
//  1. Parse the recurring payload from task.
//  2. Acquire a 24h lock keyed by task.ID through TaskQueue.AcquireLock.
//     Per the original design notes, this is a Redis SETNX lock. The task ID
//     embeds the recurring ID and execution timestamp, which makes it a natural
//     idempotency key. A duplicate delivery returns nil without doing anything.
//  3. Load the recurring definition:
//     - if it was deleted, skip it (nil);
//     - if it is inactive, skip it (nil);
//     - if its EndDate has passed, deactivate it and return.
//  4. Create the concrete transaction, adjust the account balance and advance
//     the schedule, all in one DB transaction (see executeOccurrence).
//  5. If the definition is still active, schedule the next run. A scheduling
//     failure is only logged, because the current occurrence is already
//     committed and ProcessOverdue acts as the compensation path.
//
// A non-nil error is returned for payload, lock, load and DB failures.
//
// NOTE(review): the idempotency lock is taken before the work and is never
// released on failure. When this method returns an error, a redelivery of the
// same task ID within 24h is therefore skipped as "already processed" and the
// occurrence is effectively lost. Releasing the lock on the error paths needs a
// release API on TaskQueue that is not visible here. TODO confirm/add.
func (h *RecurringTransactionHandler) Handle(ctx context.Context, task *DelayedTask) error {
	payload, err := ParseRecurringPayload(task)
	if err != nil {
		return fmt.Errorf("failed to parse payload: %w", err)
	}

	// Idempotency guard: the lock lives 24h, long enough to cover the retry window.
	lockKey := fmt.Sprintf("novault:lock:recurring:%s", task.ID)
	isNew, err := h.taskQueue.AcquireLock(ctx, lockKey, 24*time.Hour)
	if err != nil {
		return fmt.Errorf("failed to check idempotency: %w", err)
	}
	if !isNew {
		log.Printf("[RecurringHandler] Task %s already processed (idempotency check), skipping", task.ID)
		return nil
	}

	log.Printf("[RecurringHandler] Processing recurring transaction %d for user %d",
		payload.RecurringTransactionID, task.UserID)

	recurring, err := h.recurringRepo.GetByID(task.UserID, payload.RecurringTransactionID)
	if err != nil {
		// errors.Is also matches when the repository wraps the gorm sentinel.
		if errors.Is(err, gorm.ErrRecordNotFound) {
			log.Printf("[RecurringHandler] Recurring transaction %d not found, skipping", payload.RecurringTransactionID)
			return nil // Deleted in the meantime; retrying would not help.
		}
		return fmt.Errorf("failed to get recurring transaction: %w", err)
	}

	if !recurring.IsActive {
		log.Printf("[RecurringHandler] Recurring transaction %d is inactive, skipping", recurring.ID)
		return nil
	}
	if recurring.EndDate != nil && time.Now().After(*recurring.EndDate) {
		log.Printf("[RecurringHandler] Recurring transaction %d has ended, deactivating", recurring.ID)
		recurring.IsActive = false
		if err := h.recurringRepo.Update(recurring); err != nil {
			return fmt.Errorf("failed to deactivate recurring transaction: %w", err)
		}
		return nil
	}

	transaction, err := h.executeOccurrence(ctx, recurring)
	if err != nil {
		return err
	}

	log.Printf("[RecurringHandler] Created transaction %d from recurring %d, amount: %.2f %s",
		transaction.ID, recurring.ID, recurring.Amount, recurring.Currency)

	if recurring.IsActive {
		if err := h.scheduleNext(ctx, recurring); err != nil {
			// Not fatal: this occurrence is committed; the compensation path
			// (ProcessOverdue) picks up a missed follow-up.
			log.Printf("[RecurringHandler] Warning: failed to schedule next execution: %v", err)
		}
	}
	return nil
}

// executeOccurrence turns recurring.NextOccurrence into a concrete transaction.
// One DB transaction (bound to ctx) does the following:
//   - inserts the transaction row;
//   - adjusts the account balance with an atomic SQL expression;
//   - advances next_occurrence;
//   - clears is_active when the following run would fall after EndDate.
//
// On success, recurring is updated in memory to match the DB and the created
// transaction is returned. On any error, or on a panic, the DB transaction is
// rolled back.
func (h *RecurringTransactionHandler) executeOccurrence(ctx context.Context, recurring *models.RecurringTransaction) (*models.Transaction, error) {
	// Resolve the balance delta first, so an unknown type fails before any write.
	// The read-modify-write happens in SQL
	// (UPDATE accounts SET balance = balance +/- ? WHERE id = ?), which avoids
	// the race of reading the balance into Go and writing it back.
	var balanceExpr interface{}
	switch recurring.Type {
	case models.TransactionTypeIncome:
		balanceExpr = gorm.Expr("balance + ?", recurring.Amount)
	case models.TransactionTypeExpense:
		balanceExpr = gorm.Expr("balance - ?", recurring.Amount)
	default:
		return nil, fmt.Errorf("unknown transaction type: %s", recurring.Type)
	}

	tx := h.db.WithContext(ctx).Begin()
	if tx.Error != nil {
		return nil, fmt.Errorf("failed to begin transaction: %w", tx.Error)
	}
	committed := false
	defer func() {
		// Roll back on every early return and on panic. The rollback error is
		// deliberately ignored: the original failure is what gets reported.
		if !committed {
			tx.Rollback()
		}
	}()

	transaction := &models.Transaction{
		UserID:          recurring.UserID,
		Amount:          recurring.Amount,
		Type:            recurring.Type,
		CategoryID:      recurring.CategoryID,
		AccountID:       recurring.AccountID,
		Currency:        recurring.Currency,
		TransactionDate: recurring.NextOccurrence,
		Note:            recurring.Note,
		RecurringID:     &recurring.ID,
	}
	if err := tx.Create(transaction).Error; err != nil {
		return nil, fmt.Errorf("failed to create transaction: %w", err)
	}

	if err := tx.Model(&models.Account{}).
		Where("id = ?", recurring.AccountID).
		Update("balance", balanceExpr).Error; err != nil {
		return nil, fmt.Errorf("failed to update account balance atomic: %w", err)
	}

	nextOccurrence := h.calculateNextOccurrence(recurring.NextOccurrence, recurring.Frequency)
	stillActive := recurring.EndDate == nil || !nextOccurrence.After(*recurring.EndDate)
	if !stillActive {
		log.Printf("[RecurringHandler] Recurring transaction %d will end after this execution", recurring.ID)
	}

	// Write only the two scheduling columns. Save() would write back every
	// column of the snapshot loaded earlier and could overwrite a concurrent
	// user edit. A map is used so that is_active=false (a zero value) is written.
	if err := tx.Model(recurring).Updates(map[string]interface{}{
		"next_occurrence": nextOccurrence,
		"is_active":       stillActive,
	}).Error; err != nil {
		return nil, fmt.Errorf("failed to update recurring transaction: %w", err)
	}

	if err := tx.Commit().Error; err != nil {
		return nil, fmt.Errorf("failed to commit transaction: %w", err)
	}
	committed = true

	recurring.NextOccurrence = nextOccurrence
	recurring.IsActive = stillActive
	return transaction, nil
}
// calculateNextOccurrence returns the run time that follows current for the
// given frequency:
//   - daily and weekly add 1 and 7 calendar days;
//   - monthly and yearly add 1 and 12 calendar months, clamping the day of
//     month to the target month's length (Jan 31 -> Feb 28/29,
//     Feb 29 -> Feb 28 of the next year);
//   - unknown frequencies fall back to monthly.
//
// The clamp matters because time.Time.AddDate normalizes overflow: Jan 31
// plus one month is Mar 2 or 3. That would skip February entirely and shift
// every later run. The result is computed from current alone, so a series that
// started on the 31st stays on the shorter day once it passes a short month.
func (h *RecurringTransactionHandler) calculateNextOccurrence(current time.Time, frequency models.FrequencyType) time.Time {
	switch frequency {
	case models.FrequencyDaily:
		return current.AddDate(0, 0, 1)
	case models.FrequencyWeekly:
		return current.AddDate(0, 0, 7)
	case models.FrequencyMonthly:
		return addMonthsClamped(current, 1)
	case models.FrequencyYearly:
		return addMonthsClamped(current, 12)
	default:
		return addMonthsClamped(current, 1) // Default: monthly.
	}
}

// addMonthsClamped adds months calendar months to t and keeps t's wall-clock
// time and location. The day of month is clamped to the last day of the
// target month instead of overflowing into the next one.
func addMonthsClamped(t time.Time, months int) time.Time {
	year, month, day := t.Date()
	hour, minute, sec := t.Clock()
	// Day 1 never overflows, so this lands in the intended target month
	// (time.Date normalizes month values outside 1..12 into the year).
	first := time.Date(year, month+time.Month(months), 1, hour, minute, sec, t.Nanosecond(), t.Location())
	// Last day of the target month: day 1 of the following month minus one day.
	if last := first.AddDate(0, 1, -1).Day(); day > last {
		day = last
	}
	return time.Date(first.Year(), first.Month(), day, hour, minute, sec, t.Nanosecond(), t.Location())
}
// scheduleNext puts the next run of recurring on the task queue, timed at its
// current NextOccurrence. Any error from the queue is returned unchanged.
func (h *RecurringTransactionHandler) scheduleNext(ctx context.Context, recurring *models.RecurringTransaction) error {
	owner, id, runAt := recurring.UserID, recurring.ID, recurring.NextOccurrence
	return h.taskQueue.ScheduleRecurringTransaction(ctx, owner, id, runAt)
}
// ScheduleAllActive enqueues the next run of every active recurring
// transaction. It is meant to be called at startup.
//
// Behavior:
//   - Records whose EndDate has already passed are deactivated instead of
//     scheduled.
//   - Per-record failures (deactivation or enqueue) are logged and skipped, so
//     one bad row does not block the rest.
//   - It returns an error only when the initial query fails or ctx is cancelled.
func (h *RecurringTransactionHandler) ScheduleAllActive(ctx context.Context) error {
	log.Println("[RecurringHandler] Scheduling all active recurring transactions...")

	var recurringList []models.RecurringTransaction
	if err := h.db.WithContext(ctx).Where("is_active = ?", true).Find(&recurringList).Error; err != nil {
		return fmt.Errorf("failed to get active recurring transactions: %w", err)
	}

	now := time.Now()
	scheduledCount := 0
	for i := range recurringList {
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("scheduling interrupted after %d recurring transactions: %w", scheduledCount, err)
		}
		recurring := &recurringList[i]

		// Already expired: deactivate instead of scheduling. Only is_active is
		// written; Save() would rewrite every column from this snapshot.
		if recurring.EndDate != nil && now.After(*recurring.EndDate) {
			if err := h.db.WithContext(ctx).Model(recurring).Update("is_active", false).Error; err != nil {
				log.Printf("[RecurringHandler] Failed to deactivate expired recurring %d: %v", recurring.ID, err)
			}
			continue
		}

		if err := h.taskQueue.ScheduleRecurringTransaction(
			ctx,
			recurring.UserID,
			recurring.ID,
			recurring.NextOccurrence,
		); err != nil {
			log.Printf("[RecurringHandler] Failed to schedule recurring %d: %v", recurring.ID, err)
			continue
		}
		scheduledCount++
	}

	log.Printf("[RecurringHandler] Scheduled %d recurring transactions", scheduledCount)
	return nil
}
// ProcessOverdue is the compensation path for runs that were missed or whose
// follow-up scheduling failed. It enqueues every active recurring transaction
// whose next_occurrence is at or before now, each at its stored
// NextOccurrence.
//
// Enqueue failures are logged and skipped. It returns an error only when the
// query fails or ctx is cancelled.
func (h *RecurringTransactionHandler) ProcessOverdue(ctx context.Context) error {
	log.Println("[RecurringHandler] Processing overdue recurring transactions...")

	var overdueList []models.RecurringTransaction
	err := h.db.WithContext(ctx).
		Where("is_active = ? AND next_occurrence <= ?", true, time.Now()).
		Find(&overdueList).Error
	if err != nil {
		return fmt.Errorf("failed to get overdue recurring transactions: %w", err)
	}

	if len(overdueList) == 0 {
		log.Println("[RecurringHandler] No overdue recurring transactions found")
		return nil
	}
	log.Printf("[RecurringHandler] Found %d overdue recurring transactions", len(overdueList))

	rescheduled := 0
	for i := range overdueList {
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("overdue processing interrupted after %d recurring transactions: %w", rescheduled, err)
		}
		recurring := &overdueList[i]
		if err := h.taskQueue.ScheduleRecurringTransaction(
			ctx,
			recurring.UserID,
			recurring.ID,
			recurring.NextOccurrence,
		); err != nil {
			log.Printf("[RecurringHandler] Failed to schedule overdue recurring %d: %v", recurring.ID, err)
			continue
		}
		rescheduled++
	}

	log.Printf("[RecurringHandler] Rescheduled %d/%d overdue recurring transactions", rescheduled, len(overdueList))
	return nil
}