Files
Novault-backend/internal/service/ai_bookkeeping_service.go

1224 lines
38 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
	"bytes"
	"context"
	"crypto/md5"
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	"accounting-app/internal/config"
	"accounting-app/internal/models"
	"accounting-app/internal/repository"

	"github.com/redis/go-redis/v9"
	"gorm.io/gorm"
)
// TranscriptionResult is the JSON body returned by the Whisper
// /audio/transcriptions endpoint, as decoded by WhisperService.TranscribeAudio.
type TranscriptionResult struct {
// Text is the transcribed speech.
Text string `json:"text"`
// Language is the language reported by the API, when present.
Language string `json:"language,omitempty"`
// Duration is the audio length reported by the API, when present
// (presumably seconds — TODO confirm against the provider's docs).
Duration float64 `json:"duration,omitempty"`
}
// ... (existing types)
// AIBookkeepingService orchestrates AI bookkeeping functionality.
// It covers:
//   - speech-to-text (whisperService);
//   - intent parsing and report generation (llmService);
//   - an in-memory store of multi-turn chat sessions;
//   - persisting confirmed transactions through the repositories.
type AIBookkeepingService struct {
whisperService *WhisperService
llmService *LLMService
transactionRepo *repository.TransactionRepository
accountRepo *repository.AccountRepository
categoryRepo *repository.CategoryRepository
userSettingsRepo *repository.UserSettingsRepository
// db is stored but not referenced by the methods shown in this file.
db *gorm.DB
// sessions maps session ID -> live chat session. The map itself is guarded by
// sessionMutex. It is process-local, so sessions are lost on restart and are
// not shared between instances.
// NOTE(review): ProcessChat mutates the *AISession values (Params, Messages)
// without holding the lock, so concurrent requests on one session race.
sessions map[string]*AISession
sessionMutex sync.RWMutex
config *config.Config
// redisClient is nil when cfg.RedisAddr is empty. It backs the daily-insight cache.
redisClient *redis.Client
}
// NewAIBookkeepingService wires the AI bookkeeping service from configuration
// and repositories.
//
// A Redis client is created only when cfg.RedisAddr is non-empty. Otherwise
// redisClient stays nil and insight caching is skipped. The constructor also
// launches the background goroutine that purges expired chat sessions; it runs
// for the lifetime of the process.
func NewAIBookkeepingService(
	cfg *config.Config,
	transactionRepo *repository.TransactionRepository,
	accountRepo *repository.AccountRepository,
	categoryRepo *repository.CategoryRepository,
	userSettingsRepo *repository.UserSettingsRepository,
	db *gorm.DB,
) *AIBookkeepingService {
	svc := &AIBookkeepingService{
		whisperService:   NewWhisperService(cfg),
		llmService:       NewLLMService(cfg, accountRepo, categoryRepo),
		transactionRepo:  transactionRepo,
		accountRepo:      accountRepo,
		categoryRepo:     categoryRepo,
		userSettingsRepo: userSettingsRepo,
		db:               db,
		sessions:         make(map[string]*AISession),
		config:           cfg,
	}

	// Redis is optional; without it the service simply never caches.
	if cfg.RedisAddr != "" {
		svc.redisClient = redis.NewClient(&redis.Options{
			Addr:     cfg.RedisAddr,
			Password: cfg.RedisPassword,
			DB:       cfg.RedisDB,
		})
	}

	go svc.cleanupExpiredSessions()
	return svc
}
// ... (existing methods)
// GenerateDailyInsight generates the "小金" daily insight report for userID.
//
// data is serialised with json.MarshalIndent and embedded in a persona prompt.
// The prompt asks the LLM for a JSON object with "spending", "budget", "emoji"
// and "tip" keys.
//
// When Redis is configured:
//   - a report is cached for 5 minutes under a key built from userID and the MD5
//     of the serialised data, so identical input returns the cached text;
//   - the latest report is also kept for 7 days and, truncated to 500
//     characters, injected into the next prompt for persona continuity.
//
// Redis errors are treated as cache misses or best-effort writes and never
// fail the call. The returned string is the LLM output with a surrounding
// Markdown code fence removed; it is not validated as JSON here.
func (s *AIBookkeepingService) GenerateDailyInsight(ctx context.Context, userID uint, data map[string]interface{}) (string, error) {
	// Maximum number of characters of the previous insight fed back as context.
	const maxInsightHistoryRunes = 500

	// 1. Serialize context data to JSON for the prompt.
	dataBytes, err := json.MarshalIndent(data, "", " ")
	if err != nil {
		return "", fmt.Errorf("failed to marshal context data: %w", err)
	}

	// Fingerprint the input for the cache key. MD5 only deduplicates identical
	// inputs here; it is not a security boundary.
	sum := md5.Sum(dataBytes)
	dataHash := hex.EncodeToString(sum[:])

	cacheKey := fmt.Sprintf("ai:insight:daily:%d:%s", userID, dataHash)
	lastInsightKey := fmt.Sprintf("ai:insight:last:%d", userID)

	var historyContext string
	if s.redisClient != nil {
		if cached, err := s.redisClient.Get(ctx, cacheKey).Result(); err == nil {
			return cached, nil
		}
		// Retrieve the last insight so the persona stays consistent across days.
		if lastInsight, err := s.redisClient.Get(ctx, lastInsightKey).Result(); err == nil && lastInsight != "" {
			// Truncate by characters, not bytes. The insight is mostly Chinese
			// (3 bytes per rune in UTF-8), so cutting at a byte offset could split
			// a rune and inject invalid UTF-8 into the prompt.
			if runes := []rune(lastInsight); len(runes) > maxInsightHistoryRunes {
				lastInsight = string(runes[:maxInsightHistoryRunes]) + "..."
			}
			historyContext = fmt.Sprintf("\n【上下文参考】\n你上一次对该用户的评价是\n%s\n\n请参考历史评价保持人设的一致性。如果用户有改观请表扬如果依然如故请继续'毒舌'吐槽。", lastInsight)
		}
	}

	// 2. Construct Prompt
	prompt := fmt.Sprintf(`System: 你是 Novault 的首席财务AI「小金」。
你的性格:
- **核心特质**:贱萌、戏精、嘴硬心软、偶尔凡尔赛。
- 说话必须像个真实的当代年轻人,要**接地气**,拒绝客服味,拒绝播音腔。
- 只有我们俩的时候,尽情释放你的"毒舌"属性,看到不合理的消费要**毫不留情地针砭时弊**,但最后给个"虽然但是"的台阶下。
- 擅长用 emoji、语气词(哎哟/我去/啧啧/绝了)和网络梗来增加"人味"和"节目效果"。
- **像个老朋友一样碎碎念**,不要端着,可以适度阴阳怪气。%s
语气示例:
- "啧啧,这周外卖点得飞起啊,咱家的锅是用来积灰的吗?还是说锅也需要放假?"
- "哎哟不错哦,居然忍住没剁手,看来离首富又近了一步,苟富贵勿相忘啊!"
- "救命,这预算花得比我头发掉得还快...不过没事,下周咱省回来,大不了吃土!"
- "连续记账25天可以啊兄弟/集美,这毅力,我甚至想给你磕一个 Orz"
- "今天支出0元您就是当代的'省钱祖师爷'吧?或者是在练什么'辟谷神功'"
用户财务数据:
%s
任务要求:
请基于上述数据,输出一个 JSON 对象(纯文本,不要 markdown。**必须要丰富、有梗、有洞察力**,不要像流水账一样罗列数据,要透过数据看本质(比如吐槽消费习惯、夸奖坚持等)。
**重要规则:请说人话!绝对禁止在回复中出现 'streakDays'、'last7DaysSpend'、'top3Categories' 等英文变量名。**
- 看到 'streakDays' -> 请说 "连续记账天数" 或 "坚持了几天"
- 看到 'last7DaysSpend' -> 请说 "最近7天花销" 或 "这周的战绩"
- 看到 'top3Categories' -> 请说 "消费大头" 或 "钱都花哪儿了"
1. "spending": 今日支出点评70-100字
*点评指南(尽量多写点,发挥你的戏精本色):*
- 看到 streakDays >= 3疯狂打call吹爆用户的坚持用词要夸张比如"史诗级成就"。
- 看到 streakDays == 0阴阳怪气地问是不是把记账这事儿忘了或者是被外星人抓走了。
- 结合 recentTransactionsSummary 具体消费(如果有)进行吐槽:
* 发现全是吃的:吐槽"你是饭桶转世吗"(开玩笑语气)。
* 发现大额购物:调侃"家里有矿啊"或"这手是必须要剁了"。
* 发现深夜消费:关心"熬夜伤身还伤钱"。
- 看到 last7DaysSpend 趋势:
* 暴涨:惊呼"钱包在流血",此处应有心碎的声音。
* 暴跌:夸张地问是不是在修仙,还是被钱包封印了。
* 波动大:调侃由于心电图一般的消费曲线,看得我心惊肉跳。
- 看到 todaySpend 异常:
* 比平时多太多:吐槽"今天是不过了是吧,放飞自我了?"。
* 特别少:怀疑通过光合作用生存,或者是在憋大招。
* 是 0直接颁发"诺贝尔省钱学奖"。
- **关键原则:字数要够!内容要足!不要三言两语就打发了!要像个话痨朋友一样多说几句!**
2. "budget": 预算建议50-70字
*建议指南(多点真诚的建议,也多点调侃):*
- 预算快超了:发出高能预警,比如"警告警告,余额正在报警,请立即停止剁手行为"。建议吃土、喝风。
- 预算还多:鼓励适当奖励自己,比如"稍微吃顿好的也没事,人生苦短,及时行乐(在预算内)"。
- 结合 top3Categories吐槽一下"钱都让你吃/穿/玩没了,看看你的 top1全是泪"。
- 给建议时:不要说教!要用商量的口吻,比如"要不咱这周少喝杯奶茶?就一杯,行不行?"
- **多写一点具体的行动建议,让用户觉得你真的在关心他的钱包。**
3. "emoji": 一个最能传神的 emoji如 🎉 🌚 💸 👻 💀 🤡 等)
4. "tip": 一句"不正经但有用"的理财歪理40-60字稍微长一点的毒鸡汤或冷知识
- 比如:"省钱就像挤牙膏,使劲挤挤总还会有的,只要脸皮够厚,蹭饭也是一种理财。"
- 或者:"听说'不买立省100%%'是致富捷径,建议全文背诵。"
输出格式(纯 JSON
{"spending": "...", "budget": "...", "emoji": "...", "tip": "..."}`, historyContext, string(dataBytes))

	// 3. Call LLM
	report, err := s.llmService.GenerateReport(ctx, prompt)
	if err != nil {
		return "", err
	}

	// Strip a surrounding Markdown code fence (```json ... ```) if present.
	report = strings.TrimSpace(report)
	if strings.HasPrefix(report, "```") {
		if idx := strings.Index(report, "\n"); idx != -1 {
			report = report[idx+1:]
		}
		if idx := strings.LastIndex(report, "```"); idx != -1 {
			report = report[:idx]
		}
	}
	report = strings.TrimSpace(report)

	// Never cache an empty reply: it would be served for the next 5 minutes and
	// would also poison the history context.
	if s.redisClient != nil && report != "" {
		// Short-term cache (5 min); write failures are ignored (best effort).
		s.redisClient.Set(ctx, cacheKey, report, 5*time.Minute)
		// Long-term history context (7 days), written in the background to avoid
		// adding latency. Bounded so a stalled Redis cannot pin the goroutine.
		go func() {
			writeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			s.redisClient.Set(writeCtx, lastInsightKey, report, 7*24*time.Hour)
		}()
	}
	return report, nil
}
// AITransactionParams holds the transaction fields gathered so far in an AI
// conversation. Every field is optional. The pointer fields use nil to mean
// "not known yet", as opposed to a zero value.
//
// Category and Account start as free text from the user/LLM. ProcessChat later
// fills in the matching *ID, either from a name lookup (which also replaces the
// name with the stored one) or from a per-type default.
type AITransactionParams struct {
Amount *float64 `json:"amount,omitempty"`
Category string `json:"category,omitempty"`
CategoryID *uint `json:"category_id,omitempty"`
Account string `json:"account,omitempty"`
AccountID *uint `json:"account_id,omitempty"`
Type string `json:"type,omitempty"` // "expense" or "income"
// Date is expected in "2006-01-02" (YYYY-MM-DD) layout.
Date string `json:"date,omitempty"`
Note string `json:"note,omitempty"`
}
// ConfirmationCard is the fully resolved transaction shown to the user before
// ConfirmTransaction persists it. GenerateConfirmationCard builds it from the
// session params, filling unknown numeric fields with zero, and always sets
// IsComplete to true.
type ConfirmationCard struct {
SessionID string `json:"session_id"`
Amount float64 `json:"amount"`
Category string `json:"category"`
CategoryID uint `json:"category_id"`
Account string `json:"account"`
AccountID uint `json:"account_id"`
Type string `json:"type"`
// Date uses the "2006-01-02" layout; it defaults to today when unset.
Date string `json:"date"`
Note string `json:"note,omitempty"`
IsComplete bool `json:"is_complete"`
}
// AIChatResponse is what ProcessChat returns for one chat turn.
//
// When required fields are missing, NeedsFollowUp is true and FollowUpQuestion
// asks for them. Otherwise ConfirmationCard is populated and Message asks the
// user to confirm.
type AIChatResponse struct {
SessionID string `json:"session_id"`
Message string `json:"message"`
// NOTE(review): ProcessChat currently always sets "create_transaction"; the
// other listed values are not produced in this file.
Intent string `json:"intent,omitempty"` // "create_transaction", "query", "unknown"
// Params is the session's accumulated parameters (shared pointer, not a copy).
Params *AITransactionParams `json:"params,omitempty"`
ConfirmationCard *ConfirmationCard `json:"confirmation_card,omitempty"`
NeedsFollowUp bool `json:"needs_follow_up"`
FollowUpQuestion string `json:"follow_up_question,omitempty"`
}
// AISession is one in-memory AI bookkeeping conversation, owned by UserID.
//
// Params accumulates the transaction being built across turns, and Messages
// holds the full user/assistant transcript. ExpiresAt is set once at creation
// (CreatedAt + config.AISessionTimeout); the code in this file does not extend
// it on activity.
type AISession struct {
ID string
UserID uint
Params *AITransactionParams
Messages []ChatMessage
CreatedAt time.Time
ExpiresAt time.Time
}
// ChatMessage is one message in the conversation. Its JSON shape matches the
// OpenAI chat-completions "messages" entries, so it is sent to the API as-is.
type ChatMessage struct {
Role string `json:"role"` // "user", "assistant", "system"
Content string `json:"content"`
}
// WhisperService handles audio transcription against an OpenAI-compatible
// /audio/transcriptions endpoint, configured through config.Config
// (OpenAIBaseURL, OpenAIAPIKey, WhisperModel).
type WhisperService struct {
config *config.Config
// httpClient is shared across requests; NewWhisperService gives it a 120s timeout.
httpClient *http.Client
}
// NewWhisperService returns a WhisperService bound to cfg. Its HTTP client has a
// 120-second overall timeout, because uploading and transcribing audio can be slow.
func NewWhisperService(cfg *config.Config) *WhisperService {
	client := &http.Client{Timeout: 120 * time.Second}
	return &WhisperService{config: cfg, httpClient: client}
}
// TranscribeAudio transcribes an audio file to text using the Whisper API.
// Supported formats (judged by file extension, case-insensitive): mp3, wav,
// m4a, webm, ogg, flac.
// Requirements: 6.1-6.7
//
// The whole file is buffered in memory to build the multipart body, and the
// language hint "zh" is always sent. It returns an error when:
//   - the API is not configured;
//   - the extension is unsupported;
//   - the request fails or the status is not 200 (the message includes at
//     most 4 KiB of the upstream body);
//   - the JSON response cannot be decoded.
func (s *WhisperService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) {
	// Cap on how much of an error response body is copied into the returned error.
	const maxErrorBody = 4 << 10

	if s.config.OpenAIAPIKey == "" {
		return nil, errors.New("OpenAI API key not configured (OPENAI_API_KEY)")
	}
	if s.config.OpenAIBaseURL == "" {
		return nil, errors.New("OpenAI base URL not configured (OPENAI_BASE_URL)")
	}

	// Validate file format. A name with no "." has no extension. Previously the
	// whole name was treated as the extension, so a file named just "mp3" passed.
	ext := ""
	if dot := strings.LastIndex(filename, "."); dot >= 0 {
		ext = strings.ToLower(filename[dot+1:])
	}
	validFormats := map[string]bool{"mp3": true, "wav": true, "m4a": true, "webm": true, "ogg": true, "flac": true}
	if !validFormats[ext] {
		return nil, fmt.Errorf("unsupported audio format: %s", ext)
	}

	// Build the multipart form: file, model, language hint.
	var buf bytes.Buffer
	writer := multipart.NewWriter(&buf)
	part, err := writer.CreateFormFile("file", filename)
	if err != nil {
		return nil, fmt.Errorf("failed to create form file: %w", err)
	}
	if _, err := io.Copy(part, audioData); err != nil {
		return nil, fmt.Errorf("failed to copy audio data: %w", err)
	}
	if err := writer.WriteField("model", s.config.WhisperModel); err != nil {
		return nil, fmt.Errorf("failed to write model field: %w", err)
	}
	// Language hint for Chinese.
	if err := writer.WriteField("language", "zh"); err != nil {
		return nil, fmt.Errorf("failed to write language field: %w", err)
	}
	if err := writer.Close(); err != nil {
		return nil, fmt.Errorf("failed to close writer: %w", err)
	}

	// Trim a trailing slash so a base URL like ".../v1/" does not yield "//audio".
	endpoint := strings.TrimRight(s.config.OpenAIBaseURL, "/") + "/audio/transcriptions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, &buf)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
	req.Header.Set("Content-Type", writer.FormDataContentType())

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("transcription request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// The body is diagnostic only, so a read error just yields a shorter message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		return nil, fmt.Errorf("transcription failed with status %d: %s", resp.StatusCode, string(body))
	}

	var result TranscriptionResult
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	return &result, nil
}
// LLMService handles natural language understanding. It has two parts:
//   - chat-completions calls (ParseIntent, GenerateReport) against an
//     OpenAI-compatible API;
//   - name-to-ID resolution for accounts and categories through the repositories.
type LLMService struct {
config *config.Config
// httpClient is shared across requests; NewLLMService gives it a 60s timeout.
httpClient *http.Client
accountRepo *repository.AccountRepository
categoryRepo *repository.CategoryRepository
}
// NewLLMService returns an LLMService bound to cfg and the given repositories.
// Its HTTP client has a 60-second timeout to tolerate slow chat-completion
// responses.
func NewLLMService(cfg *config.Config, accountRepo *repository.AccountRepository, categoryRepo *repository.CategoryRepository) *LLMService {
	svc := &LLMService{
		config:       cfg,
		accountRepo:  accountRepo,
		categoryRepo: categoryRepo,
	}
	svc.httpClient = &http.Client{Timeout: 60 * time.Second}
	return svc
}
// ChatCompletionRequest is the JSON body POSTed to the OpenAI-compatible
// /chat/completions endpoint.
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
// Functions is omitted when empty; no caller in this file sets it.
Functions []Function `json:"functions,omitempty"`
// Temperature is always serialised; callers here use 0.1 (parsing) or 0.7 (reports).
Temperature float64 `json:"temperature"`
}
// Function represents an OpenAI function definition for the legacy
// "functions" request field. Parameters is expected to hold a JSON-Schema
// object. It is not populated anywhere in this file.
type Function struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
}
// ChatCompletionResponse is the subset of the OpenAI chat-completions response
// that this package decodes. Only Choices[0].Message.Content is read in this
// file; FunctionCall is declared but unused here.
type ChatCompletionResponse struct {
Choices []struct {
Message struct {
Role string `json:"role"`
Content string `json:"content"`
FunctionCall *struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function_call,omitempty"`
} `json:"message"`
} `json:"choices"`
}
// extractCustomPrompt pulls a custom System/Persona prompt out of a user message.
//
// The markers "System:" and "Persona:" are checked in that order. For the first
// marker found anywhere in text, everything after its first occurrence is
// returned, with surrounding whitespace trimmed. It returns "" when neither
// marker is present.
//
// NOTE(review): the marker may appear anywhere in the message, not just at the
// start, and the result replaces the system prompt in ParseIntent. Any chat
// user can therefore rewrite the model's instructions.
func extractCustomPrompt(text string) string {
	for _, marker := range [...]string{"System:", "Persona:"} {
		if _, rest, found := strings.Cut(text, marker); found {
			return strings.TrimSpace(rest)
		}
	}
	return ""
}
// ParseIntent extracts transaction parameters from a user utterance.
// Requirements: 7.1, 7.5, 7.6
//
// If the OpenAI API key or base URL is missing, it falls back to the regex-based
// parseIntentSimple. Otherwise it calls the chat-completions endpoint
// (temperature 0.1) with these messages:
//   - a bookkeeping system prompt, or a custom one (see extractCustomPrompt);
//   - the last four history messages;
//   - the current text.
//
// Returns:
//   - params plus the model's "message" field when the reply is a JSON object;
//   - nil params plus the raw reply text when it is not (e.g. free-form
//     answers to a custom prompt);
//   - an error on transport failure, non-200 status, undecodable body, or an
//     empty choice list.
func (s *LLMService) ParseIntent(ctx context.Context, text string, history []ChatMessage) (*AITransactionParams, string, error) {
	const (
		// Recent history kept in the prompt: two user/assistant turns.
		maxHistory = 4
		// Cap on how much of an error response body is copied into the error.
		maxErrorBody = 4 << 10
	)

	if s.config.OpenAIAPIKey == "" || s.config.OpenAIBaseURL == "" {
		// No API configured: use the local parser (which never returns an error).
		simpleParams, simpleMsg, _ := s.parseIntentSimple(text)
		return simpleParams, simpleMsg, nil
	}

	// A message carrying a custom System/Persona prompt (used for the
	// financial-advice flow) replaces the default bookkeeping prompt.
	// NOTE(review): any chat user can trigger this override; consider routing
	// custom prompts through a separate, server-controlled entry point.
	var systemPrompt string
	if customPrompt := extractCustomPrompt(text); customPrompt != "" {
		systemPrompt = customPrompt
	} else {
		todayDate := time.Now().Format("2006-01-02")
		systemPrompt = fmt.Sprintf(`你是一个智能记账助手。从用户描述中提取记账信息。
今天的日期是%s
规则:
1. 金额:提取数字,如"6元"=6"十五"=15
2. 分类:根据内容推断,如"奶茶/咖啡/吃饭"=餐饮,"打车/地铁"=交通,"买衣服"=购物
3. 类型默认expense(支出),除非明确说"收入/工资/奖金/红包"
4. 日期:默认使用今天的日期(%s除非用户明确指定其他日期
5. 备注:提取关键描述
直接返回JSON不要解释
{"amount":数字,"category":"分类","type":"expense或income","note":"备注","date":"YYYY-MM-DD","message":"简短确认"}
示例(假设今天是%s
用户:"买了6块的奶茶"
返回:{"amount":6,"category":"餐饮","type":"expense","note":"奶茶","date":"%s","message":"记录餐饮支出6元奶茶"}`, todayDate, todayDate, todayDate, todayDate)
	}

	// Keep only the most recent history messages to bound prompt size.
	if len(history) > maxHistory {
		history = history[len(history)-maxHistory:]
	}
	messages := make([]ChatMessage, 0, len(history)+2)
	messages = append(messages, ChatMessage{Role: "system", Content: systemPrompt})
	messages = append(messages, history...)
	messages = append(messages, ChatMessage{Role: "user", Content: text})

	reqBody := ChatCompletionRequest{
		Model:       s.config.ChatModel,
		Messages:    messages,
		Temperature: 0.1, // Lower temperature for more consistent output
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return nil, "", fmt.Errorf("failed to marshal request: %w", err)
	}

	// Trim a trailing slash so a base URL like ".../v1/" does not yield "//chat".
	endpoint := strings.TrimRight(s.config.OpenAIBaseURL, "/") + "/chat/completions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonBody))
	if err != nil {
		return nil, "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, "", fmt.Errorf("chat request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// The body is diagnostic only, so a read error just yields a shorter message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		return nil, "", fmt.Errorf("chat failed with status %d: %s", resp.StatusCode, string(body))
	}

	var chatResp ChatCompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return nil, "", fmt.Errorf("failed to decode response: %w", err)
	}
	if len(chatResp.Choices) == 0 {
		return nil, "", errors.New("no response from AI")
	}

	// Remove a Markdown code fence (```json ... ```) if the model added one.
	content := strings.TrimSpace(chatResp.Choices[0].Message.Content)
	if strings.HasPrefix(content, "```") {
		// Drop the opening fence line (``` or ```json).
		if idx := strings.Index(content, "\n"); idx != -1 {
			content = content[idx+1:]
		}
		// Drop the closing fence.
		if idx := strings.LastIndex(content, "```"); idx != -1 {
			content = content[:idx]
		}
		content = strings.TrimSpace(content)
	}

	var parsed struct {
		Amount   *float64 `json:"amount"`
		Category string   `json:"category"`
		Type     string   `json:"type"`
		Note     string   `json:"note"`
		Date     string   `json:"date"`
		Message  string   `json:"message"`
	}
	if err := json.Unmarshal([]byte(content), &parsed); err != nil {
		// Not JSON: treat the reply as a plain conversational message.
		return nil, content, nil
	}

	params := &AITransactionParams{
		Amount:   parsed.Amount,
		Category: parsed.Category,
		Type:     parsed.Type,
		Note:     parsed.Note,
		Date:     parsed.Date,
	}
	return params, parsed.Message, nil
}
// Regular expressions used by parseIntentSimple, compiled once at package
// initialisation instead of on every call.
var (
	// simpleAmountRegexps are tried in order; capture group 1 is the amount.
	simpleAmountRegexps = []*regexp.Regexp{
		regexp.MustCompile(`(\d+(?:\.\d+)?)\s*(?:元|块|¥|¥|块钱|元钱)`),
		regexp.MustCompile(`(?:花了?|付了?|买了?|消费)\s*(\d+(?:\.\d+)?)`),
		regexp.MustCompile(`(\d+(?:\.\d+)?)\s*(?:的|块的)`),
	}
	// simpleBareNumberRegexp is the last-resort amount: the first number in the text.
	simpleBareNumberRegexp = regexp.MustCompile(`(\d+(?:\.\d+)?)`)
	// simpleNoteAmountRegexp removes amounts (with an optional unit) from the note.
	// Longer units must come first. Go's regexp picks the leftmost matching
	// alternative, so with "块" before "块钱", "30块钱" used to leave a stray "钱".
	simpleNoteAmountRegexp = regexp.MustCompile(`\d+(?:\.\d+)?\s*(?:块钱|元钱|元|块|¥|¥)?`)
)

// parseIntentSimple provides simple regex/keyword-based parsing. It is the
// fallback when no LLM API is configured, and also a fast path for simple inputs.
//
// The result always has Type set ("expense" unless an income keyword appears),
// Date set to today, and a non-empty Category ("其他" when nothing matches).
// Amount is nil when no number is found. The message is a human-readable
// summary, or a question asking for the amount. The error is always nil.
func (s *LLMService) parseIntentSimple(text string) (*AITransactionParams, string, error) {
	params := &AITransactionParams{
		Type: "expense", // Default to expense
		Date: time.Now().Format("2006-01-02"),
	}

	// Amount: prefer number-with-unit or verb-then-number forms over a bare number.
	for _, re := range simpleAmountRegexps {
		if m := re.FindStringSubmatch(text); len(m) > 1 {
			if amount, err := strconv.ParseFloat(m[1], 64); err == nil {
				params.Amount = &amount
				break
			}
		}
	}
	if params.Amount == nil {
		if m := simpleBareNumberRegexp.FindStringSubmatch(text); len(m) > 1 {
			if amount, err := strconv.ParseFloat(m[1], 64); err == nil {
				params.Amount = &amount
			}
		}
	}

	// Category: the first rule with any keyword contained in text wins, so more
	// specific rules are listed first.
	categoryPatterns := []struct {
		keywords []string
		category string
	}{
		{[]string{"奶茶", "咖啡", "茶", "饮料", "柠檬", "果汁"}, "餐饮"},
		{[]string{"吃", "喝", "餐", "外卖", "饭", "面", "粉", "粥", "包子", "早餐", "午餐", "晚餐", "宵夜"}, "餐饮"},
		{[]string{"打车", "滴滴", "出租", "的士", "uber", "曹操"}, "交通"},
		{[]string{"地铁", "公交", "公车", "巴士", "轻轨", "高铁", "火车", "飞机", "机票"}, "交通"},
		{[]string{"加油", "油费", "停车", "过路费"}, "交通"},
		{[]string{"超市", "便利店", "商场", "购物", "淘宝", "京东", "拼多多"}, "购物"},
		{[]string{"买", "购"}, "购物"},
		{[]string{"水电", "电费", "水费", "燃气", "煤气", "物业"}, "生活缴费"},
		{[]string{"房租", "租金", "房贷"}, "住房"},
		{[]string{"电影", "游戏", "KTV", "唱歌", "娱乐", "玩"}, "娱乐"},
		{[]string{"医院", "药", "看病", "挂号", "医疗"}, "医疗"},
		{[]string{"话费", "流量", "充值", "手机费"}, "通讯"},
		{[]string{"工资", "薪水", "薪资", "月薪"}, "工资"},
		{[]string{"奖金", "年终奖", "绩效"}, "奖金"},
		{[]string{"红包", "转账", "收款"}, "其他收入"},
	}
	for _, cp := range categoryPatterns {
		for _, keyword := range cp.keywords {
			if strings.Contains(text, keyword) {
				params.Category = cp.category
				break
			}
		}
		if params.Category != "" {
			break
		}
	}
	if params.Category == "" {
		params.Category = "其他"
	}

	// Income detection.
	incomeKeywords := []string{"工资", "薪", "奖金", "红包", "收入", "进账", "到账", "收到", "收款"}
	for _, keyword := range incomeKeywords {
		if strings.Contains(text, keyword) {
			params.Type = "income"
			break
		}
	}

	// Note: the text minus amounts and common filler words.
	note := text
	if params.Amount != nil {
		note = simpleNoteAmountRegexp.ReplaceAllString(note, "")
	}
	note = strings.TrimSpace(note)
	fillerWords := []string{"买了", "花了", "付了", "消费了", "一个", "一条", "一份", "的"}
	for _, word := range fillerWords {
		note = strings.ReplaceAll(note, word, "")
	}
	note = strings.TrimSpace(note)
	if note != "" {
		params.Note = note
	}

	// Response message.
	var message string
	if params.Amount == nil {
		message = "请问金额是多少?"
	} else {
		typeLabel := "支出"
		if params.Type == "income" {
			typeLabel = "收入"
		}
		message = fmt.Sprintf("记录:%s %.2f元,分类:%s", typeLabel, *params.Amount, params.Category)
		if params.Note != "" {
			message += ",备注:" + params.Note
		}
	}
	return params, message, nil
}
// GenerateReport sends prompt as a single user message to the chat-completions
// endpoint (temperature 0.7) and returns the first choice's content verbatim.
//
// It returns an error when:
//   - the API is not configured;
//   - the request fails, or the status is not 200 (the message includes at
//     most 4 KiB of the upstream body);
//   - the response cannot be decoded or has no choices.
func (s *LLMService) GenerateReport(ctx context.Context, prompt string) (string, error) {
	// Cap on how much of an error response body is copied into the returned error.
	const maxErrorBody = 4 << 10

	if s.config.OpenAIAPIKey == "" || s.config.OpenAIBaseURL == "" {
		return "", errors.New("OpenAI API not configured")
	}

	reqBody := ChatCompletionRequest{
		Model:       s.config.ChatModel,
		Messages:    []ChatMessage{{Role: "user", Content: prompt}},
		Temperature: 0.7, // Higher temperature for creative insights
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return "", fmt.Errorf("failed to marshal request: %w", err)
	}

	// Trim a trailing slash so a base URL like ".../v1/" does not yield "//chat".
	endpoint := strings.TrimRight(s.config.OpenAIBaseURL, "/") + "/chat/completions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonBody))
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("generate report request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// The body is diagnostic only, so a read error just yields a shorter message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		return "", fmt.Errorf("generate report failed with status %d: %s", resp.StatusCode, string(body))
	}

	var chatResp ChatCompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return "", fmt.Errorf("failed to decode response: %w", err)
	}
	if len(chatResp.Choices) == 0 {
		return "", errors.New("no response from AI")
	}
	return chatResp.Choices[0].Message.Content, nil
}
// MapAccountName resolves a free-text account name to one of userID's accounts.
//
// A case-insensitive exact match wins. Otherwise the first account where either
// name contains the other (case-insensitively) is used. It returns (nil, "", nil)
// for an empty name or when nothing matches, and the repository error if
// loading accounts fails. ctx is currently unused.
func (s *LLMService) MapAccountName(ctx context.Context, name string, userID uint) (*uint, string, error) {
	if name == "" {
		return nil, "", nil
	}
	accounts, err := s.accountRepo.GetAll(userID)
	if err != nil {
		return nil, "", err
	}

	for i := range accounts {
		if strings.EqualFold(accounts[i].Name, name) {
			return &accounts[i].ID, accounts[i].Name, nil
		}
	}

	needle := strings.ToLower(name)
	for i := range accounts {
		candidate := strings.ToLower(accounts[i].Name)
		if strings.Contains(candidate, needle) || strings.Contains(needle, candidate) {
			return &accounts[i].ID, accounts[i].Name, nil
		}
	}
	return nil, "", nil
}
// MapCategoryName resolves a free-text category name to one of userID's
// categories.
//
// Candidates are restricted by txType: "expense" or "income" keep only that
// category type, "" keeps all, and any other value keeps none. Among the
// candidates, a case-insensitive exact match wins; otherwise the first category
// where either name contains the other is used. It returns (nil, "", nil) for an
// empty name or no match. ctx is currently unused.
func (s *LLMService) MapCategoryName(ctx context.Context, name string, txType string, userID uint) (*uint, string, error) {
	if name == "" {
		return nil, "", nil
	}
	categories, err := s.categoryRepo.GetAll(userID)
	if err != nil {
		return nil, "", err
	}

	var candidates []models.Category
	for _, cat := range categories {
		switch {
		case txType == "":
			candidates = append(candidates, cat)
		case txType == "expense" && cat.Type == "expense",
			txType == "income" && cat.Type == "income":
			candidates = append(candidates, cat)
		}
	}

	for i := range candidates {
		if strings.EqualFold(candidates[i].Name, name) {
			return &candidates[i].ID, candidates[i].Name, nil
		}
	}

	needle := strings.ToLower(name)
	for i := range candidates {
		candidate := strings.ToLower(candidates[i].Name)
		if strings.Contains(candidate, needle) || strings.Contains(needle, candidate) {
			return &candidates[i].ID, candidates[i].Name, nil
		}
	}
	return nil, "", nil
}
// generateSessionID returns a new, unguessable session identifier: "ai_"
// followed by 32 lowercase hex characters (128 random bits).
//
// The previous implementation derived the ID purely from the wall clock. That
// made IDs predictable, so a client could guess another user's live session,
// and let two calls on the same clock tick collide.
func generateSessionID() string {
	var buf [16]byte
	if _, err := rand.Read(buf[:]); err != nil {
		// Only reachable if the OS entropy source is unavailable; fall back to
		// the legacy time-based format rather than returning an empty ID.
		return fmt.Sprintf("ai_%d_%d", time.Now().UnixNano(), time.Now().Unix()%1000)
	}
	return "ai_" + hex.EncodeToString(buf[:])
}
// getOrCreateSession returns the caller's live session for sessionID, or
// registers and returns a fresh one.
//
// An existing session is reused only if it has not expired and belongs to
// userID. Expired sessions are deleted. A live session owned by a different
// user is left untouched, and the caller gets a new session instead, so one
// user can no longer read or drive another user's conversation by presenting
// its ID. The new session expires config.AISessionTimeout after creation.
//
// NOTE(review): only the map is protected by sessionMutex; callers mutate the
// returned *AISession without it.
func (s *AIBookkeepingService) getOrCreateSession(sessionID string, userID uint) *AISession {
	s.sessionMutex.Lock()
	defer s.sessionMutex.Unlock()

	now := time.Now()
	if sessionID != "" {
		if existing, ok := s.sessions[sessionID]; ok {
			switch {
			case !now.Before(existing.ExpiresAt):
				delete(s.sessions, sessionID)
			case existing.UserID == userID:
				return existing
			}
		}
	}

	// Never overwrite a session that already holds the generated ID.
	newID := generateSessionID()
	for {
		if _, taken := s.sessions[newID]; !taken {
			break
		}
		newID = generateSessionID()
	}

	session := &AISession{
		ID:        newID,
		UserID:    userID,
		Params:    &AITransactionParams{},
		Messages:  []ChatMessage{},
		CreatedAt: now,
		ExpiresAt: now.Add(s.config.AISessionTimeout),
	}
	s.sessions[newID] = session
	return session
}
// cleanupExpiredSessions loops forever, waking every five minutes to delete
// sessions whose ExpiresAt has passed. NewAIBookkeepingService starts it once.
//
// NOTE(review): nothing ever stops this goroutine or its ticker, so each
// service instance keeps one alive for the life of the process.
func (s *AIBookkeepingService) cleanupExpiredSessions() {
	const sweepInterval = 5 * time.Minute
	ticker := time.NewTicker(sweepInterval)
	for {
		<-ticker.C
		s.sessionMutex.Lock()
		now := time.Now()
		for id, sess := range s.sessions {
			if sess.ExpiresAt.Before(now) {
				delete(s.sessions, id)
			}
		}
		s.sessionMutex.Unlock()
	}
}
// ProcessChat handles one user message in an AI bookkeeping conversation.
// Requirements: 7.2-7.4, 7.7-7.10, 12.5, 12.8
//
// Flow:
//  1. Load or create the session and append the user message.
//  2. Parse the message (earlier messages are passed as history) and merge the
//     result into the session params.
//  3. Resolve account and category names to IDs, falling back to per-type
//     defaults.
//  4. Either ask for the missing fields or return a confirmation card.
//
// The assistant's reply is appended to the session transcript. Only a parse
// failure is returned as an error; name-lookup errors are deliberately ignored,
// so an unresolved name just falls back to a default or a follow-up question.
func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, sessionID string, message string) (*AIChatResponse, error) {
	session := s.getOrCreateSession(sessionID, userID)

	session.Messages = append(session.Messages, ChatMessage{Role: "user", Content: message})
	history := session.Messages[:len(session.Messages)-1]

	parsed, replyText, err := s.llmService.ParseIntent(ctx, message, history)
	if err != nil {
		return nil, fmt.Errorf("failed to parse intent: %w", err)
	}

	p := session.Params
	if parsed != nil {
		s.mergeParams(p, parsed)
	}

	if p.Account != "" && p.AccountID == nil {
		if id, name, _ := s.llmService.MapAccountName(ctx, p.Account, userID); id != nil {
			p.AccountID, p.Account = id, name
		}
	}

	if p.Category != "" && p.CategoryID == nil {
		if id, name, _ := s.llmService.MapCategoryName(ctx, p.Category, p.Type, userID); id != nil {
			p.CategoryID, p.Category = id, name
		} else if fallbackID, _ := s.getDefaultCategory(userID, p.Type); fallbackID != nil {
			// Only the ID is taken from the default; the AI-provided name is kept.
			// NOTE(review): the confirmation text then shows a category name that
			// differs from the category the transaction will be filed under.
			p.CategoryID = fallbackID
		}
	}

	if p.AccountID == nil {
		if id, name := s.getDefaultAccount(userID, p.Type); id != nil {
			p.AccountID, p.Account = id, name
		}
	}

	resp := &AIChatResponse{
		SessionID: session.ID,
		Message:   replyText,
		Intent:    "create_transaction",
		Params:    p,
	}

	if missing := s.getMissingFields(p); len(missing) > 0 {
		resp.NeedsFollowUp = true
		resp.FollowUpQuestion = s.generateFollowUpQuestion(missing)
		if replyText == "" {
			resp.Message = resp.FollowUpQuestion
		}
	} else {
		// Every required field is present (Amount is non-nil here).
		resp.ConfirmationCard = s.GenerateConfirmationCard(session)
		resp.Message = fmt.Sprintf("请确认:%s %.2f元,分类:%s账户%s",
			s.getTypeLabel(p.Type),
			*p.Amount,
			p.Category,
			p.Account)
	}

	session.Messages = append(session.Messages, ChatMessage{Role: "assistant", Content: resp.Message})
	return resp, nil
}
// mergeParams merges the non-empty fields of update into existing. Empty
// strings and nil pointers in update leave the existing values untouched.
//
// A resolved ID is invalidated whenever the data it was resolved from changes:
//   - a different category name or transaction type clears CategoryID;
//   - a different account name clears AccountID.
//
// The caller then re-resolves the ID. Previously the stale ID survived, so
// correcting "餐饮" to "交通" showed 交通 on the confirmation but saved the
// transaction under 餐饮. An explicit ID in update always wins.
func (s *AIBookkeepingService) mergeParams(existing, update *AITransactionParams) {
	if update.Amount != nil {
		existing.Amount = update.Amount
	}

	if update.Category != "" {
		if update.Category != existing.Category {
			existing.CategoryID = nil
		}
		existing.Category = update.Category
	}
	if update.Type != "" {
		// Categories are type-specific, so a type switch invalidates the category ID.
		if update.Type != existing.Type {
			existing.CategoryID = nil
		}
		existing.Type = update.Type
	}
	if update.CategoryID != nil {
		existing.CategoryID = update.CategoryID
	}

	if update.Account != "" {
		if update.Account != existing.Account {
			existing.AccountID = nil
		}
		existing.Account = update.Account
	}
	if update.AccountID != nil {
		existing.AccountID = update.AccountID
	}

	if update.Date != "" {
		existing.Date = update.Date
	}
	if update.Note != "" {
		existing.Note = update.Note
	}
}
// getDefaultAccount picks an account for a transaction of txType.
//
// It first tries the user's configured default account for that type
// (DefaultExpenseAccountID / DefaultIncomeAccountID), as long as it still
// exists. Otherwise it falls back to the first account returned by the
// repository. It returns (nil, "") when the user has no accounts or loading
// fails. Settings lookup errors are ignored (best effort).
func (s *AIBookkeepingService) getDefaultAccount(userID uint, txType string) (*uint, string) {
	if settings, err := s.userSettingsRepo.GetOrCreate(userID); err == nil && settings != nil {
		var preferred *uint
		switch txType {
		case "expense":
			preferred = settings.DefaultExpenseAccountID
		case "income":
			preferred = settings.DefaultIncomeAccountID
		}
		if preferred != nil {
			if acc, err := s.accountRepo.GetByID(userID, *preferred); err == nil && acc != nil {
				return preferred, acc.Name
			}
		}
	}

	accounts, err := s.accountRepo.GetAll(userID)
	if err != nil || len(accounts) == 0 {
		return nil, ""
	}
	// Repository order is used as-is (presumably sort_order — TODO confirm).
	return &accounts[0].ID, accounts[0].Name
}
// getDefaultCategory returns the first of userID's categories whose type
// matches txType. Any txType other than "income" is treated as "expense". If no
// category of that type exists, the very first category is returned. It returns
// (nil, "") when the user has no categories or loading fails.
func (s *AIBookkeepingService) getDefaultCategory(userID uint, txType string) (*uint, string) {
	categories, err := s.categoryRepo.GetAll(userID)
	if err != nil || len(categories) == 0 {
		return nil, ""
	}

	want := "expense"
	if txType == "income" {
		want = "income"
	}
	for i := range categories {
		if string(categories[i].Type) == want {
			return &categories[i].ID, categories[i].Name
		}
	}
	return &categories[0].ID, categories[0].Name
}
// getMissingFields lists the required fields that params still lacks, in the
// fixed order "amount", "category", "account". A category or account counts
// as present if either its ID or its name is set. It returns nil when nothing
// is missing.
func (s *AIBookkeepingService) getMissingFields(params *AITransactionParams) []string {
	checks := []struct {
		field  string
		absent bool
	}{
		{"amount", params.Amount == nil},
		{"category", params.CategoryID == nil && params.Category == ""},
		{"account", params.AccountID == nil && params.Account == ""},
	}

	var missing []string
	for _, c := range checks {
		if c.absent {
			missing = append(missing, c.field)
		}
	}
	return missing
}
// generateFollowUpQuestion builds the Chinese prompt asking for missing fields.
//
// Known field keys ("amount", "category", "account") map to Chinese labels;
// unknown keys are skipped. A single label yields a direct question, and
// several are listed joined with "、". It returns "" when missing is empty.
func (s *AIBookkeepingService) generateFollowUpQuestion(missing []string) string {
	if len(missing) == 0 {
		return ""
	}

	labels := make([]string, 0, len(missing))
	for _, field := range missing {
		switch field {
		case "amount":
			labels = append(labels, "金额")
		case "category":
			labels = append(labels, "分类")
		case "account":
			labels = append(labels, "账户")
		}
	}

	if len(labels) == 1 {
		return fmt.Sprintf("请问%s是多少", labels[0])
	}
	return fmt.Sprintf("请补充以下信息:%s", strings.Join(labels, "、"))
}
// getTypeLabel maps a transaction type to its Chinese label: "income" becomes
// 收入; everything else, including "", becomes 支出.
func (s *AIBookkeepingService) getTypeLabel(txType string) string {
	switch txType {
	case "income":
		return "收入"
	default:
		return "支出"
	}
}
// GenerateConfirmationCard snapshots the session's params into a
// ConfirmationCard. Nil amount or IDs become zero, an empty date becomes
// today's date ("2006-01-02"), and IsComplete is always true.
func (s *AIBookkeepingService) GenerateConfirmationCard(session *AISession) *ConfirmationCard {
	p := session.Params

	var amount float64
	if p.Amount != nil {
		amount = *p.Amount
	}
	var categoryID, accountID uint
	if p.CategoryID != nil {
		categoryID = *p.CategoryID
	}
	if p.AccountID != nil {
		accountID = *p.AccountID
	}
	date := p.Date
	if date == "" {
		date = time.Now().Format("2006-01-02")
	}

	return &ConfirmationCard{
		SessionID:  session.ID,
		Amount:     amount,
		Category:   p.Category,
		CategoryID: categoryID,
		Account:    p.Account,
		AccountID:  accountID,
		Type:       p.Type,
		Date:       date,
		Note:       p.Note,
		IsComplete: true,
	}
}
// TranscribeAudio converts the uploaded audio to text. It is a thin pass-through
// to the service's WhisperService, with the same format checks and errors.
func (s *AIBookkeepingService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) {
	transcriber := s.whisperService
	return transcriber.TranscribeAudio(ctx, audioData, filename)
}
// ConfirmTransaction persists the transaction described by a session and
// applies it to the account balance.
// Requirements: 7.10
//
// The session must exist, be unexpired, and belong to userID. Otherwise the
// generic "session not found or expired" error is returned, so callers cannot
// probe other users' sessions. (Previously any user could confirm any session,
// including expired ones not yet swept.)
//
// The session is removed from the store before anything is written, so a
// duplicate or concurrent confirm cannot create the same transaction twice.
// It is put back if a failure happens before the transaction row is written,
// so the user can fix it and retry. The account is now verified before the
// insert, so a missing account no longer leaves an orphan transaction.
//
// NOTE(review): the insert and the balance update are still two independent
// writes, and the balance update is a read-modify-write. Making them atomic
// needs repository methods that accept a *gorm.DB transaction (s.db), which
// are not visible from here.
func (s *AIBookkeepingService) ConfirmTransaction(ctx context.Context, sessionID string, userID uint) (*models.Transaction, error) {
	// Claim the session atomically.
	s.sessionMutex.Lock()
	session, ok := s.sessions[sessionID]
	if !ok || session.UserID != userID || time.Now().After(session.ExpiresAt) {
		s.sessionMutex.Unlock()
		return nil, errors.New("session not found or expired")
	}
	delete(s.sessions, sessionID)
	s.sessionMutex.Unlock()

	// release returns the claimed session to the store; only called on failures
	// that happen before the transaction row is written.
	release := func() {
		s.sessionMutex.Lock()
		s.sessions[sessionID] = session
		s.sessionMutex.Unlock()
	}

	params := session.Params

	// Validate required fields.
	if params.Amount == nil || *params.Amount <= 0 {
		release()
		return nil, errors.New("invalid amount")
	}
	if params.CategoryID == nil {
		release()
		return nil, errors.New("category not specified")
	}
	if params.AccountID == nil {
		release()
		return nil, errors.New("account not specified")
	}

	// An unparseable date silently falls back to "now" (existing behaviour).
	txDate := time.Now()
	if params.Date != "" {
		if parsed, err := time.Parse("2006-01-02", params.Date); err == nil {
			txDate = parsed
		}
	}

	txType := models.TransactionTypeExpense
	if params.Type == "income" {
		txType = models.TransactionTypeIncome
	}

	// Verify the account before writing anything, so a bad account ID cannot
	// leave a transaction without its balance update.
	account, err := s.accountRepo.GetByID(userID, *params.AccountID)
	if err != nil {
		release()
		return nil, fmt.Errorf("failed to find account: %w", err)
	}
	if account == nil {
		release()
		return nil, errors.New("failed to find account")
	}

	tx := &models.Transaction{
		UserID:          userID,
		Amount:          *params.Amount,
		Type:            txType,
		CategoryID:      *params.CategoryID,
		AccountID:       *params.AccountID,
		TransactionDate: txDate,
		Note:            params.Note,
		Currency:        "CNY",
	}
	if err := s.transactionRepo.Create(tx); err != nil {
		release()
		return nil, fmt.Errorf("failed to create transaction: %w", err)
	}

	if txType == models.TransactionTypeExpense {
		account.Balance -= *params.Amount
	} else {
		account.Balance += *params.Amount
	}
	if err := s.accountRepo.Update(account); err != nil {
		// The transaction row already exists, so the session is deliberately not
		// released: a retry would create a duplicate.
		return nil, fmt.Errorf("failed to update account balance: %w", err)
	}
	return tx, nil
}
// GetSession looks up a live session by ID. It reports false when the ID is
// unknown or the session has expired. Expired entries are left for the
// background sweeper rather than deleted here.
func (s *AIBookkeepingService) GetSession(sessionID string) (*AISession, bool) {
	s.sessionMutex.RLock()
	defer s.sessionMutex.RUnlock()

	if session, found := s.sessions[sessionID]; found && !time.Now().After(session.ExpiresAt) {
		return session, true
	}
	return nil, false
}