2026-01-25 21:59:00 +08:00
package repository
import (
"errors"
"fmt"
"accounting-app/internal/models"
"gorm.io/gorm"
)
// Sentinel errors returned by CategoryRepository. They are returned
// unwrapped, so callers can match them with errors.Is or ==.
var (
	// ErrCategoryNotFound reports that no category matched the lookup
	// (ID or name) for the requesting user.
	ErrCategoryNotFound = errors.New("category not found")

	// ErrCategoryInUse reports that Delete refused to remove a category
	// still referenced by transactions, budgets, or recurring transactions.
	ErrCategoryInUse = errors.New("category is in use and cannot be deleted")

	// ErrCategoryHasChildren reports that Delete refused to remove a
	// category that other categories name as their parent.
	ErrCategoryHasChildren = errors.New("category has children and cannot be deleted")
)
// CategoryRepository performs persistence operations on models.Category.
// Most read methods take a userID and restrict their queries to rows owned
// by that user.
type CategoryRepository struct {
	// db is the GORM handle through which every query is issued.
	db *gorm.DB
}
// NewCategoryRepository returns a CategoryRepository that issues all of
// its queries through the given GORM handle.
func NewCategoryRepository(db *gorm.DB) *CategoryRepository {
	repo := new(CategoryRepository)
	repo.db = db
	return repo
}
2026-01-28 09:55:29 +08:00
// GetDB returns the *gorm.DB this repository was constructed with. The
// handle is returned as-is; no new session or copy is created.
func (r *CategoryRepository) GetDB() *gorm.DB {
	handle := r.db
	return handle
}
2026-01-25 21:59:00 +08:00
// Create inserts category as a new row. Any database error is wrapped
// with context and returned.
func (r *CategoryRepository) Create(category *models.Category) error {
	err := r.db.Create(category).Error
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to create category: %w", err)
}
// GetByID loads the category whose primary key is id and which belongs to
// userID. It returns ErrCategoryNotFound when no such row exists, including
// when the row exists but is owned by a different user. Other database
// errors are wrapped.
func (r *CategoryRepository) GetByID(userID uint, id uint) (*models.Category, error) {
	found := new(models.Category)
	err := r.db.Where("user_id = ?", userID).First(found, id).Error
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrCategoryNotFound
	default:
		return nil, fmt.Errorf("failed to get category: %w", err)
	}
}
// GetAll lists every category owned by userID. Results are ordered by
// sort_order, then created_at, both ascending.
func (r *CategoryRepository) GetAll(userID uint) ([]models.Category, error) {
	const ordering = "sort_order ASC, created_at ASC"
	var result []models.Category
	query := r.db.Where("user_id = ?", userID).Order(ordering)
	if err := query.Find(&result).Error; err != nil {
		return nil, fmt.Errorf("failed to get categories: %w", err)
	}
	return result, nil
}
// GetByType lists the categories owned by userID whose type equals
// categoryType (for example income or expense). Results are ordered by
// sort_order, then created_at, both ascending.
func (r *CategoryRepository) GetByType(userID uint, categoryType models.CategoryType) ([]models.Category, error) {
	const ordering = "sort_order ASC, created_at ASC"
	var matched []models.Category
	err := r.db.
		Where("user_id = ? AND type = ?", userID, categoryType).
		Order(ordering).
		Find(&matched).Error
	if err != nil {
		return nil, fmt.Errorf("failed to get categories by type: %w", err)
	}
	return matched, nil
}
// GetRootCategories lists the top-level categories of userID, meaning those
// whose parent_id is NULL. Results are ordered by sort_order, then
// created_at, both ascending.
func (r *CategoryRepository) GetRootCategories(userID uint) ([]models.Category, error) {
	const ordering = "sort_order ASC, created_at ASC"
	var roots []models.Category
	query := r.db.Where("user_id = ? AND parent_id IS NULL", userID).Order(ordering)
	if err := query.Find(&roots).Error; err != nil {
		return nil, fmt.Errorf("failed to get root categories: %w", err)
	}
	return roots, nil
}
// GetChildren lists the direct children of parentID that are owned by
// userID. Results are ordered by sort_order, then created_at, both
// ascending. Grandchildren are not included.
func (r *CategoryRepository) GetChildren(userID uint, parentID uint) ([]models.Category, error) {
	const ordering = "sort_order ASC, created_at ASC"
	var kids []models.Category
	err := r.db.
		Where("user_id = ? AND parent_id = ?", userID, parentID).
		Order(ordering).
		Find(&kids).Error
	if err != nil {
		return nil, fmt.Errorf("failed to get child categories: %w", err)
	}
	return kids, nil
}
// GetWithChildren loads the category with primary key id owned by userID
// and preloads its Children association. Only one level is loaded, and the
// children are ordered by sort_order, then created_at, ascending.
// Returns ErrCategoryNotFound if no matching category exists.
func (r *CategoryRepository) GetWithChildren(userID uint, id uint) (*models.Category, error) {
	orderChildren := func(tx *gorm.DB) *gorm.DB {
		return tx.Order("sort_order ASC, created_at ASC")
	}

	var category models.Category
	err := r.db.Preload("Children", orderChildren).
		Where("user_id = ?", userID).
		First(&category, id).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, ErrCategoryNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get category with children: %w", err)
	}
	return &category, nil
}
// GetWithParent loads the category with primary key id owned by userID and
// preloads its Parent association. Returns ErrCategoryNotFound if no
// matching category exists.
func (r *CategoryRepository) GetWithParent(userID uint, id uint) (*models.Category, error) {
	result := new(models.Category)
	err := r.db.Preload("Parent").Where("user_id = ?", userID).First(result, id).Error
	switch {
	case err == nil:
		return result, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrCategoryNotFound
	default:
		return nil, fmt.Errorf("failed to get category with parent: %w", err)
	}
}
// Update writes every field of category back to the database using GORM
// Save. It first confirms that a row with category.ID owned by
// category.UserID exists, and returns ErrCategoryNotFound otherwise.
//
// NOTE(review): the ownership check relies on category.UserID exactly as
// supplied by the caller. The check and the Save are also separate
// statements, so they are not atomic.
func (r *CategoryRepository) Update(category *models.Category) error {
	var current models.Category
	lookup := r.db.Where("user_id = ?", category.UserID).First(&current, category.ID)
	if err := lookup.Error; err != nil {
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return fmt.Errorf("failed to check category existence: %w", err)
		}
		return ErrCategoryNotFound
	}

	if err := r.db.Save(category).Error; err != nil {
		return fmt.Errorf("failed to update category: %w", err)
	}
	return nil
}
// Delete permanently removes the category with primary key id owned by
// userID. Nothing is deleted, and an error is returned, when any of the
// following holds. The checks run in this order:
//   - no such category exists for userID (ErrCategoryNotFound);
//   - another category names it as parent (ErrCategoryHasChildren);
//   - a transaction, budget, or recurring transaction references it
//     (ErrCategoryInUse).
//
// Database failures during any check or the delete are wrapped with context.
//
// NOTE(review): the reference checks and the final delete are separate
// statements outside any transaction. A referencing row inserted
// concurrently after its check would not be detected here. A foreign-key
// constraint is the authoritative guard; confirm one exists in the schema.
func (r *CategoryRepository) Delete(userID uint, id uint) error {
	// Ownership/existence check: only a category belonging to userID may be deleted.
	var category models.Category
	if err := r.db.Where("user_id = ?", userID).First(&category, id).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return ErrCategoryNotFound
		}
		return fmt.Errorf("failed to check category existence: %w", err)
	}

	// Children are checked before other references, so a parent that is also
	// in use reports ErrCategoryHasChildren.
	children, err := r.countCategoryReferences(&models.Category{}, "parent_id", id)
	if err != nil {
		return fmt.Errorf("failed to check child categories: %w", err)
	}
	if children > 0 {
		return ErrCategoryHasChildren
	}

	// Tables that reference a category through category_id. The label is
	// interpolated into the error text, e.g. "failed to check category budgets".
	dependents := []struct {
		model any
		label string
	}{
		{&models.Transaction{}, "transactions"},
		{&models.Budget{}, "budgets"},
		{&models.RecurringTransaction{}, "recurring transactions"},
	}
	for _, dep := range dependents {
		n, err := r.countCategoryReferences(dep.model, "category_id", id)
		if err != nil {
			return fmt.Errorf("failed to check category %s: %w", dep.label, err)
		}
		if n > 0 {
			return ErrCategoryInUse
		}
	}

	// Hard delete: per the original design, Category has no DeletedAt field.
	if err := r.db.Delete(&category).Error; err != nil {
		return fmt.Errorf("failed to delete category: %w", err)
	}
	return nil
}

// countCategoryReferences counts the rows in model's table whose column
// equals categoryID. column is concatenated into the SQL, so it must only
// ever be a fixed, trusted identifier. The only callers are in Delete,
// which passes literals.
func (r *CategoryRepository) countCategoryReferences(model any, column string, categoryID uint) (int64, error) {
	var n int64
	err := r.db.Model(model).Where(column+" = ?", categoryID).Count(&n).Error
	return n, err
}
// ExistsByID reports whether a category with primary key id is owned by
// userID. A category owned by another user yields false.
func (r *CategoryRepository) ExistsByID(userID uint, id uint) (bool, error) {
	var n int64
	query := r.db.Model(&models.Category{}).Where("user_id = ? AND id = ?", userID, id)
	if err := query.Count(&n).Error; err != nil {
		return false, fmt.Errorf("failed to check category existence: %w", err)
	}
	exists := n > 0
	return exists, nil
}
// ExistsByName reports whether userID owns at least one category named
// name, regardless of its type.
func (r *CategoryRepository) ExistsByName(userID uint, name string) (bool, error) {
	var n int64
	err := r.db.Model(&models.Category{}).
		Where("user_id = ? AND name = ?", userID, name).
		Count(&n).Error
	if err != nil {
		return false, fmt.Errorf("failed to check category name existence: %w", err)
	}
	return n > 0, nil
}
// ExistsByNameAndType reports whether userID owns a category that has both
// the given name and the given type.
func (r *CategoryRepository) ExistsByNameAndType(userID uint, name string, categoryType models.CategoryType) (bool, error) {
	var n int64
	query := r.db.Model(&models.Category{}).
		Where("user_id = ? AND name = ? AND type = ?", userID, name, categoryType)
	if err := query.Count(&n).Error; err != nil {
		return false, fmt.Errorf("failed to check category name and type existence: %w", err)
	}
	found := n > 0
	return found, nil
}
// ExistsByNameExcludingID reports whether userID owns a category named
// name other than the one with primary key excludeID. This is useful for
// rename validation. Unlike ExistsByNameAndType, the category type is not
// considered.
func (r *CategoryRepository) ExistsByNameExcludingID(userID uint, name string, excludeID uint) (bool, error) {
	var n int64
	err := r.db.Model(&models.Category{}).
		Where("user_id = ? AND name = ? AND id != ?", userID, name, excludeID).
		Count(&n).Error
	if err != nil {
		return false, fmt.Errorf("failed to check category name existence: %w", err)
	}
	return n > 0, nil
}
// GetRootCategoriesByType lists the top-level categories (parent_id IS NULL)
// of userID that have type categoryType. Results are ordered by sort_order,
// then created_at, both ascending.
func (r *CategoryRepository) GetRootCategoriesByType(userID uint, categoryType models.CategoryType) ([]models.Category, error) {
	const ordering = "sort_order ASC, created_at ASC"
	var roots []models.Category
	query := r.db.
		Where("user_id = ? AND parent_id IS NULL AND type = ?", userID, categoryType).
		Order(ordering)
	if err := query.Find(&roots).Error; err != nil {
		return nil, fmt.Errorf("failed to get root categories by type: %w", err)
	}
	return roots, nil
}
// GetAllWithChildren returns the root categories of userID (those with
// parent_id IS NULL), each with its direct Children preloaded. It does not
// return every category: non-root categories appear only inside their
// parent's Children. Only one level is preloaded, so the Children slices of
// those children are not populated. Roots and each Children slice are
// ordered by sort_order, then created_at, ascending.
func (r *CategoryRepository) GetAllWithChildren(userID uint) ([]models.Category, error) {
	const ordering = "sort_order ASC, created_at ASC"
	sortChildren := func(tx *gorm.DB) *gorm.DB {
		return tx.Order(ordering)
	}

	var roots []models.Category
	err := r.db.
		Preload("Children", sortChildren).
		Where("user_id = ? AND parent_id IS NULL", userID).
		Order(ordering).
		Find(&roots).Error
	if err != nil {
		return nil, fmt.Errorf("failed to get categories with children: %w", err)
	}
	return roots, nil
}
// CountByType returns how many categories of type categoryType userID owns.
// All levels of the hierarchy are counted, roots and children alike. On
// error the count is 0.
func (r *CategoryRepository) CountByType(userID uint, categoryType models.CategoryType) (int64, error) {
	var total int64
	err := r.db.Model(&models.Category{}).
		Where("user_id = ? AND type = ?", userID, categoryType).
		Count(&total).Error
	if err != nil {
		return 0, fmt.Errorf("failed to count categories by type: %w", err)
	}
	return total, nil
}
// GetByName loads a category named name that is owned by userID, and
// returns ErrCategoryNotFound if there is none.
//
// NOTE(review): the lookup ignores type. If the same name exists under
// several types (ExistsByNameAndType suggests this is permitted), GORM's
// First picks a single row, ordered by primary key. Confirm that callers
// expect this.
func (r *CategoryRepository) GetByName(userID uint, name string) (*models.Category, error) {
	match := new(models.Category)
	err := r.db.Where("user_id = ? AND name = ?", userID, name).First(match).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, ErrCategoryNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get category by name: %w", err)
	}
	return match, nil
}