feat: 添加预算管理功能,包括预算的创建、查询、更新、删除及进度计算。

This commit is contained in:
2026-01-29 21:43:35 +08:00
parent f1fa7b6c54
commit cf34f8b3d0
4 changed files with 112 additions and 35 deletions

View File

@@ -3,8 +3,8 @@ package handler
import ( import (
"strconv" "strconv"
"accounting-app/pkg/api"
"accounting-app/internal/service" "accounting-app/internal/service"
"accounting-app/pkg/api"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -60,8 +60,6 @@ func (h *BudgetHandler) CreateBudget(c *gin.Context) {
api.BadRequest(c, err.Error()) api.BadRequest(c, err.Error())
case service.ErrInvalidPeriodType: case service.ErrInvalidPeriodType:
api.BadRequest(c, err.Error()) api.BadRequest(c, err.Error())
case service.ErrCategoryOrAccountRequired:
api.BadRequest(c, err.Error())
default: default:
api.InternalError(c, "Failed to create budget") api.InternalError(c, "Failed to create budget")
} }
@@ -150,8 +148,6 @@ func (h *BudgetHandler) UpdateBudget(c *gin.Context) {
api.BadRequest(c, err.Error()) api.BadRequest(c, err.Error())
case service.ErrInvalidPeriodType: case service.ErrInvalidPeriodType:
api.BadRequest(c, err.Error()) api.BadRequest(c, err.Error())
case service.ErrCategoryOrAccountRequired:
api.BadRequest(c, err.Error())
default: default:
api.InternalError(c, "Failed to update budget") api.InternalError(c, "Failed to update budget")
} }

View File

@@ -278,6 +278,7 @@ type Account struct {
PiggyBanks []PiggyBank `gorm:"foreignKey:LinkedAccountID" json:"-"` PiggyBanks []PiggyBank `gorm:"foreignKey:LinkedAccountID" json:"-"`
ParentAccount *Account `gorm:"foreignKey:ParentAccountID" json:"parent_account,omitempty"` ParentAccount *Account `gorm:"foreignKey:ParentAccountID" json:"parent_account,omitempty"`
SubAccounts []Account `gorm:"foreignKey:ParentAccountID" json:"sub_accounts,omitempty"` SubAccounts []Account `gorm:"foreignKey:ParentAccountID" json:"sub_accounts,omitempty"`
Tags []Tag `gorm:"many2many:account_tags;" json:"tags,omitempty"`
} }
// TableName specifies the table name for Account // TableName specifies the table name for Account
@@ -417,6 +418,17 @@ func (TransactionTag) TableName() string {
return "transaction_tags" return "transaction_tags"
} }
// AccountTag represents the many-to-many relationship between accounts and tags
type AccountTag struct {
AccountID uint `gorm:"primaryKey" json:"account_id"`
TagID uint `gorm:"primaryKey" json:"tag_id"`
}
// TableName specifies the table name for AccountTag
func (AccountTag) TableName() string {
return "account_tags"
}
// Budget represents a spending budget for a category or account // Budget represents a spending budget for a category or account
type Budget struct { type Budget struct {
BaseModel BaseModel
@@ -843,6 +855,7 @@ func AllModels() []interface{} {
&Tag{}, &Tag{},
&Transaction{}, &Transaction{},
&TransactionTag{}, // Explicit join table for many-to-many relationship &TransactionTag{}, // Explicit join table for many-to-many relationship
&AccountTag{}, // Explicit join table for account-tag many-to-many relationship
&Budget{}, &Budget{},
&PiggyBank{}, &PiggyBank{},
&RecurringTransaction{}, &RecurringTransaction{},

View File

@@ -32,6 +32,7 @@ type AccountInput struct {
PaymentDate *int `json:"payment_date,omitempty"` PaymentDate *int `json:"payment_date,omitempty"`
WarningThreshold *float64 `json:"warning_threshold,omitempty"` WarningThreshold *float64 `json:"warning_threshold,omitempty"`
AccountCode string `json:"account_code,omitempty"` AccountCode string `json:"account_code,omitempty"`
TagIDs []uint `json:"tag_ids,omitempty"`
} }
// TransferInput represents the input data for a transfer operation // TransferInput represents the input data for a transfer operation
@@ -99,6 +100,18 @@ func (s *AccountService) CreateAccount(userID uint, input AccountInput) (*models
return nil, fmt.Errorf("failed to create account: %w", err) return nil, fmt.Errorf("failed to create account: %w", err)
} }
// Handle tags association
if len(input.TagIDs) > 0 {
var tags []models.Tag
if err := s.db.Where("id IN ? AND user_id = ?", input.TagIDs, userID).Find(&tags).Error; err != nil {
return nil, fmt.Errorf("failed to find tags: %w", err)
}
if err := s.db.Model(account).Association("Tags").Replace(tags); err != nil {
return nil, fmt.Errorf("failed to associate tags: %w", err)
}
account.Tags = tags
}
return account, nil return account, nil
} }
@@ -117,7 +130,8 @@ func (s *AccountService) GetAccount(userID, id uint) (*models.Account, error) {
// GetAllAccounts retrieves all accounts for a specific user // GetAllAccounts retrieves all accounts for a specific user
func (s *AccountService) GetAllAccounts(userID uint) ([]models.Account, error) { func (s *AccountService) GetAllAccounts(userID uint) ([]models.Account, error) {
accounts, err := s.repo.GetAll(userID) var accounts []models.Account
err := s.db.Where("user_id = ?", userID).Preload("Tags").Order("sort_order ASC, id ASC").Find(&accounts).Error
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get accounts: %w", err) return nil, fmt.Errorf("failed to get accounts: %w", err)
} }
@@ -164,6 +178,18 @@ func (s *AccountService) UpdateAccount(userID, id uint, input AccountInput) (*mo
return nil, fmt.Errorf("failed to update account: %w", err) return nil, fmt.Errorf("failed to update account: %w", err)
} }
// Handle tags association
var tags []models.Tag
if len(input.TagIDs) > 0 {
if err := s.db.Where("id IN ? AND user_id = ?", input.TagIDs, userID).Find(&tags).Error; err != nil {
return nil, fmt.Errorf("failed to find tags: %w", err)
}
}
if err := s.db.Model(account).Association("Tags").Replace(tags); err != nil {
return nil, fmt.Errorf("failed to update tags: %w", err)
}
account.Tags = tags
return account, nil return account, nil
} }

View File

@@ -18,7 +18,6 @@ var (
ErrInvalidBudgetAmount = errors.New("budget amount must be positive") ErrInvalidBudgetAmount = errors.New("budget amount must be positive")
ErrInvalidDateRange = errors.New("end date must be after start date") ErrInvalidDateRange = errors.New("end date must be after start date")
ErrInvalidPeriodType = errors.New("invalid period type") ErrInvalidPeriodType = errors.New("invalid period type")
ErrCategoryOrAccountRequired = errors.New("either category or account must be specified")
) )
// BudgetInput represents the input data for creating or updating a budget // BudgetInput represents the input data for creating or updating a budget
@@ -72,10 +71,7 @@ func (s *BudgetService) CreateBudget(input BudgetInput) (*models.Budget, error)
return nil, ErrInvalidBudgetAmount return nil, ErrInvalidBudgetAmount
} }
// Validate that at least category or account is specified // 分类和账户都可选,支持全局预算
if input.CategoryID == nil && input.AccountID == nil {
return nil, ErrCategoryOrAccountRequired
}
// Validate date range // Validate date range
if input.EndDate != nil && input.EndDate.Before(input.StartDate) { if input.EndDate != nil && input.EndDate.Before(input.StartDate) {
@@ -147,10 +143,7 @@ func (s *BudgetService) UpdateBudget(userID, id uint, input BudgetInput) (*model
return nil, ErrInvalidBudgetAmount return nil, ErrInvalidBudgetAmount
} }
// Validate that at least category or account is specified // 分类和账户都可选,支持全局预算
if input.CategoryID == nil && input.AccountID == nil {
return nil, ErrCategoryOrAccountRequired
}
// Validate date range // Validate date range
if input.EndDate != nil && input.EndDate.Before(input.StartDate) { if input.EndDate != nil && input.EndDate.Before(input.StartDate) {
@@ -222,41 +215,55 @@ func (s *BudgetService) GetBudgetProgress(userID, id uint) (*BudgetProgress, err
startDate, endDate := s.calculateCurrentPeriod(budget, now) startDate, endDate := s.calculateCurrentPeriod(budget, now)
// Get spent amount for the current period // Get spent amount for the current period
spent, err := s.repo.GetSpentAmount(budget, startDate, endDate) currentSpent, err := s.repo.GetSpentAmount(budget, startDate, endDate)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to calculate spent amount: %w", err) return nil, fmt.Errorf("failed to calculate spent amount: %w", err)
} }
// Calculate effective budget amount (considering rolling budget) // Calculate effective budget amount
effectiveAmount := budget.Amount effectiveAmount := budget.Amount
totalSpent := currentSpent
if budget.IsRolling { if budget.IsRolling {
// For rolling budgets, add the previous period's remaining balance // 滚动预算:结余自动累加到下一周期
prevStartDate, prevEndDate := s.calculatePreviousPeriod(budget, now) // 当期可用额度 = 总额度 - 历史支出
prevSpent, err := s.repo.GetSpentAmount(budget, prevStartDate, prevEndDate)
// 计算已过的完整周期数(不含当期)
periodsElapsed := s.calculatePeriodsElapsed(budget, startDate)
// 总额度 = (已过周期数 + 当期) × 单期额度
totalBudget := budget.Amount * float64(periodsElapsed+1)
// 获取历史支出(从预算开始到当期开始前一秒)
historyEnd := startDate.Add(-time.Second)
historySpent := 0.0
if historyEnd.After(budget.StartDate) {
historySpent, err = s.repo.GetSpentAmount(budget, budget.StartDate, historyEnd)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to calculate previous period spent: %w", err) return nil, fmt.Errorf("failed to calculate history spent: %w", err)
} }
prevRemaining := budget.Amount - prevSpent
if prevRemaining > 0 {
effectiveAmount += prevRemaining
} }
// 当期可用额度 = 总额度 - 历史支出
effectiveAmount = totalBudget - historySpent
totalSpent = currentSpent
} }
// Calculate progress metrics // Calculate progress metrics
remaining := effectiveAmount - spent remaining := effectiveAmount - totalSpent
progress := 0.0 progress := 0.0
if effectiveAmount > 0 { if effectiveAmount > 0 {
progress = (spent / effectiveAmount) * 100 progress = (totalSpent / effectiveAmount) * 100
} }
isOverBudget := spent > effectiveAmount isOverBudget := totalSpent > effectiveAmount
isNearLimit := progress >= 80.0 && !isOverBudget isNearLimit := progress >= 80.0 && !isOverBudget
return &BudgetProgress{ return &BudgetProgress{
BudgetID: budget.ID, BudgetID: budget.ID,
Name: budget.Name, Name: budget.Name,
Amount: effectiveAmount, Amount: effectiveAmount,
Spent: spent, Spent: totalSpent,
Remaining: remaining, Remaining: remaining,
Progress: progress, Progress: progress,
PeriodType: budget.PeriodType, PeriodType: budget.PeriodType,
@@ -353,6 +360,41 @@ func (s *BudgetService) calculatePreviousPeriod(budget *models.Budget, reference
} }
} }
// calculatePeriodsElapsed 计算从预算开始日期到当前周期开始日期之间的完整周期数
func (s *BudgetService) calculatePeriodsElapsed(budget *models.Budget, currentPeriodStart time.Time) int {
if currentPeriodStart.Before(budget.StartDate) || currentPeriodStart.Equal(budget.StartDate) {
return 0
}
var periods int
switch budget.PeriodType {
case models.PeriodTypeDaily:
periods = int(currentPeriodStart.Sub(budget.StartDate).Hours() / 24)
case models.PeriodTypeWeekly:
periods = int(currentPeriodStart.Sub(budget.StartDate).Hours() / (24 * 7))
case models.PeriodTypeMonthly:
yearDiff := currentPeriodStart.Year() - budget.StartDate.Year()
monthDiff := int(currentPeriodStart.Month()) - int(budget.StartDate.Month())
periods = yearDiff*12 + monthDiff
case models.PeriodTypeYearly:
periods = currentPeriodStart.Year() - budget.StartDate.Year()
default:
yearDiff := currentPeriodStart.Year() - budget.StartDate.Year()
monthDiff := int(currentPeriodStart.Month()) - int(budget.StartDate.Month())
periods = yearDiff*12 + monthDiff
}
// 确保返回非负数
if periods < 0 {
return 0
}
return periods
}
// isValidPeriodType checks if a period type is valid // isValidPeriodType checks if a period type is valid
func isValidPeriodType(periodType models.PeriodType) bool { func isValidPeriodType(periodType models.PeriodType) bool {
switch periodType { switch periodType {