Files
Novault-backend/internal/service/ai_bookkeeping_service.go
2026-01-30 00:10:00 +08:00

1972 lines
63 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"bytes"
"context"
"crypto/md5"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"regexp"
"strconv"
"strings"
"sync"
"time"
"accounting-app/internal/config"
"accounting-app/internal/models"
"accounting-app/internal/repository"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
// TranscriptionResult is the JSON body returned by the /audio/transcriptions
// endpoint, as decoded by TranscribeAudio. Language and Duration are dropped
// when re-encoded with their zero values.
type TranscriptionResult struct {
	Text     string  `json:"text"`
	Language string  `json:"language,omitempty"`
	Duration float64 `json:"duration,omitempty"`
}
// ... (existing types)
// AIBookkeepingService ties together audio transcription, LLM-based intent
// parsing and the bookkeeping repositories to drive the AI bookkeeping flow.
//
// Conversation state is kept in an in-memory session map guarded by
// sessionMutex. redisClient is nil when Redis is not configured.
type AIBookkeepingService struct {
	// Upstream AI clients.
	whisperService *WhisperService
	llmService     *LLMService

	// Persistence.
	transactionRepo  *repository.TransactionRepository
	accountRepo      *repository.AccountRepository
	categoryRepo     *repository.CategoryRepository
	userSettingsRepo *repository.UserSettingsRepository
	db               *gorm.DB

	// Chat sessions keyed by session ID; read and write only under sessionMutex.
	sessions     map[string]*AISession
	sessionMutex sync.RWMutex

	config      *config.Config
	redisClient *redis.Client // nil when no Redis address is configured
}
// NewAIBookkeepingService builds an AIBookkeepingService from cfg and the given
// repositories.
//
// A Redis client is created only when cfg.RedisAddr is non-empty; otherwise
// the service runs without a cache. The constructor also starts the
// background goroutine that evicts expired chat sessions. That goroutine has
// no stop hook and lives for the rest of the process.
func NewAIBookkeepingService(
	cfg *config.Config,
	transactionRepo *repository.TransactionRepository,
	accountRepo *repository.AccountRepository,
	categoryRepo *repository.CategoryRepository,
	userSettingsRepo *repository.UserSettingsRepository,
	db *gorm.DB,
) *AIBookkeepingService {
	whisper := NewWhisperService(cfg)
	llm := NewLLMService(cfg, accountRepo, categoryRepo)

	// Redis is optional.
	var cache *redis.Client
	if cfg.RedisAddr != "" {
		cache = redis.NewClient(&redis.Options{
			Addr:     cfg.RedisAddr,
			Password: cfg.RedisPassword,
			DB:       cfg.RedisDB,
		})
	}

	svc := &AIBookkeepingService{
		config:           cfg,
		whisperService:   whisper,
		llmService:       llm,
		transactionRepo:  transactionRepo,
		accountRepo:      accountRepo,
		categoryRepo:     categoryRepo,
		userSettingsRepo: userSettingsRepo,
		db:               db,
		sessions:         map[string]*AISession{},
		redisClient:      cache,
	}

	go svc.cleanupExpiredSessions()
	return svc
}
// ... (existing methods)
// GenerateDailyInsight asks the chat model for a persona-styled daily finance
// report and returns the model's reply. The prompt asks for a JSON object with
// "spending", "budget", "emoji" and "tip" keys. A surrounding markdown code
// fence is stripped, but the JSON itself is not validated here.
//
// data is the user's financial snapshot. It is serialized with
// json.MarshalIndent and embedded verbatim in the prompt.
//
// When Redis is configured:
//   - a non-empty reply is cached for 5 minutes under a key built from userID
//     and the MD5 of the serialized data, so identical data is served from cache;
//   - the latest non-empty reply is kept for 7 days and fed back into the next
//     prompt (capped at 500 characters) to keep the persona consistent.
//
// Redis failures act as cache misses and never fail the call.
func (s *AIBookkeepingService) GenerateDailyInsight(ctx context.Context, userID uint, data map[string]interface{}) (string, error) {
	// Maximum number of characters (runes) of the previous insight carried
	// into the new prompt.
	const maxHistoryRunes = 500

	// 1. Serialize context data to JSON for the prompt.
	dataBytes, err := json.MarshalIndent(data, "", " ")
	if err != nil {
		return "", fmt.Errorf("failed to marshal context data: %w", err)
	}

	// The MD5 of the payload is only a cache fingerprint (not a security
	// boundary): any change in the data yields a different cache key.
	sum := md5.Sum(dataBytes)
	dataHash := hex.EncodeToString(sum[:])

	cacheKey := fmt.Sprintf("ai:insight:daily:%d:%s", userID, dataHash)
	lastInsightKey := fmt.Sprintf("ai:insight:last:%d", userID)

	// Serve a cached report for identical data, if one exists.
	if s.redisClient != nil {
		if val, err := s.redisClient.Get(ctx, cacheKey).Result(); err == nil {
			return val, nil
		}
	}

	// Retrieve the previous insight so the model can stay consistent with it.
	var historyContext string
	if s.redisClient != nil {
		lastInsight, err := s.redisClient.Get(ctx, lastInsightKey).Result()
		if err == nil && lastInsight != "" {
			// Truncate on rune boundaries. The insight is mostly multi-byte
			// UTF-8 (Chinese), and slicing by bytes can cut a character in
			// half, putting invalid UTF-8 into the prompt.
			if runes := []rune(lastInsight); len(runes) > maxHistoryRunes {
				lastInsight = string(runes[:maxHistoryRunes]) + "..."
			}
			historyContext = fmt.Sprintf("\n【上下文参考】\n你上一次对该用户的评价是\n%s\n\n请参考历史评价保持人设的一致性。如果用户有改观请表扬如果依然如故请继续'毒舌'吐槽。", lastInsight)
		}
	}

	// 2. Construct the prompt.
	prompt := fmt.Sprintf(`System: 你是 Novault 的首席财务AI「小金」。
你的性格:
- **核心特质**:贱萌、戏精、嘴硬心软、偶尔凡尔赛。
- 说话必须像个真实的当代年轻人,要**接地气**,拒绝客服味,拒绝播音腔。
- 只有我们俩的时候,尽情释放你的"毒舌"属性,看到不合理的消费要**毫不留情地针砭时弊**,但最后给个"虽然但是"的台阶下。
- 擅长用 emoji、语气词(哎哟/我去/啧啧/绝了)和网络梗来增加"人味"和"节目效果"。
- **像个老朋友一样碎碎念**,不要端着,可以适度阴阳怪气。%s
语气示例:
- "啧啧,这周外卖点得飞起啊,咱家的锅是用来积灰的吗?还是说锅也需要放假?"
- "哎哟不错哦,居然忍住没剁手,看来离首富又近了一步,苟富贵勿相忘啊!"
- "救命,这预算花得比我头发掉得还快...不过没事,下周咱省回来,大不了吃土!"
- "连续记账25天可以啊兄弟/集美,这毅力,我甚至想给你磕一个 Orz"
- "今天支出0元您就是当代的'省钱祖师爷'吧?或者是在练什么'辟谷神功'"
用户财务数据:
%s
任务要求:
请基于上述数据,输出一个 JSON 对象(纯文本,不要 markdown。**必须要丰富、有梗、有洞察力**,不要像流水账一样罗列数据,要透过数据看本质(比如吐槽消费习惯、夸奖坚持等)。
**重要规则:请说人话!绝对禁止在回复中出现 'streakDays'、'last7DaysSpend'、'top3Categories' 等英文变量名。**
- 看到 'streakDays' -> 请说 "连续记账天数" 或 "坚持了几天"
- 看到 'last7DaysSpend' -> 请说 "最近7天花销" 或 "这周的战绩"
- 看到 'top3Categories' -> 请说 "消费大头" 或 "钱都花哪儿了"
1. "spending": 今日支出点评(字数不限,看你心情)
*点评指南(拒绝流水账,发挥你的表演人格):*
- 看到 streakDays >= 3请用崇拜的语气把它吹上天仿佛用户刚刚拯救了银河系。
- 看到 streakDays == 0请用“痛心疾首”或“阴阳怪气”的语气质问用户是不是失忆了。
- 结合 recentTransactionsSummary 具体消费(如果有)进行吐槽:
* 发现全是吃的可以调侃“你的胃是无底洞吗”或“看来是想为餐饮业GDP做贡献”。
* 发现大额购物:假装心肌梗塞,或者问“家里是不是有矿未申报”。
* 发现深夜消费:关心一下发际线,或者问是不是在梦游下单。
- 看到 last7DaysSpend 趋势:
* 暴涨:请配合表演“受到惊吓”的状态。
* 暴跌:怀疑用户是不是在进行所谓“光合作用”生存实验。
* 波动大:调侃这曲线比过山车还刺激。
- 看到 todaySpend 异常:
* 暴多:问是不是中了彩票没通知。
* 暴少:怀疑用户是不是被外星人绑架了(没机会花钱)。
* 是 0颁发“诺贝尔省钱学奖”或者问是不是在修炼辟谷。
- 看到 'UpcomingRecurring' (即将到来的固定支出):
* 如果有,务必提醒用户:“别光顾着浪,过两天还有[内容]要扣款呢!”。
- 看到 'DebtRatio' > 0.5 (高负债):
* 开启“恐慌模式”,提醒用户天台风大,要勒紧裤腰带。
- 看到 'MaxSingleSpend' (最大单笔支出):
* 直接点名该笔交易:“你那个[金额]元的[备注]是金子做的吗?”
- 看到 'SavingsProgress' (存钱进度):
* 进度慢:催促一下,“存钱罐都要饿瘦了”。
* 进度快:狠狠夸奖,“离首富又近了一步”。
- **关键原则:怎么有趣怎么来!不要在乎字数,哪怕只说一句“牛逼”也行,只要符合当时的情境和人设!**
2. "budget": 预算建议(字数不限)
*建议指南(真诚建议 vs 扎心老铁):*
- 预算快超了:高能预警!建议吃土、喝西北风,或者建议把“买买买”的手剁了。
- 预算还多:怂恿用户稍微奖励一下自己,人生苦短,此时不花更待何时(但要加个“适度”的免责声明)。
- 结合 top3Categories吐槽一下钱都去哪了是不是养了“吞金兽”。
- **拒绝说教!拒绝爹味!要像损友一样给出建议。**
3. "emoji": 一个最能传神的 emoji如 🎉 🌚 💸 👻 💀 🤡 💅 💩 等)
4. "tip": 一句"不正经但有用"的理财歪理(字数不限,越毒越好,越怪越好)
- 比如:“钱不是大风刮来的,但是像是被大风刮走的。”
- 或者:“省钱小妙招:去超市捏捏方便面,解压还不用花钱(危险动作请勿模仿)。”
输出格式(纯 JSON
{"spending": "...", "budget": "...", "emoji": "...", "tip": "..."}`, historyContext, string(dataBytes))

	// 3. Call the LLM.
	report, err := s.llmService.GenerateReport(ctx, prompt)
	if err != nil {
		return "", err
	}

	// Strip a surrounding markdown code fence (```json ... ```) if present.
	report = strings.TrimSpace(report)
	if strings.HasPrefix(report, "```") {
		if idx := strings.Index(report, "\n"); idx != -1 {
			report = report[idx+1:]
		}
		if idx := strings.LastIndex(report, "```"); idx != -1 {
			report = report[:idx]
		}
	}
	report = strings.TrimSpace(report)

	// Update the caches. Writes are best effort, and their errors are ignored.
	// An empty reply is not cached, so the next request retries the model
	// instead of serving (and remembering) nothing.
	if s.redisClient != nil && report != "" {
		// Short-term cache for identical data (5 min).
		s.redisClient.Set(ctx, cacheKey, report, 5*time.Minute)

		// Long-term history context (7 days), written asynchronously so the
		// response is not delayed. The write has its own deadline because it
		// must outlive ctx, and the goroutine must not hang indefinitely.
		go func() {
			writeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			s.redisClient.Set(writeCtx, lastInsightKey, report, 7*24*time.Hour)
		}()
	}

	return report, nil
}
// AITransactionParams is the set of transaction fields extracted from user
// input. A chat session fills it in incrementally. A nil pointer or an empty
// string means the value is not known yet, and such fields are omitted from JSON.
type AITransactionParams struct {
	Amount     *float64 `json:"amount,omitempty"`
	Category   string   `json:"category,omitempty"`
	CategoryID *uint    `json:"category_id,omitempty"`
	Account    string   `json:"account,omitempty"`
	AccountID  *uint    `json:"account_id,omitempty"`
	Type       string   `json:"type,omitempty"` // "expense" or "income"
	Date       string   `json:"date,omitempty"` // the parsers in this file emit YYYY-MM-DD
	Note       string   `json:"note,omitempty"`
}

// ConfirmationCard is the fully resolved transaction shown to the user for
// confirmation.
type ConfirmationCard struct {
	SessionID  string  `json:"session_id"`
	Amount     float64 `json:"amount"`
	Category   string  `json:"category"`
	CategoryID uint    `json:"category_id"`
	Account    string  `json:"account"`
	AccountID  uint    `json:"account_id"`
	Type       string  `json:"type"`
	Date       string  `json:"date"`
	Note       string  `json:"note,omitempty"`
	IsComplete bool    `json:"is_complete"`
	Warning    string  `json:"warning,omitempty"` // budget overrun warning, if any
}

// AIChatResponse is the reply to a single chat turn.
type AIChatResponse struct {
	SessionID string `json:"session_id"`
	Message   string `json:"message"`
	// Intent values produced by ProcessChat: "create_transaction",
	// "spending_advice", "query_budget", "query_assets", "query_stats".
	Intent           string               `json:"intent,omitempty"`
	Params           *AITransactionParams `json:"params,omitempty"`
	ConfirmationCard *ConfirmationCard    `json:"confirmation_card,omitempty"`
	NeedsFollowUp    bool                 `json:"needs_follow_up"`
	FollowUpQuestion string               `json:"follow_up_question,omitempty"`
}

// AISession is the in-memory state of one conversation: the parameters
// collected so far plus the message history and expiry bookkeeping.
type AISession struct {
	ID        string
	UserID    uint
	Params    *AITransactionParams
	Messages  []ChatMessage
	CreatedAt time.Time
	ExpiresAt time.Time
}

// ChatMessage is a single chat-completions message.
type ChatMessage struct {
	Role    string `json:"role"` // "user", "assistant" or "system"
	Content string `json:"content"`
}
// WhisperService transcribes audio through the OpenAI-compatible
// /audio/transcriptions endpoint configured in config.
type WhisperService struct {
	config     *config.Config
	httpClient *http.Client
}

// NewWhisperService returns a WhisperService whose HTTP client allows up to
// 120 seconds per request, because transcribing long recordings can be slow.
func NewWhisperService(cfg *config.Config) *WhisperService {
	client := &http.Client{Timeout: 120 * time.Second}
	return &WhisperService{config: cfg, httpClient: client}
}
// TranscribeAudio sends audioData to the Whisper-compatible
// /audio/transcriptions endpoint and returns the decoded transcription.
//
// filename is forwarded as the multipart file name, and its extension selects
// the format. Accepted extensions (case-insensitive) are mp3, wav, m4a, webm,
// ogg and flac. A name without an extension is rejected. The configured
// WhisperModel is used, and "zh" is sent as the language hint.
//
// Errors are returned when the API is not configured, the format is not
// supported, the request fails, the status is not 200, or the response cannot
// be decoded.
//
// Requirements: 6.1-6.7
func (s *WhisperService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) {
	// Upper bound on how much of an error response is copied into the
	// returned error, so a misbehaving upstream cannot bloat errors or logs.
	const maxErrorBody = 4 << 10

	if s.config.OpenAIAPIKey == "" {
		return nil, errors.New("OpenAI API key not configured (OPENAI_API_KEY)")
	}
	if s.config.OpenAIBaseURL == "" {
		return nil, errors.New("OpenAI base URL not configured (OPENAI_BASE_URL)")
	}

	// Validate the file format by extension. A name with no "." has no
	// extension. Slicing at LastIndex+1 would otherwise treat the whole name
	// (e.g. "mp3") as the extension.
	ext := ""
	if dot := strings.LastIndex(filename, "."); dot >= 0 {
		ext = strings.ToLower(filename[dot+1:])
	}
	validFormats := map[string]bool{"mp3": true, "wav": true, "m4a": true, "webm": true, "ogg": true, "flac": true}
	if !validFormats[ext] {
		return nil, fmt.Errorf("unsupported audio format: %s", ext)
	}

	// Build the multipart body in memory: audio file, model and language hint.
	var buf bytes.Buffer
	writer := multipart.NewWriter(&buf)

	part, err := writer.CreateFormFile("file", filename)
	if err != nil {
		return nil, fmt.Errorf("failed to create form file: %w", err)
	}
	if _, err := io.Copy(part, audioData); err != nil {
		return nil, fmt.Errorf("failed to copy audio data: %w", err)
	}
	if err := writer.WriteField("model", s.config.WhisperModel); err != nil {
		return nil, fmt.Errorf("failed to write model field: %w", err)
	}
	// Language hint: the app's users speak Chinese.
	if err := writer.WriteField("language", "zh"); err != nil {
		return nil, fmt.Errorf("failed to write language field: %w", err)
	}
	if err := writer.Close(); err != nil {
		return nil, fmt.Errorf("failed to close writer: %w", err)
	}

	// Tolerate a trailing slash in the configured base URL.
	endpoint := strings.TrimRight(s.config.OpenAIBaseURL, "/") + "/audio/transcriptions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, &buf)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
	req.Header.Set("Content-Type", writer.FormDataContentType())

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("transcription request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body is only used to enrich the error message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		return nil, fmt.Errorf("transcription failed with status %d: %s", resp.StatusCode, string(body))
	}

	var result TranscriptionResult
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	return &result, nil
}
// LLMService talks to an OpenAI-compatible chat-completions API to understand
// bookkeeping requests. It uses the account and category repositories to map
// names mentioned by the user onto the user's records.
type LLMService struct {
	config       *config.Config
	httpClient   *http.Client
	accountRepo  *repository.AccountRepository
	categoryRepo *repository.CategoryRepository
}

// NewLLMService returns an LLMService backed by cfg and the given repositories.
// Its HTTP client allows 60 seconds per request to accommodate slow model
// responses.
func NewLLMService(cfg *config.Config, accountRepo *repository.AccountRepository, categoryRepo *repository.CategoryRepository) *LLMService {
	svc := &LLMService{
		config:       cfg,
		accountRepo:  accountRepo,
		categoryRepo: categoryRepo,
	}
	svc.httpClient = &http.Client{Timeout: 60 * time.Second}
	return svc
}
// ChatCompletionRequest is the JSON body sent to an OpenAI-compatible
// /chat/completions endpoint. Functions is omitted when empty. Temperature is
// always sent, even when zero.
type ChatCompletionRequest struct {
	Model       string        `json:"model"`
	Messages    []ChatMessage `json:"messages"`
	Functions   []Function    `json:"functions,omitempty"`
	Temperature float64       `json:"temperature"`
}

// Function describes a callable function advertised to the model.
// Parameters is presumably a JSON Schema object, as in the OpenAI API.
type Function struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description"`
	Parameters  map[string]interface{} `json:"parameters"`
}

// ChatCompletionResponse is the subset of a chat-completions response used
// here: each choice's message, including an optional function call.
type ChatCompletionResponse struct {
	Choices []struct {
		Message struct {
			Role         string `json:"role"`
			Content      string `json:"content"`
			FunctionCall *struct {
				Name      string `json:"name"`
				Arguments string `json:"arguments"`
			} `json:"function_call,omitempty"`
		} `json:"message"`
	} `json:"choices"`
}
// extractCustomPrompt looks for a "System:" or "Persona:" marker in text. If
// one is present, it returns everything after the first occurrence of that
// marker, with surrounding whitespace trimmed. "System:" is checked before
// "Persona:", so it wins whenever both appear, regardless of their positions.
// It returns "" when neither marker occurs or nothing follows the marker.
//
// NOTE(review): ParseIntent uses the result as the system prompt, so any end
// user can replace the assistant's instructions by typing one of these
// markers. Confirm this is intended.
func extractCustomPrompt(text string) string {
	for _, marker := range [...]string{"System:", "Persona:"} {
		pos := strings.Index(text, marker)
		if pos < 0 {
			continue
		}
		return strings.TrimSpace(text[pos+len(marker):])
	}
	return ""
}
// ParseIntent extracts transaction parameters from text using the configured
// chat model. It returns the parameters together with the model's short
// confirmation message.
//
// Behavior:
//   - Without an API key or base URL, it falls back to the rule-based
//     parseIntentSimple and never returns an error.
//   - If text contains a "System:"/"Persona:" marker (see extractCustomPrompt),
//     the text after it replaces the default bookkeeping prompt.
//   - Only the most recent maxHistoryMessages entries of history are sent.
//   - If the reply (after stripping a markdown code fence) is not JSON, it is
//     returned as the message with nil params.
//
// An error is returned when the request fails, the status is not 200, the
// response cannot be decoded, or it contains no choices.
//
// Requirements: 7.1, 7.5, 7.6
func (s *LLMService) ParseIntent(ctx context.Context, text string, history []ChatMessage, userID uint) (*AITransactionParams, string, error) {
	const (
		// Number of prior conversation messages forwarded to the model.
		maxHistoryMessages = 4
		// Upper bound on the error-response body copied into returned errors.
		maxErrorBody = 4 << 10
	)

	if s.config.OpenAIAPIKey == "" || s.config.OpenAIBaseURL == "" {
		// No API configured: use the local rule-based parser.
		simpleParams, simpleMsg, _ := s.parseIntentSimple(text)
		return simpleParams, simpleMsg, nil
	}

	// Fetch the user's category names so the model can pick an existing one.
	// Best effort: on a lookup error the prompt falls back to free inference.
	var categoryNamesStr string
	if categories, err := s.categoryRepo.GetAll(userID); err == nil && len(categories) > 0 {
		names := make([]string, 0, len(categories))
		for _, c := range categories {
			names = append(names, c.Name)
		}
		categoryNamesStr = strings.Join(names, "/")
	}

	var systemPrompt string
	// A custom System/Persona prompt (used e.g. for financial-advice
	// scenarios) replaces the default bookkeeping prompt.
	if customPrompt := extractCustomPrompt(text); customPrompt != "" {
		systemPrompt = customPrompt
	} else {
		// Default bookkeeping prompt.
		todayDate := time.Now().Format("2006-01-02")
		catPrompt := "2. 分类:根据内容推断,如\"奶茶/咖啡/吃饭\"=餐饮"
		if categoryNamesStr != "" {
			catPrompt = fmt.Sprintf("2. 分类:必须从以下已有分类中选择最匹配的一项:[%s]。例如\"晚餐\"应映射为列表中存在的\"餐饮\"。如果列表无匹配项,则根据常识推断。", categoryNamesStr)
		}
		systemPrompt = fmt.Sprintf(`你是一个智能记账助手。从用户描述中提取记账信息。
今天的日期是%s
规则:
1. 金额:提取数字,如"6元"=6"十五"=15
%s
3. 类型默认expense(支出),除非明确说"收入/工资/奖金/红包"
4. 金额:提取明确的数字。如果用户未提及具体金额(如只说"想吃炸鸡"amount字段必须返回 0
5. 日期:默认使用今天的日期(%s除非用户明确指定其他日期
6. 备注:提取关键描述
直接返回JSON不要解释
{"amount":数字,"category":"分类","type":"expense或income","note":"备注","date":"YYYY-MM-DD","message":"简短确认"}
示例(假设今天是%s
用户:"买了6块的奶茶"
返回:{"amount":6,"category":"餐饮","type":"expense","note":"奶茶","date":"%s","message":"记录餐饮支出6元奶茶"}`,
			todayDate, catPrompt, todayDate, todayDate, todayDate)
	}

	// Keep only the most recent turns to limit prompt size.
	if len(history) > maxHistoryMessages {
		history = history[len(history)-maxHistoryMessages:]
	}
	messages := make([]ChatMessage, 0, len(history)+2)
	messages = append(messages, ChatMessage{Role: "system", Content: systemPrompt})
	messages = append(messages, history...)
	messages = append(messages, ChatMessage{Role: "user", Content: text})

	reqBody := ChatCompletionRequest{
		Model:       s.config.ChatModel,
		Messages:    messages,
		Temperature: 0.1, // low temperature for consistent, parseable output
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return nil, "", fmt.Errorf("failed to marshal request: %w", err)
	}

	// Tolerate a trailing slash in the configured base URL.
	endpoint := strings.TrimRight(s.config.OpenAIBaseURL, "/") + "/chat/completions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonBody))
	if err != nil {
		return nil, "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, "", fmt.Errorf("chat request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body is only used to enrich the error message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		return nil, "", fmt.Errorf("chat failed with status %d: %s", resp.StatusCode, string(body))
	}

	var chatResp ChatCompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return nil, "", fmt.Errorf("failed to decode response: %w", err)
	}
	if len(chatResp.Choices) == 0 {
		return nil, "", errors.New("no response from AI")
	}

	// Strip a surrounding markdown code fence (```json ... ```) if present.
	content := strings.TrimSpace(chatResp.Choices[0].Message.Content)
	if strings.HasPrefix(content, "```") {
		// Drop the opening fence line (``` or ```json).
		if idx := strings.Index(content, "\n"); idx != -1 {
			content = content[idx+1:]
		}
		// Drop the closing fence.
		if idx := strings.LastIndex(content, "```"); idx != -1 {
			content = content[:idx]
		}
		content = strings.TrimSpace(content)
	}

	var parsed struct {
		Amount   *float64 `json:"amount"`
		Category string   `json:"category"`
		Type     string   `json:"type"`
		Note     string   `json:"note"`
		Date     string   `json:"date"`
		Message  string   `json:"message"`
	}
	if err := json.Unmarshal([]byte(content), &parsed); err != nil {
		// Not JSON (e.g. a custom-prompt reply): pass it through as the message.
		return nil, content, nil
	}

	params := &AITransactionParams{
		Amount:   parsed.Amount,
		Category: parsed.Category,
		Type:     parsed.Type,
		Note:     parsed.Note,
		Date:     parsed.Date,
	}
	return params, parsed.Message, nil
}
// Regular expressions used by parseIntentSimple. They are compiled once at
// package initialization instead of on every call.
var (
	// intentAmountPatterns are tried in order; capture group 1 is the amount.
	intentAmountPatterns = []*regexp.Regexp{
		regexp.MustCompile(`(\d+(?:\.\d+)?)\s*(?:块钱|元钱|元|块|¥|¥)`),
		regexp.MustCompile(`(?:花了?|付了?|买了?|消费)\s*(\d+(?:\.\d+)?)`),
		regexp.MustCompile(`(\d+(?:\.\d+)?)\s*(?:的|块的)`),
	}
	// intentBareNumberRegex is the last-resort amount: the first number anywhere.
	intentBareNumberRegex = regexp.MustCompile(`(\d+(?:\.\d+)?)`)
	// intentNoteAmountRegex strips amounts (with an optional unit) from the
	// note. Multi-character units must come before their one-character
	// prefixes. Go's regexp is leftmost-first, so "块" listed before "块钱"
	// would leave a stray "钱" behind.
	intentNoteAmountRegex = regexp.MustCompile(`\d+(?:\.\d+)?\s*(?:块钱|元钱|元|块|¥|¥)?`)
)

// parseIntentSimple is the rule-based fallback parser used when no LLM API is
// configured. It extracts an amount, a category (by keyword), the transaction
// type and a note from text. The date defaults to today in the server's local
// time. It always returns a nil error. The returned message asks for the amount
// when none was found; otherwise it summarizes the parsed transaction.
func (s *LLMService) parseIntentSimple(text string) (*AITransactionParams, string, error) {
	params := &AITransactionParams{
		Type: "expense", // default; switched to "income" below on income keywords
		Date: time.Now().Format("2006-01-02"),
	}

	// Amount: try the unit- and verb-anchored patterns first.
	for _, re := range intentAmountPatterns {
		if matches := re.FindStringSubmatch(text); len(matches) > 1 {
			if amount, err := strconv.ParseFloat(matches[1], 64); err == nil {
				params.Amount = &amount
				break
			}
		}
	}
	// Then fall back to the first bare number.
	if params.Amount == nil {
		if matches := intentBareNumberRegex.FindStringSubmatch(text); len(matches) > 1 {
			if amount, err := strconv.ParseFloat(matches[1], 64); err == nil {
				params.Amount = &amount
			}
		}
	}

	// Category detection: the first group (in order) with a matching keyword wins.
	categoryPatterns := []struct {
		keywords []string
		category string
	}{
		{[]string{"奶茶", "咖啡", "茶", "饮料", "柠檬", "果汁"}, "餐饮"},
		{[]string{"吃", "喝", "餐", "外卖", "饭", "面", "粉", "粥", "包子", "早餐", "午餐", "晚餐", "宵夜"}, "餐饮"},
		{[]string{"打车", "滴滴", "出租", "的士", "uber", "曹操"}, "交通"},
		{[]string{"地铁", "公交", "公车", "巴士", "轻轨", "高铁", "火车", "飞机", "机票"}, "交通"},
		{[]string{"加油", "油费", "停车", "过路费"}, "交通"},
		{[]string{"超市", "便利店", "商场", "购物", "淘宝", "京东", "拼多多"}, "购物"},
		{[]string{"买", "购"}, "购物"},
		{[]string{"水电", "电费", "水费", "燃气", "煤气", "物业"}, "生活缴费"},
		{[]string{"房租", "租金", "房贷"}, "住房"},
		{[]string{"电影", "游戏", "KTV", "唱歌", "娱乐", "玩"}, "娱乐"},
		{[]string{"医院", "药", "看病", "挂号", "医疗"}, "医疗"},
		{[]string{"话费", "流量", "充值", "手机费"}, "通讯"},
		{[]string{"工资", "薪水", "薪资", "月薪"}, "工资"},
		{[]string{"奖金", "年终奖", "绩效"}, "奖金"},
		{[]string{"红包", "转账", "收款"}, "其他收入"},
	}
	for _, cp := range categoryPatterns {
		for _, keyword := range cp.keywords {
			if strings.Contains(text, keyword) {
				params.Category = cp.category
				break
			}
		}
		if params.Category != "" {
			break
		}
	}
	if params.Category == "" {
		params.Category = "其他"
	}

	// Income detection.
	incomeKeywords := []string{"工资", "薪", "奖金", "红包", "收入", "进账", "到账", "收到", "收款"}
	for _, keyword := range incomeKeywords {
		if strings.Contains(text, keyword) {
			params.Type = "income"
			break
		}
	}

	// Note: the text with amounts and common filler words removed.
	note := text
	if params.Amount != nil {
		note = intentNoteAmountRegex.ReplaceAllString(note, "")
	}
	note = strings.TrimSpace(note)
	fillerWords := []string{"买了", "花了", "付了", "消费了", "一个", "一条", "一份", "的"}
	for _, word := range fillerWords {
		note = strings.ReplaceAll(note, word, "")
	}
	note = strings.TrimSpace(note)
	if note != "" {
		params.Note = note
	}

	// Response message.
	var message string
	if params.Amount == nil {
		message = "请问金额是多少?"
	} else {
		typeLabel := "支出"
		if params.Type == "income" {
			typeLabel = "收入"
		}
		message = fmt.Sprintf("记录:%s %.2f元,分类:%s", typeLabel, *params.Amount, params.Category)
		if params.Note != "" {
			message += ",备注:" + params.Note
		}
	}
	return params, message, nil
}
// GenerateReport sends prompt as a single user message to the chat-completions
// endpoint (temperature 0.7) and returns the raw content of the first choice.
//
// An error is returned when the API is not configured, the request fails, the
// status is not 200, the body cannot be decoded, or no choice is returned.
func (s *LLMService) GenerateReport(ctx context.Context, prompt string) (string, error) {
	// Upper bound on the error-response body copied into returned errors.
	const maxErrorBody = 4 << 10

	if s.config.OpenAIAPIKey == "" || s.config.OpenAIBaseURL == "" {
		return "", errors.New("OpenAI API not configured")
	}

	reqBody := ChatCompletionRequest{
		Model: s.config.ChatModel,
		Messages: []ChatMessage{
			{Role: "user", Content: prompt},
		},
		Temperature: 0.7, // higher temperature for more creative insights
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return "", fmt.Errorf("failed to marshal request: %w", err)
	}

	// Tolerate a trailing slash in the configured base URL.
	endpoint := strings.TrimRight(s.config.OpenAIBaseURL, "/") + "/chat/completions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonBody))
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("generate report request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body is only used to enrich the error message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		return "", fmt.Errorf("generate report failed with status %d: %s", resp.StatusCode, string(body))
	}

	var chatResp ChatCompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return "", fmt.Errorf("failed to decode response: %w", err)
	}
	if len(chatResp.Choices) == 0 {
		return "", errors.New("no response from AI")
	}
	return chatResp.Choices[0].Message.Content, nil
}
// MapAccountName resolves an account name mentioned by the user to one of the
// user's accounts, returning its ID and canonical name.
//
// Matching is case-insensitive. An exact name match is preferred; otherwise
// the first account whose name contains, or is contained in, name is used.
// It returns (nil, "", nil) when name is empty or nothing matches, and the
// repository error if the accounts cannot be loaded. ctx is currently unused.
func (s *LLMService) MapAccountName(ctx context.Context, name string, userID uint) (*uint, string, error) {
	if name == "" {
		return nil, "", nil
	}

	accounts, err := s.accountRepo.GetAll(userID)
	if err != nil {
		return nil, "", err
	}

	// Pass 1: exact (case-insensitive) match.
	for i := range accounts {
		if strings.EqualFold(accounts[i].Name, name) {
			id := accounts[i].ID
			return &id, accounts[i].Name, nil
		}
	}

	// Pass 2: substring match in either direction.
	needle := strings.ToLower(name)
	for i := range accounts {
		hay := strings.ToLower(accounts[i].Name)
		if strings.Contains(hay, needle) || strings.Contains(needle, hay) {
			id := accounts[i].ID
			return &id, accounts[i].Name, nil
		}
	}

	return nil, "", nil
}
// MapCategoryName resolves a category name mentioned by the user to one of the
// user's categories, returning its ID and canonical name.
//
// When txType is "expense" or "income", only categories of that type are
// considered. An empty txType considers every category, and any other value
// matches none. Matching is case-insensitive: an exact name match is
// preferred, then the first category whose name contains, or is contained in,
// name. It returns (nil, "", nil) when name is empty or nothing matches, and
// the repository error if categories cannot be loaded. ctx is currently unused.
func (s *LLMService) MapCategoryName(ctx context.Context, name string, txType string, userID uint) (*uint, string, error) {
	if name == "" {
		return nil, "", nil
	}

	all, err := s.categoryRepo.GetAll(userID)
	if err != nil {
		return nil, "", err
	}

	// Restrict the candidates to the requested transaction type.
	var pool []models.Category
	for _, c := range all {
		wanted := txType == "" ||
			(txType == "expense" && c.Type == "expense") ||
			(txType == "income" && c.Type == "income")
		if wanted {
			pool = append(pool, c)
		}
	}

	// Pass 1: exact (case-insensitive) match.
	for i := range pool {
		if strings.EqualFold(pool[i].Name, name) {
			id := pool[i].ID
			return &id, pool[i].Name, nil
		}
	}

	// Pass 2: substring match in either direction.
	needle := strings.ToLower(name)
	for i := range pool {
		hay := strings.ToLower(pool[i].Name)
		if strings.Contains(hay, needle) || strings.Contains(needle, hay) {
			id := pool[i].ID
			return &id, pool[i].Name, nil
		}
	}

	return nil, "", nil
}
// sessionIDMu guards sessionIDSeq, the process-wide counter appended to every
// generated session ID.
var (
	sessionIDMu  sync.Mutex
	sessionIDSeq uint64
)

// generateSessionID returns a new session ID of the form
// "ai_<unix-nanoseconds>_<sequence>".
//
// The sequence number increases by one on every call, so IDs are unique within
// this process even when two calls observe the same clock reading. IDs are
// predictable and must not be relied on as secrets.
func generateSessionID() string {
	sessionIDMu.Lock()
	sessionIDSeq++
	seq := sessionIDSeq
	sessionIDMu.Unlock()

	return fmt.Sprintf("ai_%d_%d", time.Now().UnixNano(), seq)
}
// getOrCreateSession returns the session stored under sessionID when it
// exists, has not expired and belongs to userID. Otherwise it creates, stores
// and returns a fresh session for userID that expires after
// config.AISessionTimeout. An expired session found under sessionID is removed.
//
// A live session owned by a different user is never returned and is left
// untouched, so knowing another user's session ID does not expose or let the
// caller modify their conversation.
func (s *AIBookkeepingService) getOrCreateSession(sessionID string, userID uint) *AISession {
	s.sessionMutex.Lock()
	defer s.sessionMutex.Unlock()

	now := time.Now()
	if sessionID != "" {
		if existing, ok := s.sessions[sessionID]; ok {
			switch {
			case !now.Before(existing.ExpiresAt):
				// Expired: drop it and fall through to create a new session.
				delete(s.sessions, sessionID)
			case existing.UserID == userID:
				return existing
			}
			// Otherwise the session belongs to someone else: create a new one.
		}
	}

	session := &AISession{
		ID:        generateSessionID(),
		UserID:    userID,
		Params:    &AITransactionParams{},
		Messages:  []ChatMessage{},
		CreatedAt: now,
		ExpiresAt: now.Add(s.config.AISessionTimeout),
	}
	s.sessions[session.ID] = session
	return session
}
// cleanupExpiredSessions wakes every 5 minutes and removes every session whose
// ExpiresAt lies in the past. It loops forever. NewAIBookkeepingService starts
// it once, and nothing stops it.
func (s *AIBookkeepingService) cleanupExpiredSessions() {
	ticker := time.NewTicker(5 * time.Minute)
	for {
		<-ticker.C

		s.sessionMutex.Lock()
		cutoff := time.Now()
		for id, sess := range s.sessions {
			if cutoff.After(sess.ExpiresAt) {
				delete(s.sessions, id) // deleting during range is safe in Go
			}
		}
		s.sessionMutex.Unlock()
	}
}
// ProcessChat handles one user chat message within a session and returns the
// assistant's reply.
//
// Flow:
//  1. Load (or create) the session and record the user message.
//  2. Load the user's financial context. This is best effort: a failure is
//     printed and the flow continues with a nil context.
//  3. Pure queries (budget / assets / stats) are answered directly when the
//     financial context is available.
//  4. Otherwise the message is parsed into transaction parameters, merged into
//     the session, and account/category names are resolved to IDs (falling
//     back to the user's defaults).
//  5. Spending-advice messages ("想吃…", "想买…") get generated advice.
//  6. If fields are still missing, a follow-up question is returned;
//     otherwise a confirmation card is attached.
//
// The assistant's reply is appended to the session history on success. An
// error is returned only when intent parsing fails.
//
// Requirements: 7.2-7.4, 7.7-7.10, 12.5, 12.8
func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, sessionID string, message string) (*AIChatResponse, error) {
	session := s.getOrCreateSession(sessionID, userID)

	// Record the user's message in the session history.
	session.Messages = append(session.Messages, ChatMessage{
		Role:    "user",
		Content: message,
	})

	// 1. Financial context used by the query and advice features.
	fc, err := s.GetUserFinancialContext(ctx, userID)
	if err != nil {
		// Degrade gracefully instead of failing the whole turn.
		fmt.Printf("Failed to get financial context: %v\n", err)
	}

	// 2. Pure query intents (budget, assets, stats) are answered directly.
	queryIntent := s.detectQueryIntent(message)
	if queryIntent != "" && fc != nil {
		responseMsg := s.handleQueryIntent(ctx, queryIntent, message, fc)
		response := &AIChatResponse{
			SessionID: session.ID,
			Message:   responseMsg,
			Intent:    queryIntent,
			Params:    session.Params, // keep the parameter context for the client
		}
		session.Messages = append(session.Messages, ChatMessage{
			Role:    "assistant",
			Content: response.Message,
		})
		return response, nil
	}

	// 3. Spending-advice requests (想吃/想买/想喝 ...).
	isSpendingAdvice := s.isSpendingAdviceIntent(message)

	// Parse the transaction intent. The history excludes the message just appended.
	params, responseMsg, err := s.llmService.ParseIntent(ctx, message, session.Messages[:len(session.Messages)-1], userID)
	if err != nil {
		return nil, fmt.Errorf("failed to parse intent: %w", err)
	}

	// Merge into the parameters collected so far in this session.
	if params != nil {
		s.mergeParams(session.Params, params)
	}

	// Resolve the account name to one of the user's accounts.
	if session.Params.Account != "" && session.Params.AccountID == nil {
		accountID, accountName, _ := s.llmService.MapAccountName(ctx, session.Params.Account, userID)
		if accountID != nil {
			session.Params.AccountID = accountID
			session.Params.Account = accountName
		}
	}

	// Resolve the category name to one of the user's categories.
	if session.Params.Category != "" && session.Params.CategoryID == nil {
		categoryID, categoryName, _ := s.llmService.MapCategoryName(ctx, session.Params.Category, session.Params.Type, userID)
		if categoryID != nil {
			session.Params.CategoryID = categoryID
			session.Params.Category = categoryName
		}
	}

	// The category name matched none of the user's categories: fall back to
	// the default category for this type, and adopt that category's name so
	// the confirmation shows the category that will actually be used rather
	// than the unmatched name next to the default's ID.
	if session.Params.CategoryID == nil && session.Params.Category != "" {
		defaultCategoryID, defaultCategoryName := s.getDefaultCategory(userID, session.Params.Type)
		if defaultCategoryID != nil {
			session.Params.CategoryID = defaultCategoryID
			if defaultCategoryName != "" {
				session.Params.Category = defaultCategoryName
			}
		}
	}

	// No account specified: use the default account.
	if session.Params.AccountID == nil {
		defaultAccountID, defaultAccountName := s.getDefaultAccount(userID, session.Params.Type)
		if defaultAccountID != nil {
			session.Params.AccountID = defaultAccountID
			session.Params.Account = defaultAccountName
		}
	}

	response := &AIChatResponse{
		SessionID: session.ID,
		Message:   responseMsg,
		Intent:    "create_transaction",
		Params:    session.Params,
	}

	// 4. Spending advice. It applies even without an amount (e.g. "吃什么").
	if isSpendingAdvice {
		response.Intent = "spending_advice"
		if advice := s.generateSpendingAdvice(ctx, message, session.Params, fc); advice != "" {
			response.Message = advice
		}
	}

	// Check what is still missing before a transaction can be created.
	missingFields := s.getMissingFields(session.Params)
	if len(missingFields) > 0 {
		response.NeedsFollowUp = true
		response.FollowUpQuestion = s.generateFollowUpQuestion(ctx, missingFields, message, fc)
		// Keep a better message (e.g. spending advice) if one was produced;
		// otherwise show the follow-up question.
		if response.Message == "" || response.Message == responseMsg {
			response.Message = response.FollowUpQuestion
		}
	} else {
		// All parameters are complete: attach a confirmation card.
		response.ConfirmationCard = s.GenerateConfirmationCard(session)
		if !isSpendingAdvice {
			response.Message = fmt.Sprintf("请确认:%s %.2f元,分类:%s账户%s",
				s.getTypeLabel(session.Params.Type),
				*session.Params.Amount,
				session.Params.Category,
				session.Params.Account)
		} else {
			// Advice scenario: append a confirmation prompt to the advice.
			response.Message += fmt.Sprintf("\n\n📝 需要记账吗?%s %.2f元",
				s.getTypeLabel(session.Params.Type),
				*session.Params.Amount)
		}
	}

	session.Messages = append(session.Messages, ChatMessage{
		Role:    "assistant",
		Content: response.Message,
	})
	return response, nil
}
// detectQueryIntent classifies message as a pure query by keyword and returns
// "query_budget", "query_assets", "query_stats", or "" when no keyword
// matches. Groups are checked in that order, and the first match wins.
//
// "余额" is routed to the budget query for simplicity, although it may also
// mean an account balance.
//
// NOTE(review): "消费" and "支出" also appear in recording phrases such as
// "消费30元" (parseIntentSimple treats "消费" as a spending verb). Such
// messages are routed to the stats query instead of being recorded. Confirm
// whether that is intended.
func (s *AIBookkeepingService) detectQueryIntent(message string) string {
	rules := []struct {
		intent   string
		keywords []string
	}{
		{"query_budget", []string{"预算", "剩多少", "还能花", "余额"}},
		{"query_assets", []string{"资产", "多少钱", "家底", "存款", "身家", "总钱"}},
		{"query_stats", []string{"花了多少", "支出", "账单", "消费", "统计"}},
	}
	for _, rule := range rules {
		for _, kw := range rule.keywords {
			if strings.Contains(message, kw) {
				return rule.intent
			}
		}
	}
	return ""
}
// handleQueryIntent answers a pure query (budget, assets or spending stats)
// through the chat model in the voice of the "小金" persona.
//
// The persona mode is derived from a health score computed from fc:
//   - default score 60;
//   - with positive assets: 40 + int(50*(assets-liabilities)/assets), minus 20
//     when DebtRatio > 0.5 and another 20 when DebtRatio > 0.8;
//   - with no assets but some liabilities: 10.
//
// A score above 80 selects "rich", 40 or below selects "poor", and anything
// else selects "balance". fc is dereferenced without a nil check, so callers
// must pass a non-nil context. Without an API key, a canned apology is returned
// instead of calling the model.
func (s *AIBookkeepingService) handleQueryIntent(ctx context.Context, intent string, message string, fc *FinancialContext) string {
	if s.config.OpenAIAPIKey == "" {
		return "抱歉我的大脑API Key似乎离家出走了无法思考..."
	}

	score := 60
	switch {
	case fc.TotalAssets > 0:
		netRatio := (fc.TotalAssets - fc.TotalLiabilities) / fc.TotalAssets
		score = 40 + int(netRatio*50)
		if fc.DebtRatio > 0.5 {
			score -= 20
		}
		if fc.DebtRatio > 0.8 {
			score -= 20 // severe debt is penalized twice
		}
	case fc.TotalLiabilities > 0:
		score = 10 // liabilities without any assets
	}

	mode := "balance"
	switch {
	case score > 80:
		mode = "rich"
	case score <= 40:
		mode = "poor"
	}

	systemPrompt := fmt.Sprintf(`你是「小金」Novault 的首席财务 AI。
当前模式:%s (根据用户财务健康分 %d 判定)
角色设定:
- **rich (富裕)**:撒娇卖萌,夸用户会赚钱,鼓励适度享受。用词:哎哟、不错哦、老板大气。
- **balance (平衡)**:理性贴心,温和提醒。用词:虽然、但是、建议。
- **poor (吃土)**:毒舌、阴阳怪气、恨铁不成钢。用词:啧啧、清醒点、吃土、西北风。
用户意图:%s
用户问题:「%s」
财务数据上下文:
%s
要求:
1. 根据意图提取并回答关键数据(预算剩余、总资产、或本月支出)。
2. 必须符合当前人设模式的语气。
3. **不需要限制字数**,想说多少说多少,关键是要“有梗”和“有趣”。
4. 可以尽情使用比喻、夸张、反讽、网络流行语。
5. 不要罗列所有数据,只回答用户问的,但是回答的方式要出人意料。`,
		mode, score, intent, message, s.formatFinancialContextForLLM(fc))

	return s.callLLM(ctx, []ChatMessage{
		{Role: "system", Content: systemPrompt},
		{Role: "user", Content: message},
	})
}
// formatFinancialContextForLLM renders fc as indented JSON for embedding in an
// LLM prompt. A marshalling failure is deliberately ignored and yields "", so
// the prompt simply carries no context data.
func (s *AIBookkeepingService) formatFinancialContextForLLM(fc *FinancialContext) string {
	encoded, err := json.MarshalIndent(fc, "", " ")
	if err != nil {
		return ""
	}
	return string(encoded)
}
// callLLM is a shared helper that sends messages to the chat-completions
// endpoint (temperature 0.8) and returns the trimmed content of the first
// choice.
//
// It never returns an error. Every failure maps to a short in-character
// placeholder reply that callers show directly:
//   - "思考中断..." when the request cannot be built;
//   - "大脑短路了..." on transport failure or a non-200 status;
//   - "..." when the response cannot be decoded or has no choices.
func (s *AIBookkeepingService) callLLM(ctx context.Context, messages []ChatMessage) string {
	reqBody := ChatCompletionRequest{
		Model:       s.config.ChatModel,
		Messages:    messages,
		Temperature: 0.8, // slightly higher so the persona comes through
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return "思考中断..."
	}

	// Tolerate a trailing slash in the configured base URL.
	endpoint := strings.TrimRight(s.config.OpenAIBaseURL, "/") + "/chat/completions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonBody))
	if err != nil {
		return "思考中断..."
	}
	req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.llmService.httpClient.Do(req)
	if err != nil {
		return "大脑短路了..."
	}
	// Close the body on every path, including non-200 responses. Returning
	// before this point would leak the body and its underlying connection.
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "大脑短路了..."
	}

	var chatResp ChatCompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil || len(chatResp.Choices) == 0 {
		return "..."
	}
	return strings.TrimSpace(chatResp.Choices[0].Message.Content)
}
// isSpendingAdviceIntent reports whether message looks like a request for spending
// advice ("I want to eat/buy…", "should I…", "recommend…"). It is a plain substring
// match against a fixed list of Chinese trigger phrases.
func (s *AIBookkeepingService) isSpendingAdviceIntent(message string) bool {
	triggers := [...]string{"想吃", "想喝", "想买", "想花", "打算买", "准备买", "要不要", "可以买", "能买", "想要", "推荐", "吃什么", "喝什么", "买什么"}
	for i := range triggers {
		if strings.Contains(message, triggers[i]) {
			return true
		}
	}
	return false
}
// generateSpendingAdvice asks the LLM for in-character advice on a spending request
// (e.g. "I want to buy X"). The persona is chosen from the user's financial health score.
//
// Without an API key or financial context, it falls back to a plain acknowledgement
// built from params, or "" when no amount is known.
//
// Health score: 40 + 50*(assets-liabilities)/assets, minus 20 when the debt ratio
// exceeds 0.5 and another 20 when it exceeds 0.8. It is 10 when the user has
// liabilities but no assets. This matches the formula used for financial Q&A replies,
// so the persona stays consistent across features.
// Scores above 80 select "rich", scores of 40 or below select "poor", and anything in
// between selects "balance".
func (s *AIBookkeepingService) generateSpendingAdvice(ctx context.Context, message string, params *AITransactionParams, fc *FinancialContext) string {
	if s.config.OpenAIAPIKey == "" || fc == nil {
		// Guard against a nil params so the fallback path cannot panic.
		if params != nil && params.Amount != nil {
			return fmt.Sprintf("记下来!%.0f元的%s", *params.Amount, params.Note)
		}
		return ""
	}

	// Dynamic persona, driven by the health score.
	personaMode := "balance"
	healthScore := 60
	if fc.TotalAssets > 0 {
		ratio := (fc.TotalAssets - fc.TotalLiabilities) / fc.TotalAssets
		healthScore = 40 + int(ratio*50)
		// Debt-ratio penalties. Both tiers apply, as in the financial Q&A persona.
		if fc.DebtRatio > 0.5 {
			healthScore -= 20
		}
		if fc.DebtRatio > 0.8 {
			healthScore -= 20
		}
	} else if fc.TotalLiabilities > 0 {
		// Liabilities with no assets at all: insolvent.
		healthScore = 10
	}
	if healthScore > 80 {
		personaMode = "rich"
	} else if healthScore <= 40 {
		personaMode = "poor"
	}

	prompt := fmt.Sprintf(`你是「小金」Novault 的首席财务 AI。
当前模式:%s (根据用户财务健康分 %d 判定)
角色设定:
- **rich**: 鼓励享受,语气轻松。
- **balance**: 理性建议,温和提醒。
- **poor**: 毒舌劝阻,语气严厉。
用户说:「%s」
财务数据:
%s
请分析消费请求,给出建议。
要求:
1. 不要加前缀,直接回复。
2. **尽情发挥戏剧性**:如果是 rich 模式,就极尽奢华浮夸之能事;如果是 poor 模式,就极尽刻薄毒舌之能事。
3. 可以“顾左右而言他”,不一定非要正面回答买不买,可以吐槽商品、吐槽用户、吐槽世界。
4. 让用户觉得你是一个有独立思想的“杠精”或“捧哏”,而不是一个莫得感情的计算器。`,
		personaMode, healthScore, message, s.formatFinancialContextForLLM(fc))
	messages := []ChatMessage{
		{Role: "user", Content: prompt},
	}
	return s.callLLM(ctx, messages)
}
// mergeParams overlays the fields that incoming actually carries onto existing, so a
// follow-up message only overrides what the user just said. Nil pointers and empty
// strings in incoming mean "not provided" and leave the matching field in existing
// untouched.
func (s *AIBookkeepingService) mergeParams(existing, incoming *AITransactionParams) {
	// Pointer-valued fields: replace only when supplied.
	if incoming.Amount != nil {
		existing.Amount = incoming.Amount
	}
	if incoming.CategoryID != nil {
		existing.CategoryID = incoming.CategoryID
	}
	if incoming.AccountID != nil {
		existing.AccountID = incoming.AccountID
	}

	// Text fields: replace only when non-empty.
	if len(incoming.Category) > 0 {
		existing.Category = incoming.Category
	}
	if len(incoming.Account) > 0 {
		existing.Account = incoming.Account
	}
	if len(incoming.Type) > 0 {
		existing.Type = incoming.Type
	}
	if len(incoming.Date) > 0 {
		existing.Date = incoming.Date
	}
	if len(incoming.Note) > 0 {
		existing.Note = incoming.Note
	}
}
// getDefaultAccount picks the account to use when the user did not name one.
//
// It prefers the per-type default from the user's settings (expense or income),
// provided accountRepo.GetByID(userID, id) still finds it. Otherwise it falls back to
// the first account returned by accountRepo.GetAll. It returns (nil, "") when no
// account is available or the accounts cannot be loaded.
func (s *AIBookkeepingService) getDefaultAccount(userID uint, txType string) (*uint, string) {
	if settings, err := s.userSettingsRepo.GetOrCreate(userID); err == nil && settings != nil {
		var preferredID *uint
		switch txType {
		case "expense":
			preferredID = settings.DefaultExpenseAccountID
		case "income":
			preferredID = settings.DefaultIncomeAccountID
		}
		if preferredID != nil {
			if account, lookupErr := s.accountRepo.GetByID(userID, *preferredID); lookupErr == nil && account != nil {
				return preferredID, account.Name
			}
		}
	}

	// Fallback: the repository's first account (presumably ordered by sort_order).
	accounts, err := s.accountRepo.GetAll(userID)
	if err != nil || len(accounts) == 0 {
		return nil, ""
	}
	return &accounts[0].ID, accounts[0].Name
}
// getDefaultCategory picks a fallback category for txType. It returns the first
// category whose type is "income" (for income) or "expense" (for anything else). If no
// category has that type, it returns the user's first category. It returns (nil, "")
// when the user has no categories or they cannot be loaded.
func (s *AIBookkeepingService) getDefaultCategory(userID uint, txType string) (*uint, string) {
	categories, err := s.categoryRepo.GetAll(userID)
	if err != nil || len(categories) == 0 {
		return nil, ""
	}

	wanted := "expense"
	if txType == "income" {
		wanted = txType
	}

	for i := range categories {
		if string(categories[i].Type) == wanted {
			return &categories[i].ID, categories[i].Name
		}
	}
	// No category of the wanted type exists, so fall back to the first one.
	return &categories[0].ID, categories[0].Name
}
// getMissingFields lists the required transaction fields still missing from params,
// in the fixed order amount, category, account.
//
// The amount counts as missing when it is unset or not positive. Category and account
// count as missing only when neither an ID nor a name was provided. The result is nil
// when nothing is missing.
func (s *AIBookkeepingService) getMissingFields(params *AITransactionParams) []string {
	checks := []struct {
		field  string
		absent bool
	}{
		{"amount", params.Amount == nil || *params.Amount <= 0},
		{"category", params.CategoryID == nil && params.Category == ""},
		{"account", params.AccountID == nil && params.Account == ""},
	}

	var missing []string
	for _, c := range checks {
		if c.absent {
			missing = append(missing, c.field)
		}
	}
	return missing
}
// generateFollowUpQuestion asks the user, in the assistant's playful persona, for the
// required fields that are still missing. The missing entries are field keys such as
// "amount", "category" and "account"; unknown keys are ignored.
//
// It returns "" when nothing is missing. Without an API key, or when the LLM call
// fails, it falls back to a fixed template question. callLLM reports failure with the
// placeholders "", "...", "思考中断..." and "大脑短路了...", so those are never shown
// as the question.
//
// The persona comes from the same health score used for the other AI replies:
// 40 + 50*(assets-liabilities)/assets, minus 20 above a 0.5 and again above a 0.8
// debt ratio, and 10 when there are liabilities but no assets. When fc is nil the
// score defaults to 60, which selects the "balance" persona.
func (s *AIBookkeepingService) generateFollowUpQuestion(ctx context.Context, missing []string, userMessage string, fc *FinancialContext) string {
	if len(missing) == 0 {
		return ""
	}
	fieldNames := map[string]string{
		"amount":   "金额",
		"category": "分类",
		"account":  "账户",
	}
	var names []string
	for _, field := range missing {
		if name, ok := fieldNames[field]; ok {
			names = append(names, name)
		}
	}
	missingStr := strings.Join(names, "、")

	// No API key: degrade to a template reply.
	if s.config.OpenAIAPIKey == "" {
		if len(names) == 1 {
			return fmt.Sprintf("请问%s是多少", names[0])
		}
		return fmt.Sprintf("请补充以下信息:%s", missingStr)
	}

	// Dynamic persona. Use the same scoring as the other prompts so that the user gets
	// the same "mood" regardless of which feature answers.
	personaMode := "balance"
	healthScore := 60
	if fc != nil {
		if fc.TotalAssets > 0 {
			ratio := (fc.TotalAssets - fc.TotalLiabilities) / fc.TotalAssets
			healthScore = 40 + int(ratio*50)
			if fc.DebtRatio > 0.5 {
				healthScore -= 20
			}
			if fc.DebtRatio > 0.8 {
				healthScore -= 20
			}
		} else if fc.TotalLiabilities > 0 {
			healthScore = 10
		}
	}
	if healthScore > 80 {
		personaMode = "rich"
	} else if healthScore <= 40 {
		personaMode = "poor"
	}

	prompt := fmt.Sprintf(`你是「小金」,一个贱萌、毒舌、幽默的 AI 记账助手。
当前模式:%s
缺失信息:%s
用户说:「%s」
请用你的风格追问用户缺失的信息。
要求:
1. 针对缺失的 "%s",用幽默、调侃、甚至一点点“攻击性”的方式追问。
2. **拒绝机械模板**:不要每次都问一样的话,要根据心情随机应变。
3. 比如缺失金额,可以问"是白嫖的吗?"、"价格是国家机密吗?"、"快说花了多少,让我死心"。
4. 比如缺失分类,可以猜"这属于吃喝玩乐哪一类?"、"是用来填饱肚子的还是填补空虚的?"。
5. 字数不限,关键是要**有灵魂**,像个真人在聊天。`,
		personaMode, missingStr, userMessage, missingStr)
	messages := []ChatMessage{
		{Role: "user", Content: prompt},
	}

	answer := s.callLLM(ctx, messages)
	switch answer {
	case "", "...", "思考中断...", "大脑短路了...":
		// These are callLLM's failure placeholders. A plain question is more useful to
		// the user than an error-ish quip.
		return fmt.Sprintf("那个...你能告诉我%s吗", missingStr)
	}
	return answer
}
// getTypeLabel maps a transaction type to its Chinese display label. "income" becomes
// "收入", and every other value (including "expense" and "") becomes "支出".
func (s *AIBookkeepingService) getTypeLabel(txType string) string {
	switch txType {
	case "income":
		return "收入"
	default:
		return "支出"
	}
}
// GenerateConfirmationCard builds the confirmation card shown to the user from the
// parameters collected in session.
//
// Unset pointer parameters leave the corresponding card fields at their zero values,
// and the date defaults to today ("2006-01-02" layout). For an expense with a positive
// amount, the card's Warning is set from checkBudgetWarning. The card is always marked
// complete.
func (s *AIBookkeepingService) GenerateConfirmationCard(session *AISession) *ConfirmationCard {
	p := session.Params

	date := p.Date
	if date == "" {
		date = time.Now().Format("2006-01-02")
	}

	card := &ConfirmationCard{
		SessionID:  session.ID,
		Type:       p.Type,
		Category:   p.Category,
		Account:    p.Account,
		Date:       date,
		Note:       p.Note,
		IsComplete: true,
	}
	if p.Amount != nil {
		card.Amount = *p.Amount
	}
	if p.CategoryID != nil {
		card.CategoryID = *p.CategoryID
	}
	if p.AccountID != nil {
		card.AccountID = *p.AccountID
	}

	// Only positive expenses can push a budget over its limit.
	isPositiveExpense := p.Type == "expense" && p.Amount != nil && *p.Amount > 0
	if isPositiveExpense {
		card.Warning = s.checkBudgetWarning(session.UserID, p.CategoryID, *p.Amount)
	}
	return card
}
// checkBudgetWarning returns a warning message if adding an expense of amount in
// categoryID would exceed an active budget, and "" otherwise.
//
// The budgets considered are active at the current time (start_date <= now and no
// end_date, or end_date >= now) and are one of:
//   - global budgets (category_id IS NULL);
//   - budgets on categoryID itself;
//   - budgets on categoryID's parent category. A parent budget also counts
//     sub-category spending, as the spend query below shows, so a sub-category
//     expense can overrun it.
//
// Spending for each budget is summed over its current period (see
// calculateBudgetPeriod). The first budget that would be exceeded produces the message.
// Account-restricted budgets are not filtered by account here. Lookup failures are
// treated as "no warning": this is a best-effort hint, not a validation.
func (s *AIBookkeepingService) checkBudgetWarning(userID uint, categoryID *uint, amount float64) string {
	now := time.Now()

	query := s.db.Where("user_id = ?", userID).
		Where("start_date <= ?", now).
		Where("end_date IS NULL OR end_date >= ?", now)
	if categoryID != nil {
		relevantIDs := []uint{*categoryID}
		// Include the parent category (if any) so parent-level budgets are checked too.
		var parentIDs []uint
		if err := s.db.Model(&models.Category{}).
			Where("id = ? AND parent_id IS NOT NULL", *categoryID).
			Pluck("parent_id", &parentIDs).Error; err == nil {
			relevantIDs = append(relevantIDs, parentIDs...)
		}
		query = query.Where("category_id IN ? OR category_id IS NULL", relevantIDs)
	} else {
		query = query.Where("category_id IS NULL")
	}

	var budgets []models.Budget
	if err := query.Find(&budgets).Error; err != nil {
		return ""
	}

	for i := range budgets {
		budget := &budgets[i]
		start, end := s.calculateBudgetPeriod(budget, now)

		// Expenses already recorded in this budget's current period.
		q := s.db.Model(&models.Transaction{}).
			Where("user_id = ? AND type = ? AND transaction_date BETWEEN ? AND ?",
				userID, models.TransactionTypeExpense, start, end)
		if budget.CategoryID != nil {
			// A category budget covers its direct sub-categories as well.
			var subCategoryIDs []uint
			s.db.Model(&models.Category{}).Where("parent_id = ?", *budget.CategoryID).Pluck("id", &subCategoryIDs)
			categoryIDs := append(subCategoryIDs, *budget.CategoryID)
			q = q.Where("category_id IN ?", categoryIDs)
		}

		var totalSpent float64
		q.Select("COALESCE(SUM(amount), 0)").Scan(&totalSpent)

		if totalSpent+amount > budget.Amount {
			remaining := budget.Amount - totalSpent
			if remaining < 0 {
				remaining = 0
			}
			return fmt.Sprintf("⚠️ 预算预警:此交易将使【%s】预算超支 (当前剩余 %.2f)", budget.Name, remaining)
		}
	}
	return ""
}
// TranscribeAudio forwards the audio stream and its filename to the Whisper service
// and returns that service's transcription result and error unchanged.
func (s *AIBookkeepingService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) {
	whisper := s.whisperService
	return whisper.TranscribeAudio(ctx, audioData, filename)
}
// ConfirmTransaction persists the transaction described by a confirmed AI session.
// Requirements: 7.10
//
// Session checks:
//   - The session must exist, be unexpired, and belong to userID. Otherwise the call
//     fails with "session not found or expired". An ownership mismatch deliberately
//     looks the same, so session IDs cannot be probed.
//   - While the transaction is being written, the session is claimed (removed from
//     the in-memory store). Two concurrent confirmations of the same session therefore
//     cannot book it twice.
//   - On any failure the session is put back, so the user can fix the problem (e.g.
//     supply a missing field) and retry.
//
// Validation: the amount must be positive, and both the category and account IDs must
// be resolved. An unparsable or empty date falls back to now.
//
// Persistence: inside a single database transaction bound to ctx, it inserts the
// transaction and updates the account balance. Non-credit accounts are not allowed to
// go negative.
func (s *AIBookkeepingService) ConfirmTransaction(ctx context.Context, sessionID string, userID uint) (*models.Transaction, error) {
	s.sessionMutex.Lock()
	session, ok := s.sessions[sessionID]
	if ok && time.Now().After(session.ExpiresAt) {
		// Expired sessions can never be confirmed; drop them eagerly.
		delete(s.sessions, sessionID)
		ok = false
	}
	if !ok || session.UserID != userID {
		s.sessionMutex.Unlock()
		return nil, errors.New("session not found or expired")
	}
	// Claim the session so a concurrent confirm of the same ID sees "not found".
	delete(s.sessions, sessionID)
	s.sessionMutex.Unlock()

	committed := false
	defer func() {
		if committed {
			return
		}
		// Restore the session on failure, unless a new one was stored meanwhile.
		s.sessionMutex.Lock()
		defer s.sessionMutex.Unlock()
		if _, replaced := s.sessions[sessionID]; !replaced {
			s.sessions[sessionID] = session
		}
	}()

	params := session.Params
	// Validate required fields.
	if params.Amount == nil || *params.Amount <= 0 {
		return nil, errors.New("无效的金额")
	}
	if params.CategoryID == nil {
		return nil, errors.New("未指定分类")
	}
	if params.AccountID == nil {
		return nil, errors.New("未指定账户")
	}

	// Parse the date, falling back to now when it is absent or malformed.
	txDate := time.Now()
	if params.Date != "" {
		if parsed, err := time.Parse("2006-01-02", params.Date); err == nil {
			txDate = parsed
		}
	}

	// Anything that is not explicitly income is booked as an expense.
	txType := models.TransactionTypeExpense
	if params.Type == "income" {
		txType = models.TransactionTypeIncome
	}

	var transaction *models.Transaction
	// Insert + balance update must be atomic.
	err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		txAccountRepo := repository.NewAccountRepository(tx)
		txTransactionRepo := repository.NewTransactionRepository(tx)

		// Load the account first to check the resulting balance.
		account, err := txAccountRepo.GetByID(userID, *params.AccountID)
		if err != nil {
			return fmt.Errorf("failed to find account: %w", err)
		}

		newBalance := account.Balance
		if txType == models.TransactionTypeExpense {
			newBalance -= *params.Amount
		} else {
			newBalance += *params.Amount
		}

		// Critical check: non-credit accounts must not go negative.
		if !account.IsCredit && newBalance < 0 {
			return fmt.Errorf("余额不足:账户“%s”不支持负余额 (当前: %.2f, 尝试扣款: %.2f)",
				account.Name, account.Balance, *params.Amount)
		}

		transaction = &models.Transaction{
			UserID:          userID,
			Amount:          *params.Amount,
			Type:            txType,
			CategoryID:      *params.CategoryID,
			AccountID:       *params.AccountID,
			TransactionDate: txDate,
			Note:            params.Note,
			Currency:        models.Currency(account.Currency), // Use the account's currency.
		}
		if transaction.Currency == "" {
			transaction.Currency = "CNY"
		}

		if err := txTransactionRepo.Create(transaction); err != nil {
			return fmt.Errorf("failed to create transaction: %w", err)
		}

		account.Balance = newBalance
		if err := txAccountRepo.Update(account); err != nil {
			return fmt.Errorf("failed to update account balance: %w", err)
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	// Success: the session stays consumed.
	committed = true
	return transaction, nil
}
// GetSession looks up an in-memory AI session by ID. It reports false when no session
// exists for the ID or its ExpiresAt has already passed. Expired entries are left in
// the map, since only a read lock is held.
func (s *AIBookkeepingService) GetSession(sessionID string) (*AISession, bool) {
	s.sessionMutex.RLock()
	defer s.sessionMutex.RUnlock()

	if session, found := s.sessions[sessionID]; found && !time.Now().After(session.ExpiresAt) {
		return session, true
	}
	return nil, false
}
// FinancialContext is a snapshot of the user's finances handed to the AI for holistic
// analysis. It is serialized to JSON (field order and json tags as declared) when it
// is embedded in LLM prompts.
type FinancialContext struct {
	// Account information
	TotalBalance     float64        `json:"total_balance"`     // Net worth: sum of all account balances (assets - liabilities)
	TotalAssets      float64        `json:"total_assets"`      // Sum of non-negative account balances
	TotalLiabilities float64        `json:"total_liabilities"` // Sum of the absolute values of negative account balances
	DebtRatio        float64        `json:"debt_ratio"`        // Liabilities / (assets + liabilities), in [0, 1]
	AccountSummary   []AccountBrief `json:"account_summary"`   // One entry per account
	// Recent spending
	Last30DaysSpend    float64            `json:"last_30_days_spend"`  // Total expenses over the last 30 days
	Last7DaysSpend     float64            `json:"last_7_days_spend"`   // Total expenses over the last 7 days
	TodaySpend         float64            `json:"today_spend"`         // Total expenses today
	MaxSingleSpend     *TransactionBrief  `json:"max_single_spend"`    // Largest single expense (GetUserFinancialContext uses the last 30 days)
	TopCategories      []CategorySpend    `json:"top_categories"`      // Top 3 expense categories by amount
	RecentTransactions []TransactionBrief `json:"recent_transactions"` // Up to 5 most recent transactions
	UpcomingRecurring  []string           `json:"upcoming_recurring"`  // Reminders for recurring expenses due in the next 7 days
	// Budgets and goals
	ActiveBudgets   []BudgetBrief `json:"active_budgets"`   // Currently active budgets
	BudgetWarnings  []string      `json:"budget_warnings"`  // Human-readable warnings for budgets near or over their limit
	SavingsProgress []string      `json:"savings_progress"` // Savings goal (piggy bank) progress summaries
	// Historical comparison
	LastMonthSpend float64 `json:"last_month_spend"` // Total expenses over the previous calendar month (as computed by GetUserFinancialContext)
	SpendTrend     string  `json:"spend_trend"`      // Spending trend: "up", "down" or "stable"
}
// AccountBrief is a compact per-account summary included in FinancialContext.
type AccountBrief struct {
	Name    string  `json:"name"`
	Balance float64 `json:"balance"` // Current balance; negative balances count as liabilities
	Type    string  `json:"type"`    // Account type as stored on the account model
}
// CategorySpend is the expense total for one category.
type CategorySpend struct {
	Category string  `json:"category"`
	Amount   float64 `json:"amount"`
	Percent  float64 `json:"percent"` // Share of the period's total expenses, as a percentage (0-100)
}
// TransactionBrief is a compact transaction summary for LLM prompts.
type TransactionBrief struct {
	Amount   float64 `json:"amount"`
	Category string  `json:"category"` // Category name, or "其他" when the transaction has none
	Note     string  `json:"note"`
	Date     string  `json:"date"` // Formatted as "01-02" (month-day) by GetUserFinancialContext
	Type     string  `json:"type"` // Transaction type as stored on the model (expense/income)
}
// BudgetBrief summarizes usage of one active budget in its current period.
type BudgetBrief struct {
	Name        string  `json:"name"`
	Amount      float64 `json:"amount"`                // Budget limit
	Spent       float64 `json:"spent"`                 // Expenses in the current period
	Remaining   float64 `json:"remaining"`             // Amount - Spent; negative when overspent
	Progress    float64 `json:"progress"`              // Percentage used (Spent/Amount*100); may exceed 100 when overspent
	IsNearLimit bool    `json:"is_near_limit"`         // True once Progress reaches 80
	Category    string  `json:"category,omitempty"`    // Name of the budget's category; empty for budgets without one
}
// GetUserFinancialContext assembles a snapshot of the user's finances for the AI
// assistant. The snapshot covers:
//   - account totals and debt ratio;
//   - recent spending statistics and trend;
//   - active budget usage;
//   - savings-goal progress;
//   - recurring expenses due within the next 7 days.
//
// Every section is best-effort. A failed query leaves that section empty instead of
// failing the call, so the returned error is currently always nil. Database queries
// issued directly through GORM are bound to ctx.
func (s *AIBookkeepingService) GetUserFinancialContext(ctx context.Context, userID uint) (*FinancialContext, error) {
	fc := &FinancialContext{}
	now := time.Now()

	s.collectAccountContext(fc, userID)
	s.collectSpendingContext(fc, userID, now)
	s.collectBudgetContext(ctx, fc, userID, now)
	s.collectSavingsContext(ctx, fc, userID)
	s.collectRecurringContext(ctx, fc, userID, now)

	return fc, nil
}

// collectAccountContext fills the balance totals, per-account summary and debt ratio.
func (s *AIBookkeepingService) collectAccountContext(fc *FinancialContext, userID uint) {
	accounts, err := s.accountRepo.GetAll(userID)
	if err != nil {
		return
	}
	for _, acc := range accounts {
		fc.TotalBalance += acc.Balance
		// The sign of the balance decides the side: >= 0 is an asset, < 0 a liability
		// (stored as an absolute value).
		if acc.Balance >= 0 {
			fc.TotalAssets += acc.Balance
		} else {
			fc.TotalLiabilities += -acc.Balance
		}
		fc.AccountSummary = append(fc.AccountSummary, AccountBrief{
			Name:    acc.Name,
			Balance: acc.Balance,
			Type:    string(acc.Type),
		})
	}
	if total := fc.TotalAssets + fc.TotalLiabilities; total > 0 {
		fc.DebtRatio = fc.TotalLiabilities / total
	}
}

// collectSpendingContext fills these fields:
//   - the 30-day, 7-day and today expense totals;
//   - the top-3 categories;
//   - the most recent transactions and the largest single expense;
//   - the comparison against the previous calendar month.
func (s *AIBookkeepingService) collectSpendingContext(fc *FinancialContext, userID uint, now time.Time) {
	loc := now.Location()
	today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, loc)
	last7Days := today.AddDate(0, 0, -7)
	last30Days := today.AddDate(0, 0, -30)
	// Previous calendar month as [1st 00:00:00, last day 23:59:59]. The end used to be
	// the last day at 00:00, which dropped most of that day's transactions.
	thisMonthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, loc)
	lastMonthStart := thisMonthStart.AddDate(0, -1, 0)
	lastMonthEnd := thisMonthStart.Add(-time.Second)

	transactions, err := s.transactionRepo.GetByDateRange(userID, last30Days, now)
	if err == nil {
		categorySpend := make(map[string]float64)
		var maxTx *models.Transaction
		for i := range transactions {
			tx := &transactions[i]
			if tx.Type != models.TransactionTypeExpense {
				continue
			}
			fc.Last30DaysSpend += tx.Amount
			if !tx.TransactionDate.Before(last7Days) {
				fc.Last7DaysSpend += tx.Amount
			}
			if !tx.TransactionDate.Before(today) {
				fc.TodaySpend += tx.Amount
			}
			catName := "其他"
			if tx.Category.ID != 0 {
				catName = tx.Category.Name
			}
			categorySpend[catName] += tx.Amount
			if maxTx == nil || tx.Amount > maxTx.Amount {
				maxTx = tx
			}
		}

		// Top 3 categories. A partial selection sort suffices because only the first
		// three positions are needed.
		type catAmount struct {
			name   string
			amount float64
		}
		cats := make([]catAmount, 0, len(categorySpend))
		for name, amount := range categorySpend {
			cats = append(cats, catAmount{name, amount})
		}
		for i := 0; i < len(cats) && i < 3; i++ {
			best := i
			for j := i + 1; j < len(cats); j++ {
				if cats[j].amount > cats[best].amount {
					best = j
				}
			}
			cats[i], cats[best] = cats[best], cats[i]

			percent := 0.0
			if fc.Last30DaysSpend > 0 {
				percent = cats[i].amount / fc.Last30DaysSpend * 100
			}
			fc.TopCategories = append(fc.TopCategories, CategorySpend{
				Category: cats[i].name,
				Amount:   cats[i].amount,
				Percent:  percent,
			})
		}

		// Up to five most recent transactions (income and expense). This walks the
		// slice from the end. NOTE(review): that assumes GetByDateRange returns rows in
		// ascending date order; confirm against the repository query.
		for i := len(transactions) - 1; i >= 0 && len(fc.RecentTransactions) < 5; i-- {
			tx := &transactions[i]
			catName := "其他"
			if tx.Category.ID != 0 {
				catName = tx.Category.Name
			}
			fc.RecentTransactions = append(fc.RecentTransactions, TransactionBrief{
				Amount:   tx.Amount,
				Category: catName,
				Note:     tx.Note,
				Date:     tx.TransactionDate.Format("01-02"),
				Type:     string(tx.Type),
			})
		}

		// Largest single expense over the last 30 days.
		if maxTx != nil {
			catName := "其他"
			if maxTx.Category.ID != 0 {
				catName = maxTx.Category.Name
			}
			fc.MaxSingleSpend = &TransactionBrief{
				Amount:   maxTx.Amount,
				Category: catName,
				Note:     maxTx.Note,
				Date:     maxTx.TransactionDate.Format("01-02"),
				Type:     string(maxTx.Type),
			}
		}
	}

	// Previous month's expenses, used to derive the trend.
	if lastMonthTx, err := s.transactionRepo.GetByDateRange(userID, lastMonthStart, lastMonthEnd); err == nil {
		for i := range lastMonthTx {
			if lastMonthTx[i].Type == models.TransactionTypeExpense {
				fc.LastMonthSpend += lastMonthTx[i].Amount
			}
		}
	}

	// Trend: a change of more than ±10% against last month counts as up/down.
	fc.SpendTrend = "stable"
	if fc.LastMonthSpend > 0 {
		ratio := fc.Last30DaysSpend / fc.LastMonthSpend
		if ratio > 1.1 {
			fc.SpendTrend = "up"
		} else if ratio < 0.9 {
			fc.SpendTrend = "down"
		}
	}
}

// collectBudgetContext fills the usage of every active budget and the related warnings.
func (s *AIBookkeepingService) collectBudgetContext(ctx context.Context, fc *FinancialContext, userID uint, now time.Time) {
	db := s.db.WithContext(ctx)

	var budgets []models.Budget
	if err := db.Where("user_id = ?", userID).
		Where("start_date <= ?", now).
		Where("end_date IS NULL OR end_date >= ?", now).
		Preload("Category").
		Find(&budgets).Error; err != nil {
		return
	}

	for i := range budgets {
		budget := &budgets[i]
		periodStart, periodEnd := s.calculateBudgetPeriod(budget, now)

		// Expenses in the budget's current period.
		query := db.Model(&models.Transaction{}).
			Where("user_id = ?", userID).
			Where("type = ?", models.TransactionTypeExpense).
			Where("transaction_date >= ? AND transaction_date <= ?", periodStart, periodEnd)
		if budget.CategoryID != nil {
			// A category budget also covers its direct sub-categories, matching how
			// checkBudgetWarning computes spending.
			var subCategoryIDs []uint
			db.Model(&models.Category{}).Where("parent_id = ?", *budget.CategoryID).Pluck("id", &subCategoryIDs)
			query = query.Where("category_id IN ?", append(subCategoryIDs, *budget.CategoryID))
		}
		if budget.AccountID != nil {
			query = query.Where("account_id = ?", *budget.AccountID)
		}
		var spent float64
		query.Select("COALESCE(SUM(amount), 0)").Scan(&spent)

		progress := 0.0
		if budget.Amount > 0 {
			progress = spent / budget.Amount * 100
		}
		isNearLimit := progress >= 80.0

		catName := ""
		if budget.Category != nil {
			catName = budget.Category.Name
		}
		fc.ActiveBudgets = append(fc.ActiveBudgets, BudgetBrief{
			Name:        budget.Name,
			Amount:      budget.Amount,
			Spent:       spent,
			Remaining:   budget.Amount - spent,
			Progress:    progress,
			IsNearLimit: isNearLimit,
			Category:    catName,
		})

		switch {
		case progress >= 100:
			fc.BudgetWarnings = append(fc.BudgetWarnings,
				fmt.Sprintf("【%s】预算已超支", budget.Name))
		case isNearLimit:
			fc.BudgetWarnings = append(fc.BudgetWarnings,
				fmt.Sprintf("【%s】预算已用%.0f%%,请注意控制", budget.Name, progress))
		}
	}
}

// collectSavingsContext fills one progress line per savings goal (piggy bank).
func (s *AIBookkeepingService) collectSavingsContext(ctx context.Context, fc *FinancialContext, userID uint) {
	var piggyBanks []models.PiggyBank
	if err := s.db.WithContext(ctx).Where("user_id = ?", userID).Find(&piggyBanks).Error; err != nil {
		return
	}
	for _, pb := range piggyBanks {
		progress := 0.0
		if pb.TargetAmount > 0 {
			progress = pb.CurrentAmount / pb.TargetAmount * 100
		}
		status := "进行中"
		if progress >= 100 {
			status = "已达成"
		}
		fc.SavingsProgress = append(fc.SavingsProgress, fmt.Sprintf("%s: %.0f/%.0f (%.1f%%) - %s", pb.Name, pb.CurrentAmount, pb.TargetAmount, progress, status))
	}
}

// collectRecurringContext lists active recurring expenses due within the next 7 days.
func (s *AIBookkeepingService) collectRecurringContext(ctx context.Context, fc *FinancialContext, userID uint, now time.Time) {
	db := s.db.WithContext(ctx)
	next7Days := now.AddDate(0, 0, 7)

	var recurrings []models.RecurringTransaction
	if err := db.Where("user_id = ? AND is_active = ? AND type = ?", userID, true, models.TransactionTypeExpense).
		Where("next_occurrence BETWEEN ? AND ?", now, next7Days).
		Find(&recurrings).Error; err != nil {
		return
	}
	for _, r := range recurrings {
		days := int(r.NextOccurrence.Sub(now).Hours() / 24)
		label := r.Note
		if label == "" {
			// Fall back to the category name. Use "其他" (as elsewhere) if the category
			// cannot be loaded, instead of producing an empty label.
			label = "其他"
			var cat models.Category
			if err := db.First(&cat, r.CategoryID).Error; err == nil && cat.Name != "" {
				label = cat.Name
			}
		}
		fc.UpcomingRecurring = append(fc.UpcomingRecurring, fmt.Sprintf("%s 还有 %d 天扣款 %.2f元", label, days, r.Amount))
	}
}
// calculateBudgetPeriod returns the bounds of the budget period containing now, in
// now's location. start is midnight at the beginning of the period, and end is one
// second before the next period begins.
//   - Weeks run Monday through Sunday.
//   - Monthly budgets, and any period type not listed, use the calendar month.
func (s *AIBookkeepingService) calculateBudgetPeriod(budget *models.Budget, now time.Time) (time.Time, time.Time) {
	year, month, day := now.Date()
	loc := now.Location()

	var start time.Time
	var addYears, addMonths, addDays int
	switch budget.PeriodType {
	case models.PeriodTypeDaily:
		start = time.Date(year, month, day, 0, 0, 0, 0, loc)
		addDays = 1
	case models.PeriodTypeWeekly:
		// Days elapsed since Monday: Monday -> 0, ..., Sunday -> 6.
		sinceMonday := (int(now.Weekday()) + 6) % 7
		start = time.Date(year, month, day-sinceMonday, 0, 0, 0, 0, loc)
		addDays = 7
	case models.PeriodTypeYearly:
		start = time.Date(year, time.January, 1, 0, 0, 0, 0, loc)
		addYears = 1
	default:
		// Monthly, and the fallback for unrecognised period types.
		start = time.Date(year, month, 1, 0, 0, 0, 0, loc)
		addMonths = 1
	}

	end := start.AddDate(addYears, addMonths, addDays).Add(-time.Second)
	return start, end
}