feat: 添加预算管理功能,包括预算的创建、查询、更新、删除及进度计算。
This commit is contained in:
@@ -3,8 +3,8 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"accounting-app/pkg/api"
|
|
||||||
"accounting-app/internal/service"
|
"accounting-app/internal/service"
|
||||||
|
"accounting-app/pkg/api"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -60,8 +60,6 @@ func (h *BudgetHandler) CreateBudget(c *gin.Context) {
|
|||||||
api.BadRequest(c, err.Error())
|
api.BadRequest(c, err.Error())
|
||||||
case service.ErrInvalidPeriodType:
|
case service.ErrInvalidPeriodType:
|
||||||
api.BadRequest(c, err.Error())
|
api.BadRequest(c, err.Error())
|
||||||
case service.ErrCategoryOrAccountRequired:
|
|
||||||
api.BadRequest(c, err.Error())
|
|
||||||
default:
|
default:
|
||||||
api.InternalError(c, "Failed to create budget")
|
api.InternalError(c, "Failed to create budget")
|
||||||
}
|
}
|
||||||
@@ -150,8 +148,6 @@ func (h *BudgetHandler) UpdateBudget(c *gin.Context) {
|
|||||||
api.BadRequest(c, err.Error())
|
api.BadRequest(c, err.Error())
|
||||||
case service.ErrInvalidPeriodType:
|
case service.ErrInvalidPeriodType:
|
||||||
api.BadRequest(c, err.Error())
|
api.BadRequest(c, err.Error())
|
||||||
case service.ErrCategoryOrAccountRequired:
|
|
||||||
api.BadRequest(c, err.Error())
|
|
||||||
default:
|
default:
|
||||||
api.InternalError(c, "Failed to update budget")
|
api.InternalError(c, "Failed to update budget")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -278,6 +278,7 @@ type Account struct {
|
|||||||
PiggyBanks []PiggyBank `gorm:"foreignKey:LinkedAccountID" json:"-"`
|
PiggyBanks []PiggyBank `gorm:"foreignKey:LinkedAccountID" json:"-"`
|
||||||
ParentAccount *Account `gorm:"foreignKey:ParentAccountID" json:"parent_account,omitempty"`
|
ParentAccount *Account `gorm:"foreignKey:ParentAccountID" json:"parent_account,omitempty"`
|
||||||
SubAccounts []Account `gorm:"foreignKey:ParentAccountID" json:"sub_accounts,omitempty"`
|
SubAccounts []Account `gorm:"foreignKey:ParentAccountID" json:"sub_accounts,omitempty"`
|
||||||
|
Tags []Tag `gorm:"many2many:account_tags;" json:"tags,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TableName specifies the table name for Account
|
// TableName specifies the table name for Account
|
||||||
@@ -417,6 +418,17 @@ func (TransactionTag) TableName() string {
|
|||||||
return "transaction_tags"
|
return "transaction_tags"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AccountTag represents the many-to-many relationship between accounts and tags
|
||||||
|
type AccountTag struct {
|
||||||
|
AccountID uint `gorm:"primaryKey" json:"account_id"`
|
||||||
|
TagID uint `gorm:"primaryKey" json:"tag_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName specifies the table name for AccountTag
|
||||||
|
func (AccountTag) TableName() string {
|
||||||
|
return "account_tags"
|
||||||
|
}
|
||||||
|
|
||||||
// Budget represents a spending budget for a category or account
|
// Budget represents a spending budget for a category or account
|
||||||
type Budget struct {
|
type Budget struct {
|
||||||
BaseModel
|
BaseModel
|
||||||
@@ -843,6 +855,7 @@ func AllModels() []interface{} {
|
|||||||
&Tag{},
|
&Tag{},
|
||||||
&Transaction{},
|
&Transaction{},
|
||||||
&TransactionTag{}, // Explicit join table for many-to-many relationship
|
&TransactionTag{}, // Explicit join table for many-to-many relationship
|
||||||
|
&AccountTag{}, // Explicit join table for account-tag many-to-many relationship
|
||||||
&Budget{},
|
&Budget{},
|
||||||
&PiggyBank{},
|
&PiggyBank{},
|
||||||
&RecurringTransaction{},
|
&RecurringTransaction{},
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ type AccountInput struct {
|
|||||||
PaymentDate *int `json:"payment_date,omitempty"`
|
PaymentDate *int `json:"payment_date,omitempty"`
|
||||||
WarningThreshold *float64 `json:"warning_threshold,omitempty"`
|
WarningThreshold *float64 `json:"warning_threshold,omitempty"`
|
||||||
AccountCode string `json:"account_code,omitempty"`
|
AccountCode string `json:"account_code,omitempty"`
|
||||||
|
TagIDs []uint `json:"tag_ids,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TransferInput represents the input data for a transfer operation
|
// TransferInput represents the input data for a transfer operation
|
||||||
@@ -99,6 +100,18 @@ func (s *AccountService) CreateAccount(userID uint, input AccountInput) (*models
|
|||||||
return nil, fmt.Errorf("failed to create account: %w", err)
|
return nil, fmt.Errorf("failed to create account: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle tags association
|
||||||
|
if len(input.TagIDs) > 0 {
|
||||||
|
var tags []models.Tag
|
||||||
|
if err := s.db.Where("id IN ? AND user_id = ?", input.TagIDs, userID).Find(&tags).Error; err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to find tags: %w", err)
|
||||||
|
}
|
||||||
|
if err := s.db.Model(account).Association("Tags").Replace(tags); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to associate tags: %w", err)
|
||||||
|
}
|
||||||
|
account.Tags = tags
|
||||||
|
}
|
||||||
|
|
||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -117,7 +130,8 @@ func (s *AccountService) GetAccount(userID, id uint) (*models.Account, error) {
|
|||||||
|
|
||||||
// GetAllAccounts retrieves all accounts for a specific user
|
// GetAllAccounts retrieves all accounts for a specific user
|
||||||
func (s *AccountService) GetAllAccounts(userID uint) ([]models.Account, error) {
|
func (s *AccountService) GetAllAccounts(userID uint) ([]models.Account, error) {
|
||||||
accounts, err := s.repo.GetAll(userID)
|
var accounts []models.Account
|
||||||
|
err := s.db.Where("user_id = ?", userID).Preload("Tags").Order("sort_order ASC, id ASC").Find(&accounts).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get accounts: %w", err)
|
return nil, fmt.Errorf("failed to get accounts: %w", err)
|
||||||
}
|
}
|
||||||
@@ -164,6 +178,18 @@ func (s *AccountService) UpdateAccount(userID, id uint, input AccountInput) (*mo
|
|||||||
return nil, fmt.Errorf("failed to update account: %w", err)
|
return nil, fmt.Errorf("failed to update account: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle tags association
|
||||||
|
var tags []models.Tag
|
||||||
|
if len(input.TagIDs) > 0 {
|
||||||
|
if err := s.db.Where("id IN ? AND user_id = ?", input.TagIDs, userID).Find(&tags).Error; err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to find tags: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := s.db.Model(account).Association("Tags").Replace(tags); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to update tags: %w", err)
|
||||||
|
}
|
||||||
|
account.Tags = tags
|
||||||
|
|
||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,12 +13,11 @@ import (
|
|||||||
|
|
||||||
// Service layer errors for budgets
|
// Service layer errors for budgets
|
||||||
var (
|
var (
|
||||||
ErrBudgetNotFound = errors.New("budget not found")
|
ErrBudgetNotFound = errors.New("budget not found")
|
||||||
ErrBudgetInUse = errors.New("budget is in use and cannot be deleted")
|
ErrBudgetInUse = errors.New("budget is in use and cannot be deleted")
|
||||||
ErrInvalidBudgetAmount = errors.New("budget amount must be positive")
|
ErrInvalidBudgetAmount = errors.New("budget amount must be positive")
|
||||||
ErrInvalidDateRange = errors.New("end date must be after start date")
|
ErrInvalidDateRange = errors.New("end date must be after start date")
|
||||||
ErrInvalidPeriodType = errors.New("invalid period type")
|
ErrInvalidPeriodType = errors.New("invalid period type")
|
||||||
ErrCategoryOrAccountRequired = errors.New("either category or account must be specified")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// BudgetInput represents the input data for creating or updating a budget
|
// BudgetInput represents the input data for creating or updating a budget
|
||||||
@@ -72,10 +71,7 @@ func (s *BudgetService) CreateBudget(input BudgetInput) (*models.Budget, error)
|
|||||||
return nil, ErrInvalidBudgetAmount
|
return nil, ErrInvalidBudgetAmount
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate that at least category or account is specified
|
// 分类和账户都可选,支持全局预算
|
||||||
if input.CategoryID == nil && input.AccountID == nil {
|
|
||||||
return nil, ErrCategoryOrAccountRequired
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate date range
|
// Validate date range
|
||||||
if input.EndDate != nil && input.EndDate.Before(input.StartDate) {
|
if input.EndDate != nil && input.EndDate.Before(input.StartDate) {
|
||||||
@@ -147,10 +143,7 @@ func (s *BudgetService) UpdateBudget(userID, id uint, input BudgetInput) (*model
|
|||||||
return nil, ErrInvalidBudgetAmount
|
return nil, ErrInvalidBudgetAmount
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate that at least category or account is specified
|
// 分类和账户都可选,支持全局预算
|
||||||
if input.CategoryID == nil && input.AccountID == nil {
|
|
||||||
return nil, ErrCategoryOrAccountRequired
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate date range
|
// Validate date range
|
||||||
if input.EndDate != nil && input.EndDate.Before(input.StartDate) {
|
if input.EndDate != nil && input.EndDate.Before(input.StartDate) {
|
||||||
@@ -222,41 +215,55 @@ func (s *BudgetService) GetBudgetProgress(userID, id uint) (*BudgetProgress, err
|
|||||||
startDate, endDate := s.calculateCurrentPeriod(budget, now)
|
startDate, endDate := s.calculateCurrentPeriod(budget, now)
|
||||||
|
|
||||||
// Get spent amount for the current period
|
// Get spent amount for the current period
|
||||||
spent, err := s.repo.GetSpentAmount(budget, startDate, endDate)
|
currentSpent, err := s.repo.GetSpentAmount(budget, startDate, endDate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to calculate spent amount: %w", err)
|
return nil, fmt.Errorf("failed to calculate spent amount: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate effective budget amount (considering rolling budget)
|
// Calculate effective budget amount
|
||||||
effectiveAmount := budget.Amount
|
effectiveAmount := budget.Amount
|
||||||
|
totalSpent := currentSpent
|
||||||
|
|
||||||
if budget.IsRolling {
|
if budget.IsRolling {
|
||||||
// For rolling budgets, add the previous period's remaining balance
|
// 滚动预算:结余自动累加到下一周期
|
||||||
prevStartDate, prevEndDate := s.calculatePreviousPeriod(budget, now)
|
// 当期可用额度 = 总额度 - 历史支出
|
||||||
prevSpent, err := s.repo.GetSpentAmount(budget, prevStartDate, prevEndDate)
|
|
||||||
if err != nil {
|
// 计算已过的完整周期数(不含当期)
|
||||||
return nil, fmt.Errorf("failed to calculate previous period spent: %w", err)
|
periodsElapsed := s.calculatePeriodsElapsed(budget, startDate)
|
||||||
}
|
|
||||||
prevRemaining := budget.Amount - prevSpent
|
// 总额度 = (已过周期数 + 当期) × 单期额度
|
||||||
if prevRemaining > 0 {
|
totalBudget := budget.Amount * float64(periodsElapsed+1)
|
||||||
effectiveAmount += prevRemaining
|
|
||||||
|
// 获取历史支出(从预算开始到当期开始前一秒)
|
||||||
|
historyEnd := startDate.Add(-time.Second)
|
||||||
|
historySpent := 0.0
|
||||||
|
if historyEnd.After(budget.StartDate) {
|
||||||
|
historySpent, err = s.repo.GetSpentAmount(budget, budget.StartDate, historyEnd)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to calculate history spent: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 当期可用额度 = 总额度 - 历史支出
|
||||||
|
effectiveAmount = totalBudget - historySpent
|
||||||
|
totalSpent = currentSpent
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate progress metrics
|
// Calculate progress metrics
|
||||||
remaining := effectiveAmount - spent
|
remaining := effectiveAmount - totalSpent
|
||||||
progress := 0.0
|
progress := 0.0
|
||||||
if effectiveAmount > 0 {
|
if effectiveAmount > 0 {
|
||||||
progress = (spent / effectiveAmount) * 100
|
progress = (totalSpent / effectiveAmount) * 100
|
||||||
}
|
}
|
||||||
|
|
||||||
isOverBudget := spent > effectiveAmount
|
isOverBudget := totalSpent > effectiveAmount
|
||||||
isNearLimit := progress >= 80.0 && !isOverBudget
|
isNearLimit := progress >= 80.0 && !isOverBudget
|
||||||
|
|
||||||
return &BudgetProgress{
|
return &BudgetProgress{
|
||||||
BudgetID: budget.ID,
|
BudgetID: budget.ID,
|
||||||
Name: budget.Name,
|
Name: budget.Name,
|
||||||
Amount: effectiveAmount,
|
Amount: effectiveAmount,
|
||||||
Spent: spent,
|
Spent: totalSpent,
|
||||||
Remaining: remaining,
|
Remaining: remaining,
|
||||||
Progress: progress,
|
Progress: progress,
|
||||||
PeriodType: budget.PeriodType,
|
PeriodType: budget.PeriodType,
|
||||||
@@ -353,6 +360,41 @@ func (s *BudgetService) calculatePreviousPeriod(budget *models.Budget, reference
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// calculatePeriodsElapsed 计算从预算开始日期到当前周期开始日期之间的完整周期数
|
||||||
|
func (s *BudgetService) calculatePeriodsElapsed(budget *models.Budget, currentPeriodStart time.Time) int {
|
||||||
|
if currentPeriodStart.Before(budget.StartDate) || currentPeriodStart.Equal(budget.StartDate) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var periods int
|
||||||
|
switch budget.PeriodType {
|
||||||
|
case models.PeriodTypeDaily:
|
||||||
|
periods = int(currentPeriodStart.Sub(budget.StartDate).Hours() / 24)
|
||||||
|
|
||||||
|
case models.PeriodTypeWeekly:
|
||||||
|
periods = int(currentPeriodStart.Sub(budget.StartDate).Hours() / (24 * 7))
|
||||||
|
|
||||||
|
case models.PeriodTypeMonthly:
|
||||||
|
yearDiff := currentPeriodStart.Year() - budget.StartDate.Year()
|
||||||
|
monthDiff := int(currentPeriodStart.Month()) - int(budget.StartDate.Month())
|
||||||
|
periods = yearDiff*12 + monthDiff
|
||||||
|
|
||||||
|
case models.PeriodTypeYearly:
|
||||||
|
periods = currentPeriodStart.Year() - budget.StartDate.Year()
|
||||||
|
|
||||||
|
default:
|
||||||
|
yearDiff := currentPeriodStart.Year() - budget.StartDate.Year()
|
||||||
|
monthDiff := int(currentPeriodStart.Month()) - int(budget.StartDate.Month())
|
||||||
|
periods = yearDiff*12 + monthDiff
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保返回非负数
|
||||||
|
if periods < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return periods
|
||||||
|
}
|
||||||
|
|
||||||
// isValidPeriodType checks if a period type is valid
|
// isValidPeriodType checks if a period type is valid
|
||||||
func isValidPeriodType(periodType models.PeriodType) bool {
|
func isValidPeriodType(periodType models.PeriodType) bool {
|
||||||
switch periodType {
|
switch periodType {
|
||||||
|
|||||||
Reference in New Issue
Block a user