feat: 添加预算管理功能,包括预算的创建、查询、更新、删除及进度计算。
This commit is contained in:
@@ -3,8 +3,8 @@ package handler
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"accounting-app/pkg/api"
|
||||
"accounting-app/internal/service"
|
||||
"accounting-app/pkg/api"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -60,8 +60,6 @@ func (h *BudgetHandler) CreateBudget(c *gin.Context) {
|
||||
api.BadRequest(c, err.Error())
|
||||
case service.ErrInvalidPeriodType:
|
||||
api.BadRequest(c, err.Error())
|
||||
case service.ErrCategoryOrAccountRequired:
|
||||
api.BadRequest(c, err.Error())
|
||||
default:
|
||||
api.InternalError(c, "Failed to create budget")
|
||||
}
|
||||
@@ -150,8 +148,6 @@ func (h *BudgetHandler) UpdateBudget(c *gin.Context) {
|
||||
api.BadRequest(c, err.Error())
|
||||
case service.ErrInvalidPeriodType:
|
||||
api.BadRequest(c, err.Error())
|
||||
case service.ErrCategoryOrAccountRequired:
|
||||
api.BadRequest(c, err.Error())
|
||||
default:
|
||||
api.InternalError(c, "Failed to update budget")
|
||||
}
|
||||
|
||||
@@ -278,6 +278,7 @@ type Account struct {
|
||||
PiggyBanks []PiggyBank `gorm:"foreignKey:LinkedAccountID" json:"-"`
|
||||
ParentAccount *Account `gorm:"foreignKey:ParentAccountID" json:"parent_account,omitempty"`
|
||||
SubAccounts []Account `gorm:"foreignKey:ParentAccountID" json:"sub_accounts,omitempty"`
|
||||
Tags []Tag `gorm:"many2many:account_tags;" json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
// TableName specifies the table name for Account
|
||||
@@ -417,6 +418,17 @@ func (TransactionTag) TableName() string {
|
||||
return "transaction_tags"
|
||||
}
|
||||
|
||||
// AccountTag represents the many-to-many relationship between accounts and tags
|
||||
type AccountTag struct {
|
||||
AccountID uint `gorm:"primaryKey" json:"account_id"`
|
||||
TagID uint `gorm:"primaryKey" json:"tag_id"`
|
||||
}
|
||||
|
||||
// TableName specifies the table name for AccountTag
|
||||
func (AccountTag) TableName() string {
|
||||
return "account_tags"
|
||||
}
|
||||
|
||||
// Budget represents a spending budget for a category or account
|
||||
type Budget struct {
|
||||
BaseModel
|
||||
@@ -843,6 +855,7 @@ func AllModels() []interface{} {
|
||||
&Tag{},
|
||||
&Transaction{},
|
||||
&TransactionTag{}, // Explicit join table for many-to-many relationship
|
||||
&AccountTag{}, // Explicit join table for account-tag many-to-many relationship
|
||||
&Budget{},
|
||||
&PiggyBank{},
|
||||
&RecurringTransaction{},
|
||||
|
||||
@@ -32,6 +32,7 @@ type AccountInput struct {
|
||||
PaymentDate *int `json:"payment_date,omitempty"`
|
||||
WarningThreshold *float64 `json:"warning_threshold,omitempty"`
|
||||
AccountCode string `json:"account_code,omitempty"`
|
||||
TagIDs []uint `json:"tag_ids,omitempty"`
|
||||
}
|
||||
|
||||
// TransferInput represents the input data for a transfer operation
|
||||
@@ -99,6 +100,18 @@ func (s *AccountService) CreateAccount(userID uint, input AccountInput) (*models
|
||||
return nil, fmt.Errorf("failed to create account: %w", err)
|
||||
}
|
||||
|
||||
// Handle tags association
|
||||
if len(input.TagIDs) > 0 {
|
||||
var tags []models.Tag
|
||||
if err := s.db.Where("id IN ? AND user_id = ?", input.TagIDs, userID).Find(&tags).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to find tags: %w", err)
|
||||
}
|
||||
if err := s.db.Model(account).Association("Tags").Replace(tags); err != nil {
|
||||
return nil, fmt.Errorf("failed to associate tags: %w", err)
|
||||
}
|
||||
account.Tags = tags
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
@@ -117,7 +130,8 @@ func (s *AccountService) GetAccount(userID, id uint) (*models.Account, error) {
|
||||
|
||||
// GetAllAccounts retrieves all accounts for a specific user
|
||||
func (s *AccountService) GetAllAccounts(userID uint) ([]models.Account, error) {
|
||||
accounts, err := s.repo.GetAll(userID)
|
||||
var accounts []models.Account
|
||||
err := s.db.Where("user_id = ?", userID).Preload("Tags").Order("sort_order ASC, id ASC").Find(&accounts).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get accounts: %w", err)
|
||||
}
|
||||
@@ -164,6 +178,18 @@ func (s *AccountService) UpdateAccount(userID, id uint, input AccountInput) (*mo
|
||||
return nil, fmt.Errorf("failed to update account: %w", err)
|
||||
}
|
||||
|
||||
// Handle tags association
|
||||
var tags []models.Tag
|
||||
if len(input.TagIDs) > 0 {
|
||||
if err := s.db.Where("id IN ? AND user_id = ?", input.TagIDs, userID).Find(&tags).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to find tags: %w", err)
|
||||
}
|
||||
}
|
||||
if err := s.db.Model(account).Association("Tags").Replace(tags); err != nil {
|
||||
return nil, fmt.Errorf("failed to update tags: %w", err)
|
||||
}
|
||||
account.Tags = tags
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -13,12 +13,11 @@ import (
|
||||
|
||||
// Service layer errors for budgets
|
||||
var (
|
||||
ErrBudgetNotFound = errors.New("budget not found")
|
||||
ErrBudgetInUse = errors.New("budget is in use and cannot be deleted")
|
||||
ErrInvalidBudgetAmount = errors.New("budget amount must be positive")
|
||||
ErrInvalidDateRange = errors.New("end date must be after start date")
|
||||
ErrInvalidPeriodType = errors.New("invalid period type")
|
||||
ErrCategoryOrAccountRequired = errors.New("either category or account must be specified")
|
||||
ErrBudgetNotFound = errors.New("budget not found")
|
||||
ErrBudgetInUse = errors.New("budget is in use and cannot be deleted")
|
||||
ErrInvalidBudgetAmount = errors.New("budget amount must be positive")
|
||||
ErrInvalidDateRange = errors.New("end date must be after start date")
|
||||
ErrInvalidPeriodType = errors.New("invalid period type")
|
||||
)
|
||||
|
||||
// BudgetInput represents the input data for creating or updating a budget
|
||||
@@ -72,10 +71,7 @@ func (s *BudgetService) CreateBudget(input BudgetInput) (*models.Budget, error)
|
||||
return nil, ErrInvalidBudgetAmount
|
||||
}
|
||||
|
||||
// Validate that at least category or account is specified
|
||||
if input.CategoryID == nil && input.AccountID == nil {
|
||||
return nil, ErrCategoryOrAccountRequired
|
||||
}
|
||||
// 分类和账户都可选,支持全局预算
|
||||
|
||||
// Validate date range
|
||||
if input.EndDate != nil && input.EndDate.Before(input.StartDate) {
|
||||
@@ -147,10 +143,7 @@ func (s *BudgetService) UpdateBudget(userID, id uint, input BudgetInput) (*model
|
||||
return nil, ErrInvalidBudgetAmount
|
||||
}
|
||||
|
||||
// Validate that at least category or account is specified
|
||||
if input.CategoryID == nil && input.AccountID == nil {
|
||||
return nil, ErrCategoryOrAccountRequired
|
||||
}
|
||||
// 分类和账户都可选,支持全局预算
|
||||
|
||||
// Validate date range
|
||||
if input.EndDate != nil && input.EndDate.Before(input.StartDate) {
|
||||
@@ -222,41 +215,55 @@ func (s *BudgetService) GetBudgetProgress(userID, id uint) (*BudgetProgress, err
|
||||
startDate, endDate := s.calculateCurrentPeriod(budget, now)
|
||||
|
||||
// Get spent amount for the current period
|
||||
spent, err := s.repo.GetSpentAmount(budget, startDate, endDate)
|
||||
currentSpent, err := s.repo.GetSpentAmount(budget, startDate, endDate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to calculate spent amount: %w", err)
|
||||
}
|
||||
|
||||
// Calculate effective budget amount (considering rolling budget)
|
||||
// Calculate effective budget amount
|
||||
effectiveAmount := budget.Amount
|
||||
totalSpent := currentSpent
|
||||
|
||||
if budget.IsRolling {
|
||||
// For rolling budgets, add the previous period's remaining balance
|
||||
prevStartDate, prevEndDate := s.calculatePreviousPeriod(budget, now)
|
||||
prevSpent, err := s.repo.GetSpentAmount(budget, prevStartDate, prevEndDate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to calculate previous period spent: %w", err)
|
||||
}
|
||||
prevRemaining := budget.Amount - prevSpent
|
||||
if prevRemaining > 0 {
|
||||
effectiveAmount += prevRemaining
|
||||
// 滚动预算:结余自动累加到下一周期
|
||||
// 当期可用额度 = 总额度 - 历史支出
|
||||
|
||||
// 计算已过的完整周期数(不含当期)
|
||||
periodsElapsed := s.calculatePeriodsElapsed(budget, startDate)
|
||||
|
||||
// 总额度 = (已过周期数 + 当期) × 单期额度
|
||||
totalBudget := budget.Amount * float64(periodsElapsed+1)
|
||||
|
||||
// 获取历史支出(从预算开始到当期开始前一秒)
|
||||
historyEnd := startDate.Add(-time.Second)
|
||||
historySpent := 0.0
|
||||
if historyEnd.After(budget.StartDate) {
|
||||
historySpent, err = s.repo.GetSpentAmount(budget, budget.StartDate, historyEnd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to calculate history spent: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 当期可用额度 = 总额度 - 历史支出
|
||||
effectiveAmount = totalBudget - historySpent
|
||||
totalSpent = currentSpent
|
||||
}
|
||||
|
||||
// Calculate progress metrics
|
||||
remaining := effectiveAmount - spent
|
||||
remaining := effectiveAmount - totalSpent
|
||||
progress := 0.0
|
||||
if effectiveAmount > 0 {
|
||||
progress = (spent / effectiveAmount) * 100
|
||||
progress = (totalSpent / effectiveAmount) * 100
|
||||
}
|
||||
|
||||
isOverBudget := spent > effectiveAmount
|
||||
isOverBudget := totalSpent > effectiveAmount
|
||||
isNearLimit := progress >= 80.0 && !isOverBudget
|
||||
|
||||
return &BudgetProgress{
|
||||
BudgetID: budget.ID,
|
||||
Name: budget.Name,
|
||||
Amount: effectiveAmount,
|
||||
Spent: spent,
|
||||
Spent: totalSpent,
|
||||
Remaining: remaining,
|
||||
Progress: progress,
|
||||
PeriodType: budget.PeriodType,
|
||||
@@ -353,6 +360,41 @@ func (s *BudgetService) calculatePreviousPeriod(budget *models.Budget, reference
|
||||
}
|
||||
}
|
||||
|
||||
// calculatePeriodsElapsed 计算从预算开始日期到当前周期开始日期之间的完整周期数
|
||||
func (s *BudgetService) calculatePeriodsElapsed(budget *models.Budget, currentPeriodStart time.Time) int {
|
||||
if currentPeriodStart.Before(budget.StartDate) || currentPeriodStart.Equal(budget.StartDate) {
|
||||
return 0
|
||||
}
|
||||
|
||||
var periods int
|
||||
switch budget.PeriodType {
|
||||
case models.PeriodTypeDaily:
|
||||
periods = int(currentPeriodStart.Sub(budget.StartDate).Hours() / 24)
|
||||
|
||||
case models.PeriodTypeWeekly:
|
||||
periods = int(currentPeriodStart.Sub(budget.StartDate).Hours() / (24 * 7))
|
||||
|
||||
case models.PeriodTypeMonthly:
|
||||
yearDiff := currentPeriodStart.Year() - budget.StartDate.Year()
|
||||
monthDiff := int(currentPeriodStart.Month()) - int(budget.StartDate.Month())
|
||||
periods = yearDiff*12 + monthDiff
|
||||
|
||||
case models.PeriodTypeYearly:
|
||||
periods = currentPeriodStart.Year() - budget.StartDate.Year()
|
||||
|
||||
default:
|
||||
yearDiff := currentPeriodStart.Year() - budget.StartDate.Year()
|
||||
monthDiff := int(currentPeriodStart.Month()) - int(budget.StartDate.Month())
|
||||
periods = yearDiff*12 + monthDiff
|
||||
}
|
||||
|
||||
// 确保返回非负数
|
||||
if periods < 0 {
|
||||
return 0
|
||||
}
|
||||
return periods
|
||||
}
|
||||
|
||||
// isValidPeriodType checks if a period type is valid
|
||||
func isValidPeriodType(periodType models.PeriodType) bool {
|
||||
switch periodType {
|
||||
|
||||
Reference in New Issue
Block a user