feat: 添加预算管理功能,包括预算的创建、查询、更新、删除及进度计算。

This commit is contained in:
2026-01-29 21:43:35 +08:00
parent f1fa7b6c54
commit cf34f8b3d0
4 changed files with 112 additions and 35 deletions

View File

@@ -3,8 +3,8 @@ package handler
import (
"strconv"
"accounting-app/pkg/api"
"accounting-app/internal/service"
"accounting-app/pkg/api"
"github.com/gin-gonic/gin"
)
@@ -60,8 +60,6 @@ func (h *BudgetHandler) CreateBudget(c *gin.Context) {
api.BadRequest(c, err.Error())
case service.ErrInvalidPeriodType:
api.BadRequest(c, err.Error())
case service.ErrCategoryOrAccountRequired:
api.BadRequest(c, err.Error())
default:
api.InternalError(c, "Failed to create budget")
}
@@ -150,8 +148,6 @@ func (h *BudgetHandler) UpdateBudget(c *gin.Context) {
api.BadRequest(c, err.Error())
case service.ErrInvalidPeriodType:
api.BadRequest(c, err.Error())
case service.ErrCategoryOrAccountRequired:
api.BadRequest(c, err.Error())
default:
api.InternalError(c, "Failed to update budget")
}

View File

@@ -278,6 +278,7 @@ type Account struct {
PiggyBanks []PiggyBank `gorm:"foreignKey:LinkedAccountID" json:"-"`
ParentAccount *Account `gorm:"foreignKey:ParentAccountID" json:"parent_account,omitempty"`
SubAccounts []Account `gorm:"foreignKey:ParentAccountID" json:"sub_accounts,omitempty"`
Tags []Tag `gorm:"many2many:account_tags;" json:"tags,omitempty"`
}
// TableName specifies the table name for Account
@@ -417,6 +418,17 @@ func (TransactionTag) TableName() string {
return "transaction_tags"
}
// AccountTag represents the many-to-many relationship between accounts and tags
type AccountTag struct {
AccountID uint `gorm:"primaryKey" json:"account_id"`
TagID uint `gorm:"primaryKey" json:"tag_id"`
}
// TableName specifies the table name for AccountTag
func (AccountTag) TableName() string {
return "account_tags"
}
// Budget represents a spending budget for a category or account
type Budget struct {
BaseModel
@@ -843,6 +855,7 @@ func AllModels() []interface{} {
&Tag{},
&Transaction{},
&TransactionTag{}, // Explicit join table for many-to-many relationship
&AccountTag{}, // Explicit join table for account-tag many-to-many relationship
&Budget{},
&PiggyBank{},
&RecurringTransaction{},

View File

@@ -32,6 +32,7 @@ type AccountInput struct {
PaymentDate *int `json:"payment_date,omitempty"`
WarningThreshold *float64 `json:"warning_threshold,omitempty"`
AccountCode string `json:"account_code,omitempty"`
TagIDs []uint `json:"tag_ids,omitempty"`
}
// TransferInput represents the input data for a transfer operation
@@ -99,6 +100,18 @@ func (s *AccountService) CreateAccount(userID uint, input AccountInput) (*models
return nil, fmt.Errorf("failed to create account: %w", err)
}
// Handle tags association
if len(input.TagIDs) > 0 {
var tags []models.Tag
if err := s.db.Where("id IN ? AND user_id = ?", input.TagIDs, userID).Find(&tags).Error; err != nil {
return nil, fmt.Errorf("failed to find tags: %w", err)
}
if err := s.db.Model(account).Association("Tags").Replace(tags); err != nil {
return nil, fmt.Errorf("failed to associate tags: %w", err)
}
account.Tags = tags
}
return account, nil
}
@@ -117,7 +130,8 @@ func (s *AccountService) GetAccount(userID, id uint) (*models.Account, error) {
// GetAllAccounts retrieves all accounts for a specific user
func (s *AccountService) GetAllAccounts(userID uint) ([]models.Account, error) {
accounts, err := s.repo.GetAll(userID)
var accounts []models.Account
err := s.db.Where("user_id = ?", userID).Preload("Tags").Order("sort_order ASC, id ASC").Find(&accounts).Error
if err != nil {
return nil, fmt.Errorf("failed to get accounts: %w", err)
}
@@ -164,6 +178,18 @@ func (s *AccountService) UpdateAccount(userID, id uint, input AccountInput) (*mo
return nil, fmt.Errorf("failed to update account: %w", err)
}
// Handle tags association
var tags []models.Tag
if len(input.TagIDs) > 0 {
if err := s.db.Where("id IN ? AND user_id = ?", input.TagIDs, userID).Find(&tags).Error; err != nil {
return nil, fmt.Errorf("failed to find tags: %w", err)
}
}
if err := s.db.Model(account).Association("Tags").Replace(tags); err != nil {
return nil, fmt.Errorf("failed to update tags: %w", err)
}
account.Tags = tags
return account, nil
}

View File

@@ -13,12 +13,11 @@ import (
// Service layer errors for budgets
var (
ErrBudgetNotFound = errors.New("budget not found")
ErrBudgetInUse = errors.New("budget is in use and cannot be deleted")
ErrInvalidBudgetAmount = errors.New("budget amount must be positive")
ErrInvalidDateRange = errors.New("end date must be after start date")
ErrInvalidPeriodType = errors.New("invalid period type")
ErrCategoryOrAccountRequired = errors.New("either category or account must be specified")
ErrBudgetNotFound = errors.New("budget not found")
ErrBudgetInUse = errors.New("budget is in use and cannot be deleted")
ErrInvalidBudgetAmount = errors.New("budget amount must be positive")
ErrInvalidDateRange = errors.New("end date must be after start date")
ErrInvalidPeriodType = errors.New("invalid period type")
)
// BudgetInput represents the input data for creating or updating a budget
@@ -72,10 +71,7 @@ func (s *BudgetService) CreateBudget(input BudgetInput) (*models.Budget, error)
return nil, ErrInvalidBudgetAmount
}
// Validate that at least category or account is specified
if input.CategoryID == nil && input.AccountID == nil {
return nil, ErrCategoryOrAccountRequired
}
// 分类和账户都可选,支持全局预算
// Validate date range
if input.EndDate != nil && input.EndDate.Before(input.StartDate) {
@@ -147,10 +143,7 @@ func (s *BudgetService) UpdateBudget(userID, id uint, input BudgetInput) (*model
return nil, ErrInvalidBudgetAmount
}
// Validate that at least category or account is specified
if input.CategoryID == nil && input.AccountID == nil {
return nil, ErrCategoryOrAccountRequired
}
// 分类和账户都可选,支持全局预算
// Validate date range
if input.EndDate != nil && input.EndDate.Before(input.StartDate) {
@@ -222,41 +215,55 @@ func (s *BudgetService) GetBudgetProgress(userID, id uint) (*BudgetProgress, err
startDate, endDate := s.calculateCurrentPeriod(budget, now)
// Get spent amount for the current period
spent, err := s.repo.GetSpentAmount(budget, startDate, endDate)
currentSpent, err := s.repo.GetSpentAmount(budget, startDate, endDate)
if err != nil {
return nil, fmt.Errorf("failed to calculate spent amount: %w", err)
}
// Calculate effective budget amount (considering rolling budget)
// Calculate effective budget amount
effectiveAmount := budget.Amount
totalSpent := currentSpent
if budget.IsRolling {
// For rolling budgets, add the previous period's remaining balance
prevStartDate, prevEndDate := s.calculatePreviousPeriod(budget, now)
prevSpent, err := s.repo.GetSpentAmount(budget, prevStartDate, prevEndDate)
if err != nil {
return nil, fmt.Errorf("failed to calculate previous period spent: %w", err)
}
prevRemaining := budget.Amount - prevSpent
if prevRemaining > 0 {
effectiveAmount += prevRemaining
// 滚动预算:结余自动累加到下一周期
// 当期可用额度 = 总额度 - 历史支出
// 计算已过的完整周期数(不含当期)
periodsElapsed := s.calculatePeriodsElapsed(budget, startDate)
// 总额度 = (已过周期数 + 当期) × 单期额度
totalBudget := budget.Amount * float64(periodsElapsed+1)
// 获取历史支出(从预算开始到当期开始前一秒)
historyEnd := startDate.Add(-time.Second)
historySpent := 0.0
if historyEnd.After(budget.StartDate) {
historySpent, err = s.repo.GetSpentAmount(budget, budget.StartDate, historyEnd)
if err != nil {
return nil, fmt.Errorf("failed to calculate history spent: %w", err)
}
}
// 当期可用额度 = 总额度 - 历史支出
effectiveAmount = totalBudget - historySpent
totalSpent = currentSpent
}
// Calculate progress metrics
remaining := effectiveAmount - spent
remaining := effectiveAmount - totalSpent
progress := 0.0
if effectiveAmount > 0 {
progress = (spent / effectiveAmount) * 100
progress = (totalSpent / effectiveAmount) * 100
}
isOverBudget := spent > effectiveAmount
isOverBudget := totalSpent > effectiveAmount
isNearLimit := progress >= 80.0 && !isOverBudget
return &BudgetProgress{
BudgetID: budget.ID,
Name: budget.Name,
Amount: effectiveAmount,
Spent: spent,
Spent: totalSpent,
Remaining: remaining,
Progress: progress,
PeriodType: budget.PeriodType,
@@ -353,6 +360,41 @@ func (s *BudgetService) calculatePreviousPeriod(budget *models.Budget, reference
}
}
// calculatePeriodsElapsed 计算从预算开始日期到当前周期开始日期之间的完整周期数
func (s *BudgetService) calculatePeriodsElapsed(budget *models.Budget, currentPeriodStart time.Time) int {
if currentPeriodStart.Before(budget.StartDate) || currentPeriodStart.Equal(budget.StartDate) {
return 0
}
var periods int
switch budget.PeriodType {
case models.PeriodTypeDaily:
periods = int(currentPeriodStart.Sub(budget.StartDate).Hours() / 24)
case models.PeriodTypeWeekly:
periods = int(currentPeriodStart.Sub(budget.StartDate).Hours() / (24 * 7))
case models.PeriodTypeMonthly:
yearDiff := currentPeriodStart.Year() - budget.StartDate.Year()
monthDiff := int(currentPeriodStart.Month()) - int(budget.StartDate.Month())
periods = yearDiff*12 + monthDiff
case models.PeriodTypeYearly:
periods = currentPeriodStart.Year() - budget.StartDate.Year()
default:
yearDiff := currentPeriodStart.Year() - budget.StartDate.Year()
monthDiff := int(currentPeriodStart.Month()) - int(budget.StartDate.Month())
periods = yearDiff*12 + monthDiff
}
// 确保返回非负数
if periods < 0 {
return 0
}
return periods
}
// isValidPeriodType checks if a period type is valid
func isValidPeriodType(periodType models.PeriodType) bool {
switch periodType {