diff --git a/internal/repository/budget_repository.go b/internal/repository/budget_repository.go index 410a864..ce3f8db 100644 --- a/internal/repository/budget_repository.go +++ b/internal/repository/budget_repository.go @@ -144,7 +144,13 @@ func (r *BudgetRepository) GetSpentAmount(budget *models.Budget, startDate, endD // Filter by category if specified if budget.CategoryID != nil { - query = query.Where("category_id = ?", *budget.CategoryID) + // Get sub-categories + var subCategoryIDs []uint + r.db.Model(&models.Category{}).Where("parent_id = ?", *budget.CategoryID).Pluck("id", &subCategoryIDs) + + // Include the category itself and all its children + categoryIDs := append(subCategoryIDs, *budget.CategoryID) + query = query.Where("category_id IN ?", categoryIDs) } // Filter by account if specified diff --git a/internal/service/account_service.go b/internal/service/account_service.go index 4efdb80..45b41b8 100644 --- a/internal/service/account_service.go +++ b/internal/service/account_service.go @@ -12,12 +12,12 @@ import ( // Service layer errors var ( - ErrAccountNotFound = errors.New("account not found") - ErrAccountInUse = errors.New("account is in use and cannot be deleted") - ErrInsufficientBalance = errors.New("insufficient balance for this operation") - ErrSameAccountTransfer = errors.New("cannot transfer to the same account") - ErrInvalidTransferAmount = errors.New("transfer amount must be positive") - ErrNegativeBalanceNotAllowed = errors.New("negative balance not allowed for non-credit accounts") + ErrAccountNotFound = errors.New("账户不存在") + ErrAccountInUse = errors.New("账户正在使用中,无法删除") + ErrInsufficientBalance = errors.New("余额不足") + ErrSameAccountTransfer = errors.New("不能转账给同一个账户") + ErrInvalidTransferAmount = errors.New("转账金额必须大于0") + ErrNegativeBalanceNotAllowed = errors.New("非信用账户不允许负余额") ) // AccountInput represents the input data for creating or updating an account diff --git a/internal/service/ai_bookkeeping_service.go b/internal/service/ai_bookkeeping_service.go index 2b1593e..66ec7b1 100644 --- a/internal/service/ai_bookkeeping_service.go +++ b/internal/service/ai_bookkeeping_service.go @@ -1407,9 +1407,71 @@ func (s *AIBookkeepingService) GenerateConfirmationCard(session *AISession) *Con card.Date = time.Now().Format("2006-01-02") } + // Check for budget warnings + if params.Amount != nil && *params.Amount > 0 && params.Type == "expense" { + card.Warning = s.checkBudgetWarning(session.UserID, params.CategoryID, *params.Amount) + } + return card } +// checkBudgetWarning checks if the transaction exceeds any budget +func (s *AIBookkeepingService) checkBudgetWarning(userID uint, categoryID *uint, amount float64) string { + now := time.Now() + var budgets []models.Budget + + // Find active budgets + // We check: + // 1. Budgets specifically for this category + // 2. Global budgets (CategoryID is NULL) + query := s.db.Where("user_id = ?", userID). + Where("start_date <= ?", now). + Where("end_date IS NULL OR end_date >= ?", now) + + if categoryID != nil { + query = query.Where("category_id = ? OR category_id IS NULL", *categoryID) + } else { + query = query.Where("category_id IS NULL") + } + + if err := query.Find(&budgets).Error; err != nil { + return "" + } + + for _, budget := range budgets { + // Calculate current period + start, end := s.calculateBudgetPeriod(&budget, now) + + // Query spent amount + var totalSpent float64 + q := s.db.Model(&models.Transaction{}). + Where("user_id = ? AND type = ? AND transaction_date BETWEEN ? AND ?", + userID, models.TransactionTypeExpense, start, end) + + if budget.CategoryID != nil { + // Get sub-categories + var subCategoryIDs []uint + s.db.Model(&models.Category{}).Where("parent_id = ?", *budget.CategoryID).Pluck("id", &subCategoryIDs) + categoryIDs := append(subCategoryIDs, *budget.CategoryID) + q = q.Where("category_id IN ?", categoryIDs) + } + // If budget has account restriction, we should ideally check that too, + // but we don't always have accountID resolved here perfectly or it might be complex. + // For now, focusing on category budgets which are most common. + + q.Select("COALESCE(SUM(amount), 0)").Scan(&totalSpent) + + if totalSpent+amount > budget.Amount { + remaining := budget.Amount - totalSpent + if remaining < 0 { + remaining = 0 + } + return fmt.Sprintf("⚠️ 预算预警:此交易将使【%s】预算超支 (当前剩余 %.2f)", budget.Name, remaining) + } + } + return "" +} + // TranscribeAudio transcribes audio and returns text func (s *AIBookkeepingService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) { return s.whisperService.TranscribeAudio(ctx, audioData, filename) @@ -1430,13 +1492,13 @@ func (s *AIBookkeepingService) ConfirmTransaction(ctx context.Context, sessionID // Validate required fields if params.Amount == nil || *params.Amount <= 0 { - return nil, errors.New("invalid amount") + return nil, errors.New("无效的金额") } if params.CategoryID == nil { - return nil, errors.New("category not specified") + return nil, errors.New("未指定分类") } if params.AccountID == nil { - return nil, errors.New("account not specified") + return nil, errors.New("未指定账户") } // Parse date @@ -1480,7 +1542,7 @@ func (s *AIBookkeepingService) ConfirmTransaction(ctx context.Context, sessionID // Critical Check: Prevent negative balance for non-credit accounts if !account.IsCredit && newBalance < 0 { - return fmt.Errorf("insufficient balance: account '%s' does not support negative balance (current: %.2f, try: %.2f)", + return fmt.Errorf("余额不足:账户“%s”不支持负余额 (当前: %.2f, 尝试扣款: %.2f)", account.Name, account.Balance, *params.Amount) } diff --git a/internal/service/transaction_service.go b/internal/service/transaction_service.go index 1ed3357..b65a451 100644 --- a/internal/service/transaction_service.go +++ b/internal/service/transaction_service.go @@ -11,18 +11,19 @@ import ( "gorm.io/gorm" ) +// Transaction service errors // Transaction service errors var ( - ErrTransactionNotFound = errors.New("transaction not found") - ErrInvalidTransactionType = errors.New("invalid transaction type") - ErrMissingRequiredField = errors.New("missing required field") - ErrInvalidAmount = errors.New("amount must be positive") - ErrInvalidCurrency = errors.New("invalid currency") - ErrCategoryNotFoundForTxn = errors.New("category not found") - ErrAccountNotFoundForTxn = errors.New("account not found") - ErrToAccountNotFoundForTxn = errors.New("destination account not found for transfer") - ErrToAccountRequiredForTxn = errors.New("destination account is required for transfer transactions") - ErrSameAccountTransferForTxn = errors.New("cannot transfer to the same account") + ErrTransactionNotFound = errors.New("交易不存在") + ErrInvalidTransactionType = errors.New("无效的交易类型") + ErrMissingRequiredField = errors.New("缺少必填字段") + ErrInvalidAmount = errors.New("金额必须大于0") + ErrInvalidCurrency = errors.New("无效的货币") + ErrCategoryNotFoundForTxn = errors.New("分类不存在") + ErrAccountNotFoundForTxn = errors.New("账户不存在") + ErrToAccountNotFoundForTxn = errors.New("转账目标账户不存在") + ErrToAccountRequiredForTxn = errors.New("转账必须指定目标账户") + ErrSameAccountTransferForTxn = errors.New("不能转账给同一个账户") ) // TransactionInput represents the input data for creating or updating a transaction