diff --git a/internal/repository/transaction_repository.go b/internal/repository/transaction_repository.go index 8578e70..5533842 100644 --- a/internal/repository/transaction_repository.go +++ b/internal/repository/transaction_repository.go @@ -264,7 +264,11 @@ func (r *TransactionRepository) applyFilters(query *gorm.DB, filter TransactionF // Entity filters if filter.CategoryID != nil { - query = query.Where("category_id = ?", *filter.CategoryID) + // Include sub-categories + var subCategoryIDs []uint + r.db.Model(&models.Category{}).Where("parent_id = ?", *filter.CategoryID).Pluck("id", &subCategoryIDs) + ids := append(subCategoryIDs, *filter.CategoryID) + query = query.Where("category_id IN ?", ids) } if filter.AccountID != nil { query = query.Where("account_id = ? OR to_account_id = ?", *filter.AccountID, *filter.AccountID) diff --git a/internal/service/ai_bookkeeping_service.go b/internal/service/ai_bookkeeping_service.go index 66ec7b1..9188ea5 100644 --- a/internal/service/ai_bookkeeping_service.go +++ b/internal/service/ai_bookkeeping_service.go @@ -445,31 +445,45 @@ func extractCustomPrompt(text string) string { // ParseIntent extracts transaction parameters from text // Requirements: 7.1, 7.5, 7.6 -func (s *LLMService) ParseIntent(ctx context.Context, text string, history []ChatMessage) (*AITransactionParams, string, error) { - +func (s *LLMService) ParseIntent(ctx context.Context, text string, history []ChatMessage, userID uint) (*AITransactionParams, string, error) { if s.config.OpenAIAPIKey == "" || s.config.OpenAIBaseURL == "" { // No API key, return simple parsing result simpleParams, simpleMsg, _ := s.parseIntentSimple(text) return simpleParams, simpleMsg, nil } + // Fetch user categories for prompt context + var categoryNamesStr string + if categories, err := s.categoryRepo.GetAll(userID); err == nil && len(categories) > 0 { + var names []string + for _, c := range categories { + names = append(names, c.Name) + } + categoryNamesStr = strings.Join(names, "/") + } + // Build messages with history var systemPrompt string // 检查是否有自定义 System/Persona prompt(用于财务建议等场景) - // 如果有,直接使用自定义 prompt 覆盖默认记账 prompt if customPrompt := extractCustomPrompt(text); customPrompt != "" { systemPrompt = customPrompt } else { // 使用默认的记账 prompt todayDate := time.Now().Format("2006-01-02") + + catPrompt := "2. 分类:根据内容推断,如\"奶茶/咖啡/吃饭\"=餐饮" + if categoryNamesStr != "" { + catPrompt = fmt.Sprintf("2. 分类:必须从以下已有分类中选择最匹配的一项:[%s]。例如\"晚餐\"应映射为列表中存在的\"餐饮\"。如果列表无匹配项,则根据常识推断。", categoryNamesStr) + } + systemPrompt = fmt.Sprintf(`你是一个智能记账助手。从用户描述中提取记账信息。 今天的日期是%s 规则: 1. 金额:提取数字,如"6元"=6,"十五"=15 -2. 分类:根据内容推断,如"奶茶/咖啡/吃饭"=餐饮,"打车/地铁"=交通,"买衣服"=购物 +%s 3. 类型:默认expense(支出),除非明确说"收入/工资/奖金/红包" 4. 金额:提取明确的数字。如果用户未提及具体金额(如只说"想吃炸鸡"),amount字段必须返回 0 5. 日期:默认使用今天的日期(%s),除非用户明确指定其他日期 @@ -480,7 +494,8 @@ func (s *LLMService) ParseIntent(ctx context.Context, text string, history []Cha 示例(假设今天是%s): 用户:"买了6块的奶茶" -返回:{"amount":6,"category":"餐饮","type":"expense","note":"奶茶","date":"%s","message":"记录:餐饮支出6元,奶茶"}`, todayDate, todayDate, todayDate, todayDate) +返回:{"amount":6,"category":"餐饮","type":"expense","note":"奶茶","date":"%s","message":"记录:餐饮支出6元,奶茶"}`, + todayDate, catPrompt, todayDate, todayDate, todayDate) } messages := []ChatMessage{ @@ -915,7 +930,7 @@ func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, ses isSpendingAdvice := s.isSpendingAdviceIntent(message) // Parse intent for transaction - params, responseMsg, err := s.llmService.ParseIntent(ctx, message, session.Messages[:len(session.Messages)-1]) + params, responseMsg, err := s.llmService.ParseIntent(ctx, message, session.Messages[:len(session.Messages)-1], userID) if err != nil { return nil, fmt.Errorf("failed to parse intent: %w", err) }