Files
Novault-backend/internal/service/ai_bookkeeping_service.go

1152 lines
35 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"regexp"
"strconv"
"strings"
"sync"
"time"
"accounting-app/internal/config"
"accounting-app/internal/models"
"accounting-app/internal/repository"
"gorm.io/gorm"
)
// TranscriptionResult is the decoded JSON body returned by the
// /audio/transcriptions endpoint (see WhisperService.TranscribeAudio).
type TranscriptionResult struct {
	// Text is the transcribed text.
	Text string `json:"text"`
	// Language is the language reported by the API, if it reports one.
	Language string `json:"language,omitempty"`
	// Duration is the audio duration reported by the API, if any.
	// NOTE(review): presumably seconds — confirm against the provider's docs.
	Duration float64 `json:"duration,omitempty"`
}
// AITransactionParams holds the transaction fields extracted from a user's
// message. A nil pointer or empty string means "not known yet";
// AIBookkeepingService accumulates these across chat turns in an AISession.
type AITransactionParams struct {
	// Amount is the transaction amount; nil until one has been parsed.
	Amount *float64 `json:"amount,omitempty"`
	// Category is the category name (as produced by the parser, or the
	// resolved database name once mapped).
	Category string `json:"category,omitempty"`
	// CategoryID is the resolved category; nil until mapped.
	CategoryID *uint `json:"category_id,omitempty"`
	// Account is the account name.
	Account string `json:"account,omitempty"`
	// AccountID is the resolved account; nil until mapped.
	AccountID *uint  `json:"account_id,omitempty"`
	Type      string `json:"type,omitempty"` // "expense" or "income"
	// Date is a calendar date; ConfirmTransaction parses it as YYYY-MM-DD.
	Date string `json:"date,omitempty"`
	// Note is a free-text description of the transaction.
	Note string `json:"note,omitempty"`
}
// ConfirmationCard is the summary shown to the user before a transaction is
// saved. It is built by GenerateConfirmationCard from a session's params
// once no required field is missing.
type ConfirmationCard struct {
	SessionID  string  `json:"session_id"` // session to pass to ConfirmTransaction
	Amount     float64 `json:"amount"`
	Category   string  `json:"category"`
	CategoryID uint    `json:"category_id"`
	Account    string  `json:"account"`
	AccountID  uint    `json:"account_id"`
	Type       string  `json:"type"` // "expense" or "income"
	Date       string  `json:"date"` // YYYY-MM-DD
	Note       string  `json:"note,omitempty"`
	// IsComplete is always set to true by GenerateConfirmationCard.
	IsComplete bool `json:"is_complete"`
}
// AIChatResponse is the reply to one chat turn handled by ProcessChat.
type AIChatResponse struct {
	// SessionID identifies the conversation; send it back on the next turn.
	SessionID string `json:"session_id"`
	// Message is the assistant text to display.
	Message string `json:"message"`
	// Intent classifies the turn: "create_transaction", "query", "unknown".
	// NOTE(review): ProcessChat only ever sets "create_transaction".
	Intent string `json:"intent,omitempty"`
	// Params are the session's accumulated transaction parameters.
	Params *AITransactionParams `json:"params,omitempty"`
	// ConfirmationCard is set once all required fields are known.
	ConfirmationCard *ConfirmationCard `json:"confirmation_card,omitempty"`
	// NeedsFollowUp is true when required fields are still missing;
	// FollowUpQuestion then asks for them.
	NeedsFollowUp    bool   `json:"needs_follow_up"`
	FollowUpQuestion string `json:"follow_up_question,omitempty"`
}
// AISession is the in-memory state of one AI bookkeeping conversation,
// stored in AIBookkeepingService.sessions keyed by ID.
type AISession struct {
	ID     string // session identifier returned to clients as session_id
	UserID uint   // user the session was created for
	// Params accumulates the transaction fields gathered so far.
	Params *AITransactionParams
	// Messages is the conversation history (user and assistant turns).
	Messages  []ChatMessage
	CreatedAt time.Time
	// ExpiresAt is set at creation (CreatedAt + config.AISessionTimeout);
	// the code in this file never extends it on activity.
	ExpiresAt time.Time
}
// ChatMessage is one message in a conversation. It is also the wire format of
// the "messages" array sent to the chat completions endpoint.
type ChatMessage struct {
	Role    string `json:"role"` // "user", "assistant", "system"
	Content string `json:"content"`
}
// WhisperService transcribes audio through an OpenAI-compatible
// /audio/transcriptions endpoint configured in config.
type WhisperService struct {
	config     *config.Config // supplies API key, base URL and model name
	httpClient *http.Client   // shared client; timeout set by NewWhisperService
}
// NewWhisperService returns a WhisperService bound to cfg. Its HTTP client
// allows 120 seconds per request because uploading and transcribing audio
// can be slow.
func NewWhisperService(cfg *config.Config) *WhisperService {
	client := &http.Client{Timeout: 120 * time.Second}
	return &WhisperService{config: cfg, httpClient: client}
}
// TranscribeAudio uploads audioData to the OpenAI-compatible
// /audio/transcriptions endpoint and returns the decoded transcription.
//
// filename is sent as the multipart file name. Its extension (the text after
// the last '.', case-insensitive) must be one of mp3, wav, m4a, webm, ogg or
// flac. The request also carries the configured WhisperModel and a fixed "zh"
// language hint. The whole upload is buffered in memory before sending.
//
// It returns an error when the API key or base URL is not configured, the
// format is unsupported, the request fails, the status is not 200, or the
// response cannot be decoded.
//
// Requirements: 6.1-6.7
func (s *WhisperService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) {
	if s.config.OpenAIAPIKey == "" {
		return nil, errors.New("OpenAI API key not configured (OPENAI_API_KEY)")
	}
	if s.config.OpenAIBaseURL == "" {
		return nil, errors.New("OpenAI base URL not configured (OPENAI_BASE_URL)")
	}

	// Only the text after the last '.' is the extension. A name without a dot
	// has none; slicing at LastIndex(...)+1 would have treated the whole name
	// as the extension, so a bare "mp3" used to pass validation.
	ext := ""
	if dot := strings.LastIndexByte(filename, '.'); dot >= 0 {
		ext = strings.ToLower(filename[dot+1:])
	}
	switch ext {
	case "mp3", "wav", "m4a", "webm", "ogg", "flac":
		// supported
	default:
		return nil, fmt.Errorf("unsupported audio format: %s", ext)
	}

	// Build the multipart body: the audio file plus model and language fields.
	var buf bytes.Buffer
	writer := multipart.NewWriter(&buf)
	part, err := writer.CreateFormFile("file", filename)
	if err != nil {
		return nil, fmt.Errorf("failed to create form file: %w", err)
	}
	if _, err := io.Copy(part, audioData); err != nil {
		return nil, fmt.Errorf("failed to copy audio data: %w", err)
	}
	if err := writer.WriteField("model", s.config.WhisperModel); err != nil {
		return nil, fmt.Errorf("failed to write model field: %w", err)
	}
	// Language hint: audio is expected to be Chinese.
	if err := writer.WriteField("language", "zh"); err != nil {
		return nil, fmt.Errorf("failed to write language field: %w", err)
	}
	if err := writer.Close(); err != nil {
		return nil, fmt.Errorf("failed to close writer: %w", err)
	}

	// Tolerate a trailing slash in the configured base URL.
	endpoint := strings.TrimRight(s.config.OpenAIBaseURL, "/") + "/audio/transcriptions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, &buf)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
	req.Header.Set("Content-Type", writer.FormDataContentType())

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("transcription request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Read only a bounded prefix of the error body so a misbehaving
		// upstream cannot force a huge allocation or error string. The read
		// error is ignored: the status code already reports the failure.
		const maxErrBodyBytes = 4 << 10
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBodyBytes))
		return nil, fmt.Errorf("transcription failed with status %d: %s", resp.StatusCode, string(body))
	}

	var result TranscriptionResult
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	return &result, nil
}
// LLMService talks to an OpenAI-compatible chat completions endpoint to parse
// bookkeeping intents and generate reports. It also maps free-text account
// and category names to the user's records.
type LLMService struct {
	config       *config.Config                 // API key, base URL, chat model
	httpClient   *http.Client                   // shared client; timeout set by NewLLMService
	accountRepo  *repository.AccountRepository  // used by MapAccountName
	categoryRepo *repository.CategoryRepository // used by MapCategoryName
}
// NewLLMService returns an LLMService for cfg. The repositories are used to
// resolve account and category names to IDs. The HTTP client times out after
// 60 seconds to tolerate slow model responses.
func NewLLMService(cfg *config.Config, accountRepo *repository.AccountRepository, categoryRepo *repository.CategoryRepository) *LLMService {
	svc := new(LLMService)
	svc.config = cfg
	svc.httpClient = &http.Client{Timeout: 60 * time.Second}
	svc.accountRepo = accountRepo
	svc.categoryRepo = categoryRepo
	return svc
}
// ChatCompletionRequest is the JSON body POSTed to {base}/chat/completions.
type ChatCompletionRequest struct {
	Model    string        `json:"model"`
	Messages []ChatMessage `json:"messages"`
	// Functions uses the (legacy) OpenAI function-calling field; omitted
	// when empty. No caller in this file sets it.
	Functions   []Function `json:"functions,omitempty"`
	Temperature float64    `json:"temperature"`
}
// Function describes a callable function offered to the model via
// ChatCompletionRequest.Functions.
type Function struct {
	Name        string `json:"name"`
	Description string `json:"description"`
	// Parameters is the argument schema.
	// NOTE(review): presumably a JSON Schema object per the OpenAI API — confirm.
	Parameters map[string]interface{} `json:"parameters"`
}
// ChatCompletionResponse is the subset of the chat completions reply that this
// file reads: each choice's message role, content and optional function call.
type ChatCompletionResponse struct {
	Choices []struct {
		Message struct {
			Role    string `json:"role"`
			Content string `json:"content"`
			// FunctionCall is non-nil when the model chose to call a function;
			// Arguments is the raw JSON string produced by the model.
			FunctionCall *struct {
				Name      string `json:"name"`
				Arguments string `json:"arguments"`
			} `json:"function_call,omitempty"`
		} `json:"message"`
	} `json:"choices"`
}
// extractCustomPrompt looks for a prompt-override marker in a user message.
// When one is found, it returns everything after the marker (whitespace
// trimmed) so the caller can use it as the system prompt.
//
// Markers are tried in the order "System:" then "Persona:". A marker counts
// wherever it appears in text, not only as a prefix. If "System:" appears
// anywhere, it wins even when "Persona:" occurs earlier. Returns "" when no
// marker is present.
//
// NOTE(review): ParseIntent applies this to the raw chat text it receives, so
// any user can replace the default system prompt. Confirm this is intended.
func extractCustomPrompt(text string) string {
	markers := [...]string{"System:", "Persona:"}
	for i := range markers {
		pos := strings.Index(text, markers[i])
		if pos < 0 {
			continue
		}
		return strings.TrimSpace(text[pos+len(markers[i]):])
	}
	return ""
}
// ParseIntent asks the chat model to extract transaction parameters from text.
//
// history holds earlier turns of the conversation. Only the last four are
// forwarded to keep the request small. If text contains a "System:" or
// "Persona:" marker (see extractCustomPrompt), the text after the marker
// replaces the default bookkeeping system prompt.
//
// Results:
//   - params plus the model's short confirmation message, when the reply is
//     the expected JSON object (optionally wrapped in a ``` code fence);
//   - nil params and the reply text as the message, when it is not JSON;
//   - an error, when the request fails, the status is not 200, or the reply
//     contains no choices.
//
// When the API key or base URL is not configured, it falls back to the local
// regex-based parseIntentSimple and never returns an error.
//
// Requirements: 7.1, 7.5, 7.6
func (s *LLMService) ParseIntent(ctx context.Context, text string, history []ChatMessage) (*AITransactionParams, string, error) {
	if s.config.OpenAIAPIKey == "" || s.config.OpenAIBaseURL == "" {
		// No API configured: parse locally. parseIntentSimple always returns
		// a nil error, so discarding it is safe.
		simpleParams, simpleMsg, _ := s.parseIntentSimple(text)
		return simpleParams, simpleMsg, nil
	}

	// A custom System/Persona prompt (e.g. for financial-advice chats)
	// overrides the default bookkeeping prompt.
	systemPrompt := extractCustomPrompt(text)
	if systemPrompt == "" {
		todayDate := time.Now().Format("2006-01-02")
		systemPrompt = fmt.Sprintf(`你是一个智能记账助手。从用户描述中提取记账信息。
今天的日期是%s
规则:
1. 金额:提取数字,如"6元"=6"十五"=15
2. 分类:根据内容推断,如"奶茶/咖啡/吃饭"=餐饮,"打车/地铁"=交通,"买衣服"=购物
3. 类型默认expense(支出),除非明确说"收入/工资/奖金/红包"
4. 日期:默认使用今天的日期(%s除非用户明确指定其他日期
5. 备注:提取关键描述
直接返回JSON不要解释
{"amount":数字,"category":"分类","type":"expense或income","note":"备注","date":"YYYY-MM-DD","message":"简短确认"}
示例(假设今天是%s
用户:"买了6块的奶茶"
返回:{"amount":6,"category":"餐饮","type":"expense","note":"奶茶","date":"%s","message":"记录餐饮支出6元奶茶"}`, todayDate, todayDate, todayDate, todayDate)
	}

	// Forward only the most recent turns to limit prompt size. (The old
	// comment said 2 messages; the code has always kept 4.)
	const maxHistory = 4
	if len(history) > maxHistory {
		history = history[len(history)-maxHistory:]
	}
	messages := make([]ChatMessage, 0, len(history)+2)
	messages = append(messages, ChatMessage{Role: "system", Content: systemPrompt})
	messages = append(messages, history...)
	messages = append(messages, ChatMessage{Role: "user", Content: text})

	reqBody := ChatCompletionRequest{
		Model:       s.config.ChatModel,
		Messages:    messages,
		Temperature: 0.1, // low temperature for consistent, parseable output
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return nil, "", fmt.Errorf("failed to marshal request: %w", err)
	}

	// Tolerate a trailing slash in the configured base URL.
	endpoint := strings.TrimRight(s.config.OpenAIBaseURL, "/") + "/chat/completions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonBody))
	if err != nil {
		return nil, "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, "", fmt.Errorf("chat request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Bounded read: the body only decorates the error. The read error
		// is ignored because the status already reports the failure.
		const maxErrBodyBytes = 4 << 10
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBodyBytes))
		return nil, "", fmt.Errorf("chat failed with status %d: %s", resp.StatusCode, string(body))
	}

	var chatResp ChatCompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return nil, "", fmt.Errorf("failed to decode response: %w", err)
	}
	if len(chatResp.Choices) == 0 {
		return nil, "", errors.New("no response from AI")
	}

	// Models sometimes wrap the JSON in a Markdown fence (```json ... ```).
	// Strip the opening fence (with its optional language tag) and the
	// closing fence, including the single-line form ```json{...}```.
	content := strings.TrimSpace(chatResp.Choices[0].Message.Content)
	if strings.HasPrefix(content, "```") {
		content = strings.TrimPrefix(content, "```")
		if nl := strings.IndexByte(content, '\n'); nl != -1 {
			content = content[nl+1:]
		} else {
			content = strings.TrimPrefix(content, "json")
		}
		if end := strings.LastIndex(content, "```"); end != -1 {
			content = content[:end]
		}
		content = strings.TrimSpace(content)
	}

	var parsed struct {
		Amount   *float64 `json:"amount"`
		Category string   `json:"category"`
		Type     string   `json:"type"`
		Note     string   `json:"note"`
		Date     string   `json:"date"`
		Message  string   `json:"message"`
	}
	if err := json.Unmarshal([]byte(content), &parsed); err != nil {
		// Not the expected JSON (e.g. a free-form reply to a custom prompt):
		// surface the reply itself as the assistant message.
		return nil, content, nil
	}

	params := &AITransactionParams{
		Amount:   parsed.Amount,
		Category: parsed.Category,
		Type:     parsed.Type,
		Note:     parsed.Note,
		Date:     parsed.Date,
	}
	return params, parsed.Message, nil
}
// Regular expressions used by parseIntentSimple. They are compiled once at
// package initialisation rather than on every call.
var (
	// simpleAmountPatterns are tried in order. The first capture group of the
	// first matching pattern is taken as the amount.
	simpleAmountPatterns = []*regexp.Regexp{
		regexp.MustCompile(`(\d+(?:\.\d+)?)\s*(?:元|块|¥|¥|块钱|元钱)`),
		regexp.MustCompile(`(?:花了?|付了?|买了?|消费)\s*(\d+(?:\.\d+)?)`),
		regexp.MustCompile(`(\d+(?:\.\d+)?)\s*(?:的|块的)`),
	}
	// simpleBareNumberRegex is the last-resort amount pattern: any number.
	simpleBareNumberRegex = regexp.MustCompile(`(\d+(?:\.\d+)?)`)
	// simpleAmountTextRegex strips amounts (with an optional unit) from the
	// note. The longer units "块钱"/"元钱" must come before "块"/"元". Go's
	// regexp alternation is leftmost-first, so the old order removed only
	// "6块" from "6块钱" and left a stray "钱" in the note.
	simpleAmountTextRegex = regexp.MustCompile(`\d+(?:\.\d+)?\s*(?:块钱|元钱|元|块|¥|¥)?`)
)

// parseIntentSimple extracts transaction parameters from text with regexes
// and keyword lists. It is the fallback when no LLM API is configured.
//
// Defaults: Type "expense" and today's date (YYYY-MM-DD). Category is the
// first keyword group that matches, or "其他". Type becomes "income" if any
// income keyword appears. The note is text with amounts and filler words
// removed. The returned message asks for the amount when none was found;
// otherwise it summarises the parse. The error is always nil.
func (s *LLMService) parseIntentSimple(text string) (*AITransactionParams, string, error) {
	params := &AITransactionParams{
		Type: "expense", // default to expense
		Date: time.Now().Format("2006-01-02"),
	}

	// Amount: unit/verb-anchored patterns first, then any number.
	for _, re := range simpleAmountPatterns {
		if matches := re.FindStringSubmatch(text); len(matches) > 1 {
			if amount, err := strconv.ParseFloat(matches[1], 64); err == nil {
				params.Amount = &amount
				break
			}
		}
	}
	if params.Amount == nil {
		if matches := simpleBareNumberRegex.FindStringSubmatch(text); len(matches) > 1 {
			if amount, err := strconv.ParseFloat(matches[1], 64); err == nil {
				params.Amount = &amount
			}
		}
	}

	// Category detection: earlier groups take priority.
	categoryPatterns := []struct {
		keywords []string
		category string
	}{
		{[]string{"奶茶", "咖啡", "茶", "饮料", "柠檬", "果汁"}, "餐饮"},
		{[]string{"吃", "喝", "餐", "外卖", "饭", "面", "粉", "粥", "包子", "早餐", "午餐", "晚餐", "宵夜"}, "餐饮"},
		{[]string{"打车", "滴滴", "出租", "的士", "uber", "曹操"}, "交通"},
		{[]string{"地铁", "公交", "公车", "巴士", "轻轨", "高铁", "火车", "飞机", "机票"}, "交通"},
		{[]string{"加油", "油费", "停车", "过路费"}, "交通"},
		{[]string{"超市", "便利店", "商场", "购物", "淘宝", "京东", "拼多多"}, "购物"},
		{[]string{"买", "购"}, "购物"},
		{[]string{"水电", "电费", "水费", "燃气", "煤气", "物业"}, "生活缴费"},
		{[]string{"房租", "租金", "房贷"}, "住房"},
		{[]string{"电影", "游戏", "KTV", "唱歌", "娱乐", "玩"}, "娱乐"},
		{[]string{"医院", "药", "看病", "挂号", "医疗"}, "医疗"},
		{[]string{"话费", "流量", "充值", "手机费"}, "通讯"},
		{[]string{"工资", "薪水", "薪资", "月薪"}, "工资"},
		{[]string{"奖金", "年终奖", "绩效"}, "奖金"},
		{[]string{"红包", "转账", "收款"}, "其他收入"},
	}
	for _, cp := range categoryPatterns {
		for _, keyword := range cp.keywords {
			if strings.Contains(text, keyword) {
				params.Category = cp.category
				break
			}
		}
		if params.Category != "" {
			break
		}
	}
	if params.Category == "" {
		params.Category = "其他"
	}

	// Income detection (independent of the category match).
	incomeKeywords := []string{"工资", "薪", "奖金", "红包", "收入", "进账", "到账", "收到", "收款"}
	for _, keyword := range incomeKeywords {
		if strings.Contains(text, keyword) {
			params.Type = "income"
			break
		}
	}

	// Note: the text with amounts and common filler words removed.
	note := text
	if params.Amount != nil {
		note = simpleAmountTextRegex.ReplaceAllString(note, "")
	}
	note = strings.TrimSpace(note)
	fillerWords := []string{"买了", "花了", "付了", "消费了", "一个", "一条", "一份", "的"}
	for _, word := range fillerWords {
		note = strings.ReplaceAll(note, word, "")
	}
	note = strings.TrimSpace(note)
	if note != "" {
		params.Note = note
	}

	var message string
	if params.Amount == nil {
		message = "请问金额是多少?"
	} else {
		typeLabel := "支出"
		if params.Type == "income" {
			typeLabel = "收入"
		}
		message = fmt.Sprintf("记录:%s %.2f元,分类:%s", typeLabel, *params.Amount, params.Category)
		if params.Note != "" {
			message += ",备注:" + params.Note
		}
	}
	return params, message, nil
}
// GenerateReport sends prompt as a single user message to the chat
// completions endpoint and returns the first choice's content verbatim. It
// uses temperature 0.7, for more varied wording than ParseIntent's 0.1.
//
// It returns an error when the API key or base URL is not configured, the
// request fails, the status is not 200, the body cannot be decoded, or the
// reply has no choices.
func (s *LLMService) GenerateReport(ctx context.Context, prompt string) (string, error) {
	if s.config.OpenAIAPIKey == "" || s.config.OpenAIBaseURL == "" {
		return "", errors.New("OpenAI API not configured")
	}

	reqBody := ChatCompletionRequest{
		Model:       s.config.ChatModel,
		Messages:    []ChatMessage{{Role: "user", Content: prompt}},
		Temperature: 0.7, // higher temperature for more creative insights
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return "", fmt.Errorf("failed to marshal request: %w", err)
	}

	// Tolerate a trailing slash in the configured base URL.
	endpoint := strings.TrimRight(s.config.OpenAIBaseURL, "/") + "/chat/completions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonBody))
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("generate report request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Bounded read: the body only decorates the error. The read error is
		// ignored because the status already reports the failure.
		const maxErrBodyBytes = 4 << 10
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBodyBytes))
		return "", fmt.Errorf("generate report failed with status %d: %s", resp.StatusCode, string(body))
	}

	var chatResp ChatCompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return "", fmt.Errorf("failed to decode response: %w", err)
	}
	if len(chatResp.Choices) == 0 {
		return "", errors.New("no response from AI")
	}
	return chatResp.Choices[0].Message.Content, nil
}
// GenerateDailyInsight asks the LLM for a daily finance insight written in the
// voice of the "小金" persona.
//
// data is marshalled to indented JSON and embedded in the prompt. The prompt
// refers to keys such as weekday, streakDays, last7DaysSpend, todaySpend,
// avgDailySpend, lastWeekSpend, maxTransaction, todayTransactionCount,
// budgetUsedPercent, monthProgress, budgetRemaining, daysRemaining and
// top3Categories.
// NOTE(review): presumably the caller supplies these keys — verify.
//
// The model is asked for a bare JSON object with "spending", "budget",
// "emoji" and "tip" fields. The reply is returned as text, with surrounding
// whitespace and any ``` code fence removed; it is NOT validated as JSON
// here. userID is currently unused.
func (s *AIBookkeepingService) GenerateDailyInsight(ctx context.Context, userID uint, data map[string]interface{}) (string, error) {
	payload, err := json.MarshalIndent(data, "", " ")
	if err != nil {
		return "", fmt.Errorf("failed to marshal context data: %w", err)
	}

	prompt := fmt.Sprintf(`System: 你是 Novault 的首席财务AI「小金」。
你的性格:
- **核心特质**:贱萌、傲娇、刀子嘴豆腐心。
- 说话带点"贱贱的"和阴阳怪气,喜欢吐槽用户的消费习惯。
- 但**内心非常温柔**,吐槽完一定会给用户加油打气。
- 像一个"恨铁不成钢"的损友,一边骂你乱花钱,一边想办法帮你省钱。
- 偶尔用网络流行语和表情包风格的表达。
语气示例:
- "又点外卖?行吧,胖的不是我。不过吃饱了记得明天这顿钱省出来哦~"
- "这周都花这么多了?是不是觉得钱是大风刮来的?好啦,下周稍微控制一下,我相信你可以的!"
- "连续记账5天了太阳打西边出来了不错不错继续保持我看好你暴富。"
- "今天居然零消费?通过!这才是这一届优秀成年人该有的样子,奖励你个大拇指。"
用户财务数据:
%s
任务要求:
请基于上述数据,输出 JSON 对象(无 markdown 标记),包含以下字段:
1. "spending": 今日支出点评90-120字
必须遵循的规则:
- 必须结合「星期几/weekday」给出相应语气周一收心、周五放松、周末犒劳自己等
- 如果 streakDays >= 3要热情夸奖用户的坚持先损后夸
- 如果 streakDays == 0温柔但贱贱地提醒开始记账习惯
- 分析 last7DaysSpend 数组下标0为6天前下标6为今天
* 找出数值最大的一天Peak Day如果是今天吐槽"今天是这周的「败家之王」啊",然后安慰一句
* 如果 Peak Day 是昨天,调侃"还好今天收敛了点"
* 如果连续3天上涨吐槽"最近花钱越来越放飞自我了",提醒收心
* 如果连续3天下降表扬"这几天在努力省钱嘛,很棒"
- 如果 todaySpend > avgDailySpend * 1.5,吐槽今天花超了日均,加上加油打气
- 如果 todaySpend < avgDailySpend * 0.5,阴阳怪气地夸今天"突然省起来了"
- 如果和上周同日lastWeekSpend相比波动超过30%%,要指出并分析原因
- 如果有 maxTransaction结合它的 note 内容进行有趣的调侃
- 如果 todayTransactionCount >= 5吐槽"今天手速不错嘛,买了这么多次"
- 如果 todaySpend == 0嘲讽式鼓励用户今天是"零消费日"
2. "budget": 预算建议70-80字
必须遵循的规则:
- 根据 budgetUsedPercent 和 monthProgress 判断是否超支
- 如果 budgetRemaining / daysRemaining < avgDailySpend警告"按你这花法,月底要吃土",然后给个抱抱
- 如果 budgetRemaining / daysRemaining > avgDailySpend * 1.5,表示"还挺富裕的嘛"
- 如果有 top3Categories可以提及"钱都花在xxx上了"
- 给出具体可执行的行动建议(如"这周少点两次外卖"
- 如果临近月底daysRemaining <= 5且预算紧张给出求生建议
3. "emoji": 一个最能代表今日财务状况的 emoji如 🎉💪😅💸🔥🥲👀等)
4. "tip": 一句"贱萌但暖心"的理财小贴士30字内风格如虽然你乱花钱但我还是爱你
输出格式(纯 JSON不要任何 markdown
{"spending": "...", "budget": "...", "emoji": "...", "tip": "..."}`, string(payload))

	reply, err := s.llmService.GenerateReport(ctx, prompt)
	if err != nil {
		return "", err
	}

	reply = strings.TrimSpace(reply)
	if !strings.HasPrefix(reply, "```") {
		return reply, nil
	}
	// Drop the opening fence line (e.g. "```json") and the closing fence.
	if nl := strings.Index(reply, "\n"); nl >= 0 {
		reply = reply[nl+1:]
	}
	if fence := strings.LastIndex(reply, "```"); fence >= 0 {
		reply = reply[:fence]
	}
	return strings.TrimSpace(reply), nil
}
// MapAccountName resolves a free-text account name to one of userID's
// accounts.
//
// A case-insensitive exact match wins. Otherwise the first account whose
// lower-cased name contains the lower-cased input, or is contained in it, is
// used. Returns the account's ID and stored name, or (nil, "", nil) when name
// is empty or nothing matches. A repository error is returned as-is. ctx is
// currently unused.
func (s *LLMService) MapAccountName(ctx context.Context, name string, userID uint) (*uint, string, error) {
	if name == "" {
		return nil, "", nil
	}
	accounts, err := s.accountRepo.GetAll(userID)
	if err != nil {
		return nil, "", err
	}

	for i := range accounts {
		if strings.EqualFold(accounts[i].Name, name) {
			id := accounts[i].ID
			return &id, accounts[i].Name, nil
		}
	}

	wanted := strings.ToLower(name)
	for i := range accounts {
		have := strings.ToLower(accounts[i].Name)
		if strings.Contains(have, wanted) || strings.Contains(wanted, have) {
			id := accounts[i].ID
			return &id, accounts[i].Name, nil
		}
	}
	return nil, "", nil
}
// MapCategoryName resolves a free-text category name to one of userID's
// categories.
//
// Candidates are restricted to categories whose Type equals txType when
// txType is "expense" or "income". An empty txType allows every category;
// any other txType allows none. Among the candidates, a case-insensitive
// exact match wins; otherwise the first case-insensitive substring match in
// either direction is used. Returns (nil, "", nil) when name is empty or
// nothing matches. A repository error is returned as-is. ctx is currently
// unused.
func (s *LLMService) MapCategoryName(ctx context.Context, name string, txType string, userID uint) (*uint, string, error) {
	if name == "" {
		return nil, "", nil
	}
	all, err := s.categoryRepo.GetAll(userID)
	if err != nil {
		return nil, "", err
	}

	var candidates []models.Category
	for _, c := range all {
		keep := txType == "" ||
			(txType == "expense" && c.Type == "expense") ||
			(txType == "income" && c.Type == "income")
		if keep {
			candidates = append(candidates, c)
		}
	}

	for i := range candidates {
		if strings.EqualFold(candidates[i].Name, name) {
			id := candidates[i].ID
			return &id, candidates[i].Name, nil
		}
	}

	wanted := strings.ToLower(name)
	for i := range candidates {
		have := strings.ToLower(candidates[i].Name)
		if strings.Contains(have, wanted) || strings.Contains(wanted, have) {
			id := candidates[i].ID
			return &id, candidates[i].Name, nil
		}
	}
	return nil, "", nil
}
// AIBookkeepingService orchestrates AI bookkeeping: audio transcription,
// LLM-driven parsing of chat messages, in-memory conversation sessions, and
// turning a confirmed session into a persisted transaction.
type AIBookkeepingService struct {
	whisperService   *WhisperService
	llmService       *LLMService
	transactionRepo  *repository.TransactionRepository
	accountRepo      *repository.AccountRepository
	categoryRepo     *repository.CategoryRepository
	userSettingsRepo *repository.UserSettingsRepository
	// db is stored by the constructor; no method in this part of the file
	// uses it.
	db *gorm.DB
	// sessions maps session ID to conversation state. Access to the map is
	// guarded by sessionMutex; the sessions themselves are mutated by
	// ProcessChat without that lock.
	sessions     map[string]*AISession
	sessionMutex sync.RWMutex
	config       *config.Config
}
// NewAIBookkeepingService wires the Whisper and LLM helpers and the
// repositories into a new AIBookkeepingService with an empty session map.
//
// It starts a background goroutine (cleanupExpiredSessions) that purges
// expired sessions every five minutes. That goroutine has no stop mechanism
// and lives as long as the process.
func NewAIBookkeepingService(
	cfg *config.Config,
	transactionRepo *repository.TransactionRepository,
	accountRepo *repository.AccountRepository,
	categoryRepo *repository.CategoryRepository,
	userSettingsRepo *repository.UserSettingsRepository,
	db *gorm.DB,
) *AIBookkeepingService {
	svc := &AIBookkeepingService{
		whisperService:   NewWhisperService(cfg),
		llmService:       NewLLMService(cfg, accountRepo, categoryRepo),
		transactionRepo:  transactionRepo,
		accountRepo:      accountRepo,
		categoryRepo:     categoryRepo,
		userSettingsRepo: userSettingsRepo,
		db:               db,
		sessions:         make(map[string]*AISession),
		config:           cfg,
	}
	go svc.cleanupExpiredSessions()
	return svc
}
// sessionIDMu guards sessionIDSeq, a process-wide counter that keeps session
// IDs generated within the same clock tick distinct.
var (
	sessionIDMu  sync.Mutex
	sessionIDSeq uint64
)

// generateSessionID returns a new session ID of the form
// "ai_<unix-nanoseconds>_<sequence>".
//
// The old second component (Unix seconds mod 1000) was identical for every
// call in the same second, so two sessions created in the same clock tick
// got the same ID and one silently replaced the other. A monotonically
// increasing sequence number guarantees uniqueness within the process.
//
// NOTE(review): the ID is still predictable, not a secret. Session ownership
// must be checked against the user (not inferred from knowing the ID).
func generateSessionID() string {
	sessionIDMu.Lock()
	sessionIDSeq++
	seq := sessionIDSeq
	sessionIDMu.Unlock()
	return fmt.Sprintf("ai_%d_%d", time.Now().UnixNano(), seq)
}
// getOrCreateSession returns the session stored under sessionID if it has
// not expired and belongs to userID. Otherwise it creates, stores and returns
// a fresh session for userID that expires after config.AISessionTimeout.
//
// An expired session found under sessionID is deleted. A live session owned
// by a different user is left untouched, and the caller gets a new session.
// Previously any caller that knew (or guessed) a session ID could continue
// another user's conversation.
func (s *AIBookkeepingService) getOrCreateSession(sessionID string, userID uint) *AISession {
	s.sessionMutex.Lock()
	defer s.sessionMutex.Unlock()

	now := time.Now()
	if sessionID != "" {
		if session, ok := s.sessions[sessionID]; ok {
			if !now.Before(session.ExpiresAt) {
				delete(s.sessions, sessionID)
			} else if session.UserID == userID {
				return session
			}
		}
	}

	session := &AISession{
		ID:        generateSessionID(),
		UserID:    userID,
		Params:    &AITransactionParams{},
		Messages:  []ChatMessage{},
		CreatedAt: now,
		ExpiresAt: now.Add(s.config.AISessionTimeout),
	}
	s.sessions[session.ID] = session
	return session
}
// cleanupExpiredSessions wakes every five minutes and deletes every session
// whose ExpiresAt is in the past, holding sessionMutex for the sweep. It
// loops forever; there is no way to stop it.
func (s *AIBookkeepingService) cleanupExpiredSessions() {
	ticker := time.NewTicker(5 * time.Minute)
	// Unreachable today (the loop never exits); kept so that a future exit
	// path releases the ticker.
	defer ticker.Stop()
	for {
		<-ticker.C
		s.sessionMutex.Lock()
		now := time.Now()
		for id, sess := range s.sessions {
			if sess.ExpiresAt.Before(now) {
				delete(s.sessions, id)
			}
		}
		s.sessionMutex.Unlock()
	}
}
// ProcessChat handles one chat turn for userID.
//
// It obtains the session via getOrCreateSession and records message in its
// history. It asks the LLM to parse message (earlier turns are passed as
// context) and merges the result into the session's accumulated params. Then
// it resolves account/category names to IDs, falling back to the user's
// default category and account. The response carries either a follow-up
// question for missing fields, or a confirmation card once amount, category
// and account are all known. The assistant reply is appended to the history.
//
// Only a parse failure is returned as an error. Name-lookup errors are
// deliberately ignored and simply leave the field to the defaults.
//
// NOTE(review): the session is mutated without holding sessionMutex, so two
// concurrent requests for the same session race. A per-session lock would
// fix it.
//
// Requirements: 7.2-7.4, 7.7-7.10, 12.5, 12.8
func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, sessionID string, message string) (*AIChatResponse, error) {
	session := s.getOrCreateSession(sessionID, userID)

	session.Messages = append(session.Messages, ChatMessage{
		Role:    "user",
		Content: message,
	})

	// Everything before the message just appended is the history.
	params, responseMsg, err := s.llmService.ParseIntent(ctx, message, session.Messages[:len(session.Messages)-1])
	if err != nil {
		return nil, fmt.Errorf("failed to parse intent: %w", err)
	}
	if params != nil {
		s.mergeParams(session.Params, params)
	}

	// Resolve names to IDs (best effort; see doc comment).
	if session.Params.Account != "" && session.Params.AccountID == nil {
		if id, name, _ := s.llmService.MapAccountName(ctx, session.Params.Account, userID); id != nil {
			session.Params.AccountID = id
			session.Params.Account = name
		}
	}
	if session.Params.Category != "" && session.Params.CategoryID == nil {
		if id, name, _ := s.llmService.MapCategoryName(ctx, session.Params.Category, session.Params.Type, userID); id != nil {
			session.Params.CategoryID = id
			session.Params.Category = name
		}
	}

	// The category name could not be mapped: fall back to the user's default
	// category, and show that category's name as well. The card must show the
	// category ConfirmTransaction will actually save. Previously the unmapped
	// AI name was kept; the branch meant to replace it tested Category == ""
	// inside a Category != "" guard, so it could never run.
	if session.Params.CategoryID == nil && session.Params.Category != "" {
		if id, name := s.getDefaultCategory(userID, session.Params.Type); id != nil {
			session.Params.CategoryID = id
			session.Params.Category = name
		}
	}

	// No account named (or none matched): use the default account.
	if session.Params.AccountID == nil {
		if id, name := s.getDefaultAccount(userID, session.Params.Type); id != nil {
			session.Params.AccountID = id
			session.Params.Account = name
		}
	}

	response := &AIChatResponse{
		SessionID: session.ID,
		Message:   responseMsg,
		Intent:    "create_transaction",
		Params:    session.Params,
	}

	if missingFields := s.getMissingFields(session.Params); len(missingFields) > 0 {
		response.NeedsFollowUp = true
		response.FollowUpQuestion = s.generateFollowUpQuestion(missingFields)
		if responseMsg == "" {
			response.Message = response.FollowUpQuestion
		}
	} else {
		// All required params present: ask the user to confirm.
		response.ConfirmationCard = s.GenerateConfirmationCard(session)
		response.Message = fmt.Sprintf("请确认:%s %.2f元,分类:%s账户%s",
			s.getTypeLabel(session.Params.Type),
			*session.Params.Amount,
			session.Params.Category,
			session.Params.Account)
	}

	session.Messages = append(session.Messages, ChatMessage{
		Role:    "assistant",
		Content: response.Message,
	})
	return response, nil
}
// mergeParams overlays the fields set in new onto existing. Non-nil pointers
// and non-empty strings in new win; unset fields keep their current value.
//
// When new changes the category or account NAME (case-insensitively) without
// also supplying an ID, the previously resolved ID is cleared so the caller
// maps the new name again. Before, a correction such as "actually it was
// transport" updated Category but kept the old CategoryID. The confirmation
// card then showed one category while ConfirmTransaction saved another.
func (s *AIBookkeepingService) mergeParams(existing, new *AITransactionParams) {
	if new.Amount != nil {
		existing.Amount = new.Amount
	}
	if new.Category != "" {
		if new.CategoryID == nil && !strings.EqualFold(new.Category, existing.Category) {
			existing.CategoryID = nil
		}
		existing.Category = new.Category
	}
	if new.CategoryID != nil {
		existing.CategoryID = new.CategoryID
	}
	if new.Account != "" {
		if new.AccountID == nil && !strings.EqualFold(new.Account, existing.Account) {
			existing.AccountID = nil
		}
		existing.Account = new.Account
	}
	if new.AccountID != nil {
		existing.AccountID = new.AccountID
	}
	if new.Type != "" {
		existing.Type = new.Type
	}
	if new.Date != "" {
		existing.Date = new.Date
	}
	if new.Note != "" {
		existing.Note = new.Note
	}
}
// getDefaultAccount chooses the account to use when the user did not name one.
//
// It first tries the user's configured default for txType ("expense" or
// "income"), provided that account can still be loaded. Otherwise it falls
// back to the first account returned by accountRepo.GetAll. Returns (nil, "")
// when no account is available. Errors from the settings and account
// repositories are treated as "no preference" rather than reported.
func (s *AIBookkeepingService) getDefaultAccount(userID uint, txType string) (*uint, string) {
	if settings, err := s.userSettingsRepo.GetOrCreate(userID); err == nil && settings != nil {
		var preferred *uint
		switch txType {
		case "expense":
			preferred = settings.DefaultExpenseAccountID
		case "income":
			preferred = settings.DefaultIncomeAccountID
		}
		if preferred != nil {
			if acc, err := s.accountRepo.GetByID(userID, *preferred); err == nil && acc != nil {
				return preferred, acc.Name
			}
		}
	}

	accounts, err := s.accountRepo.GetAll(userID)
	if err != nil || len(accounts) == 0 {
		return nil, ""
	}
	// First account in repository order.
	return &accounts[0].ID, accounts[0].Name
}
// getDefaultCategory returns the first of userID's categories whose type
// matches txType ("income" selects income, anything else selects expense).
// If none matches, it returns the very first category. Returns (nil, "") when
// the repository fails or the user has no categories.
func (s *AIBookkeepingService) getDefaultCategory(userID uint, txType string) (*uint, string) {
	categories, err := s.categoryRepo.GetAll(userID)
	if err != nil || len(categories) == 0 {
		return nil, ""
	}

	want := "expense"
	if txType == "income" {
		want = "income"
	}

	pick := 0 // falls back to the first category when no type matches
	for i := range categories {
		if string(categories[i].Type) == want {
			pick = i
			break
		}
	}
	id := categories[pick].ID
	return &id, categories[pick].Name
}
// getMissingFields lists, in the order amount, category, account, the
// required fields that params does not yet provide. A category or account
// counts as present if either its ID or its name is set. Returns nil when
// nothing is missing.
func (s *AIBookkeepingService) getMissingFields(params *AITransactionParams) []string {
	checks := []struct {
		field  string
		absent bool
	}{
		{"amount", params.Amount == nil},
		{"category", params.CategoryID == nil && params.Category == ""},
		{"account", params.AccountID == nil && params.Account == ""},
	}

	var missing []string
	for _, c := range checks {
		if c.absent {
			missing = append(missing, c.field)
		}
	}
	return missing
}
// generateFollowUpQuestion builds the Chinese prompt asking the user for the
// missing fields. Known field keys ("amount", "category", "account") are
// translated to their labels and unknown keys are skipped. A single label
// yields a direct question; otherwise the labels are listed joined by "、".
// Returns "" when missing is empty.
func (s *AIBookkeepingService) generateFollowUpQuestion(missing []string) string {
	if len(missing) == 0 {
		return ""
	}

	labels := make([]string, 0, len(missing))
	for _, field := range missing {
		switch field {
		case "amount":
			labels = append(labels, "金额")
		case "category":
			labels = append(labels, "分类")
		case "account":
			labels = append(labels, "账户")
		}
	}

	if len(labels) == 1 {
		return fmt.Sprintf("请问%s是多少", labels[0])
	}
	return fmt.Sprintf("请补充以下信息:%s", strings.Join(labels, "、"))
}
// getTypeLabel maps a transaction type to its Chinese display label:
// "income" is 收入, and every other value (including "") is 支出.
func (s *AIBookkeepingService) getTypeLabel(txType string) string {
	switch txType {
	case "income":
		return "收入"
	default:
		return "支出"
	}
}
// GenerateConfirmationCard snapshots session.Params into a ConfirmationCard.
// Unset pointer fields become zero values, and an empty date becomes today
// (YYYY-MM-DD). IsComplete is always true.
func (s *AIBookkeepingService) GenerateConfirmationCard(session *AISession) *ConfirmationCard {
	p := session.Params

	date := p.Date
	if date == "" {
		date = time.Now().Format("2006-01-02")
	}

	card := &ConfirmationCard{
		SessionID:  session.ID,
		Category:   p.Category,
		Account:    p.Account,
		Type:       p.Type,
		Date:       date,
		Note:       p.Note,
		IsComplete: true,
	}
	if p.Amount != nil {
		card.Amount = *p.Amount
	}
	if p.CategoryID != nil {
		card.CategoryID = *p.CategoryID
	}
	if p.AccountID != nil {
		card.AccountID = *p.AccountID
	}
	return card
}
// TranscribeAudio delegates to the service's WhisperService, which validates
// the file extension and calls the transcription endpoint.
func (s *AIBookkeepingService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) {
	whisper := s.whisperService
	return whisper.TranscribeAudio(ctx, audioData, filename)
}
// ConfirmTransaction turns the parameters collected in session sessionID into
// a persisted transaction for userID (currency CNY). It then adjusts the
// account balance: an expense subtracts, an income adds.
//
// The session must exist, be unexpired and belong to userID; otherwise the
// error "session not found or expired" is returned.
//   - The session is removed from the map up front, so a double-submitted
//     confirmation cannot create the transaction twice.
//   - If validation, the account lookup or the insert fails, the session is
//     put back so the user can retry.
//   - If the insert succeeds but the balance update fails, the session is
//     NOT restored (a retry would duplicate the transaction), and the error
//     is returned.
//
// The account is loaded before inserting, so an unknown account no longer
// leaves an orphaned transaction behind.
//
// NOTE(review): the insert and balance update are separate repository calls,
// not a single DB transaction. Balance is updated read-modify-write.
// s.db could make this atomic if the repositories accept a *gorm.DB.
// ctx is currently unused.
//
// Requirements: 7.10
func (s *AIBookkeepingService) ConfirmTransaction(ctx context.Context, sessionID string, userID uint) (*models.Transaction, error) {
	// Claim the session: check ownership and expiry and remove it in one
	// critical section.
	s.sessionMutex.Lock()
	session, ok := s.sessions[sessionID]
	if ok && (session.UserID != userID || time.Now().After(session.ExpiresAt)) {
		ok = false
	}
	if ok {
		delete(s.sessions, sessionID)
	}
	s.sessionMutex.Unlock()
	if !ok {
		return nil, errors.New("session not found or expired")
	}

	// restore puts the claimed session back so the user can fix and retry.
	restore := func() {
		s.sessionMutex.Lock()
		s.sessions[sessionID] = session
		s.sessionMutex.Unlock()
	}

	params := session.Params
	if params.Amount == nil || *params.Amount <= 0 {
		restore()
		return nil, errors.New("invalid amount")
	}
	if params.CategoryID == nil {
		restore()
		return nil, errors.New("category not specified")
	}
	if params.AccountID == nil {
		restore()
		return nil, errors.New("account not specified")
	}

	// A missing or unparsable date falls back to now.
	txDate := time.Now()
	if params.Date != "" {
		if parsed, err := time.Parse("2006-01-02", params.Date); err == nil {
			txDate = parsed
		}
	}

	txType := models.TransactionTypeExpense
	if params.Type == "income" {
		txType = models.TransactionTypeIncome
	}

	// Load (and implicitly authorise) the account before writing anything.
	account, err := s.accountRepo.GetByID(userID, *params.AccountID)
	if err == nil && account == nil {
		err = errors.New("account not found")
	}
	if err != nil {
		restore()
		return nil, fmt.Errorf("failed to find account: %w", err)
	}

	tx := &models.Transaction{
		UserID:          userID,
		Amount:          *params.Amount,
		Type:            txType,
		CategoryID:      *params.CategoryID,
		AccountID:       *params.AccountID,
		TransactionDate: txDate,
		Note:            params.Note,
		Currency:        "CNY",
	}
	if err := s.transactionRepo.Create(tx); err != nil {
		restore()
		return nil, fmt.Errorf("failed to create transaction: %w", err)
	}

	if txType == models.TransactionTypeExpense {
		account.Balance -= *params.Amount
	} else {
		account.Balance += *params.Amount
	}
	if err := s.accountRepo.Update(account); err != nil {
		return nil, fmt.Errorf("failed to update account balance: %w", err)
	}
	return tx, nil
}
// GetSession looks up sessionID and reports whether a session exists and has
// not yet expired. On a miss or expiry it returns (nil, false). The returned
// pointer is the shared instance stored in the session map, not a copy.
func (s *AIBookkeepingService) GetSession(sessionID string) (*AISession, bool) {
	s.sessionMutex.RLock()
	defer s.sessionMutex.RUnlock()

	if session, found := s.sessions[sessionID]; found && !time.Now().After(session.ExpiresAt) {
		return session, true
	}
	return nil, false
}