feat: 新增基础模型定义和交易数据仓库接口。
This commit is contained in:
@@ -31,8 +31,9 @@ const (
|
||||
AccountTypeDebitCard AccountType = "debit_card"
|
||||
AccountTypeCreditCard AccountType = "credit_card"
|
||||
AccountTypeEWallet AccountType = "e_wallet"
|
||||
AccountTypeCreditLine AccountType = "credit_line" // 花呗、白<EFBFBD>?
|
||||
AccountTypeCreditLine AccountType = "credit_line" // 花呗、白条
|
||||
AccountTypeInvestment AccountType = "investment"
|
||||
AccountTypeReceivable AccountType = "receivable" // 应收账款(别人欠我的钱)
|
||||
)
|
||||
|
||||
// FrequencyType represents the frequency of recurring transactions
|
||||
|
||||
@@ -264,11 +264,15 @@ func (r *TransactionRepository) applyFilters(query *gorm.DB, filter TransactionF
|
||||
|
||||
// Entity filters
|
||||
if filter.CategoryID != nil {
|
||||
// Include sub-categories
|
||||
var subCategoryIDs []uint
|
||||
r.db.Model(&models.Category{}).Where("parent_id = ?", *filter.CategoryID).Pluck("id", &subCategoryIDs)
|
||||
ids := append(subCategoryIDs, *filter.CategoryID)
|
||||
query = query.Where("category_id IN ?", ids)
|
||||
// Include all descendant sub-categories recursively
|
||||
subCategoryIDs, err := r.getAllSubCategoryIDs(*filter.CategoryID)
|
||||
if err == nil {
|
||||
ids := append(subCategoryIDs, *filter.CategoryID)
|
||||
query = query.Where("category_id IN ?", ids)
|
||||
} else {
|
||||
// Fallback: just use the specified category ID
|
||||
query = query.Where("category_id = ?", *filter.CategoryID)
|
||||
}
|
||||
}
|
||||
if filter.AccountID != nil {
|
||||
query = query.Where("account_id = ? OR to_account_id = ?", *filter.AccountID, *filter.AccountID)
|
||||
@@ -308,6 +312,29 @@ func (r *TransactionRepository) applyFilters(query *gorm.DB, filter TransactionF
|
||||
return query
|
||||
}
|
||||
|
||||
// getAllSubCategoryIDs recursively gets all descendant category IDs
|
||||
func (r *TransactionRepository) getAllSubCategoryIDs(parentID uint) ([]uint, error) {
|
||||
var allIDs []uint
|
||||
var currentLevelIDs = []uint{parentID}
|
||||
|
||||
// Iteratively find children level by level
|
||||
for len(currentLevelIDs) > 0 {
|
||||
var nextLevelIDs []uint
|
||||
if err := r.db.Model(&models.Category{}).Where("parent_id IN ?", currentLevelIDs).Pluck("id", &nextLevelIDs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(nextLevelIDs) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
allIDs = append(allIDs, nextLevelIDs...)
|
||||
currentLevelIDs = nextLevelIDs
|
||||
}
|
||||
|
||||
return allIDs, nil
|
||||
}
|
||||
|
||||
// applySorting applies sorting to the query
|
||||
func (r *TransactionRepository) applySorting(query *gorm.DB, sort TransactionSort) *gorm.DB {
|
||||
// Default sorting: transaction_date DESC (newest first)
|
||||
|
||||
Reference in New Issue
Block a user