From cf34f8b3d01c9b82f239ec73b6ff2fcd2aef983b Mon Sep 17 00:00:00 2001 From: 12975 <1297598740@qq.com> Date: Thu, 29 Jan 2026 21:43:35 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E9=A2=84=E7=AE=97?= =?UTF-8?q?=E7=AE=A1=E7=90=86=E5=8A=9F=E8=83=BD=EF=BC=8C=E5=8C=85=E6=8B=AC?= =?UTF-8?q?=E9=A2=84=E7=AE=97=E7=9A=84=E5=88=9B=E5=BB=BA=E3=80=81=E6=9F=A5?= =?UTF-8?q?=E8=AF=A2=E3=80=81=E6=9B=B4=E6=96=B0=E3=80=81=E5=88=A0=E9=99=A4?= =?UTF-8?q?=E5=8F=8A=E8=BF=9B=E5=BA=A6=E8=AE=A1=E7=AE=97=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/handler/budget_handler.go | 6 +- internal/models/models.go | 13 ++++ internal/service/account_service.go | 28 +++++++- internal/service/budget_service.go | 100 ++++++++++++++++++++-------- 4 files changed, 112 insertions(+), 35 deletions(-) diff --git a/internal/handler/budget_handler.go b/internal/handler/budget_handler.go index 7b59404..82669f1 100644 --- a/internal/handler/budget_handler.go +++ b/internal/handler/budget_handler.go @@ -3,8 +3,8 @@ package handler import ( "strconv" - "accounting-app/pkg/api" "accounting-app/internal/service" + "accounting-app/pkg/api" "github.com/gin-gonic/gin" ) @@ -60,8 +60,6 @@ func (h *BudgetHandler) CreateBudget(c *gin.Context) { api.BadRequest(c, err.Error()) case service.ErrInvalidPeriodType: api.BadRequest(c, err.Error()) - case service.ErrCategoryOrAccountRequired: - api.BadRequest(c, err.Error()) default: api.InternalError(c, "Failed to create budget") } @@ -150,8 +148,6 @@ func (h *BudgetHandler) UpdateBudget(c *gin.Context) { api.BadRequest(c, err.Error()) case service.ErrInvalidPeriodType: api.BadRequest(c, err.Error()) - case service.ErrCategoryOrAccountRequired: - api.BadRequest(c, err.Error()) default: api.InternalError(c, "Failed to update budget") } diff --git a/internal/models/models.go b/internal/models/models.go index 35f415d..626a00e 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -278,6 +278,7 @@ type Account struct { PiggyBanks []PiggyBank `gorm:"foreignKey:LinkedAccountID" json:"-"` ParentAccount *Account `gorm:"foreignKey:ParentAccountID" json:"parent_account,omitempty"` SubAccounts []Account `gorm:"foreignKey:ParentAccountID" json:"sub_accounts,omitempty"` + Tags []Tag `gorm:"many2many:account_tags;" json:"tags,omitempty"` } // TableName specifies the table name for Account @@ -417,6 +418,17 @@ func (TransactionTag) TableName() string { return "transaction_tags" } +// AccountTag represents the many-to-many relationship between accounts and tags +type AccountTag struct { + AccountID uint `gorm:"primaryKey" json:"account_id"` + TagID uint `gorm:"primaryKey" json:"tag_id"` +} + +// TableName specifies the table name for AccountTag +func (AccountTag) TableName() string { + return "account_tags" +} + // Budget represents a spending budget for a category or account type Budget struct { BaseModel @@ -843,6 +855,7 @@ func AllModels() []interface{} { &Tag{}, &Transaction{}, &TransactionTag{}, // Explicit join table for many-to-many relationship + &AccountTag{}, // Explicit join table for account-tag many-to-many relationship &Budget{}, &PiggyBank{}, &RecurringTransaction{}, diff --git a/internal/service/account_service.go b/internal/service/account_service.go index 4aa263f..4efdb80 100644 --- a/internal/service/account_service.go +++ b/internal/service/account_service.go @@ -32,6 +32,7 @@ type AccountInput struct { PaymentDate *int `json:"payment_date,omitempty"` WarningThreshold *float64 `json:"warning_threshold,omitempty"` AccountCode string `json:"account_code,omitempty"` + TagIDs []uint `json:"tag_ids,omitempty"` } // TransferInput represents the input data for a transfer operation @@ -99,6 +100,18 @@ func (s *AccountService) CreateAccount(userID uint, input AccountInput) (*models return nil, fmt.Errorf("failed to create account: %w", err) } + // Handle tags association + if len(input.TagIDs) > 0 { + var tags []models.Tag + if err := s.db.Where("id IN ? AND user_id = ?", input.TagIDs, userID).Find(&tags).Error; err != nil { + return nil, fmt.Errorf("failed to find tags: %w", err) + } + if err := s.db.Model(account).Association("Tags").Replace(tags); err != nil { + return nil, fmt.Errorf("failed to associate tags: %w", err) + } + account.Tags = tags + } + return account, nil } @@ -117,7 +130,8 @@ func (s *AccountService) GetAccount(userID, id uint) (*models.Account, error) { // GetAllAccounts retrieves all accounts for a specific user func (s *AccountService) GetAllAccounts(userID uint) ([]models.Account, error) { - accounts, err := s.repo.GetAll(userID) + var accounts []models.Account + err := s.db.Where("user_id = ?", userID).Preload("Tags").Order("sort_order ASC, id ASC").Find(&accounts).Error if err != nil { return nil, fmt.Errorf("failed to get accounts: %w", err) } @@ -164,6 +178,18 @@ func (s *AccountService) UpdateAccount(userID, id uint, input AccountInput) (*mo return nil, fmt.Errorf("failed to update account: %w", err) } + // Handle tags association + var tags []models.Tag + if len(input.TagIDs) > 0 { + if err := s.db.Where("id IN ? AND user_id = ?", input.TagIDs, userID).Find(&tags).Error; err != nil { + return nil, fmt.Errorf("failed to find tags: %w", err) + } + } + if err := s.db.Model(account).Association("Tags").Replace(tags); err != nil { + return nil, fmt.Errorf("failed to update tags: %w", err) + } + account.Tags = tags + return account, nil } diff --git a/internal/service/budget_service.go b/internal/service/budget_service.go index a355b30..350b9fb 100644 --- a/internal/service/budget_service.go +++ b/internal/service/budget_service.go @@ -13,12 +13,11 @@ import ( // Service layer errors for budgets var ( - ErrBudgetNotFound = errors.New("budget not found") - ErrBudgetInUse = errors.New("budget is in use and cannot be deleted") - ErrInvalidBudgetAmount = errors.New("budget amount must be positive") - ErrInvalidDateRange = errors.New("end date must be after start date") - ErrInvalidPeriodType = errors.New("invalid period type") - ErrCategoryOrAccountRequired = errors.New("either category or account must be specified") + ErrBudgetNotFound = errors.New("budget not found") + ErrBudgetInUse = errors.New("budget is in use and cannot be deleted") + ErrInvalidBudgetAmount = errors.New("budget amount must be positive") + ErrInvalidDateRange = errors.New("end date must be after start date") + ErrInvalidPeriodType = errors.New("invalid period type") ) // BudgetInput represents the input data for creating or updating a budget @@ -72,10 +71,7 @@ func (s *BudgetService) CreateBudget(input BudgetInput) (*models.Budget, error) return nil, ErrInvalidBudgetAmount } - // Validate that at least category or account is specified - if input.CategoryID == nil && input.AccountID == nil { - return nil, ErrCategoryOrAccountRequired - } + // 分类和账户都可选,支持全局预算 // Validate date range if input.EndDate != nil && input.EndDate.Before(input.StartDate) { @@ -147,10 +143,7 @@ func (s *BudgetService) UpdateBudget(userID, id uint, input BudgetInput) (*model return nil, ErrInvalidBudgetAmount } - // Validate that at least category or account is specified - if input.CategoryID == nil && input.AccountID == nil { - return nil, ErrCategoryOrAccountRequired - } + // 分类和账户都可选,支持全局预算 // Validate date range if input.EndDate != nil && input.EndDate.Before(input.StartDate) { @@ -222,41 +215,55 @@ func (s *BudgetService) GetBudgetProgress(userID, id uint) (*BudgetProgress, err startDate, endDate := s.calculateCurrentPeriod(budget, now) // Get spent amount for the current period - spent, err := s.repo.GetSpentAmount(budget, startDate, endDate) + currentSpent, err := s.repo.GetSpentAmount(budget, startDate, endDate) if err != nil { return nil, fmt.Errorf("failed to calculate spent amount: %w", err) } - // Calculate effective budget amount (considering rolling budget) + // Calculate effective budget amount effectiveAmount := budget.Amount + totalSpent := currentSpent + if budget.IsRolling { - // For rolling budgets, add the previous period's remaining balance - prevStartDate, prevEndDate := s.calculatePreviousPeriod(budget, now) - prevSpent, err := s.repo.GetSpentAmount(budget, prevStartDate, prevEndDate) - if err != nil { - return nil, fmt.Errorf("failed to calculate previous period spent: %w", err) - } - prevRemaining := budget.Amount - prevSpent - if prevRemaining > 0 { - effectiveAmount += prevRemaining + // 滚动预算:结余自动累加到下一周期 + // 当期可用额度 = 总额度 - 历史支出 + + // 计算已过的完整周期数(不含当期) + periodsElapsed := s.calculatePeriodsElapsed(budget, startDate) + + // 总额度 = (已过周期数 + 当期) × 单期额度 + totalBudget := budget.Amount * float64(periodsElapsed+1) + + // 获取历史支出(从预算开始到当期开始前一秒) + historyEnd := startDate.Add(-time.Second) + historySpent := 0.0 + if historyEnd.After(budget.StartDate) { + historySpent, err = s.repo.GetSpentAmount(budget, budget.StartDate, historyEnd) + if err != nil { + return nil, fmt.Errorf("failed to calculate history spent: %w", err) + } } + + // 当期可用额度 = 总额度 - 历史支出 + effectiveAmount = totalBudget - historySpent + totalSpent = currentSpent } // Calculate progress metrics - remaining := effectiveAmount - spent + remaining := effectiveAmount - totalSpent progress := 0.0 if effectiveAmount > 0 { - progress = (spent / effectiveAmount) * 100 + progress = (totalSpent / effectiveAmount) * 100 } - isOverBudget := spent > effectiveAmount + isOverBudget := totalSpent > effectiveAmount isNearLimit := progress >= 80.0 && !isOverBudget return &BudgetProgress{ BudgetID: budget.ID, Name: budget.Name, Amount: effectiveAmount, - Spent: spent, + Spent: totalSpent, Remaining: remaining, Progress: progress, PeriodType: budget.PeriodType, @@ -353,6 +360,41 @@ func (s *BudgetService) calculatePreviousPeriod(budget *models.Budget, reference } } +// calculatePeriodsElapsed 计算从预算开始日期到当前周期开始日期之间的完整周期数 +func (s *BudgetService) calculatePeriodsElapsed(budget *models.Budget, currentPeriodStart time.Time) int { + if currentPeriodStart.Before(budget.StartDate) || currentPeriodStart.Equal(budget.StartDate) { + return 0 + } + + var periods int + switch budget.PeriodType { + case models.PeriodTypeDaily: + periods = int(currentPeriodStart.Sub(budget.StartDate).Hours() / 24) + + case models.PeriodTypeWeekly: + periods = int(currentPeriodStart.Sub(budget.StartDate).Hours() / (24 * 7)) + + case models.PeriodTypeMonthly: + yearDiff := currentPeriodStart.Year() - budget.StartDate.Year() + monthDiff := int(currentPeriodStart.Month()) - int(budget.StartDate.Month()) + periods = yearDiff*12 + monthDiff + + case models.PeriodTypeYearly: + periods = currentPeriodStart.Year() - budget.StartDate.Year() + + default: + yearDiff := currentPeriodStart.Year() - budget.StartDate.Year() + monthDiff := int(currentPeriodStart.Month()) - int(budget.StartDate.Month()) + periods = yearDiff*12 + monthDiff + } + + // 确保返回非负数 + if periods < 0 { + return 0 + } + return periods +} + // isValidPeriodType checks if a period type is valid func isValidPeriodType(periodType models.PeriodType) bool { switch periodType {