package service import ( "errors" "fmt" "accounting-app/internal/models" "accounting-app/internal/repository" ) // Category service errors var ( ErrCategoryNotFound = errors.New("category not found") ErrCategoryInUse = errors.New("category is in use and cannot be deleted") ErrCategoryHasChildren = errors.New("category has children and cannot be deleted") ErrInvalidParentCategory = errors.New("invalid parent category") ErrParentTypeMismatch = errors.New("parent category type must match child category type") ErrCircularReference = errors.New("circular reference detected in category hierarchy") ErrParentIsChild = errors.New("cannot set a child category as parent") ) // CategoryInput represents the input data for creating or updating a category type CategoryInput struct { UserID uint `json:"user_id"` Name string `json:"name" binding:"required"` Icon string `json:"icon"` Type models.CategoryType `json:"type" binding:"required"` ParentID *uint `json:"parent_id,omitempty"` SortOrder int `json:"sort_order"` } // CategoryService handles business logic for categories type CategoryService struct { repo *repository.CategoryRepository } // NewCategoryService creates a new CategoryService instance func NewCategoryService(repo *repository.CategoryRepository) *CategoryService { return &CategoryService{ repo: repo, } } // CreateCategory creates a new category with business logic validation func (s *CategoryService) CreateCategory(input CategoryInput) (*models.Category, error) { // Validate parent category if provided if input.ParentID != nil { parent, err := s.repo.GetByID(input.UserID, *input.ParentID) if err != nil { if errors.Is(err, repository.ErrCategoryNotFound) { return nil, ErrInvalidParentCategory } return nil, fmt.Errorf("failed to validate parent category: %w", err) } // userID check handled by repo // Ensure parent category type matches the new category type if parent.Type != input.Type { return nil, ErrParentTypeMismatch } // Ensure parent is not already a child (only allow 2 levels) if parent.ParentID != nil { return nil, ErrParentIsChild } } // Create the category model category := &models.Category{ UserID: input.UserID, Name: input.Name, Icon: input.Icon, Type: input.Type, ParentID: input.ParentID, SortOrder: input.SortOrder, } // Save to database if err := s.repo.Create(category); err != nil { return nil, fmt.Errorf("failed to create category: %w", err) } return category, nil } // GetCategory retrieves a category by ID and verifies ownership func (s *CategoryService) GetCategory(userID, id uint) (*models.Category, error) { category, err := s.repo.GetByID(userID, id) if err != nil { if errors.Is(err, repository.ErrCategoryNotFound) { return nil, ErrCategoryNotFound } return nil, fmt.Errorf("failed to get category: %w", err) } // userID check handled by repo return category, nil } // GetCategoryWithChildren retrieves a category with its children and verifies ownership func (s *CategoryService) GetCategoryWithChildren(userID, id uint) (*models.Category, error) { category, err := s.repo.GetWithChildren(userID, id) if err != nil { if errors.Is(err, repository.ErrCategoryNotFound) { return nil, ErrCategoryNotFound } return nil, fmt.Errorf("failed to get category with children: %w", err) } // userID check handled by repo return category, nil } // GetAllCategories retrieves all categories for a user func (s *CategoryService) GetAllCategories(userID uint) ([]models.Category, error) { categories, err := s.repo.GetAll(userID) if err != nil { return nil, fmt.Errorf("failed to get categories: %w", err) } // Auto-seed if empty if len(categories) == 0 { if err := s.initDefaultCategories(userID); err != nil { // Log error but return empty list to avoid crashing fmt.Printf("Failed to seed default categories for user %d: %v\n", userID, err) return []models.Category{}, nil } // Fetch again after seeding return s.repo.GetAll(userID) } return categories, nil } // GetCategoriesByType retrieves all categories of a specific type for a user func (s *CategoryService) GetCategoriesByType(userID uint, categoryType models.CategoryType) ([]models.Category, error) { categories, err := s.repo.GetByType(userID, categoryType) if err != nil { return nil, fmt.Errorf("failed to get categories by type: %w", err) } // We don't auto-seed here because the user might just have no categories of this specific type // but might have categories of the other type. // But if they have absolutely no categories, GetAllCategories would have handled it. // We can trust GetAllCategories or just leave this as is. return categories, nil } // GetCategoryTree retrieves all categories in a hierarchical tree structure for a user // Returns only root categories with their children preloaded func (s *CategoryService) GetCategoryTree(userID uint) ([]models.Category, error) { categories, err := s.repo.GetAllWithChildren(userID) if err != nil { return nil, fmt.Errorf("failed to get category tree: %w", err) } // Auto-seed if empty if len(categories) == 0 { if err := s.initDefaultCategories(userID); err != nil { fmt.Printf("Failed to seed default categories for user %d: %v\n", userID, err) return []models.Category{}, nil } return s.repo.GetAllWithChildren(userID) } return categories, nil } // GetCategoryTreeByType retrieves categories of a specific type in a hierarchical tree structure for a user func (s *CategoryService) GetCategoryTreeByType(userID uint, categoryType models.CategoryType) ([]models.Category, error) { // Get root categories of the specified type rootCategories, err := s.repo.GetRootCategoriesByType(userID, categoryType) if err != nil { return nil, fmt.Errorf("failed to get root categories by type: %w", err) } // Load children for each root category for i := range rootCategories { children, err := s.repo.GetChildren(userID, rootCategories[i].ID) if err != nil { return nil, fmt.Errorf("failed to get children for category %d: %w", rootCategories[i].ID, err) } rootCategories[i].Children = children } return rootCategories, nil } // GetRootCategories retrieves all root categories (categories without parent) for a user func (s *CategoryService) GetRootCategories(userID uint) ([]models.Category, error) { categories, err := s.repo.GetRootCategories(userID) if err != nil { return nil, fmt.Errorf("failed to get root categories: %w", err) } return categories, nil } // GetChildCategories retrieves all child categories of a given parent func (s *CategoryService) GetChildCategories(userID, parentID uint) ([]models.Category, error) { // Verify parent exists _, err := s.repo.GetByID(userID, parentID) if err != nil { if errors.Is(err, repository.ErrCategoryNotFound) { return nil, ErrCategoryNotFound } return nil, fmt.Errorf("failed to verify parent category: %w", err) } // userID check handled by repo children, err := s.repo.GetChildren(userID, parentID) if err != nil { return nil, fmt.Errorf("failed to get child categories: %w", err) } return children, nil } // UpdateCategory updates an existing category after verifying ownership func (s *CategoryService) UpdateCategory(userID, id uint, input CategoryInput) (*models.Category, error) { // Get existing category category, err := s.repo.GetByID(userID, id) if err != nil { if errors.Is(err, repository.ErrCategoryNotFound) { return nil, ErrCategoryNotFound } return nil, fmt.Errorf("failed to get category: %w", err) } // userID check handled by repo // Validate parent category if provided if input.ParentID != nil { // Cannot set self as parent if *input.ParentID == id { return nil, ErrCircularReference } parent, err := s.repo.GetByID(userID, *input.ParentID) if err != nil { if errors.Is(err, repository.ErrCategoryNotFound) { return nil, ErrInvalidParentCategory } return nil, fmt.Errorf("failed to validate parent category: %w", err) } // userID check handled by repo // Ensure parent category type matches if parent.Type != input.Type { return nil, ErrParentTypeMismatch } // Ensure parent is not already a child (only allow 2 levels) if parent.ParentID != nil { return nil, ErrParentIsChild } // Check if the new parent is a child of the current category (circular reference) if parent.ParentID != nil && *parent.ParentID == id { return nil, ErrCircularReference } } // If this category has children and we're trying to set a parent, reject // (would create more than 2 levels) if input.ParentID != nil { children, err := s.repo.GetChildren(userID, id) if err != nil { return nil, fmt.Errorf("failed to check children: %w", err) } if len(children) > 0 { return nil, ErrParentIsChild } } // Update fields category.Name = input.Name category.Icon = input.Icon category.Type = input.Type category.ParentID = input.ParentID category.SortOrder = input.SortOrder // Save to database if err := s.repo.Update(category); err != nil { return nil, fmt.Errorf("failed to update category: %w", err) } return category, nil } // DeleteCategory deletes a category by ID after verifying ownership func (s *CategoryService) DeleteCategory(userID, id uint) error { _, err := s.repo.GetByID(userID, id) if err != nil { if errors.Is(err, repository.ErrCategoryNotFound) { return ErrCategoryNotFound } return fmt.Errorf("failed to check category existence: %w", err) } // userID check handled by repo err = s.repo.Delete(userID, id) if err != nil { if errors.Is(err, repository.ErrCategoryNotFound) { return ErrCategoryNotFound } if errors.Is(err, repository.ErrCategoryInUse) { return ErrCategoryInUse } if errors.Is(err, repository.ErrCategoryHasChildren) { return ErrCategoryHasChildren } return fmt.Errorf("failed to delete category: %w", err) } return nil } // CategoryExists checks if a category exists by ID func (s *CategoryService) CategoryExists(userID uint, id uint) (bool, error) { exists, err := s.repo.ExistsByID(userID, id) if err != nil { return false, fmt.Errorf("failed to check category existence: %w", err) } return exists, nil } // GetCategoryPath returns the full path of a category (parent -> child) func (s *CategoryService) GetCategoryPath(userID, id uint) ([]models.Category, error) { category, err := s.repo.GetWithParent(userID, id) if err != nil { if errors.Is(err, repository.ErrCategoryNotFound) { return nil, ErrCategoryNotFound } return nil, fmt.Errorf("failed to get category: %w", err) } // userID check handled by repo path := []models.Category{} if category.Parent != nil { path = append(path, *category.Parent) } path = append(path, *category) return path, nil } // initDefaultCategories seeds default categories for a user func (s *CategoryService) initDefaultCategories(userID uint) error { defaults := []struct { Name string Type models.CategoryType Icon string SortOrder int Children []struct { Name string Type models.CategoryType Icon string SortOrder int } }{ // Expenses { Name: "餐饮", Type: models.CategoryTypeExpense, Icon: "restaurant", SortOrder: 1, Children: []struct { Name string Type models.CategoryType Icon string SortOrder int }{ {Name: "早餐", Type: models.CategoryTypeExpense, Icon: "breakfast_dining", SortOrder: 1}, {Name: "午餐", Type: models.CategoryTypeExpense, Icon: "lunch_dining", SortOrder: 2}, {Name: "晚餐", Type: models.CategoryTypeExpense, Icon: "dinner_dining", SortOrder: 3}, {Name: "零食", Type: models.CategoryTypeExpense, Icon: "icecream", SortOrder: 4}, {Name: "饮料", Type: models.CategoryTypeExpense, Icon: "local_cafe", SortOrder: 5}, }, }, { Name: "交通", Type: models.CategoryTypeExpense, Icon: "directions_bus", SortOrder: 2, Children: []struct { Name string Type models.CategoryType Icon string SortOrder int }{ {Name: "地铁", Type: models.CategoryTypeExpense, Icon: "subway", SortOrder: 1}, {Name: "公交", Type: models.CategoryTypeExpense, Icon: "directions_bus", SortOrder: 2}, {Name: "打车", Type: models.CategoryTypeExpense, Icon: "local_taxi", SortOrder: 3}, {Name: "加油", Type: models.CategoryTypeExpense, Icon: "local_gas_station", SortOrder: 4}, }, }, { Name: "购物", Type: models.CategoryTypeExpense, Icon: "shopping_bag", SortOrder: 3, Children: []struct { Name string Type models.CategoryType Icon string SortOrder int }{ {Name: "服饰", Type: models.CategoryTypeExpense, Icon: "checkroom", SortOrder: 1}, {Name: "日用", Type: models.CategoryTypeExpense, Icon: "soap", SortOrder: 2}, {Name: "电子数码", Type: models.CategoryTypeExpense, Icon: "devices", SortOrder: 3}, }, }, { Name: "居住", Type: models.CategoryTypeExpense, Icon: "home", SortOrder: 4, Children: []struct { Name string Type models.CategoryType Icon string SortOrder int }{ {Name: "房租", Type: models.CategoryTypeExpense, Icon: "house", SortOrder: 1}, {Name: "水电煤", Type: models.CategoryTypeExpense, Icon: "lightbulb", SortOrder: 2}, {Name: "物业", Type: models.CategoryTypeExpense, Icon: "security", SortOrder: 3}, }, }, { Name: "娱乐", Type: models.CategoryTypeExpense, Icon: "sports_esports", SortOrder: 5, Children: []struct { Name string Type models.CategoryType Icon string SortOrder int }{ {Name: "游戏", Type: models.CategoryTypeExpense, Icon: "gamepad", SortOrder: 1}, {Name: "电影", Type: models.CategoryTypeExpense, Icon: "movie", SortOrder: 2}, {Name: "运动", Type: models.CategoryTypeExpense, Icon: "fitness_center", SortOrder: 3}, }, }, // Income { Name: "收入", Type: models.CategoryTypeIncome, Icon: "payments", SortOrder: 10, Children: []struct { Name string Type models.CategoryType Icon string SortOrder int }{ {Name: "工资", Type: models.CategoryTypeIncome, Icon: "work", SortOrder: 1}, {Name: "奖金", Type: models.CategoryTypeIncome, Icon: "star", SortOrder: 2}, {Name: "兼职", Type: models.CategoryTypeIncome, Icon: "access_time", SortOrder: 3}, {Name: "理财", Type: models.CategoryTypeIncome, Icon: "trending_up", SortOrder: 4}, }, }, } for _, cat := range defaults { parent := &models.Category{ UserID: userID, Name: cat.Name, Type: cat.Type, Icon: cat.Icon, SortOrder: cat.SortOrder, } if err := s.repo.Create(parent); err != nil { return err } for _, child := range cat.Children { c := &models.Category{ UserID: userID, Name: child.Name, Type: child.Type, Icon: child.Icon, SortOrder: child.SortOrder, ParentID: &parent.ID, } if err := s.repo.Create(c); err != nil { return err } } } return nil }