From d08f7902c94a043958ddd901677c391ecd84283c Mon Sep 17 00:00:00 2001 From: admin <1297598740@qq.com> Date: Mon, 2 Feb 2026 13:47:30 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E5=9F=BA=E7=A1=80?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=AE=9A=E4=B9=89=E5=92=8C=E4=BA=A4=E6=98=93?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E4=BB=93=E5=BA=93=E6=8E=A5=E5=8F=A3=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/models/models.go | 3 +- internal/repository/transaction_repository.go | 37 ++++++++++++++++--- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/internal/models/models.go b/internal/models/models.go index 82e1506..c17c081 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -31,8 +31,9 @@ const ( AccountTypeDebitCard AccountType = "debit_card" AccountTypeCreditCard AccountType = "credit_card" AccountTypeEWallet AccountType = "e_wallet" - AccountTypeCreditLine AccountType = "credit_line" // 花呗、白�? + AccountTypeCreditLine AccountType = "credit_line" // 花呗、白条 AccountTypeInvestment AccountType = "investment" + AccountTypeReceivable AccountType = "receivable" // 应收账款(别人欠我的钱) ) // FrequencyType represents the frequency of recurring transactions diff --git a/internal/repository/transaction_repository.go b/internal/repository/transaction_repository.go index 5533842..56ff510 100644 --- a/internal/repository/transaction_repository.go +++ b/internal/repository/transaction_repository.go @@ -264,11 +264,15 @@ func (r *TransactionRepository) applyFilters(query *gorm.DB, filter TransactionF // Entity filters if filter.CategoryID != nil { - // Include sub-categories - var subCategoryIDs []uint - r.db.Model(&models.Category{}).Where("parent_id = ?", *filter.CategoryID).Pluck("id", &subCategoryIDs) - ids := append(subCategoryIDs, *filter.CategoryID) - query = query.Where("category_id IN ?", ids) + // Include all descendant sub-categories recursively + subCategoryIDs, err := r.getAllSubCategoryIDs(*filter.CategoryID) + if err == nil { + ids := append(subCategoryIDs, *filter.CategoryID) + query = query.Where("category_id IN ?", ids) + } else { + // Fallback: just use the specified category ID + query = query.Where("category_id = ?", *filter.CategoryID) + } } if filter.AccountID != nil { query = query.Where("account_id = ? OR to_account_id = ?", *filter.AccountID, *filter.AccountID) @@ -308,6 +312,29 @@ func (r *TransactionRepository) applyFilters(query *gorm.DB, filter TransactionF return query } +// getAllSubCategoryIDs recursively gets all descendant category IDs +func (r *TransactionRepository) getAllSubCategoryIDs(parentID uint) ([]uint, error) { + var allIDs []uint + var currentLevelIDs = []uint{parentID} + + // Iteratively find children level by level + for len(currentLevelIDs) > 0 { + var nextLevelIDs []uint + if err := r.db.Model(&models.Category{}).Where("parent_id IN ?", currentLevelIDs).Pluck("id", &nextLevelIDs).Error; err != nil { + return nil, err + } + + if len(nextLevelIDs) == 0 { + break + } + + allIDs = append(allIDs, nextLevelIDs...) + currentLevelIDs = nextLevelIDs + } + + return allIDs, nil +} + // applySorting applies sorting to the query func (r *TransactionRepository) applySorting(query *gorm.DB, sort TransactionSort) *gorm.DB { // Default sorting: transaction_date DESC (newest first)