package service

import (
	"errors"
	"fmt"
	"log"

	"accounting-app/internal/models"
	"accounting-app/internal/repository"
)

// Category service errors.
var (
	ErrCategoryNotFound      = errors.New("category not found")
	ErrCategoryInUse         = errors.New("category is in use and cannot be deleted")
	ErrCategoryHasChildren   = errors.New("category has children and cannot be deleted")
	ErrInvalidParentCategory = errors.New("invalid parent category")
	ErrParentTypeMismatch    = errors.New("parent category type must match child category type")
	ErrCircularReference     = errors.New("circular reference detected in category hierarchy")
	ErrParentIsChild         = errors.New("cannot set a child category as parent")
)

// CategoryInput represents the input data for creating or updating a category.
type CategoryInput struct {
	UserID    uint                `json:"user_id"`
	Name      string              `json:"name" binding:"required"`
	Icon      string              `json:"icon"`
	Color     string              `json:"color"` // HEX color code (e.g., #FF6B35)
	Type      models.CategoryType `json:"type" binding:"required"`
	ParentID  *uint               `json:"parent_id,omitempty"`
	SortOrder int                 `json:"sort_order"`
}

// CategoryService handles business logic for categories.
//
// Categories form at most a two-level hierarchy: a root category may have
// children, a child may not have children of its own, and a child must have
// the same type as its parent.
//
// Every repository call is scoped by userID; the ownership check itself is
// delegated to the repository rather than repeated here.
type CategoryService struct {
	repo *repository.CategoryRepository
}

// NewCategoryService creates a new CategoryService instance.
func NewCategoryService(repo *repository.CategoryRepository) *CategoryService {
	return &CategoryService{
		repo: repo,
	}
}

// validateParent checks that parentID names one of the user's categories that
// can accept a child of type childType.
//
// It returns ErrInvalidParentCategory when the parent does not exist,
// ErrParentTypeMismatch when the parent's type differs from childType, and
// ErrParentIsChild when the parent is itself a child (which would create a
// third level). Other repository errors are wrapped.
func (s *CategoryService) validateParent(userID, parentID uint, childType models.CategoryType) error {
	parent, err := s.repo.GetByID(userID, parentID)
	if err != nil {
		if errors.Is(err, repository.ErrCategoryNotFound) {
			return ErrInvalidParentCategory
		}
		return fmt.Errorf("failed to validate parent category: %w", err)
	}
	if parent.Type != childType {
		return ErrParentTypeMismatch
	}
	// Only two levels are allowed, so the parent must be a root category.
	if parent.ParentID != nil {
		return ErrParentIsChild
	}
	return nil
}

// CreateCategory validates input and persists a new category for input.UserID.
//
// When input.ParentID is set, the parent must exist, be a root category and
// share input.Type (see validateParent for the returned errors).
func (s *CategoryService) CreateCategory(input CategoryInput) (*models.Category, error) {
	if input.ParentID != nil {
		if err := s.validateParent(input.UserID, *input.ParentID, input.Type); err != nil {
			return nil, err
		}
	}

	category := &models.Category{
		UserID:    input.UserID,
		Name:      input.Name,
		Icon:      input.Icon,
		Color:     input.Color,
		Type:      input.Type,
		ParentID:  input.ParentID,
		SortOrder: input.SortOrder,
	}

	if err := s.repo.Create(category); err != nil {
		return nil, fmt.Errorf("failed to create category: %w", err)
	}
	return category, nil
}

// GetCategory retrieves one of the user's categories by ID.
// It returns ErrCategoryNotFound when the repository reports no match.
func (s *CategoryService) GetCategory(userID, id uint) (*models.Category, error) {
	category, err := s.repo.GetByID(userID, id)
	if err != nil {
		if errors.Is(err, repository.ErrCategoryNotFound) {
			return nil, ErrCategoryNotFound
		}
		return nil, fmt.Errorf("failed to get category: %w", err)
	}
	return category, nil
}

// GetCategoryWithChildren retrieves one of the user's categories together with
// its children. It returns ErrCategoryNotFound when the repository reports no
// match.
func (s *CategoryService) GetCategoryWithChildren(userID, id uint) (*models.Category, error) {
	category, err := s.repo.GetWithChildren(userID, id)
	if err != nil {
		if errors.Is(err, repository.ErrCategoryNotFound) {
			return nil, ErrCategoryNotFound
		}
		return nil, fmt.Errorf("failed to get category with children: %w", err)
	}
	return category, nil
}

// GetAllCategories retrieves all categories for a user.
//
// When the user has no categories at all, the defaults are seeded first and
// the list is fetched again. Seeding is best-effort: if it fails, the failure
// is logged and an empty (non-nil) list is returned with a nil error.
func (s *CategoryService) GetAllCategories(userID uint) ([]models.Category, error) {
	categories, err := s.repo.GetAll(userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get categories: %w", err)
	}
	if len(categories) > 0 {
		return categories, nil
	}

	if err := s.initDefaultCategories(userID); err != nil {
		log.Printf("Failed to seed default categories for user %d: %v", userID, err)
		return []models.Category{}, nil
	}

	categories, err = s.repo.GetAll(userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get categories: %w", err)
	}
	return categories, nil
}

// GetCategoriesByType retrieves all categories of a specific type for a user.
//
// Unlike GetAllCategories this never seeds defaults: an empty result may
// simply mean the user has no categories of this type.
func (s *CategoryService) GetCategoriesByType(userID uint, categoryType models.CategoryType) ([]models.Category, error) {
	categories, err := s.repo.GetByType(userID, categoryType)
	if err != nil {
		return nil, fmt.Errorf("failed to get categories by type: %w", err)
	}
	return categories, nil
}

// GetCategoryTree retrieves the user's categories as a tree via the
// repository's GetAllWithChildren (root categories with children preloaded).
//
// Like GetAllCategories, an empty result triggers best-effort seeding of the
// defaults followed by a second fetch; a seeding failure is logged and an
// empty (non-nil) list is returned with a nil error.
func (s *CategoryService) GetCategoryTree(userID uint) ([]models.Category, error) {
	categories, err := s.repo.GetAllWithChildren(userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get category tree: %w", err)
	}
	if len(categories) > 0 {
		return categories, nil
	}

	if err := s.initDefaultCategories(userID); err != nil {
		log.Printf("Failed to seed default categories for user %d: %v", userID, err)
		return []models.Category{}, nil
	}

	categories, err = s.repo.GetAllWithChildren(userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get category tree: %w", err)
	}
	return categories, nil
}

// GetCategoryTreeByType retrieves the user's root categories of categoryType,
// each with its Children field populated.
//
// It issues one GetChildren query per root category.
func (s *CategoryService) GetCategoryTreeByType(userID uint, categoryType models.CategoryType) ([]models.Category, error) {
	rootCategories, err := s.repo.GetRootCategoriesByType(userID, categoryType)
	if err != nil {
		return nil, fmt.Errorf("failed to get root categories by type: %w", err)
	}

	// Index into the slice so the Children assignment lands on the element,
	// not on a loop-local copy.
	for i := range rootCategories {
		children, err := s.repo.GetChildren(userID, rootCategories[i].ID)
		if err != nil {
			return nil, fmt.Errorf("failed to get children for category %d: %w", rootCategories[i].ID, err)
		}
		rootCategories[i].Children = children
	}
	return rootCategories, nil
}

// GetRootCategories retrieves all root categories (categories without a
// parent) for a user.
func (s *CategoryService) GetRootCategories(userID uint) ([]models.Category, error) {
	categories, err := s.repo.GetRootCategories(userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get root categories: %w", err)
	}
	return categories, nil
}

// GetChildCategories retrieves all child categories of parentID.
// It returns ErrCategoryNotFound when the parent itself cannot be found.
func (s *CategoryService) GetChildCategories(userID, parentID uint) ([]models.Category, error) {
	if _, err := s.repo.GetByID(userID, parentID); err != nil {
		if errors.Is(err, repository.ErrCategoryNotFound) {
			return nil, ErrCategoryNotFound
		}
		return nil, fmt.Errorf("failed to verify parent category: %w", err)
	}

	children, err := s.repo.GetChildren(userID, parentID)
	if err != nil {
		return nil, fmt.Errorf("failed to get child categories: %w", err)
	}
	return children, nil
}

// UpdateCategory replaces the editable fields of one of the user's categories
// with the values in input and persists the result.
//
// Errors:
//   - ErrCategoryNotFound when the category does not exist.
//   - ErrCircularReference when input.ParentID is the category itself.
//   - ErrInvalidParentCategory, ErrParentTypeMismatch or ErrParentIsChild when
//     the requested parent is unsuitable (see validateParent).
//   - ErrParentIsChild when a category that has children is moved under a
//     parent (that would create a third level).
//   - ErrParentTypeMismatch when a category that has children changes type
//     (its children would no longer match their parent's type).
func (s *CategoryService) UpdateCategory(userID, id uint, input CategoryInput) (*models.Category, error) {
	category, err := s.repo.GetByID(userID, id)
	if err != nil {
		if errors.Is(err, repository.ErrCategoryNotFound) {
			return nil, ErrCategoryNotFound
		}
		return nil, fmt.Errorf("failed to get category: %w", err)
	}

	if input.ParentID != nil {
		if *input.ParentID == id {
			return nil, ErrCircularReference
		}
		// validateParent rejects any parent that is itself a child. That also
		// covers re-parenting under one of this category's own children, so no
		// separate cycle check is needed with a two-level hierarchy.
		if err := s.validateParent(userID, *input.ParentID, input.Type); err != nil {
			return nil, err
		}
	}

	// Moving under a parent, or changing type, is only safe for a category
	// without children.
	if input.ParentID != nil || input.Type != category.Type {
		children, err := s.repo.GetChildren(userID, id)
		if err != nil {
			return nil, fmt.Errorf("failed to check children: %w", err)
		}
		if len(children) > 0 {
			if input.ParentID != nil {
				return nil, ErrParentIsChild
			}
			return nil, ErrParentTypeMismatch
		}
	}

	category.Name = input.Name
	category.Icon = input.Icon
	category.Color = input.Color
	category.Type = input.Type
	category.ParentID = input.ParentID
	category.SortOrder = input.SortOrder

	if err := s.repo.Update(category); err != nil {
		return nil, fmt.Errorf("failed to update category: %w", err)
	}
	return category, nil
}

// DeleteCategory deletes one of the user's categories.
//
// Repository sentinel errors are translated into the service's own:
// ErrCategoryNotFound, ErrCategoryInUse and ErrCategoryHasChildren.
func (s *CategoryService) DeleteCategory(userID, id uint) error {
	if _, err := s.repo.GetByID(userID, id); err != nil {
		if errors.Is(err, repository.ErrCategoryNotFound) {
			return ErrCategoryNotFound
		}
		return fmt.Errorf("failed to check category existence: %w", err)
	}

	err := s.repo.Delete(userID, id)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, repository.ErrCategoryNotFound):
		return ErrCategoryNotFound
	case errors.Is(err, repository.ErrCategoryInUse):
		return ErrCategoryInUse
	case errors.Is(err, repository.ErrCategoryHasChildren):
		return ErrCategoryHasChildren
	default:
		return fmt.Errorf("failed to delete category: %w", err)
	}
}

// CategoryExists reports whether the user has a category with the given ID.
func (s *CategoryService) CategoryExists(userID uint, id uint) (bool, error) {
	exists, err := s.repo.ExistsByID(userID, id)
	if err != nil {
		return false, fmt.Errorf("failed to check category existence: %w", err)
	}
	return exists, nil
}

// GetCategoryPath returns the full path of a category (parent -> child).
//
// The result is [parent, category] when the category has a loaded parent,
// and [category] otherwise. It returns ErrCategoryNotFound when the category
// does not exist.
func (s *CategoryService) GetCategoryPath(userID, id uint) ([]models.Category, error) {
	category, err := s.repo.GetWithParent(userID, id)
	if err != nil {
		if errors.Is(err, repository.ErrCategoryNotFound) {
			return nil, ErrCategoryNotFound
		}
		return nil, fmt.Errorf("failed to get category: %w", err)
	}

	path := make([]models.Category, 0, 2)
	if category.Parent != nil {
		path = append(path, *category.Parent)
	}
	path = append(path, *category)
	return path, nil
}

// initDefaultCategories seeds default categories for a user.
//
// It first tries to copy the default_categories template table via
// models.CopyDefaultCategoriesToUser. If that leaves the user with no
// categories of any type, it falls back to initHardcodedDefaults.
//
// If the template copy created at least one category, the hardcoded set is not
// added (to avoid duplicates). In that case any copy error is returned.
func (s *CategoryService) initDefaultCategories(userID uint) error {
	copyErr := models.CopyDefaultCategoriesToUser(s.repo.GetDB(), userID)

	// Check for categories of any type: counting only one type would treat a
	// template containing just the other type as empty and seed duplicates.
	existing, err := s.repo.GetAll(userID)
	if err != nil {
		return fmt.Errorf("failed to check seeded categories: %w", err)
	}
	if len(existing) > 0 {
		if copyErr != nil {
			return fmt.Errorf("failed to copy default categories: %w", copyErr)
		}
		return nil
	}

	// Nothing was created (template table empty, or the copy failed before
	// inserting anything), so the copy error is superseded by the fallback.
	return s.initHardcodedDefaults(userID)
}

// initHardcodedDefaults seeds a built-in set of root categories and their
// children for a user. It is the fallback used by initDefaultCategories.
//
// NOTE(review): the inserts are issued one at a time through repo.Create with
// no visible transaction, so a failure part-way leaves a partial set behind.
func (s *CategoryService) initHardcodedDefaults(userID uint) error {
	// defaultCategory describes one seeded category. Children take their
	// parent's type, so the seeded tree always satisfies the type-match rule.
	type defaultCategory struct {
		Name      string
		Type      models.CategoryType
		Icon      string
		Color     string
		SortOrder int
		Children  []defaultCategory
	}

	defaults := []defaultCategory{
		// Expenses
		{
			Name: "餐饮", Type: models.CategoryTypeExpense, Icon: "mdi:silverware-fork-knife", Color: "#FF6B35", SortOrder: 1,
			Children: []defaultCategory{
				{Name: "早餐", Icon: "mdi:food-croissant", Color: "#FBBF24", SortOrder: 1},
				{Name: "午餐", Icon: "mdi:food", Color: "#FB923C", SortOrder: 2},
				{Name: "晚餐", Icon: "mdi:food-turkey", Color: "#F97316", SortOrder: 3},
				{Name: "零食", Icon: "mdi:cookie", Color: "#FDE047", SortOrder: 4},
				{Name: "饮料", Icon: "mdi:coffee", Color: "#A16207", SortOrder: 5},
			},
		},
		{
			Name: "交通", Type: models.CategoryTypeExpense, Icon: "mdi:bus", Color: "#3B82F6", SortOrder: 2,
			Children: []defaultCategory{
				{Name: "地铁", Icon: "mdi:subway-variant", Color: "#3B82F6", SortOrder: 1},
				{Name: "公交", Icon: "mdi:bus", Color: "#60A5FA", SortOrder: 2},
				{Name: "打车", Icon: "mdi:taxi", Color: "#FBBF24", SortOrder: 3},
				{Name: "加油", Icon: "mdi:gas-station", Color: "#EF4444", SortOrder: 4},
			},
		},
		{
			Name: "购物", Type: models.CategoryTypeExpense, Icon: "mdi:shopping", Color: "#EC4899", SortOrder: 3,
			Children: []defaultCategory{
				{Name: "服饰", Icon: "mdi:tshirt-crew", Color: "#EC4899", SortOrder: 1},
				{Name: "日用", Icon: "mdi:basket", Color: "#F472B6", SortOrder: 2},
				{Name: "电子数码", Icon: "mdi:laptop", Color: "#3B82F6", SortOrder: 3},
			},
		},
		{
			Name: "居住", Type: models.CategoryTypeExpense, Icon: "mdi:home", Color: "#92400E", SortOrder: 4,
			Children: []defaultCategory{
				{Name: "房租", Icon: "mdi:home-city", Color: "#92400E", SortOrder: 1},
				{Name: "水电煤", Icon: "mdi:lightbulb-on", Color: "#FBBF24", SortOrder: 2},
				{Name: "物业", Icon: "mdi:office-building", Color: "#64748B", SortOrder: 3},
			},
		},
		{
			Name: "娱乐", Type: models.CategoryTypeExpense, Icon: "mdi:gamepad-variant", Color: "#8B5CF6", SortOrder: 5,
			Children: []defaultCategory{
				{Name: "游戏", Icon: "mdi:controller-classic", Color: "#8B5CF6", SortOrder: 1},
				{Name: "电影", Icon: "mdi:movie-open", Color: "#EF4444", SortOrder: 2},
				{Name: "运动", Icon: "mdi:dumbbell", Color: "#22C55E", SortOrder: 3},
			},
		},
		// Income
		{
			Name: "工作收入", Type: models.CategoryTypeIncome, Icon: "mdi:briefcase", Color: "#10B981", SortOrder: 10,
			Children: []defaultCategory{
				{Name: "工资", Icon: "mdi:briefcase", Color: "#10B981", SortOrder: 1},
				{Name: "奖金", Icon: "mdi:trophy", Color: "#FBBF24", SortOrder: 2},
				{Name: "兼职", Icon: "mdi:briefcase-clock", Color: "#8B5CF6", SortOrder: 3},
				{Name: "理财", Icon: "mdi:chart-line", Color: "#F59E0B", SortOrder: 4},
			},
		},
	}

	for _, root := range defaults {
		parent := &models.Category{
			UserID:    userID,
			Name:      root.Name,
			Type:      root.Type,
			Icon:      root.Icon,
			Color:     root.Color,
			SortOrder: root.SortOrder,
		}
		if err := s.repo.Create(parent); err != nil {
			return fmt.Errorf("failed to create default category %q: %w", root.Name, err)
		}

		// parent.ID is read here, after Create, so each child references the
		// ID produced by that Create call (presumably assigned by the DB layer).
		for _, child := range root.Children {
			c := &models.Category{
				UserID:    userID,
				Name:      child.Name,
				Type:      root.Type,
				Icon:      child.Icon,
				Color:     child.Color,
				SortOrder: child.SortOrder,
				ParentID:  &parent.ID,
			}
			if err := s.repo.Create(c); err != nil {
				return fmt.Errorf("failed to create default category %q: %w", child.Name, err)
			}
		}
	}
	return nil
}