Files
Novault-backend/internal/service/budget_service.go

498 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
	"errors"
	"fmt"
	"log"
	"time"

	"accounting-app/internal/models"
	"accounting-app/internal/repository"

	"gorm.io/gorm"
)
// Sentinel errors returned by BudgetService. Callers should compare against
// them with errors.Is.
var (
	// ErrBudgetNotFound reports that the requested budget does not exist for the user.
	ErrBudgetNotFound = errors.New("budget not found")

	// ErrBudgetInUse reports that the repository refused to delete the budget
	// because it is in use.
	ErrBudgetInUse = errors.New("budget is in use and cannot be deleted")

	// ErrInvalidBudgetAmount reports a budget amount that is zero or negative.
	ErrInvalidBudgetAmount = errors.New("budget amount must be positive")

	// ErrInvalidDateRange reports an end date that lies before the start date.
	ErrInvalidDateRange = errors.New("end date must be after start date")

	// ErrInvalidPeriodType reports a period type outside the supported set.
	ErrInvalidPeriodType = errors.New("invalid period type")
)
// BudgetInput represents the input data for creating or updating a budget.
//
// The `binding` tags are enforced by the request binder (presumably gin's
// validator — confirm). CreateBudget and UpdateBudget re-check Amount, the
// date range and PeriodType themselves, so the service does not rely on the
// tags alone.
type BudgetInput struct {
	// UserID is the owner of the budget; CreateBudget copies it onto the model.
	// NOTE(review): the json tag lets a client supply user_id in the body —
	// confirm the handler overwrites it from the authenticated session.
	UserID uint `json:"user_id"`
	Name string `json:"name" binding:"required"`
	// Amount is the limit for a single period.
	Amount float64 `json:"amount" binding:"required,gt=0"`
	PeriodType models.PeriodType `json:"period_type" binding:"required"`
	// CategoryID and AccountID optionally narrow the budget; when both are nil
	// the budget is global.
	CategoryID *uint `json:"category_id,omitempty"`
	AccountID *uint `json:"account_id,omitempty"`
	// IsRolling carries unspent amounts from earlier periods into the current one.
	IsRolling bool `json:"is_rolling"`
	StartDate time.Time `json:"start_date" binding:"required"`
	// EndDate is optional; nil means the budget has no end date.
	EndDate *time.Time `json:"end_date,omitempty"`
}
// BudgetProgress represents the progress of a budget within its current period.
type BudgetProgress struct {
	BudgetID uint `json:"budget_id"`
	Name string `json:"name"`
	// Amount is the effective limit for the current period. For rolling budgets
	// this includes amounts carried over from earlier periods and can be
	// zero or negative after historical overspending.
	Amount float64 `json:"amount"`
	// Spent is the amount spent in the current period.
	Spent float64 `json:"spent"`
	// Remaining is Amount minus Spent; negative when over budget.
	Remaining float64 `json:"remaining"`
	// Progress is Spent as a percentage of Amount. It exceeds 100 when over
	// budget and is 0 whenever Amount is not positive.
	Progress float64 `json:"progress"`
	PeriodType models.PeriodType `json:"period_type"`
	// CurrentPeriod is the human-readable period, formatted "YYYY-MM-DD to YYYY-MM-DD".
	CurrentPeriod string `json:"current_period"`
	IsRolling bool `json:"is_rolling"`
	// IsOverBudget is true when Spent exceeds Amount.
	IsOverBudget bool `json:"is_over_budget"`
	// IsNearLimit is true when Progress is at least 80% and the budget is not
	// already over.
	IsNearLimit bool `json:"is_near_limit"`
	CategoryID *uint `json:"category_id,omitempty"`
	AccountID *uint `json:"account_id,omitempty"`
}
// BudgetService handles business logic for budgets: input validation,
// ownership-scoped CRUD through the repository, and period/progress
// calculations.
type BudgetService struct {
	// repo performs all budget persistence and spending queries.
	repo *repository.BudgetRepository
	// db is stored by NewBudgetService; none of the methods shown in this file
	// read it. NOTE(review): confirm whether it is still needed.
	db *gorm.DB
}
// NewBudgetService wires a BudgetService to its repository and database handle.
// Neither argument is validated; passing nil defers failures to first use.
func NewBudgetService(repo *repository.BudgetRepository, db *gorm.DB) *BudgetService {
	svc := new(BudgetService)
	svc.repo = repo
	svc.db = db
	return svc
}
// CreateBudget validates input and persists a new budget.
//
// Validation runs in this order and stops at the first failure:
//   - ErrInvalidBudgetAmount when Amount is not positive;
//   - ErrInvalidDateRange when EndDate is set and precedes StartDate;
//   - ErrInvalidPeriodType when PeriodType is not a supported value.
//
// Category and account are both optional; leaving them nil creates a global
// budget. Repository failures are wrapped and returned.
func (s *BudgetService) CreateBudget(input BudgetInput) (*models.Budget, error) {
	switch {
	case input.Amount <= 0:
		return nil, ErrInvalidBudgetAmount
	case input.EndDate != nil && input.EndDate.Before(input.StartDate):
		return nil, ErrInvalidDateRange
	case !isValidPeriodType(input.PeriodType):
		return nil, ErrInvalidPeriodType
	}

	created := &models.Budget{
		UserID:     input.UserID,
		Name:       input.Name,
		Amount:     input.Amount,
		PeriodType: input.PeriodType,
		CategoryID: input.CategoryID,
		AccountID:  input.AccountID,
		IsRolling:  input.IsRolling,
		StartDate:  input.StartDate,
		EndDate:    input.EndDate,
	}

	if err := s.repo.Create(created); err != nil {
		return nil, fmt.Errorf("failed to create budget: %w", err)
	}
	return created, nil
}
// GetBudget returns the budget with the given id owned by userID.
// Ownership is enforced by the repository lookup, which is scoped by userID.
// A missing budget is reported as ErrBudgetNotFound; other failures are wrapped.
func (s *BudgetService) GetBudget(userID, id uint) (*models.Budget, error) {
	found, err := s.repo.GetByID(userID, id)
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, repository.ErrBudgetNotFound):
		return nil, ErrBudgetNotFound
	default:
		return nil, fmt.Errorf("failed to get budget: %w", err)
	}
}
// GetAllBudgets returns every budget owned by userID, exactly as the repository
// lists them. Repository failures are wrapped.
func (s *BudgetService) GetAllBudgets(userID uint) ([]models.Budget, error) {
	list, err := s.repo.GetAll(userID)
	if err == nil {
		return list, nil
	}
	return nil, fmt.Errorf("failed to get budgets: %w", err)
}
// BudgetWithProgress represents a budget with calculated progress fields.
//
// models.Budget is embedded without a json tag, so encoding/json promotes its
// exported fields into the same JSON object as Spent and Progress.
// NOTE(review): if models.Budget ever gains a MarshalJSON method, it will be
// promoted here too and Spent/Progress would be dropped from the output.
type BudgetWithProgress struct {
	models.Budget
	// Spent is the amount spent in the budget's current period.
	Spent float64 `json:"spent"`
	// Progress is Spent as a percentage of the period's effective limit
	// (0 when that limit is not positive; may exceed 100).
	Progress float64 `json:"progress"`
}
// GetAllBudgetsWithProgress returns every budget owned by userID together with
// the amount spent in its current period and the percentage of that period's
// effective limit already used.
//
// For rolling budgets the effective limit is Amount × (elapsed periods + 1)
// minus everything spent from the budget's start date up to the current period.
//
// Spending lookups are best-effort. If a query fails for one budget, the
// failure is logged and that figure falls back to 0 rather than failing the
// whole list. Only the initial budget listing returns an error.
func (s *BudgetService) GetAllBudgetsWithProgress(userID uint) ([]BudgetWithProgress, error) {
	budgets, err := s.repo.GetAll(userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get budgets: %w", err)
	}

	result := make([]BudgetWithProgress, len(budgets))
	now := time.Now()
	for i, budget := range budgets {
		startDate, endDate := s.calculateCurrentPeriod(&budget, now)

		spent, err := s.repo.GetSpentAmount(&budget, startDate, endDate)
		if err != nil {
			// Best-effort: report the budget with nothing spent instead of
			// failing the whole listing, but leave a trace of the failure.
			log.Printf("budget %d: calculating current-period spending failed, reporting 0: %v", budget.ID, err)
			spent = 0
		}

		effectiveAmount := budget.Amount
		if budget.IsRolling {
			periodsElapsed := s.calculatePeriodsElapsed(&budget, startDate)
			totalBudget := budget.Amount * float64(periodsElapsed+1)

			// History covers the budget start up to one second before this period.
			historyEnd := startDate.Add(-time.Second)
			historySpent := 0.0
			if historyEnd.After(budget.StartDate) {
				historySpent, err = s.repo.GetSpentAmount(&budget, budget.StartDate, historyEnd)
				if err != nil {
					// Don't trust whatever value came back alongside the error.
					log.Printf("budget %d: calculating rolling history spending failed, reporting 0: %v", budget.ID, err)
					historySpent = 0
				}
			}
			effectiveAmount = totalBudget - historySpent
		}

		progress := 0.0
		if effectiveAmount > 0 {
			progress = (spent / effectiveAmount) * 100
		}

		result[i] = BudgetWithProgress{
			Budget:   budget,
			Spent:    spent,
			Progress: progress,
		}
	}
	return result, nil
}
// UpdateBudget replaces the editable fields of the budget id owned by userID
// and persists the result.
//
// The existing budget is loaded first (ownership is enforced by the
// repository lookup), so a missing budget yields ErrBudgetNotFound even when
// the input is also invalid. The input is then validated in the same order as
// CreateBudget: amount, date range, period type. Category and account are
// both optional; leaving them nil makes the budget global.
func (s *BudgetService) UpdateBudget(userID, id uint, input BudgetInput) (*models.Budget, error) {
	existing, err := s.repo.GetByID(userID, id)
	switch {
	case errors.Is(err, repository.ErrBudgetNotFound):
		return nil, ErrBudgetNotFound
	case err != nil:
		return nil, fmt.Errorf("failed to get budget: %w", err)
	}

	switch {
	case input.Amount <= 0:
		return nil, ErrInvalidBudgetAmount
	case input.EndDate != nil && input.EndDate.Before(input.StartDate):
		return nil, ErrInvalidDateRange
	case !isValidPeriodType(input.PeriodType):
		return nil, ErrInvalidPeriodType
	}

	// UserID and the primary key are deliberately left untouched.
	existing.Name = input.Name
	existing.Amount = input.Amount
	existing.PeriodType = input.PeriodType
	existing.CategoryID = input.CategoryID
	existing.AccountID = input.AccountID
	existing.IsRolling = input.IsRolling
	existing.StartDate = input.StartDate
	existing.EndDate = input.EndDate

	if err := s.repo.Update(existing); err != nil {
		return nil, fmt.Errorf("failed to update budget: %w", err)
	}
	return existing, nil
}
// DeleteBudget removes the budget id owned by userID.
//
// Existence is checked first through the user-scoped repository lookup. The
// delete itself can still report ErrBudgetNotFound (e.g. a concurrent delete)
// or ErrBudgetInUse, both of which are translated to this package's sentinels;
// any other failure is wrapped.
func (s *BudgetService) DeleteBudget(userID, id uint) error {
	if _, err := s.repo.GetByID(userID, id); err != nil {
		if errors.Is(err, repository.ErrBudgetNotFound) {
			return ErrBudgetNotFound
		}
		return fmt.Errorf("failed to check budget existence: %w", err)
	}

	switch err := s.repo.Delete(userID, id); {
	case err == nil:
		return nil
	case errors.Is(err, repository.ErrBudgetNotFound):
		return ErrBudgetNotFound
	case errors.Is(err, repository.ErrBudgetInUse):
		return ErrBudgetInUse
	default:
		return fmt.Errorf("failed to delete budget: %w", err)
	}
}
// GetBudgetProgress calculates and returns the progress of a budget for a user
// in the period that contains the current time. It supports daily, weekly,
// monthly, yearly and rolling budgets.
//
// Ownership is enforced by the user-scoped repository lookup. A missing
// budget yields ErrBudgetNotFound; other failures are wrapped.
func (s *BudgetService) GetBudgetProgress(userID, id uint) (*BudgetProgress, error) {
	budget, err := s.repo.GetByID(userID, id)
	if err != nil {
		if errors.Is(err, repository.ErrBudgetNotFound) {
			return nil, ErrBudgetNotFound
		}
		return nil, fmt.Errorf("failed to get budget: %w", err)
	}
	// userID check handled by repo
	return s.buildBudgetProgress(budget, time.Now())
}

// buildBudgetProgress computes progress for an already-loaded budget in the
// period containing now.
//
// For a plain budget the limit is budget.Amount. For a rolling budget, unspent
// money carries over into the next period:
//
//	limit = Amount × (periods elapsed before this one + 1) − spending before this period
//
// Any repository failure is wrapped and returned.
func (s *BudgetService) buildBudgetProgress(budget *models.Budget, now time.Time) (*BudgetProgress, error) {
	startDate, endDate := s.calculateCurrentPeriod(budget, now)

	spent, err := s.repo.GetSpentAmount(budget, startDate, endDate)
	if err != nil {
		return nil, fmt.Errorf("failed to calculate spent amount: %w", err)
	}

	effectiveAmount := budget.Amount
	if budget.IsRolling {
		periodsElapsed := s.calculatePeriodsElapsed(budget, startDate)
		totalBudget := budget.Amount * float64(periodsElapsed+1)

		// History runs from the budget start to one second before this period.
		historyEnd := startDate.Add(-time.Second)
		historySpent := 0.0
		if historyEnd.After(budget.StartDate) {
			historySpent, err = s.repo.GetSpentAmount(budget, budget.StartDate, historyEnd)
			if err != nil {
				return nil, fmt.Errorf("failed to calculate history spent: %w", err)
			}
		}
		effectiveAmount = totalBudget - historySpent
	}

	remaining := effectiveAmount - spent
	progress := 0.0
	if effectiveAmount > 0 {
		progress = (spent / effectiveAmount) * 100
	}
	isOverBudget := spent > effectiveAmount
	isNearLimit := progress >= 80.0 && !isOverBudget

	return &BudgetProgress{
		BudgetID:      budget.ID,
		Name:          budget.Name,
		Amount:        effectiveAmount,
		Spent:         spent,
		Remaining:     remaining,
		Progress:      progress,
		PeriodType:    budget.PeriodType,
		CurrentPeriod: formatPeriod(startDate, endDate),
		IsRolling:     budget.IsRolling,
		IsOverBudget:  isOverBudget,
		IsNearLimit:   isNearLimit,
		CategoryID:    budget.CategoryID,
		AccountID:     budget.AccountID,
	}, nil
}

// GetAllBudgetProgress returns progress for every budget the repository
// reports as active for userID at the current time.
//
// The budgets returned by the active-budget query are used directly, with no
// per-budget re-fetch. A single timestamp is used for all of them, so every
// budget is measured against the same instant. The result is an empty,
// non-nil slice when there are no active budgets; it encodes to JSON as []
// rather than null. The first progress failure aborts the whole call.
func (s *BudgetService) GetAllBudgetProgress(userID uint) ([]BudgetProgress, error) {
	now := time.Now()
	budgets, err := s.repo.GetActiveBudgets(userID, now)
	if err != nil {
		return nil, fmt.Errorf("failed to get active budgets: %w", err)
	}

	progressList := make([]BudgetProgress, 0, len(budgets))
	for i := range budgets {
		budget := &budgets[i]
		progress, err := s.buildBudgetProgress(budget, now)
		if err != nil {
			return nil, fmt.Errorf("failed to calculate progress for budget %d: %w", budget.ID, err)
		}
		progressList = append(progressList, *progress)
	}
	return progressList, nil
}
// calculateCurrentPeriod returns the first and last instant (inclusive, to the
// second) of the budget period containing referenceDate, in referenceDate's
// location.
//
// Daily periods are calendar days, weekly periods run Monday to Sunday, and
// yearly periods are calendar years. Monthly periods are calendar months, and
// unknown period types fall back to monthly as well.
func (s *BudgetService) calculateCurrentPeriod(budget *models.Budget, referenceDate time.Time) (time.Time, time.Time) {
	year, month, day := referenceDate.Date()
	loc := referenceDate.Location()

	var start time.Time
	var addYears, addMonths, addDays int
	switch budget.PeriodType {
	case models.PeriodTypeDaily:
		start = time.Date(year, month, day, 0, 0, 0, 0, loc)
		addDays = 1
	case models.PeriodTypeWeekly:
		// Monday => 0 ... Sunday => 6.
		sinceMonday := (int(referenceDate.Weekday()) + 6) % 7
		start = time.Date(year, month, day-sinceMonday, 0, 0, 0, 0, loc)
		addDays = 7
	case models.PeriodTypeYearly:
		start = time.Date(year, 1, 1, 0, 0, 0, 0, loc)
		addYears = 1
	default: // monthly and any unrecognized type
		start = time.Date(year, month, 1, 0, 0, 0, 0, loc)
		addMonths = 1
	}

	// The period ends one second before the next period begins.
	end := start.AddDate(addYears, addMonths, addDays).Add(-time.Second)
	return start, end
}
// calculatePreviousPeriod returns the start and end of the budget period that
// immediately precedes the one containing referenceDate.
//
// It steps back from the start of the current period rather than subtracting
// a calendar interval from referenceDate. Subtracting a month from March 31
// gives "February 31", which time.AddDate normalizes to March 3, so the
// computed "previous" period would still be March.
func (s *BudgetService) calculatePreviousPeriod(budget *models.Budget, referenceDate time.Time) (time.Time, time.Time) {
	currentStart, _ := s.calculateCurrentPeriod(budget, referenceDate)
	// Periods end one second before the next one starts, so this instant is
	// the last second of the previous period for every period type.
	return s.calculateCurrentPeriod(budget, currentStart.Add(-time.Second))
}
// calculatePeriodsElapsed counts the budget periods that lie between the period
// containing budget.StartDate and the period beginning at currentPeriodStart.
// That is the number of periods before the current one in which the budget
// existed.
//
// A partially covered first period counts as a whole period for every period
// type, matching how monthly and yearly budgets have always been counted.
// Daily and weekly counts are computed on calendar days, so DST transitions
// (23h/25h days) do not cause an off-by-one. currentPeriodStart is expected to
// be a boundary produced by calculateCurrentPeriod. Returns 0 when
// currentPeriodStart is not after budget.StartDate; never returns a negative
// number.
func (s *BudgetService) calculatePeriodsElapsed(budget *models.Budget, currentPeriodStart time.Time) int {
	if !currentPeriodStart.After(budget.StartDate) {
		return 0
	}

	// Compare calendar fields in the location the current period was computed
	// in, so a start date stored in another zone (e.g. UTC) maps to the right
	// local day and month.
	start := budget.StartDate.In(currentPeriodStart.Location())

	// calendarDays counts whole calendar days from one date to another by
	// projecting both onto UTC midnights, which are always exactly 24h apart.
	calendarDays := func(from, to time.Time) int {
		a := time.Date(from.Year(), from.Month(), from.Day(), 0, 0, 0, 0, time.UTC)
		b := time.Date(to.Year(), to.Month(), to.Day(), 0, 0, 0, 0, time.UTC)
		return int(b.Sub(a) / (24 * time.Hour))
	}

	var periods int
	switch budget.PeriodType {
	case models.PeriodTypeDaily:
		periods = calendarDays(start, currentPeriodStart)
	case models.PeriodTypeWeekly:
		// Weeks run Monday..Sunday; measure from the Monday of the start week
		// so a partial first week counts as one period.
		sinceMonday := (int(start.Weekday()) + 6) % 7
		periods = (calendarDays(start, currentPeriodStart) + sinceMonday) / 7
	case models.PeriodTypeYearly:
		periods = currentPeriodStart.Year() - start.Year()
	default: // monthly, and unknown types (calculateCurrentPeriod also treats them as monthly)
		yearDiff := currentPeriodStart.Year() - start.Year()
		monthDiff := int(currentPeriodStart.Month()) - int(start.Month())
		periods = yearDiff*12 + monthDiff
	}

	if periods < 0 {
		return 0
	}
	return periods
}
// isValidPeriodType reports whether periodType is one of the supported
// budget periods: daily, weekly, monthly or yearly.
func isValidPeriodType(periodType models.PeriodType) bool {
	return periodType == models.PeriodTypeDaily ||
		periodType == models.PeriodTypeWeekly ||
		periodType == models.PeriodTypeMonthly ||
		periodType == models.PeriodTypeYearly
}
// formatPeriod renders a period as "YYYY-MM-DD to YYYY-MM-DD", using the
// calendar date of start and end in their own locations.
func formatPeriod(start, end time.Time) string {
	const dayLayout = "2006-01-02"
	from := start.Format(dayLayout)
	to := end.Format(dayLayout)
	return fmt.Sprintf("%s to %s", from, to)
}
// GetBudgetsByCategoryID returns the budgets owned by userID that are tied to
// categoryID, as listed by the repository. Repository failures are wrapped.
func (s *BudgetService) GetBudgetsByCategoryID(userID, categoryID uint) ([]models.Budget, error) {
	matches, err := s.repo.GetByCategoryID(userID, categoryID)
	if err == nil {
		return matches, nil
	}
	return nil, fmt.Errorf("failed to get budgets by category: %w", err)
}
// GetBudgetsByAccountID returns the budgets owned by userID that are tied to
// accountID, as listed by the repository. Repository failures are wrapped.
func (s *BudgetService) GetBudgetsByAccountID(userID, accountID uint) ([]models.Budget, error) {
	matches, err := s.repo.GetByAccountID(userID, accountID)
	if err == nil {
		return matches, nil
	}
	return nil, fmt.Errorf("failed to get budgets by account: %w", err)
}
// GetActiveBudgets returns the budgets the repository considers active for
// userID at the current time. Repository failures are wrapped.
func (s *BudgetService) GetActiveBudgets(userID uint) ([]models.Budget, error) {
	asOf := time.Now()
	active, err := s.repo.GetActiveBudgets(userID, asOf)
	if err == nil {
		return active, nil
	}
	return nil, fmt.Errorf("failed to get active budgets: %w", err)
}