Files
Novault-backend/internal/service/budget_service.go

439 lines
14 KiB
Go
Raw Normal View History

2026-01-25 21:59:00 +08:00
package service
import (
"errors"
"fmt"
"time"
"accounting-app/internal/models"
"accounting-app/internal/repository"
"gorm.io/gorm"
)
// Service layer errors for budgets
var (
ErrBudgetNotFound = errors.New("budget not found")
ErrBudgetInUse = errors.New("budget is in use and cannot be deleted")
ErrInvalidBudgetAmount = errors.New("budget amount must be positive")
ErrInvalidDateRange = errors.New("end date must be after start date")
ErrInvalidPeriodType = errors.New("invalid period type")
2026-01-25 21:59:00 +08:00
)
// BudgetInput represents the input data for creating or updating a budget.
//
// The `binding` tags are request-validation rules for the HTTP layer
// (presumably gin's binding package — confirm). CreateBudget and UpdateBudget
// re-check Amount, PeriodType and the StartDate/EndDate ordering themselves.
type BudgetInput struct {
	// UserID is copied onto the new budget by CreateBudget. UpdateBudget does
	// not read it; ownership there comes from its userID argument.
	UserID uint `json:"user_id"`
	Name string `json:"name" binding:"required"`
	// Amount is the budget amount per period; must be greater than zero.
	Amount float64 `json:"amount" binding:"required,gt=0"`
	// PeriodType must be one of the daily/weekly/monthly/yearly constants
	// accepted by isValidPeriodType.
	PeriodType models.PeriodType `json:"period_type" binding:"required"`
	// CategoryID and AccountID are both optional; leaving them nil creates a
	// global budget.
	CategoryID *uint `json:"category_id,omitempty"`
	AccountID *uint `json:"account_id,omitempty"`
	// IsRolling makes GetBudgetProgress carry unspent (or overspent) amounts
	// from earlier periods into the current one.
	IsRolling bool `json:"is_rolling"`
	StartDate time.Time `json:"start_date" binding:"required"`
	// EndDate is optional; when set it must not be before StartDate.
	EndDate *time.Time `json:"end_date,omitempty"`
}
// BudgetProgress is the computed state of a budget for the period that
// contains the moment it was evaluated.
type BudgetProgress struct {
	BudgetID uint `json:"budget_id"`
	Name string `json:"name"`
	// Amount is the amount available in the current period. For a rolling
	// budget it is the accumulated allowance minus spending in earlier
	// periods, so it can differ from the configured per-period amount and can
	// even be negative.
	Amount float64 `json:"amount"`
	// Spent is the spending recorded in the current period only.
	Spent float64 `json:"spent"`
	// Remaining is Amount - Spent; negative once the budget is exceeded.
	Remaining float64 `json:"remaining"`
	Progress float64 `json:"progress"` // Percentage of Amount spent; 0 when Amount <= 0, may exceed 100
	PeriodType models.PeriodType `json:"period_type"`
	// CurrentPeriod is the current period formatted as "YYYY-MM-DD to YYYY-MM-DD".
	CurrentPeriod string `json:"current_period"`
	IsRolling bool `json:"is_rolling"`
	// IsOverBudget is true when Spent is greater than Amount.
	IsOverBudget bool `json:"is_over_budget"`
	IsNearLimit bool `json:"is_near_limit"` // Progress >= 80% while not over budget
	CategoryID *uint `json:"category_id,omitempty"`
	AccountID *uint `json:"account_id,omitempty"`
}
// BudgetService handles business logic for budgets: input validation,
// ownership-scoped CRUD through the repository, and progress calculation.
type BudgetService struct {
	// repo performs all budget persistence and spending queries in this file.
	repo *repository.BudgetRepository
	// db is stored by NewBudgetService but is not read by any method in this
	// file section. NOTE(review): confirm whether it is still needed.
	db *gorm.DB
}
// NewBudgetService wires a BudgetService to its repository and database handle.
// Neither argument is validated; both are stored as given.
func NewBudgetService(repo *repository.BudgetRepository, db *gorm.DB) *BudgetService {
	svc := new(BudgetService)
	svc.repo = repo
	svc.db = db
	return svc
}
// CreateBudget creates a new budget with business logic validation
func (s *BudgetService) CreateBudget(input BudgetInput) (*models.Budget, error) {
// Validate amount
if input.Amount <= 0 {
return nil, ErrInvalidBudgetAmount
}
// 分类和账户都可选,支持全局预算
2026-01-25 21:59:00 +08:00
// Validate date range
if input.EndDate != nil && input.EndDate.Before(input.StartDate) {
return nil, ErrInvalidDateRange
}
// Validate period type
if !isValidPeriodType(input.PeriodType) {
return nil, ErrInvalidPeriodType
}
// Create the budget model
budget := &models.Budget{
UserID: input.UserID,
Name: input.Name,
Amount: input.Amount,
PeriodType: input.PeriodType,
CategoryID: input.CategoryID,
AccountID: input.AccountID,
IsRolling: input.IsRolling,
StartDate: input.StartDate,
EndDate: input.EndDate,
}
// Save to database
if err := s.repo.Create(budget); err != nil {
return nil, fmt.Errorf("failed to create budget: %w", err)
}
return budget, nil
}
// GetBudget returns the budget with the given ID owned by userID.
//
// Ownership is enforced by the repository lookup, which is scoped by userID.
// A missing budget maps to ErrBudgetNotFound; other errors are wrapped.
func (s *BudgetService) GetBudget(userID, id uint) (*models.Budget, error) {
	found, err := s.repo.GetByID(userID, id)
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, repository.ErrBudgetNotFound):
		return nil, ErrBudgetNotFound
	default:
		return nil, fmt.Errorf("failed to get budget: %w", err)
	}
}
// GetAllBudgets lists every budget owned by userID, wrapping any repository error.
func (s *BudgetService) GetAllBudgets(userID uint) ([]models.Budget, error) {
	list, err := s.repo.GetAll(userID)
	if err == nil {
		return list, nil
	}
	return nil, fmt.Errorf("failed to get budgets: %w", err)
}
// UpdateBudget updates an existing budget after verifying ownership
func (s *BudgetService) UpdateBudget(userID, id uint, input BudgetInput) (*models.Budget, error) {
// Get existing budget
budget, err := s.repo.GetByID(userID, id)
if err != nil {
if errors.Is(err, repository.ErrBudgetNotFound) {
return nil, ErrBudgetNotFound
}
return nil, fmt.Errorf("failed to get budget: %w", err)
}
// userID check handled by repo
// Validate amount
if input.Amount <= 0 {
return nil, ErrInvalidBudgetAmount
}
// 分类和账户都可选,支持全局预算
2026-01-25 21:59:00 +08:00
// Validate date range
if input.EndDate != nil && input.EndDate.Before(input.StartDate) {
return nil, ErrInvalidDateRange
}
// Validate period type
if !isValidPeriodType(input.PeriodType) {
return nil, ErrInvalidPeriodType
}
// Update fields
budget.Name = input.Name
budget.Amount = input.Amount
budget.PeriodType = input.PeriodType
budget.CategoryID = input.CategoryID
budget.AccountID = input.AccountID
budget.IsRolling = input.IsRolling
budget.StartDate = input.StartDate
budget.EndDate = input.EndDate
// Save to database
if err := s.repo.Update(budget); err != nil {
return nil, fmt.Errorf("failed to update budget: %w", err)
}
return budget, nil
}
// DeleteBudget removes budget id owned by userID.
//
// It first confirms the budget exists for the user, then deletes it.
// Repository "not found" and "in use" errors map to ErrBudgetNotFound and
// ErrBudgetInUse; anything else is wrapped with context.
func (s *BudgetService) DeleteBudget(userID, id uint) error {
	if _, lookupErr := s.repo.GetByID(userID, id); lookupErr != nil {
		if errors.Is(lookupErr, repository.ErrBudgetNotFound) {
			return ErrBudgetNotFound
		}
		return fmt.Errorf("failed to check budget existence: %w", lookupErr)
	}

	deleteErr := s.repo.Delete(userID, id)
	switch {
	case deleteErr == nil:
		return nil
	case errors.Is(deleteErr, repository.ErrBudgetNotFound):
		return ErrBudgetNotFound
	case errors.Is(deleteErr, repository.ErrBudgetInUse):
		return ErrBudgetInUse
	default:
		return fmt.Errorf("failed to delete budget: %w", deleteErr)
	}
}
// GetBudgetProgress calculates and returns the progress of a budget for a user
// This implements the core budget progress calculation logic for weekly, monthly, and rolling budgets
func (s *BudgetService) GetBudgetProgress(userID, id uint) (*BudgetProgress, error) {
// Get the budget
budget, err := s.repo.GetByID(userID, id)
if err != nil {
if errors.Is(err, repository.ErrBudgetNotFound) {
return nil, ErrBudgetNotFound
}
return nil, fmt.Errorf("failed to get budget: %w", err)
}
// userID check handled by repo
// Calculate the current period based on budget period type
now := time.Now()
startDate, endDate := s.calculateCurrentPeriod(budget, now)
// Get spent amount for the current period
currentSpent, err := s.repo.GetSpentAmount(budget, startDate, endDate)
2026-01-25 21:59:00 +08:00
if err != nil {
return nil, fmt.Errorf("failed to calculate spent amount: %w", err)
}
// Calculate effective budget amount
2026-01-25 21:59:00 +08:00
effectiveAmount := budget.Amount
totalSpent := currentSpent
2026-01-25 21:59:00 +08:00
if budget.IsRolling {
// 滚动预算:结余自动累加到下一周期
// 当期可用额度 = 总额度 - 历史支出
// 计算已过的完整周期数(不含当期)
periodsElapsed := s.calculatePeriodsElapsed(budget, startDate)
// 总额度 = (已过周期数 + 当期) × 单期额度
totalBudget := budget.Amount * float64(periodsElapsed+1)
// 获取历史支出(从预算开始到当期开始前一秒)
historyEnd := startDate.Add(-time.Second)
historySpent := 0.0
if historyEnd.After(budget.StartDate) {
historySpent, err = s.repo.GetSpentAmount(budget, budget.StartDate, historyEnd)
if err != nil {
return nil, fmt.Errorf("failed to calculate history spent: %w", err)
}
2026-01-25 21:59:00 +08:00
}
// 当期可用额度 = 总额度 - 历史支出
effectiveAmount = totalBudget - historySpent
totalSpent = currentSpent
2026-01-25 21:59:00 +08:00
}
// Calculate progress metrics
remaining := effectiveAmount - totalSpent
2026-01-25 21:59:00 +08:00
progress := 0.0
if effectiveAmount > 0 {
progress = (totalSpent / effectiveAmount) * 100
2026-01-25 21:59:00 +08:00
}
isOverBudget := totalSpent > effectiveAmount
2026-01-25 21:59:00 +08:00
isNearLimit := progress >= 80.0 && !isOverBudget
return &BudgetProgress{
BudgetID: budget.ID,
Name: budget.Name,
Amount: effectiveAmount,
Spent: totalSpent,
2026-01-25 21:59:00 +08:00
Remaining: remaining,
Progress: progress,
PeriodType: budget.PeriodType,
CurrentPeriod: formatPeriod(startDate, endDate),
IsRolling: budget.IsRolling,
IsOverBudget: isOverBudget,
IsNearLimit: isNearLimit,
CategoryID: budget.CategoryID,
AccountID: budget.AccountID,
}, nil
}
// GetAllBudgetProgress returns progress for all active budgets for a user
func (s *BudgetService) GetAllBudgetProgress(userID uint) ([]BudgetProgress, error) {
budgets, err := s.repo.GetActiveBudgets(userID, time.Now())
if err != nil {
return nil, fmt.Errorf("failed to get active budgets: %w", err)
}
var progressList []BudgetProgress
for _, budget := range budgets {
progress, err := s.GetBudgetProgress(userID, budget.ID)
if err != nil {
return nil, fmt.Errorf("failed to calculate progress for budget %d: %w", budget.ID, err)
}
progressList = append(progressList, *progress)
}
return progressList, nil
}
// calculateCurrentPeriod returns the first and last second of the budget
// period that contains referenceDate, in referenceDate's location.
//
// Periods are calendar-aligned:
//   - a day
//   - a Monday-to-Sunday week
//   - a calendar month
//   - a calendar year
//
// Unrecognized period types are treated as monthly. The returned end is one
// second before the next period begins.
func (s *BudgetService) calculateCurrentPeriod(budget *models.Budget, referenceDate time.Time) (time.Time, time.Time) {
	year, month, day := referenceDate.Date()
	loc := referenceDate.Location()

	var periodStart time.Time
	var addYears, addMonths, addDays int
	switch budget.PeriodType {
	case models.PeriodTypeDaily:
		periodStart = time.Date(year, month, day, 0, 0, 0, 0, loc)
		addDays = 1
	case models.PeriodTypeWeekly:
		// Go's Weekday numbers Sunday as 0. Shift it so Monday is 0 and Sunday is 6.
		sinceMonday := (int(referenceDate.Weekday()) + 6) % 7
		periodStart = time.Date(year, month, day-sinceMonday, 0, 0, 0, 0, loc)
		addDays = 7
	case models.PeriodTypeYearly:
		periodStart = time.Date(year, time.January, 1, 0, 0, 0, 0, loc)
		addYears = 1
	default: // monthly, and the fallback for unknown types
		periodStart = time.Date(year, month, 1, 0, 0, 0, 0, loc)
		addMonths = 1
	}

	periodEnd := periodStart.AddDate(addYears, addMonths, addDays).Add(-time.Second)
	return periodStart, periodEnd
}
// calculatePreviousPeriod returns the start and end of the budget period that
// immediately precedes the period containing referenceDate.
//
// It steps back from the start of the current period, not from referenceDate.
// time.Time.AddDate normalizes out-of-range dates, so stepping a month back
// from March 31 yields "February 31", which normalizes to March 3. The old
// implementation therefore returned the current month instead of the
// previous one on month-end dates. The day before the current period's start
// always lies in the previous period, for every period type (including the
// monthly fallback for unknown types).
func (s *BudgetService) calculatePreviousPeriod(budget *models.Budget, referenceDate time.Time) (time.Time, time.Time) {
	currentStart, _ := s.calculateCurrentPeriod(budget, referenceDate)
	return s.calculateCurrentPeriod(budget, currentStart.AddDate(0, 0, -1))
}
// calculatePeriodsElapsed returns how many budget periods lie between the
// budget's StartDate and currentPeriodStart, excluding the current period.
// It returns 0 when currentPeriodStart is not after StartDate.
//
// Periods are counted on the calendar boundaries calculateCurrentPeriod uses:
// days, Monday-start weeks, months and years. A budget that started part-way
// through a period has that first period counted, which is how the monthly
// and yearly branches already behaved. This matches the rolling-history query
// in the caller, which includes spending from that partial first period.
//
// Daily and weekly counts come from calendar dates rather than elapsed hours,
// so a 23-hour DST day no longer drops a period. StartDate is converted to
// currentPeriodStart's location first, so both sides agree on which day,
// month and year a moment falls in.
func (s *BudgetService) calculatePeriodsElapsed(budget *models.Budget, currentPeriodStart time.Time) int {
	if !currentPeriodStart.After(budget.StartDate) {
		return 0
	}

	start := budget.StartDate.In(currentPeriodStart.Location())

	// dayNumber maps a calendar date to a day index via its UTC midnight. Every
	// UTC day is exactly 86400 seconds, so differences are exact whole days.
	dayNumber := func(t time.Time) int {
		y, m, d := t.Date()
		return int(time.Date(y, m, d, 0, 0, 0, 0, time.UTC).Unix() / 86400)
	}

	var periods int
	switch budget.PeriodType {
	case models.PeriodTypeDaily:
		periods = dayNumber(currentPeriodStart) - dayNumber(start)
	case models.PeriodTypeWeekly:
		// Measure from the Monday of the week StartDate falls in, so the week
		// the budget started in counts as a period. Go's Weekday has Sunday == 0.
		sinceMonday := (int(start.Weekday()) + 6) % 7
		periods = (dayNumber(currentPeriodStart) - dayNumber(start) + sinceMonday) / 7
	case models.PeriodTypeYearly:
		periods = currentPeriodStart.Year() - start.Year()
	default: // monthly, and the monthly fallback calculateCurrentPeriod applies to unknown types
		periods = (currentPeriodStart.Year()-start.Year())*12 + int(currentPeriodStart.Month()) - int(start.Month())
	}

	// Guard against a negative count.
	if periods < 0 {
		return 0
	}
	return periods
}
2026-01-25 21:59:00 +08:00
// isValidPeriodType checks if a period type is valid
func isValidPeriodType(periodType models.PeriodType) bool {
switch periodType {
case models.PeriodTypeDaily, models.PeriodTypeWeekly, models.PeriodTypeMonthly, models.PeriodTypeYearly:
return true
default:
return false
}
}
// formatPeriod renders a period as "YYYY-MM-DD to YYYY-MM-DD". Each date is
// formatted in its own location.
func formatPeriod(start, end time.Time) string {
	const dateLayout = "2006-01-02"
	return start.Format(dateLayout) + " to " + end.Format(dateLayout)
}
// GetBudgetsByCategoryID lists the budgets owned by userID that are attached
// to categoryID, wrapping any repository error.
func (s *BudgetService) GetBudgetsByCategoryID(userID, categoryID uint) ([]models.Budget, error) {
	matches, err := s.repo.GetByCategoryID(userID, categoryID)
	if err == nil {
		return matches, nil
	}
	return nil, fmt.Errorf("failed to get budgets by category: %w", err)
}
// GetBudgetsByAccountID lists the budgets owned by userID that are attached to
// accountID, wrapping any repository error.
func (s *BudgetService) GetBudgetsByAccountID(userID, accountID uint) ([]models.Budget, error) {
	matches, err := s.repo.GetByAccountID(userID, accountID)
	if err == nil {
		return matches, nil
	}
	return nil, fmt.Errorf("failed to get budgets by account: %w", err)
}
// GetActiveBudgets lists the budgets the repository considers active for
// userID at the current time, wrapping any repository error.
func (s *BudgetService) GetActiveBudgets(userID uint) ([]models.Budget, error) {
	active, err := s.repo.GetActiveBudgets(userID, time.Now())
	if err == nil {
		return active, nil
	}
	return nil, fmt.Errorf("failed to get active budgets: %w", err)
}