498 lines
16 KiB
Go
498 lines
16 KiB
Go
package service
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"time"
|
||
|
||
"accounting-app/internal/models"
|
||
"accounting-app/internal/repository"
|
||
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// Sentinel errors returned by BudgetService. Compare with errors.Is.
// Repository-level "not found" and "in use" errors are translated into
// ErrBudgetNotFound and ErrBudgetInUse respectively.
var (
	// ErrBudgetNotFound means no budget with the requested ID exists for the user.
	ErrBudgetNotFound = errors.New("budget not found")

	// ErrBudgetInUse means the repository reported the budget as still in use,
	// so it was not deleted.
	ErrBudgetInUse = errors.New("budget is in use and cannot be deleted")

	// ErrInvalidBudgetAmount means the requested Amount was zero or negative.
	ErrInvalidBudgetAmount = errors.New("budget amount must be positive")

	// ErrInvalidDateRange means EndDate precedes StartDate. Note that the
	// service only rejects a strictly earlier EndDate; equal dates are accepted.
	ErrInvalidDateRange = errors.New("end date must be after start date")

	// ErrInvalidPeriodType means PeriodType is not one of the supported values.
	ErrInvalidPeriodType = errors.New("invalid period type")
)
|
||
|
||
// BudgetInput represents the input data for creating or updating a budget
|
||
type BudgetInput struct {
|
||
UserID uint `json:"user_id"`
|
||
Name string `json:"name" binding:"required"`
|
||
Amount float64 `json:"amount" binding:"required,gt=0"`
|
||
PeriodType models.PeriodType `json:"period_type" binding:"required"`
|
||
CategoryID *uint `json:"category_id,omitempty"`
|
||
AccountID *uint `json:"account_id,omitempty"`
|
||
IsRolling bool `json:"is_rolling"`
|
||
StartDate time.Time `json:"start_date" binding:"required"`
|
||
EndDate *time.Time `json:"end_date,omitempty"`
|
||
}
|
||
|
||
// BudgetProgress represents the progress of a budget
|
||
type BudgetProgress struct {
|
||
BudgetID uint `json:"budget_id"`
|
||
Name string `json:"name"`
|
||
Amount float64 `json:"amount"`
|
||
Spent float64 `json:"spent"`
|
||
Remaining float64 `json:"remaining"`
|
||
Progress float64 `json:"progress"` // Percentage (0-100)
|
||
PeriodType models.PeriodType `json:"period_type"`
|
||
CurrentPeriod string `json:"current_period"`
|
||
IsRolling bool `json:"is_rolling"`
|
||
IsOverBudget bool `json:"is_over_budget"`
|
||
IsNearLimit bool `json:"is_near_limit"` // 80% threshold
|
||
CategoryID *uint `json:"category_id,omitempty"`
|
||
AccountID *uint `json:"account_id,omitempty"`
|
||
}
|
||
|
||
// BudgetService handles business logic for budgets
|
||
type BudgetService struct {
|
||
repo *repository.BudgetRepository
|
||
db *gorm.DB
|
||
}
|
||
|
||
// NewBudgetService creates a new BudgetService instance
|
||
func NewBudgetService(repo *repository.BudgetRepository, db *gorm.DB) *BudgetService {
|
||
return &BudgetService{
|
||
repo: repo,
|
||
db: db,
|
||
}
|
||
}
|
||
|
||
// CreateBudget creates a new budget with business logic validation
|
||
func (s *BudgetService) CreateBudget(input BudgetInput) (*models.Budget, error) {
|
||
// Validate amount
|
||
if input.Amount <= 0 {
|
||
return nil, ErrInvalidBudgetAmount
|
||
}
|
||
|
||
// 分类和账户都可选,支持全局预算
|
||
|
||
// Validate date range
|
||
if input.EndDate != nil && input.EndDate.Before(input.StartDate) {
|
||
return nil, ErrInvalidDateRange
|
||
}
|
||
|
||
// Validate period type
|
||
if !isValidPeriodType(input.PeriodType) {
|
||
return nil, ErrInvalidPeriodType
|
||
}
|
||
|
||
// Create the budget model
|
||
budget := &models.Budget{
|
||
UserID: input.UserID,
|
||
Name: input.Name,
|
||
Amount: input.Amount,
|
||
PeriodType: input.PeriodType,
|
||
CategoryID: input.CategoryID,
|
||
AccountID: input.AccountID,
|
||
IsRolling: input.IsRolling,
|
||
StartDate: input.StartDate,
|
||
EndDate: input.EndDate,
|
||
}
|
||
|
||
// Save to database
|
||
if err := s.repo.Create(budget); err != nil {
|
||
return nil, fmt.Errorf("failed to create budget: %w", err)
|
||
}
|
||
|
||
return budget, nil
|
||
}
|
||
|
||
// GetBudget retrieves a budget by ID and verifies ownership
|
||
func (s *BudgetService) GetBudget(userID, id uint) (*models.Budget, error) {
|
||
budget, err := s.repo.GetByID(userID, id)
|
||
if err != nil {
|
||
if errors.Is(err, repository.ErrBudgetNotFound) {
|
||
return nil, ErrBudgetNotFound
|
||
}
|
||
return nil, fmt.Errorf("failed to get budget: %w", err)
|
||
}
|
||
// userID check handled by repo
|
||
return budget, nil
|
||
}
|
||
|
||
// GetAllBudgets retrieves all budgets for a user
|
||
func (s *BudgetService) GetAllBudgets(userID uint) ([]models.Budget, error) {
|
||
budgets, err := s.repo.GetAll(userID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get budgets: %w", err)
|
||
}
|
||
return budgets, nil
|
||
}
|
||
|
||
// BudgetWithProgress represents a budget with calculated progress fields
|
||
type BudgetWithProgress struct {
|
||
models.Budget
|
||
Spent float64 `json:"spent"`
|
||
Progress float64 `json:"progress"`
|
||
}
|
||
|
||
// GetAllBudgetsWithProgress retrieves all budgets with calculated spent and progress for a user
|
||
func (s *BudgetService) GetAllBudgetsWithProgress(userID uint) ([]BudgetWithProgress, error) {
|
||
budgets, err := s.repo.GetAll(userID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get budgets: %w", err)
|
||
}
|
||
|
||
result := make([]BudgetWithProgress, len(budgets))
|
||
now := time.Now()
|
||
|
||
for i, budget := range budgets {
|
||
// Calculate current period
|
||
startDate, endDate := s.calculateCurrentPeriod(&budget, now)
|
||
|
||
// Get spent amount for current period
|
||
spent, err := s.repo.GetSpentAmount(&budget, startDate, endDate)
|
||
if err != nil {
|
||
// Log error but continue with 0 spent
|
||
spent = 0
|
||
}
|
||
|
||
// Calculate effective amount and progress
|
||
effectiveAmount := budget.Amount
|
||
|
||
if budget.IsRolling {
|
||
// For rolling budgets, calculate effective amount
|
||
periodsElapsed := s.calculatePeriodsElapsed(&budget, startDate)
|
||
totalBudget := budget.Amount * float64(periodsElapsed+1)
|
||
|
||
historyEnd := startDate.Add(-time.Second)
|
||
historySpent := 0.0
|
||
if historyEnd.After(budget.StartDate) {
|
||
historySpent, _ = s.repo.GetSpentAmount(&budget, budget.StartDate, historyEnd)
|
||
}
|
||
effectiveAmount = totalBudget - historySpent
|
||
}
|
||
|
||
progress := 0.0
|
||
if effectiveAmount > 0 {
|
||
progress = (spent / effectiveAmount) * 100
|
||
}
|
||
|
||
result[i] = BudgetWithProgress{
|
||
Budget: budget,
|
||
Spent: spent,
|
||
Progress: progress,
|
||
}
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// UpdateBudget updates an existing budget after verifying ownership
|
||
func (s *BudgetService) UpdateBudget(userID, id uint, input BudgetInput) (*models.Budget, error) {
|
||
// Get existing budget
|
||
budget, err := s.repo.GetByID(userID, id)
|
||
if err != nil {
|
||
if errors.Is(err, repository.ErrBudgetNotFound) {
|
||
return nil, ErrBudgetNotFound
|
||
}
|
||
return nil, fmt.Errorf("failed to get budget: %w", err)
|
||
}
|
||
// userID check handled by repo
|
||
|
||
// Validate amount
|
||
if input.Amount <= 0 {
|
||
return nil, ErrInvalidBudgetAmount
|
||
}
|
||
|
||
// 分类和账户都可选,支持全局预算
|
||
|
||
// Validate date range
|
||
if input.EndDate != nil && input.EndDate.Before(input.StartDate) {
|
||
return nil, ErrInvalidDateRange
|
||
}
|
||
|
||
// Validate period type
|
||
if !isValidPeriodType(input.PeriodType) {
|
||
return nil, ErrInvalidPeriodType
|
||
}
|
||
|
||
// Update fields
|
||
budget.Name = input.Name
|
||
budget.Amount = input.Amount
|
||
budget.PeriodType = input.PeriodType
|
||
budget.CategoryID = input.CategoryID
|
||
budget.AccountID = input.AccountID
|
||
budget.IsRolling = input.IsRolling
|
||
budget.StartDate = input.StartDate
|
||
budget.EndDate = input.EndDate
|
||
|
||
// Save to database
|
||
if err := s.repo.Update(budget); err != nil {
|
||
return nil, fmt.Errorf("failed to update budget: %w", err)
|
||
}
|
||
|
||
return budget, nil
|
||
}
|
||
|
||
// DeleteBudget deletes a budget by ID after verifying ownership
|
||
func (s *BudgetService) DeleteBudget(userID, id uint) error {
|
||
_, err := s.repo.GetByID(userID, id)
|
||
if err != nil {
|
||
if errors.Is(err, repository.ErrBudgetNotFound) {
|
||
return ErrBudgetNotFound
|
||
}
|
||
return fmt.Errorf("failed to check budget existence: %w", err)
|
||
}
|
||
// userID check handled by repo
|
||
|
||
err = s.repo.Delete(userID, id)
|
||
if err != nil {
|
||
if errors.Is(err, repository.ErrBudgetNotFound) {
|
||
return ErrBudgetNotFound
|
||
}
|
||
if errors.Is(err, repository.ErrBudgetInUse) {
|
||
return ErrBudgetInUse
|
||
}
|
||
return fmt.Errorf("failed to delete budget: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GetBudgetProgress calculates and returns the progress of a budget for a user
|
||
// This implements the core budget progress calculation logic for weekly, monthly, and rolling budgets
|
||
func (s *BudgetService) GetBudgetProgress(userID, id uint) (*BudgetProgress, error) {
|
||
// Get the budget
|
||
budget, err := s.repo.GetByID(userID, id)
|
||
if err != nil {
|
||
if errors.Is(err, repository.ErrBudgetNotFound) {
|
||
return nil, ErrBudgetNotFound
|
||
}
|
||
return nil, fmt.Errorf("failed to get budget: %w", err)
|
||
}
|
||
// userID check handled by repo
|
||
|
||
// Calculate the current period based on budget period type
|
||
now := time.Now()
|
||
startDate, endDate := s.calculateCurrentPeriod(budget, now)
|
||
|
||
// Get spent amount for the current period
|
||
currentSpent, err := s.repo.GetSpentAmount(budget, startDate, endDate)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to calculate spent amount: %w", err)
|
||
}
|
||
|
||
// Calculate effective budget amount
|
||
effectiveAmount := budget.Amount
|
||
totalSpent := currentSpent
|
||
|
||
if budget.IsRolling {
|
||
// 滚动预算:结余自动累加到下一周期
|
||
// 当期可用额度 = 总额度 - 历史支出
|
||
|
||
// 计算已过的完整周期数(不含当期)
|
||
periodsElapsed := s.calculatePeriodsElapsed(budget, startDate)
|
||
|
||
// 总额度 = (已过周期数 + 当期) × 单期额度
|
||
totalBudget := budget.Amount * float64(periodsElapsed+1)
|
||
|
||
// 获取历史支出(从预算开始到当期开始前一秒)
|
||
historyEnd := startDate.Add(-time.Second)
|
||
historySpent := 0.0
|
||
if historyEnd.After(budget.StartDate) {
|
||
historySpent, err = s.repo.GetSpentAmount(budget, budget.StartDate, historyEnd)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to calculate history spent: %w", err)
|
||
}
|
||
}
|
||
|
||
// 当期可用额度 = 总额度 - 历史支出
|
||
effectiveAmount = totalBudget - historySpent
|
||
totalSpent = currentSpent
|
||
}
|
||
|
||
// Calculate progress metrics
|
||
remaining := effectiveAmount - totalSpent
|
||
progress := 0.0
|
||
if effectiveAmount > 0 {
|
||
progress = (totalSpent / effectiveAmount) * 100
|
||
}
|
||
|
||
isOverBudget := totalSpent > effectiveAmount
|
||
isNearLimit := progress >= 80.0 && !isOverBudget
|
||
|
||
return &BudgetProgress{
|
||
BudgetID: budget.ID,
|
||
Name: budget.Name,
|
||
Amount: effectiveAmount,
|
||
Spent: totalSpent,
|
||
Remaining: remaining,
|
||
Progress: progress,
|
||
PeriodType: budget.PeriodType,
|
||
CurrentPeriod: formatPeriod(startDate, endDate),
|
||
IsRolling: budget.IsRolling,
|
||
IsOverBudget: isOverBudget,
|
||
IsNearLimit: isNearLimit,
|
||
CategoryID: budget.CategoryID,
|
||
AccountID: budget.AccountID,
|
||
}, nil
|
||
}
|
||
|
||
// GetAllBudgetProgress returns progress for all active budgets for a user
|
||
func (s *BudgetService) GetAllBudgetProgress(userID uint) ([]BudgetProgress, error) {
|
||
budgets, err := s.repo.GetActiveBudgets(userID, time.Now())
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get active budgets: %w", err)
|
||
}
|
||
|
||
var progressList []BudgetProgress
|
||
for _, budget := range budgets {
|
||
progress, err := s.GetBudgetProgress(userID, budget.ID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to calculate progress for budget %d: %w", budget.ID, err)
|
||
}
|
||
progressList = append(progressList, *progress)
|
||
}
|
||
|
||
return progressList, nil
|
||
}
|
||
|
||
// calculateCurrentPeriod calculates the start and end date of the current budget period
|
||
func (s *BudgetService) calculateCurrentPeriod(budget *models.Budget, referenceDate time.Time) (time.Time, time.Time) {
|
||
switch budget.PeriodType {
|
||
case models.PeriodTypeDaily:
|
||
// Daily budget: current day
|
||
start := time.Date(referenceDate.Year(), referenceDate.Month(), referenceDate.Day(), 0, 0, 0, 0, referenceDate.Location())
|
||
end := start.AddDate(0, 0, 1).Add(-time.Second)
|
||
return start, end
|
||
|
||
case models.PeriodTypeWeekly:
|
||
// Weekly budget: current week (Monday to Sunday)
|
||
weekday := int(referenceDate.Weekday())
|
||
if weekday == 0 { // Sunday
|
||
weekday = 7
|
||
}
|
||
daysFromMonday := weekday - 1
|
||
start := time.Date(referenceDate.Year(), referenceDate.Month(), referenceDate.Day()-daysFromMonday, 0, 0, 0, 0, referenceDate.Location())
|
||
end := start.AddDate(0, 0, 7).Add(-time.Second)
|
||
return start, end
|
||
|
||
case models.PeriodTypeMonthly:
|
||
// Monthly budget: current month
|
||
start := time.Date(referenceDate.Year(), referenceDate.Month(), 1, 0, 0, 0, 0, referenceDate.Location())
|
||
end := start.AddDate(0, 1, 0).Add(-time.Second)
|
||
return start, end
|
||
|
||
case models.PeriodTypeYearly:
|
||
// Yearly budget: current year
|
||
start := time.Date(referenceDate.Year(), 1, 1, 0, 0, 0, 0, referenceDate.Location())
|
||
end := start.AddDate(1, 0, 0).Add(-time.Second)
|
||
return start, end
|
||
|
||
default:
|
||
// Default to monthly
|
||
start := time.Date(referenceDate.Year(), referenceDate.Month(), 1, 0, 0, 0, 0, referenceDate.Location())
|
||
end := start.AddDate(0, 1, 0).Add(-time.Second)
|
||
return start, end
|
||
}
|
||
}
|
||
|
||
// calculatePreviousPeriod calculates the start and end date of the previous budget period
|
||
func (s *BudgetService) calculatePreviousPeriod(budget *models.Budget, referenceDate time.Time) (time.Time, time.Time) {
|
||
switch budget.PeriodType {
|
||
case models.PeriodTypeDaily:
|
||
prevDay := referenceDate.AddDate(0, 0, -1)
|
||
return s.calculateCurrentPeriod(budget, prevDay)
|
||
|
||
case models.PeriodTypeWeekly:
|
||
prevWeek := referenceDate.AddDate(0, 0, -7)
|
||
return s.calculateCurrentPeriod(budget, prevWeek)
|
||
|
||
case models.PeriodTypeMonthly:
|
||
prevMonth := referenceDate.AddDate(0, -1, 0)
|
||
return s.calculateCurrentPeriod(budget, prevMonth)
|
||
|
||
case models.PeriodTypeYearly:
|
||
prevYear := referenceDate.AddDate(-1, 0, 0)
|
||
return s.calculateCurrentPeriod(budget, prevYear)
|
||
|
||
default:
|
||
prevMonth := referenceDate.AddDate(0, -1, 0)
|
||
return s.calculateCurrentPeriod(budget, prevMonth)
|
||
}
|
||
}
|
||
|
||
// calculatePeriodsElapsed 计算从预算开始日期到当前周期开始日期之间的完整周期数
|
||
func (s *BudgetService) calculatePeriodsElapsed(budget *models.Budget, currentPeriodStart time.Time) int {
|
||
if currentPeriodStart.Before(budget.StartDate) || currentPeriodStart.Equal(budget.StartDate) {
|
||
return 0
|
||
}
|
||
|
||
var periods int
|
||
switch budget.PeriodType {
|
||
case models.PeriodTypeDaily:
|
||
periods = int(currentPeriodStart.Sub(budget.StartDate).Hours() / 24)
|
||
|
||
case models.PeriodTypeWeekly:
|
||
periods = int(currentPeriodStart.Sub(budget.StartDate).Hours() / (24 * 7))
|
||
|
||
case models.PeriodTypeMonthly:
|
||
yearDiff := currentPeriodStart.Year() - budget.StartDate.Year()
|
||
monthDiff := int(currentPeriodStart.Month()) - int(budget.StartDate.Month())
|
||
periods = yearDiff*12 + monthDiff
|
||
|
||
case models.PeriodTypeYearly:
|
||
periods = currentPeriodStart.Year() - budget.StartDate.Year()
|
||
|
||
default:
|
||
yearDiff := currentPeriodStart.Year() - budget.StartDate.Year()
|
||
monthDiff := int(currentPeriodStart.Month()) - int(budget.StartDate.Month())
|
||
periods = yearDiff*12 + monthDiff
|
||
}
|
||
|
||
// 确保返回非负数
|
||
if periods < 0 {
|
||
return 0
|
||
}
|
||
return periods
|
||
}
|
||
|
||
// isValidPeriodType checks if a period type is valid
|
||
func isValidPeriodType(periodType models.PeriodType) bool {
|
||
switch periodType {
|
||
case models.PeriodTypeDaily, models.PeriodTypeWeekly, models.PeriodTypeMonthly, models.PeriodTypeYearly:
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
// formatPeriod renders a period as "YYYY-MM-DD to YYYY-MM-DD". Only the
// calendar date of each bound (in its own location) is shown.
func formatPeriod(start, end time.Time) string {
	const dateLayout = "2006-01-02"
	// fmt.Sprint inserts no separators between adjacent string operands.
	return fmt.Sprint(start.Format(dateLayout), " to ", end.Format(dateLayout))
}
|
||
|
||
// GetBudgetsByCategoryID retrieves all budgets for a specific category and user
|
||
func (s *BudgetService) GetBudgetsByCategoryID(userID, categoryID uint) ([]models.Budget, error) {
|
||
budgets, err := s.repo.GetByCategoryID(userID, categoryID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get budgets by category: %w", err)
|
||
}
|
||
return budgets, nil
|
||
}
|
||
|
||
// GetBudgetsByAccountID retrieves all budgets for a specific account and user
|
||
func (s *BudgetService) GetBudgetsByAccountID(userID, accountID uint) ([]models.Budget, error) {
|
||
budgets, err := s.repo.GetByAccountID(userID, accountID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get budgets by account: %w", err)
|
||
}
|
||
return budgets, nil
|
||
}
|
||
|
||
// GetActiveBudgets retrieves all currently active budgets for a user
|
||
func (s *BudgetService) GetActiveBudgets(userID uint) ([]models.Budget, error) {
|
||
budgets, err := s.repo.GetActiveBudgets(userID, time.Now())
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get active budgets: %w", err)
|
||
}
|
||
return budgets, nil
|
||
}
|