package mq import ( "context" "fmt" "log" "time" "accounting-app/internal/models" "accounting-app/internal/repository" "gorm.io/gorm" ) // RecurringTransactionHandler 周期性交易任务处理器 type RecurringTransactionHandler struct { db *gorm.DB recurringRepo *repository.RecurringTransactionRepository transactionRepo *repository.TransactionRepository accountRepo *repository.AccountRepository taskQueue *TaskQueue } // NewRecurringTransactionHandler 创建处理器实例 func NewRecurringTransactionHandler( db *gorm.DB, recurringRepo *repository.RecurringTransactionRepository, transactionRepo *repository.TransactionRepository, accountRepo *repository.AccountRepository, taskQueue *TaskQueue, ) *RecurringTransactionHandler { return &RecurringTransactionHandler{ db: db, recurringRepo: recurringRepo, transactionRepo: transactionRepo, accountRepo: accountRepo, taskQueue: taskQueue, } } // Handle 处理周期性交易任务 func (h *RecurringTransactionHandler) Handle(ctx context.Context, task *DelayedTask) error { payload, err := ParseRecurringPayload(task) if err != nil { return fmt.Errorf("failed to parse payload: %w", err) } log.Printf("[RecurringHandler] Processing recurring transaction %d for user %d", payload.RecurringTransactionID, task.UserID) // 获取周期性交易记录 recurring, err := h.recurringRepo.GetByID(task.UserID, payload.RecurringTransactionID) if err != nil { if err == gorm.ErrRecordNotFound { log.Printf("[RecurringHandler] Recurring transaction %d not found, skipping", payload.RecurringTransactionID) return nil // 已被删除,不需要重试 } return fmt.Errorf("failed to get recurring transaction: %w", err) } // 检查是否已禁用或已过期 if !recurring.IsActive { log.Printf("[RecurringHandler] Recurring transaction %d is inactive, skipping", recurring.ID) return nil } if recurring.EndDate != nil && time.Now().After(*recurring.EndDate) { log.Printf("[RecurringHandler] Recurring transaction %d has ended, deactivating", recurring.ID) recurring.IsActive = false return h.recurringRepo.Update(recurring) } // 开始数据库事务 tx := h.db.Begin() if tx.Error != nil { return fmt.Errorf("failed to begin transaction: %w", tx.Error) } // 创建实际交易记录 transaction := &models.Transaction{ UserID: recurring.UserID, Amount: recurring.Amount, Type: recurring.Type, CategoryID: recurring.CategoryID, AccountID: recurring.AccountID, Currency: recurring.Currency, TransactionDate: recurring.NextOccurrence, Note: recurring.Note, RecurringID: &recurring.ID, } if err := tx.Create(transaction).Error; err != nil { tx.Rollback() return fmt.Errorf("failed to create transaction: %w", err) } // 更新账户余额 var account models.Account if err := tx.First(&account, recurring.AccountID).Error; err != nil { tx.Rollback() return fmt.Errorf("failed to get account: %w", err) } switch recurring.Type { case models.TransactionTypeIncome: account.Balance += recurring.Amount case models.TransactionTypeExpense: account.Balance -= recurring.Amount } if err := tx.Save(&account).Error; err != nil { tx.Rollback() return fmt.Errorf("failed to update account balance: %w", err) } // 计算下一次执行时间 nextOccurrence := h.calculateNextOccurrence(recurring.NextOccurrence, recurring.Frequency) recurring.NextOccurrence = nextOccurrence // 检查下次执行是否超过结束日期 if recurring.EndDate != nil && nextOccurrence.After(*recurring.EndDate) { recurring.IsActive = false log.Printf("[RecurringHandler] Recurring transaction %d will end after this execution", recurring.ID) } if err := tx.Save(recurring).Error; err != nil { tx.Rollback() return fmt.Errorf("failed to update recurring transaction: %w", err) } // 提交事务 if err := tx.Commit().Error; err != nil { return fmt.Errorf("failed to commit transaction: %w", err) } log.Printf("[RecurringHandler] Created transaction %d from recurring %d, amount: %.2f %s", transaction.ID, recurring.ID, recurring.Amount, recurring.Currency) // 如果仍然活跃,调度下一次执行 if recurring.IsActive { if err := h.scheduleNext(ctx, recurring); err != nil { log.Printf("[RecurringHandler] Warning: failed to schedule next execution: %v", err) // 不返回错误,当前任务已成功完成 } } return nil } // calculateNextOccurrence 计算下一次执行时间 func (h *RecurringTransactionHandler) calculateNextOccurrence(current time.Time, frequency models.FrequencyType) time.Time { switch frequency { case models.FrequencyDaily: return current.AddDate(0, 0, 1) case models.FrequencyWeekly: return current.AddDate(0, 0, 7) case models.FrequencyMonthly: return current.AddDate(0, 1, 0) case models.FrequencyYearly: return current.AddDate(1, 0, 0) default: return current.AddDate(0, 1, 0) // 默认月度 } } // scheduleNext 调度下一次执行 func (h *RecurringTransactionHandler) scheduleNext(ctx context.Context, recurring *models.RecurringTransaction) error { return h.taskQueue.ScheduleRecurringTransaction( ctx, recurring.UserID, recurring.ID, recurring.NextOccurrence, ) } // ScheduleAllActive 调度所有活跃的周期性交易(启动时调用) func (h *RecurringTransactionHandler) ScheduleAllActive(ctx context.Context) error { log.Println("[RecurringHandler] Scheduling all active recurring transactions...") // 获取所有活跃的周期性交易 var recurringList []models.RecurringTransaction err := h.db.Where("is_active = ?", true).Find(&recurringList).Error if err != nil { return fmt.Errorf("failed to get active recurring transactions: %w", err) } scheduledCount := 0 for _, recurring := range recurringList { // 跳过已过期的 if recurring.EndDate != nil && time.Now().After(*recurring.EndDate) { recurring.IsActive = false h.db.Save(&recurring) continue } // 调度任务 err := h.taskQueue.ScheduleRecurringTransaction( ctx, recurring.UserID, recurring.ID, recurring.NextOccurrence, ) if err != nil { log.Printf("[RecurringHandler] Failed to schedule recurring %d: %v", recurring.ID, err) continue } scheduledCount++ } log.Printf("[RecurringHandler] Scheduled %d recurring transactions", scheduledCount) return nil } // ProcessOverdue 处理所有逾期的周期性交易(补偿机制) func (h *RecurringTransactionHandler) ProcessOverdue(ctx context.Context) error { log.Println("[RecurringHandler] Processing overdue recurring transactions...") now := time.Now() var overdueList []models.RecurringTransaction err := h.db.Where("is_active = ? AND next_occurrence <= ?", true, now).Find(&overdueList).Error if err != nil { return fmt.Errorf("failed to get overdue recurring transactions: %w", err) } if len(overdueList) == 0 { log.Println("[RecurringHandler] No overdue recurring transactions found") return nil } log.Printf("[RecurringHandler] Found %d overdue recurring transactions", len(overdueList)) for _, recurring := range overdueList { err := h.taskQueue.ScheduleRecurringTransaction( ctx, recurring.UserID, recurring.ID, recurring.NextOccurrence, ) if err != nil { log.Printf("[RecurringHandler] Failed to schedule overdue recurring %d: %v", recurring.ID, err) } } return nil }