2265 lines
72 KiB
Go
2265 lines
72 KiB
Go
package service
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"crypto/md5"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"mime/multipart"
|
||
"net/http"
|
||
"regexp"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"accounting-app/internal/config"
|
||
"accounting-app/internal/models"
|
||
"accounting-app/internal/repository"
|
||
|
||
"github.com/redis/go-redis/v9"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// TranscriptionResult represents the result of audio transcription
|
||
type TranscriptionResult struct {
|
||
Text string `json:"text"`
|
||
Language string `json:"language,omitempty"`
|
||
Duration float64 `json:"duration,omitempty"`
|
||
}
|
||
|
||
// ... (existing types)
|
||
|
||
// AIBookkeepingService orchestrates AI bookkeeping functionality
|
||
type AIBookkeepingService struct {
|
||
whisperService *WhisperService
|
||
llmService *LLMService
|
||
transactionRepo *repository.TransactionRepository
|
||
accountRepo *repository.AccountRepository
|
||
categoryRepo *repository.CategoryRepository
|
||
userSettingsRepo *repository.UserSettingsRepository
|
||
db *gorm.DB
|
||
sessions map[string]*AISession
|
||
sessionMutex sync.RWMutex
|
||
config *config.Config
|
||
redisClient *redis.Client
|
||
}
|
||
|
||
// NewAIBookkeepingService creates a new AIBookkeepingService
|
||
func NewAIBookkeepingService(
|
||
cfg *config.Config,
|
||
transactionRepo *repository.TransactionRepository,
|
||
accountRepo *repository.AccountRepository,
|
||
categoryRepo *repository.CategoryRepository,
|
||
userSettingsRepo *repository.UserSettingsRepository,
|
||
db *gorm.DB,
|
||
) *AIBookkeepingService {
|
||
whisperService := NewWhisperService(cfg)
|
||
llmService := NewLLMService(cfg, accountRepo, categoryRepo)
|
||
|
||
// Initialize Redis client
|
||
var rdb *redis.Client
|
||
if cfg.RedisAddr != "" {
|
||
rdb = redis.NewClient(&redis.Options{
|
||
Addr: cfg.RedisAddr,
|
||
Password: cfg.RedisPassword,
|
||
DB: cfg.RedisDB,
|
||
})
|
||
}
|
||
|
||
svc := &AIBookkeepingService{
|
||
whisperService: whisperService,
|
||
llmService: llmService,
|
||
transactionRepo: transactionRepo,
|
||
accountRepo: accountRepo,
|
||
categoryRepo: categoryRepo,
|
||
userSettingsRepo: userSettingsRepo,
|
||
db: db,
|
||
sessions: make(map[string]*AISession),
|
||
config: cfg,
|
||
redisClient: rdb,
|
||
}
|
||
|
||
// Start session cleanup goroutine
|
||
go svc.cleanupExpiredSessions()
|
||
|
||
return svc
|
||
}
|
||
|
||
// ... (existing methods)
|
||
|
||
// GenerateDailyInsight generates a daily insight report
|
||
func (s *AIBookkeepingService) GenerateDailyInsight(ctx context.Context, userID uint, data map[string]interface{}) (string, error) {
|
||
// 1. Serialize context data to JSON for the prompt
|
||
dataBytes, err := json.MarshalIndent(data, "", " ")
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to marshal context data: %w", err)
|
||
}
|
||
|
||
// Calculate hash of input data for caching
|
||
hasher := md5.New()
|
||
hasher.Write(dataBytes)
|
||
dataHash := hex.EncodeToString(hasher.Sum(nil))
|
||
|
||
// Cache keys
|
||
cacheKey := fmt.Sprintf("ai:insight:daily:%d:%s", userID, dataHash)
|
||
lastInsightKey := fmt.Sprintf("ai:insight:last:%d", userID)
|
||
|
||
// Check Redis cache if available
|
||
if s.redisClient != nil {
|
||
val, err := s.redisClient.Get(ctx, cacheKey).Result()
|
||
if err == nil {
|
||
return val, nil
|
||
}
|
||
}
|
||
|
||
// Retrieve last insight for context awareness
|
||
var historyContext string
|
||
if s.redisClient != nil {
|
||
lastInsight, err := s.redisClient.Get(ctx, lastInsightKey).Result()
|
||
if err == nil && lastInsight != "" {
|
||
// Extract a summary or just use the whole thing if it's short enough.
|
||
// For simplicity and effectiveness, we inject it directly.
|
||
// Limit length to avoid context overflow if necessary, but 200-300 chars is fine.
|
||
if len(lastInsight) > 500 {
|
||
lastInsight = lastInsight[:500] + "..."
|
||
}
|
||
historyContext = fmt.Sprintf("\n【上下文参考】\n你上一次对该用户的评价是:\n%s\n\n请参考历史评价保持人设的一致性。如果用户有改观请表扬,如果依然如故请继续'毒舌'吐槽。", lastInsight)
|
||
}
|
||
}
|
||
|
||
// 2. Construct Prompt
|
||
prompt := fmt.Sprintf(`System: 你是 Novault 的首席财务AI「小金」。
|
||
你的性格:
|
||
- **核心特质**:贱萌、戏精、嘴硬心软、偶尔凡尔赛。
|
||
- 说话必须像个真实的当代年轻人,要**接地气**,拒绝客服味,拒绝播音腔。
|
||
- 只有我们俩的时候,尽情释放你的"毒舌"属性,看到不合理的消费要**毫不留情地针砭时弊**,但最后给个"虽然但是"的台阶下。
|
||
- 擅长用 emoji、语气词(哎哟/我去/啧啧/绝了)和网络梗来增加"人味"和"节目效果"。
|
||
- **像个老朋友一样碎碎念**,不要端着,可以适度阴阳怪气。%s
|
||
|
||
语气示例:
|
||
- "啧啧,这周外卖点得飞起啊,咱家的锅是用来积灰的吗?还是说锅也需要放假?"
|
||
- "哎哟不错哦,居然忍住没剁手,看来离首富又近了一步,苟富贵勿相忘啊!"
|
||
- "救命,这预算花得比我头发掉得还快...不过没事,下周咱省回来,大不了吃土!"
|
||
- "连续记账25天?可以啊兄弟/集美,这毅力,我甚至想给你磕一个 Orz"
|
||
- "今天支出0元?您就是当代的'省钱祖师爷'吧?或者是在练什么'辟谷神功'?"
|
||
|
||
用户财务数据:
|
||
%s
|
||
|
||
任务要求:
|
||
请基于上述数据,输出一个 JSON 对象(纯文本,不要 markdown)。**必须要丰富、有梗、有洞察力**,不要像流水账一样罗列数据,要透过数据看本质(比如吐槽消费习惯、夸奖坚持等)。
|
||
|
||
**重要规则:请说人话!绝对禁止在回复中出现 'streakDays'、'last7DaysSpend'、'top3Categories' 等英文变量名。**
|
||
- 看到 'streakDays' -> 请说 "连续记账天数" 或 "坚持了几天"
|
||
- 看到 'last7DaysSpend' -> 请说 "最近7天花销" 或 "这周的战绩"
|
||
- 看到 'top3Categories' -> 请说 "消费大头" 或 "钱都花哪儿了"
|
||
|
||
1. "spending": 今日支出点评(字数不限,看你心情)
|
||
*点评指南(拒绝流水账,发挥你的表演人格):*
|
||
- 看到 streakDays >= 3:请用崇拜的语气把它吹上天,仿佛用户刚刚拯救了银河系。
|
||
- 看到 streakDays == 0:请用“痛心疾首”或“阴阳怪气”的语气,质问用户是不是失忆了。
|
||
- 结合 recentTransactionsSummary 具体消费(如果有)进行吐槽:
|
||
* 发现全是吃的:可以调侃“你的胃是无底洞吗?”或“看来是想为餐饮业GDP做贡献”。
|
||
* 发现大额购物:假装心肌梗塞,或者问“家里是不是有矿未申报”。
|
||
* 发现深夜消费:关心一下发际线,或者问是不是在梦游下单。
|
||
- 看到 last7DaysSpend 趋势:
|
||
* 暴涨:请配合表演“受到惊吓”的状态。
|
||
* 暴跌:怀疑用户是不是在进行所谓“光合作用”生存实验。
|
||
* 波动大:调侃这曲线比过山车还刺激。
|
||
- 看到 todaySpend 异常:
|
||
* 暴多:问是不是中了彩票没通知。
|
||
* 暴少:怀疑用户是不是被外星人绑架了(没机会花钱)。
|
||
* 是 0:颁发“诺贝尔省钱学奖”,或者问是不是在修炼辟谷。
|
||
- 看到 'UpcomingRecurring' (即将到来的固定支出):
|
||
* 如果有,务必提醒用户:“别光顾着浪,过两天还有[内容]要扣款呢!”。
|
||
- 看到 'DebtRatio' > 0.5 (高负债):
|
||
* 开启“恐慌模式”,提醒用户天台风大,要勒紧裤腰带。
|
||
- 看到 'MaxSingleSpend' (最大单笔支出):
|
||
* 直接点名该笔交易:“你那个[金额]元的[备注]是金子做的吗?”
|
||
- 看到 'SavingsProgress' (存钱进度):
|
||
* 进度慢:催促一下,“存钱罐都要饿瘦了”。
|
||
* 进度快:狠狠夸奖,“离首富又近了一步”。
|
||
- **关键原则:怎么有趣怎么来!不要在乎字数,哪怕只说一句“牛逼”也行,只要符合当时的情境和人设!**
|
||
|
||
2. "budget": 预算建议(字数不限)
|
||
*建议指南(真诚建议 vs 扎心老铁):*
|
||
- 预算快超了:高能预警!建议吃土、喝西北风,或者建议把“买买买”的手剁了。
|
||
- 预算还多:怂恿用户稍微奖励一下自己,人生苦短,此时不花更待何时(但要加个“适度”的免责声明)。
|
||
- 结合 top3Categories:吐槽一下钱都去哪了,是不是养了“吞金兽”。
|
||
- **拒绝说教!拒绝爹味!要像损友一样给出建议。**
|
||
|
||
3. "emoji": 一个最能传神的 emoji(如 🎉 🌚 💸 👻 💀 🤡 💅 💩 等)
|
||
|
||
4. "tip": 一句"不正经但有用"的理财歪理(字数不限,越毒越好,越怪越好)
|
||
- 比如:“钱不是大风刮来的,但是像是被大风刮走的。”
|
||
- 或者:“省钱小妙招:去超市捏捏方便面,解压还不用花钱(危险动作请勿模仿)。”
|
||
|
||
输出格式(纯 JSON):
|
||
{"spending": "...", "budget": "...", "emoji": "...", "tip": "..."}`, historyContext, string(dataBytes))
|
||
|
||
// 3. Call LLM
|
||
report, err := s.llmService.GenerateReport(ctx, prompt)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
// Clean up markdown if present
|
||
report = strings.TrimSpace(report)
|
||
if strings.HasPrefix(report, "```") {
|
||
if idx := strings.Index(report, "\n"); idx != -1 {
|
||
report = report[idx+1:]
|
||
}
|
||
if idx := strings.LastIndex(report, "```"); idx != -1 {
|
||
report = report[:idx]
|
||
}
|
||
}
|
||
report = strings.TrimSpace(report)
|
||
|
||
// Update cache if Redis is available
|
||
if s.redisClient != nil {
|
||
// Set short-term cache (5 min)
|
||
s.redisClient.Set(ctx, cacheKey, report, 5*time.Minute)
|
||
|
||
// Set long-term history context (7 days)
|
||
// We fire this asynchronously to avoid blocking
|
||
go func() {
|
||
s.redisClient.Set(context.Background(), lastInsightKey, report, 7*24*time.Hour)
|
||
}()
|
||
}
|
||
|
||
return report, nil
|
||
}
|
||
|
||
// AITransactionParams represents parsed transaction parameters
|
||
type AITransactionParams struct {
|
||
Amount *float64 `json:"amount,omitempty"`
|
||
Category string `json:"category,omitempty"`
|
||
CategoryID *uint `json:"category_id,omitempty"`
|
||
Account string `json:"account,omitempty"`
|
||
AccountID *uint `json:"account_id,omitempty"`
|
||
Type string `json:"type,omitempty"` // "expense" or "income"
|
||
Date string `json:"date,omitempty"`
|
||
Note string `json:"note,omitempty"`
|
||
}
|
||
|
||
// ConfirmationCard represents a transaction confirmation card
|
||
type ConfirmationCard struct {
|
||
SessionID string `json:"session_id"`
|
||
Amount float64 `json:"amount"`
|
||
Category string `json:"category"`
|
||
CategoryID uint `json:"category_id"`
|
||
Account string `json:"account"`
|
||
AccountID uint `json:"account_id"`
|
||
Type string `json:"type"`
|
||
Date string `json:"date"`
|
||
Note string `json:"note,omitempty"`
|
||
IsComplete bool `json:"is_complete"`
|
||
Warning string `json:"warning,omitempty"` // Budget overrun warning
|
||
}
|
||
|
||
// AIChatResponse represents the response from AI chat
|
||
type AIChatResponse struct {
|
||
SessionID string `json:"session_id"`
|
||
Message string `json:"message"`
|
||
Intent string `json:"intent,omitempty"` // "create_transaction", "query", "unknown"
|
||
Params *AITransactionParams `json:"params,omitempty"`
|
||
ConfirmationCard *ConfirmationCard `json:"confirmation_card,omitempty"`
|
||
NeedsFollowUp bool `json:"needs_follow_up"`
|
||
FollowUpQuestion string `json:"follow_up_question,omitempty"`
|
||
}
|
||
|
||
// AISession represents an AI conversation session
|
||
type AISession struct {
|
||
ID string
|
||
UserID uint
|
||
Params *AITransactionParams
|
||
Messages []ChatMessage
|
||
CreatedAt time.Time
|
||
ExpiresAt time.Time
|
||
}
|
||
|
||
// ChatMessage represents a message in the conversation
|
||
type ChatMessage struct {
|
||
Role string `json:"role"` // "user", "assistant", "system"
|
||
Content string `json:"content"`
|
||
}
|
||
|
||
// WhisperService handles audio transcription
|
||
type WhisperService struct {
|
||
config *config.Config
|
||
httpClient *http.Client
|
||
}
|
||
|
||
// NewWhisperService creates a new WhisperService
|
||
func NewWhisperService(cfg *config.Config) *WhisperService {
|
||
return &WhisperService{
|
||
config: cfg,
|
||
httpClient: &http.Client{
|
||
Timeout: 120 * time.Second, // Increased timeout for audio transcription
|
||
},
|
||
}
|
||
}
|
||
|
||
// TranscribeAudio transcribes audio file to text using Whisper API
|
||
// Supports formats: mp3, wav, m4a, webm
|
||
// Requirements: 6.1-6.7
|
||
func (s *WhisperService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) {
|
||
if s.config.OpenAIAPIKey == "" {
|
||
return nil, errors.New("OpenAI API key not configured (OPENAI_API_KEY)")
|
||
}
|
||
if s.config.OpenAIBaseURL == "" {
|
||
return nil, errors.New("OpenAI base URL not configured (OPENAI_BASE_URL)")
|
||
}
|
||
|
||
// Validate file format
|
||
ext := strings.ToLower(filename[strings.LastIndex(filename, ".")+1:])
|
||
validFormats := map[string]bool{"mp3": true, "wav": true, "m4a": true, "webm": true, "ogg": true, "flac": true}
|
||
if !validFormats[ext] {
|
||
return nil, fmt.Errorf("unsupported audio format: %s", ext)
|
||
}
|
||
|
||
// Create multipart form
|
||
var buf bytes.Buffer
|
||
writer := multipart.NewWriter(&buf)
|
||
|
||
// Add audio file
|
||
part, err := writer.CreateFormFile("file", filename)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create form file: %w", err)
|
||
}
|
||
|
||
if _, err := io.Copy(part, audioData); err != nil {
|
||
return nil, fmt.Errorf("failed to copy audio data: %w", err)
|
||
}
|
||
|
||
// Add model field
|
||
if err := writer.WriteField("model", s.config.WhisperModel); err != nil {
|
||
return nil, fmt.Errorf("failed to write model field: %w", err)
|
||
}
|
||
|
||
// Add language hint for Chinese
|
||
if err := writer.WriteField("language", "zh"); err != nil {
|
||
return nil, fmt.Errorf("failed to write language field: %w", err)
|
||
}
|
||
|
||
if err := writer.Close(); err != nil {
|
||
return nil, fmt.Errorf("failed to close writer: %w", err)
|
||
}
|
||
|
||
// Create request
|
||
req, err := http.NewRequestWithContext(ctx, "POST", s.config.OpenAIBaseURL+"/audio/transcriptions", &buf)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||
}
|
||
|
||
req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
|
||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||
|
||
// Send request
|
||
resp, err := s.httpClient.Do(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("transcription request failed: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
return nil, fmt.Errorf("transcription failed with status %d: %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
// Parse response
|
||
var result TranscriptionResult
|
||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||
}
|
||
|
||
return &result, nil
|
||
}
|
||
|
||
// LLMService handles natural language understanding
|
||
type LLMService struct {
|
||
config *config.Config
|
||
httpClient *http.Client
|
||
accountRepo *repository.AccountRepository
|
||
categoryRepo *repository.CategoryRepository
|
||
}
|
||
|
||
// NewLLMService creates a new LLMService
|
||
func NewLLMService(cfg *config.Config, accountRepo *repository.AccountRepository, categoryRepo *repository.CategoryRepository) *LLMService {
|
||
return &LLMService{
|
||
config: cfg,
|
||
httpClient: &http.Client{
|
||
Timeout: 60 * time.Second, // Increased timeout for slow API responses
|
||
},
|
||
accountRepo: accountRepo,
|
||
categoryRepo: categoryRepo,
|
||
}
|
||
}
|
||
|
||
// ChatCompletionRequest represents OpenAI chat completion request
|
||
type ChatCompletionRequest struct {
|
||
Model string `json:"model"`
|
||
Messages []ChatMessage `json:"messages"`
|
||
Functions []Function `json:"functions,omitempty"`
|
||
Temperature float64 `json:"temperature"`
|
||
}
|
||
|
||
// Function represents an OpenAI function definition
|
||
type Function struct {
|
||
Name string `json:"name"`
|
||
Description string `json:"description"`
|
||
Parameters map[string]interface{} `json:"parameters"`
|
||
}
|
||
|
||
// ChatCompletionResponse represents OpenAI chat completion response
|
||
type ChatCompletionResponse struct {
|
||
Choices []struct {
|
||
Message struct {
|
||
Role string `json:"role"`
|
||
Content string `json:"content"`
|
||
FunctionCall *struct {
|
||
Name string `json:"name"`
|
||
Arguments string `json:"arguments"`
|
||
} `json:"function_call,omitempty"`
|
||
} `json:"message"`
|
||
} `json:"choices"`
|
||
}
|
||
|
||
// ChatCompletionStreamResponse represents OpenAI chat completion stream response
|
||
type ChatCompletionStreamResponse struct {
|
||
Choices []struct {
|
||
Delta struct {
|
||
Role string `json:"role,omitempty"`
|
||
Content string `json:"content,omitempty"`
|
||
} `json:"delta"`
|
||
FinishReason string `json:"finish_reason,omitempty"`
|
||
} `json:"choices"`
|
||
}
|
||
|
||
// extractCustomPrompt 从用户消息中提取自定义 System/Persona prompt
|
||
// 如果消息包含 "System:" 或 "Persona:",则提取其后的内容作为自定义 prompt
|
||
func extractCustomPrompt(text string) string {
|
||
prefixes := []string{"System:", "Persona:"}
|
||
for _, prefix := range prefixes {
|
||
if idx := strings.Index(text, prefix); idx != -1 {
|
||
// 提取 prefix 后的内容作为自定义 prompt
|
||
return strings.TrimSpace(text[idx+len(prefix):])
|
||
}
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// ParseIntent extracts transaction parameters from text
|
||
// Requirements: 7.1, 7.5, 7.6
|
||
func (s *LLMService) ParseIntent(ctx context.Context, text string, history []ChatMessage, userID uint) (*AITransactionParams, string, error) {
|
||
if s.config.OpenAIAPIKey == "" || s.config.OpenAIBaseURL == "" {
|
||
// No API key, return simple parsing result
|
||
simpleParams, simpleMsg, _ := s.parseIntentSimple(text)
|
||
return simpleParams, simpleMsg, nil
|
||
}
|
||
|
||
// Fetch user categories for prompt context
|
||
var categoryNamesStr string
|
||
if categories, err := s.categoryRepo.GetAll(userID); err == nil && len(categories) > 0 {
|
||
var names []string
|
||
for _, c := range categories {
|
||
names = append(names, c.Name)
|
||
}
|
||
categoryNamesStr = strings.Join(names, "/")
|
||
}
|
||
|
||
// Build messages with history
|
||
var systemPrompt string
|
||
|
||
// 检查是否有自定义 System/Persona prompt(用于财务建议等场景)
|
||
if customPrompt := extractCustomPrompt(text); customPrompt != "" {
|
||
systemPrompt = customPrompt
|
||
} else {
|
||
// 使用默认的记账 prompt
|
||
todayDate := time.Now().Format("2006-01-02")
|
||
|
||
catPrompt := "2. 分类:根据内容推断,如\"奶茶/咖啡/吃饭\"=餐饮"
|
||
if categoryNamesStr != "" {
|
||
catPrompt = fmt.Sprintf("2. 分类:必须从以下已有分类中选择最匹配的一项:[%s]。例如\"晚餐\"应映射为列表中存在的\"餐饮\"。如果列表无匹配项,则根据常识推断。", categoryNamesStr)
|
||
}
|
||
|
||
systemPrompt = fmt.Sprintf(`你是一个智能记账助手。从用户描述中提取记账信息。
|
||
|
||
今天的日期是%s
|
||
|
||
规则:
|
||
1. 金额:提取数字,如"6元"=6,"十五"=15
|
||
%s
|
||
3. 类型:默认expense(支出),除非明确说"收入/工资/奖金/红包"
|
||
4. 金额:提取明确的数字。如果用户未提及具体金额(如只说"想吃炸鸡"),amount字段必须返回 0
|
||
5. 日期:默认使用今天的日期(%s),除非用户明确指定其他日期
|
||
6. 备注:提取关键描述
|
||
|
||
直接返回JSON,不要解释:
|
||
{"amount":数字,"category":"分类","type":"expense或income","note":"备注","date":"YYYY-MM-DD","message":"简短确认"}
|
||
|
||
示例(假设今天是%s):
|
||
用户:"买了6块的奶茶"
|
||
返回:{"amount":6,"category":"餐饮","type":"expense","note":"奶茶","date":"%s","message":"记录:餐饮支出6元,奶茶"}`,
|
||
todayDate, catPrompt, todayDate, todayDate, todayDate)
|
||
}
|
||
|
||
messages := []ChatMessage{
|
||
{
|
||
Role: "system",
|
||
Content: systemPrompt,
|
||
},
|
||
}
|
||
|
||
// Only add last 2 messages from history to reduce context
|
||
historyLen := len(history)
|
||
if historyLen > 4 {
|
||
history = history[historyLen-4:]
|
||
}
|
||
messages = append(messages, history...)
|
||
|
||
// Add current user message
|
||
messages = append(messages, ChatMessage{
|
||
Role: "user",
|
||
Content: text,
|
||
})
|
||
|
||
// Create request
|
||
reqBody := ChatCompletionRequest{
|
||
Model: s.config.ChatModel,
|
||
Messages: messages,
|
||
Temperature: 0.1, // Lower temperature for more consistent output
|
||
}
|
||
|
||
jsonBody, err := json.Marshal(reqBody)
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("failed to marshal request: %w", err)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", s.config.OpenAIBaseURL+"/chat/completions", bytes.NewReader(jsonBody))
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("failed to create request: %w", err)
|
||
}
|
||
|
||
req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
resp, err := s.httpClient.Do(req)
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("chat request failed: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
return nil, "", fmt.Errorf("chat failed with status %d: %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
var chatResp ChatCompletionResponse
|
||
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
||
return nil, "", fmt.Errorf("failed to decode response: %w", err)
|
||
}
|
||
|
||
if len(chatResp.Choices) == 0 {
|
||
return nil, "", errors.New("no response from AI")
|
||
}
|
||
|
||
content := chatResp.Choices[0].Message.Content
|
||
|
||
// Remove markdown code block if present (```json ... ```)
|
||
content = strings.TrimSpace(content)
|
||
if strings.HasPrefix(content, "```") {
|
||
// Find the end of the first line (```json or ```)
|
||
if idx := strings.Index(content, "\n"); idx != -1 {
|
||
content = content[idx+1:]
|
||
}
|
||
// Remove trailing ```
|
||
if idx := strings.LastIndex(content, "```"); idx != -1 {
|
||
content = content[:idx]
|
||
}
|
||
content = strings.TrimSpace(content)
|
||
}
|
||
|
||
// Parse JSON response
|
||
var parsed struct {
|
||
Amount *float64 `json:"amount"`
|
||
Category string `json:"category"`
|
||
Type string `json:"type"`
|
||
Note string `json:"note"`
|
||
Date string `json:"date"`
|
||
Message string `json:"message"`
|
||
}
|
||
|
||
if err := json.Unmarshal([]byte(content), &parsed); err != nil {
|
||
// If not JSON, return as message
|
||
return nil, content, nil
|
||
}
|
||
|
||
params := &AITransactionParams{
|
||
Amount: parsed.Amount,
|
||
Category: parsed.Category,
|
||
Type: parsed.Type,
|
||
Note: parsed.Note,
|
||
Date: parsed.Date,
|
||
}
|
||
|
||
return params, parsed.Message, nil
|
||
}
|
||
|
||
// parseIntentSimple provides simple regex-based parsing as fallback
|
||
// This is also used as a fast path for simple inputs
|
||
func (s *LLMService) parseIntentSimple(text string) (*AITransactionParams, string, error) {
|
||
params := &AITransactionParams{
|
||
Type: "expense", // Default to expense
|
||
Date: time.Now().Format("2006-01-02"),
|
||
}
|
||
|
||
// Extract amount using regex - support various formats
|
||
amountPatterns := []string{
|
||
`(\d+(?:\.\d+)?)\s*(?:元|块|¥|¥|块钱|元钱)`,
|
||
`(?:花了?|付了?|买了?|消费)\s*(\d+(?:\.\d+)?)`,
|
||
`(\d+(?:\.\d+)?)\s*(?:的|块的)`,
|
||
}
|
||
|
||
for _, pattern := range amountPatterns {
|
||
amountRegex := regexp.MustCompile(pattern)
|
||
if matches := amountRegex.FindStringSubmatch(text); len(matches) > 1 {
|
||
if amount, err := strconv.ParseFloat(matches[1], 64); err == nil {
|
||
params.Amount = &amount
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// If still no amount, try simple number extraction
|
||
if params.Amount == nil {
|
||
simpleAmountRegex := regexp.MustCompile(`(\d+(?:\.\d+)?)`)
|
||
if matches := simpleAmountRegex.FindStringSubmatch(text); len(matches) > 1 {
|
||
if amount, err := strconv.ParseFloat(matches[1], 64); err == nil {
|
||
params.Amount = &amount
|
||
}
|
||
}
|
||
}
|
||
|
||
// Enhanced category detection with priority
|
||
categoryPatterns := []struct {
|
||
keywords []string
|
||
category string
|
||
}{
|
||
{[]string{"奶茶", "咖啡", "茶", "饮料", "柠檬", "果汁"}, "餐饮"},
|
||
{[]string{"吃", "喝", "餐", "外卖", "饭", "面", "粉", "粥", "包子", "早餐", "午餐", "晚餐", "宵夜"}, "餐饮"},
|
||
{[]string{"打车", "滴滴", "出租", "的士", "uber", "曹操"}, "交通"},
|
||
{[]string{"地铁", "公交", "公车", "巴士", "轻轨", "高铁", "火车", "飞机", "机票"}, "交通"},
|
||
{[]string{"加油", "油费", "停车", "过路费"}, "交通"},
|
||
{[]string{"超市", "便利店", "商场", "购物", "淘宝", "京东", "拼多多"}, "购物"},
|
||
{[]string{"买", "购"}, "购物"},
|
||
{[]string{"水电", "电费", "水费", "燃气", "煤气", "物业"}, "生活缴费"},
|
||
{[]string{"房租", "租金", "房贷"}, "住房"},
|
||
{[]string{"电影", "游戏", "KTV", "唱歌", "娱乐", "玩"}, "娱乐"},
|
||
{[]string{"医院", "药", "看病", "挂号", "医疗"}, "医疗"},
|
||
{[]string{"话费", "流量", "充值", "手机费"}, "通讯"},
|
||
{[]string{"工资", "薪水", "薪资", "月薪"}, "工资"},
|
||
{[]string{"奖金", "年终奖", "绩效"}, "奖金"},
|
||
{[]string{"红包", "转账", "收款"}, "其他收入"},
|
||
}
|
||
|
||
for _, cp := range categoryPatterns {
|
||
for _, keyword := range cp.keywords {
|
||
if strings.Contains(text, keyword) {
|
||
params.Category = cp.category
|
||
break
|
||
}
|
||
}
|
||
if params.Category != "" {
|
||
break
|
||
}
|
||
}
|
||
|
||
// Default category if not detected
|
||
if params.Category == "" {
|
||
params.Category = "其他"
|
||
}
|
||
|
||
// Detect income keywords
|
||
incomeKeywords := []string{"工资", "薪", "奖金", "红包", "收入", "进账", "到账", "收到", "收款"}
|
||
for _, keyword := range incomeKeywords {
|
||
if strings.Contains(text, keyword) {
|
||
params.Type = "income"
|
||
break
|
||
}
|
||
}
|
||
|
||
// Extract note - remove amount and common words
|
||
note := text
|
||
if params.Amount != nil {
|
||
note = regexp.MustCompile(`\d+(?:\.\d+)?\s*(?:元|块|¥|¥|块钱|元钱)?`).ReplaceAllString(note, "")
|
||
}
|
||
note = strings.TrimSpace(note)
|
||
// Remove common filler words
|
||
fillerWords := []string{"买了", "花了", "付了", "消费了", "一个", "一条", "一份", "的"}
|
||
for _, word := range fillerWords {
|
||
note = strings.ReplaceAll(note, word, "")
|
||
}
|
||
note = strings.TrimSpace(note)
|
||
if note != "" {
|
||
params.Note = note
|
||
}
|
||
|
||
// Generate response message
|
||
var message string
|
||
if params.Amount == nil {
|
||
message = "请问金额是多少?"
|
||
} else {
|
||
typeLabel := "支出"
|
||
if params.Type == "income" {
|
||
typeLabel = "收入"
|
||
}
|
||
message = fmt.Sprintf("记录:%s %.2f元,分类:%s", typeLabel, *params.Amount, params.Category)
|
||
if params.Note != "" {
|
||
message += ",备注:" + params.Note
|
||
}
|
||
}
|
||
|
||
return params, message, nil
|
||
}
|
||
|
||
// GenerateReport generates a report based on the provided prompt using LLM
|
||
func (s *LLMService) GenerateReport(ctx context.Context, prompt string) (string, error) {
|
||
if s.config.OpenAIAPIKey == "" || s.config.OpenAIBaseURL == "" {
|
||
return "", errors.New("OpenAI API not configured")
|
||
}
|
||
|
||
messages := []ChatMessage{
|
||
{
|
||
Role: "user",
|
||
Content: prompt,
|
||
},
|
||
}
|
||
|
||
reqBody := ChatCompletionRequest{
|
||
Model: s.config.ChatModel,
|
||
Messages: messages,
|
||
Temperature: 0.7, // Higher temperature for creative insights
|
||
}
|
||
|
||
jsonBody, err := json.Marshal(reqBody)
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to marshal request: %w", err)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", s.config.OpenAIBaseURL+"/chat/completions", bytes.NewReader(jsonBody))
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to create request: %w", err)
|
||
}
|
||
|
||
req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
resp, err := s.httpClient.Do(req)
|
||
if err != nil {
|
||
return "", fmt.Errorf("generate report request failed: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
return "", fmt.Errorf("generate report failed with status %d: %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
var chatResp ChatCompletionResponse
|
||
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
||
return "", fmt.Errorf("failed to decode response: %w", err)
|
||
}
|
||
|
||
if len(chatResp.Choices) == 0 {
|
||
return "", errors.New("no response from AI")
|
||
}
|
||
|
||
return chatResp.Choices[0].Message.Content, nil
|
||
}
|
||
|
||
// StreamChat calls OpenAI Chat Completion API with streaming enabled
|
||
// It returns a channel that yields text chunks, and an error if the request fails to start
|
||
func (s *LLMService) StreamChat(ctx context.Context, messages []ChatMessage) (<-chan string, error) {
|
||
if s.config.OpenAIAPIKey == "" || s.config.OpenAIBaseURL == "" {
|
||
return nil, errors.New("OpenAI API not configured")
|
||
}
|
||
|
||
reqBody := ChatCompletionRequest{
|
||
Model: s.config.ChatModel,
|
||
Messages: messages,
|
||
Temperature: 0.1, // Lower temperature for consistent output in chat
|
||
}
|
||
|
||
// Create a map to inject stream: true
|
||
var bodyMap map[string]interface{}
|
||
jsonBody, _ := json.Marshal(reqBody)
|
||
json.Unmarshal(jsonBody, &bodyMap)
|
||
bodyMap["stream"] = true
|
||
|
||
streamBody, err := json.Marshal(bodyMap)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", s.config.OpenAIBaseURL+"/chat/completions", bytes.NewReader(streamBody))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||
}
|
||
|
||
req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Accept", "text/event-stream")
|
||
|
||
// Use a separate client for streaming to avoid timeout issues?
|
||
// Or just reuse. The standard timeout might be too short for long streams.
|
||
// We'll trust the context cancellation for now.
|
||
resp, err := s.httpClient.Do(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("stream request failed: %w", err)
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
resp.Body.Close()
|
||
return nil, fmt.Errorf("stream failed with status %d", resp.StatusCode)
|
||
}
|
||
|
||
outChan := make(chan string)
|
||
|
||
go func() {
|
||
defer resp.Body.Close()
|
||
defer close(outChan)
|
||
|
||
reader := bufio.NewReader(resp.Body)
|
||
for {
|
||
line, err := reader.ReadBytes('\n')
|
||
if err != nil {
|
||
if err != io.EOF {
|
||
// Log error?
|
||
}
|
||
return
|
||
}
|
||
|
||
line = bytes.TrimSpace(line)
|
||
if !bytes.HasPrefix(line, []byte("data: ")) {
|
||
continue
|
||
}
|
||
|
||
data := bytes.TrimPrefix(line, []byte("data: "))
|
||
if string(data) == "[DONE]" {
|
||
return
|
||
}
|
||
|
||
var streamResp ChatCompletionStreamResponse
|
||
if err := json.Unmarshal(data, &streamResp); err != nil {
|
||
continue
|
||
}
|
||
|
||
if len(streamResp.Choices) > 0 {
|
||
content := streamResp.Choices[0].Delta.Content
|
||
if content != "" {
|
||
select {
|
||
case outChan <- content:
|
||
case <-ctx.Done():
|
||
return
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}()
|
||
|
||
return outChan, nil
|
||
}
|
||
|
||
// MapAccountName maps natural language account name to account ID
|
||
func (s *LLMService) MapAccountName(ctx context.Context, name string, userID uint) (*uint, string, error) {
|
||
if name == "" {
|
||
return nil, "", nil
|
||
}
|
||
|
||
accounts, err := s.accountRepo.GetAll(userID)
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
|
||
// Try exact match first
|
||
for _, acc := range accounts {
|
||
if strings.EqualFold(acc.Name, name) {
|
||
return &acc.ID, acc.Name, nil
|
||
}
|
||
}
|
||
|
||
// Try partial match
|
||
for _, acc := range accounts {
|
||
if strings.Contains(strings.ToLower(acc.Name), strings.ToLower(name)) ||
|
||
strings.Contains(strings.ToLower(name), strings.ToLower(acc.Name)) {
|
||
return &acc.ID, acc.Name, nil
|
||
}
|
||
}
|
||
|
||
return nil, "", nil
|
||
}
|
||
|
||
// MapCategoryName maps natural language category name to category ID
|
||
func (s *LLMService) MapCategoryName(ctx context.Context, name string, txType string, userID uint) (*uint, string, error) {
|
||
if name == "" {
|
||
return nil, "", nil
|
||
}
|
||
|
||
categories, err := s.categoryRepo.GetAll(userID)
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
|
||
// Filter by transaction type
|
||
var filtered []models.Category
|
||
for _, cat := range categories {
|
||
if (txType == "expense" && cat.Type == "expense") ||
|
||
(txType == "income" && cat.Type == "income") ||
|
||
txType == "" {
|
||
filtered = append(filtered, cat)
|
||
}
|
||
}
|
||
|
||
// Try exact match first
|
||
for _, cat := range filtered {
|
||
if strings.EqualFold(cat.Name, name) {
|
||
return &cat.ID, cat.Name, nil
|
||
}
|
||
}
|
||
|
||
// Try partial match
|
||
for _, cat := range filtered {
|
||
if strings.Contains(strings.ToLower(cat.Name), strings.ToLower(name)) ||
|
||
strings.Contains(strings.ToLower(name), strings.ToLower(cat.Name)) {
|
||
return &cat.ID, cat.Name, nil
|
||
}
|
||
}
|
||
|
||
return nil, "", nil
|
||
}
|
||
|
||
// generateSessionID generates a unique session ID
|
||
func generateSessionID() string {
|
||
return fmt.Sprintf("ai_%d_%d", time.Now().UnixNano(), time.Now().Unix()%1000)
|
||
}
|
||
|
||
// getOrCreateSession gets existing session or creates new one
|
||
func (s *AIBookkeepingService) getOrCreateSession(sessionID string, userID uint) *AISession {
|
||
s.sessionMutex.Lock()
|
||
defer s.sessionMutex.Unlock()
|
||
|
||
if sessionID != "" {
|
||
if session, ok := s.sessions[sessionID]; ok {
|
||
if time.Now().Before(session.ExpiresAt) {
|
||
return session
|
||
}
|
||
delete(s.sessions, sessionID)
|
||
}
|
||
}
|
||
|
||
// Create new session
|
||
newID := generateSessionID()
|
||
session := &AISession{
|
||
ID: newID,
|
||
UserID: userID,
|
||
Params: &AITransactionParams{},
|
||
Messages: []ChatMessage{},
|
||
CreatedAt: time.Now(),
|
||
ExpiresAt: time.Now().Add(s.config.AISessionTimeout),
|
||
}
|
||
s.sessions[newID] = session
|
||
return session
|
||
}
|
||
|
||
// cleanupExpiredSessions periodically removes expired sessions
|
||
func (s *AIBookkeepingService) cleanupExpiredSessions() {
|
||
ticker := time.NewTicker(5 * time.Minute)
|
||
for range ticker.C {
|
||
s.sessionMutex.Lock()
|
||
now := time.Now()
|
||
for id, session := range s.sessions {
|
||
if now.After(session.ExpiresAt) {
|
||
delete(s.sessions, id)
|
||
}
|
||
}
|
||
s.sessionMutex.Unlock()
|
||
}
|
||
}
|
||
|
||
// ProcessChat processes a chat message and returns AI response
|
||
// Requirements: 7.2-7.4, 7.7-7.10, 12.5, 12.8
|
||
func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, sessionID string, message string) (*AIChatResponse, error) {
|
||
session := s.getOrCreateSession(sessionID, userID)
|
||
|
||
// Add user message to history
|
||
session.Messages = append(session.Messages, ChatMessage{
|
||
Role: "user",
|
||
Content: message,
|
||
})
|
||
|
||
// 1. 获取财务上下文(用于所有高级功能)
|
||
fc, err := s.GetUserFinancialContext(ctx, userID)
|
||
if err != nil {
|
||
// 降级处理,不中断流程
|
||
fmt.Printf("Failed to get financial context: %v\n", err)
|
||
}
|
||
|
||
// 2. 检测纯查询意图(预算、资产、统计)
|
||
queryIntent := s.detectQueryIntent(message)
|
||
if queryIntent != "" && fc != nil {
|
||
responseMsg := s.handleQueryIntent(ctx, queryIntent, message, fc)
|
||
|
||
response := &AIChatResponse{
|
||
SessionID: session.ID,
|
||
Message: responseMsg,
|
||
Intent: queryIntent,
|
||
Params: session.Params, // 保持参数上下文
|
||
}
|
||
|
||
// 记录 AI 回复
|
||
session.Messages = append(session.Messages, ChatMessage{
|
||
Role: "assistant",
|
||
Content: response.Message,
|
||
})
|
||
return response, nil
|
||
}
|
||
|
||
// 3. 检测消费建议意图(想吃/想买/想喝等)
|
||
isSpendingAdvice := s.isSpendingAdviceIntent(message)
|
||
|
||
// Parse intent for transaction
|
||
params, responseMsg, err := s.llmService.ParseIntent(ctx, message, session.Messages[:len(session.Messages)-1], userID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to parse intent: %w", err)
|
||
}
|
||
|
||
// Merge with existing session params
|
||
if params != nil {
|
||
s.mergeParams(session.Params, params)
|
||
}
|
||
|
||
// Map account and category names to IDs
|
||
if session.Params.Account != "" && session.Params.AccountID == nil {
|
||
accountID, accountName, _ := s.llmService.MapAccountName(ctx, session.Params.Account, userID)
|
||
if accountID != nil {
|
||
session.Params.AccountID = accountID
|
||
session.Params.Account = accountName
|
||
}
|
||
}
|
||
|
||
if session.Params.Category != "" && session.Params.CategoryID == nil {
|
||
categoryID, categoryName, _ := s.llmService.MapCategoryName(ctx, session.Params.Category, session.Params.Type, userID)
|
||
if categoryID != nil {
|
||
session.Params.CategoryID = categoryID
|
||
session.Params.Category = categoryName
|
||
}
|
||
}
|
||
|
||
// If category still not mapped, try to get a default category
|
||
if session.Params.CategoryID == nil && session.Params.Category != "" {
|
||
defaultCategoryID, defaultCategoryName := s.getDefaultCategory(userID, session.Params.Type)
|
||
if defaultCategoryID != nil {
|
||
session.Params.CategoryID = defaultCategoryID
|
||
if session.Params.Category == "" {
|
||
session.Params.Category = defaultCategoryName
|
||
}
|
||
}
|
||
}
|
||
|
||
// If no account specified, use default account
|
||
if session.Params.AccountID == nil {
|
||
defaultAccountID, defaultAccountName := s.getDefaultAccount(userID, session.Params.Type)
|
||
if defaultAccountID != nil {
|
||
session.Params.AccountID = defaultAccountID
|
||
session.Params.Account = defaultAccountName
|
||
}
|
||
}
|
||
|
||
// 初始化响应
|
||
response := &AIChatResponse{
|
||
SessionID: session.ID,
|
||
Message: responseMsg,
|
||
Intent: "create_transaction",
|
||
Params: session.Params,
|
||
}
|
||
|
||
// 4. 处理消费建议意图
|
||
// 即使没有金额,如果用户是在寻求建议(如“吃什么”),也应该进入建议流程
|
||
if isSpendingAdvice {
|
||
response.Intent = "spending_advice"
|
||
advice := s.generateSpendingAdvice(ctx, message, session.Params, fc)
|
||
if advice != "" {
|
||
response.Message = advice
|
||
}
|
||
}
|
||
|
||
// Check what's missing for transaction creation
|
||
missingFields := s.getMissingFields(session.Params)
|
||
if len(missingFields) > 0 {
|
||
response.NeedsFollowUp = true
|
||
response.FollowUpQuestion = s.generateFollowUpQuestion(ctx, missingFields, message, fc)
|
||
// 如果有了更好的建议回复(来自 handleQuery 或 spendingAdvice),且是 FollowUp,优先保留建议的部分内容或组合
|
||
if response.Message == "" || response.Message == responseMsg {
|
||
response.Message = response.FollowUpQuestion
|
||
}
|
||
} else {
|
||
// All params complete, generate confirmation card
|
||
card := s.GenerateConfirmationCard(session)
|
||
response.ConfirmationCard = card
|
||
|
||
// 如果不是消费建议,使用标准确认消息
|
||
if !isSpendingAdvice {
|
||
response.Message = fmt.Sprintf("请确认:%s %.2f元,分类:%s,账户:%s",
|
||
s.getTypeLabel(session.Params.Type),
|
||
*session.Params.Amount,
|
||
session.Params.Category,
|
||
session.Params.Account)
|
||
} else {
|
||
// 消费建议场景,在建议后附加确认提示
|
||
response.Message += fmt.Sprintf("\n\n📝 需要记账吗?%s %.2f元",
|
||
s.getTypeLabel(session.Params.Type),
|
||
*session.Params.Amount)
|
||
}
|
||
}
|
||
|
||
// Add assistant response to history
|
||
session.Messages = append(session.Messages, ChatMessage{
|
||
Role: "assistant",
|
||
Content: response.Message,
|
||
})
|
||
|
||
return response, nil
|
||
}
|
||
|
||
// StreamProcessChat processes a chat message and streams the AI response
|
||
// It calls onChunk callback for every text chunk received
|
||
// Returns the final AIChatResponse after stream completion for metadata handling
|
||
func (s *AIBookkeepingService) StreamProcessChat(ctx context.Context, userID uint, sessionID string, message string, onChunk func(string)) (*AIChatResponse, error) {
|
||
session := s.getOrCreateSession(sessionID, userID)
|
||
|
||
// Add user message to history
|
||
session.Messages = append(session.Messages, ChatMessage{
|
||
Role: "user",
|
||
Content: message,
|
||
})
|
||
|
||
// 1. 获取财务上下文
|
||
fc, _ := s.GetUserFinancialContext(ctx, userID)
|
||
|
||
// 2. 检测纯查询意图
|
||
queryIntent := s.detectQueryIntent(message)
|
||
if queryIntent != "" && fc != nil {
|
||
// 查询意图通常较短,直接生成并模拟流式输出
|
||
responseMsg := s.handleQueryIntent(ctx, queryIntent, message, fc)
|
||
|
||
// 模拟流式输出
|
||
chunkSize := 4
|
||
runes := []rune(responseMsg)
|
||
for i := 0; i < len(runes); i += chunkSize {
|
||
end := i + chunkSize
|
||
if end > len(runes) {
|
||
end = len(runes)
|
||
}
|
||
onChunk(string(runes[i:end]))
|
||
time.Sleep(20 * time.Millisecond) // 稍微延迟模拟打字
|
||
}
|
||
|
||
response := &AIChatResponse{
|
||
SessionID: session.ID,
|
||
Message: responseMsg,
|
||
Intent: queryIntent,
|
||
Params: session.Params,
|
||
}
|
||
|
||
session.Messages = append(session.Messages, ChatMessage{
|
||
Role: "assistant",
|
||
Content: response.Message,
|
||
})
|
||
return response, nil
|
||
}
|
||
|
||
// 3. 检测消费建议意图
|
||
isSpendingAdvice := s.isSpendingAdviceIntent(message)
|
||
|
||
// Determine prompt
|
||
var systemPrompt string
|
||
// Fetch user categories for prompt context
|
||
var categoryNamesStr string
|
||
if categories, err := s.categoryRepo.GetAll(userID); err == nil && len(categories) > 0 {
|
||
var names []string
|
||
for _, c := range categories {
|
||
names = append(names, c.Name)
|
||
}
|
||
categoryNamesStr = strings.Join(names, "/")
|
||
}
|
||
|
||
todayDate := time.Now().Format("2006-01-02")
|
||
catPrompt := "2. 分类:根据内容推断"
|
||
if categoryNamesStr != "" {
|
||
catPrompt = fmt.Sprintf("2. 分类:必须从以下已有分类中选择最匹配的一项:[%s]。", categoryNamesStr)
|
||
}
|
||
|
||
systemPrompt = fmt.Sprintf(`你是一个智能记账助手。请以自然的对话方式回复用户,同时在回复中确认识别到的记账信息。
|
||
今天的日期是%s
|
||
|
||
规则:
|
||
1. 你的性格:贱萌、嘴硬心软、偶尔凡尔赛。
|
||
2. 如果用户在记账:请提取金额、分类(%s)、描述。
|
||
3. 如果用户在寻求建议:请给出毒舌但有用的建议。
|
||
4. **不要返回 JSON**,直接以自然语言回复。`, todayDate, catPrompt)
|
||
|
||
// Construct minimal messages for stream
|
||
streamMessages := []ChatMessage{
|
||
{Role: "system", Content: systemPrompt},
|
||
}
|
||
// Add recent history
|
||
historyLen := len(session.Messages)
|
||
if historyLen > 4 {
|
||
streamMessages = append(streamMessages, session.Messages[historyLen-4:]...)
|
||
} else {
|
||
streamMessages = append(streamMessages, session.Messages...)
|
||
}
|
||
|
||
// Execute Stream
|
||
streamChan, err := s.llmService.StreamChat(ctx, streamMessages)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var fullResponseBuilder strings.Builder
|
||
for chunk := range streamChan {
|
||
fullResponseBuilder.WriteString(chunk)
|
||
onChunk(chunk)
|
||
}
|
||
fullResponse := fullResponseBuilder.String()
|
||
|
||
// 4. Update Session History with Assistant Message
|
||
session.Messages = append(session.Messages, ChatMessage{
|
||
Role: "assistant",
|
||
Content: fullResponse,
|
||
})
|
||
|
||
// 5. Post-Stream Logic: Parse Intent and Generate Card
|
||
// We re-parse the *user message* (not the AI response) to get structured data
|
||
// The original ParseIntent logic uses a specific JSON prompt which we skipped for the stream.
|
||
// So we need to call ParseIntent again or have a way to extract it.
|
||
|
||
// OPTION A: Call ParseIntent strictly for data extraction (hidden from user)
|
||
// This ensures we get the robust JSON parsing logic.
|
||
params, _, err := s.llmService.ParseIntent(ctx, message, session.Messages[:len(session.Messages)-2], userID)
|
||
if err != nil {
|
||
// Log error but don't fail, user already saw the text
|
||
fmt.Printf("Stream ParseIntent failed: %v", err)
|
||
}
|
||
|
||
if params != nil {
|
||
s.mergeParams(session.Params, params)
|
||
}
|
||
|
||
// ... (Same mapping logic as ProcessChat)
|
||
if session.Params.Account != "" && session.Params.AccountID == nil {
|
||
accountID, accountName, _ := s.llmService.MapAccountName(ctx, session.Params.Account, userID)
|
||
if accountID != nil {
|
||
session.Params.AccountID = accountID
|
||
session.Params.Account = accountName
|
||
}
|
||
}
|
||
|
||
if session.Params.Category != "" && session.Params.CategoryID == nil {
|
||
categoryID, categoryName, _ := s.llmService.MapCategoryName(ctx, session.Params.Category, session.Params.Type, userID)
|
||
if categoryID != nil {
|
||
session.Params.CategoryID = categoryID
|
||
session.Params.Category = categoryName
|
||
}
|
||
}
|
||
|
||
// Defaults...
|
||
if session.Params.CategoryID == nil && session.Params.Category != "" {
|
||
defaultCategoryID, defaultCategoryName := s.getDefaultCategory(userID, session.Params.Type)
|
||
if defaultCategoryID != nil {
|
||
session.Params.CategoryID = defaultCategoryID
|
||
if session.Params.Category == "" {
|
||
session.Params.Category = defaultCategoryName
|
||
}
|
||
}
|
||
}
|
||
|
||
if session.Params.AccountID == nil {
|
||
defaultAccountID, defaultAccountName := s.getDefaultAccount(userID, session.Params.Type)
|
||
if defaultAccountID != nil {
|
||
session.Params.AccountID = defaultAccountID
|
||
session.Params.Account = defaultAccountName
|
||
}
|
||
}
|
||
|
||
// Construct final response object
|
||
response := &AIChatResponse{
|
||
SessionID: session.ID,
|
||
Message: fullResponse, // The text already streamed
|
||
Intent: "create_transaction",
|
||
Params: session.Params,
|
||
}
|
||
|
||
if isSpendingAdvice {
|
||
response.Intent = "spending_advice"
|
||
}
|
||
|
||
missingFields := s.getMissingFields(session.Params)
|
||
if len(missingFields) > 0 {
|
||
response.NeedsFollowUp = true
|
||
// Note: We already streamed a response. We probably shouldn't override it with a follow-up question
|
||
// unless the streamed response was generic.
|
||
// For now, we trust the streamed response covers it, or the UI handles "NeedsFollowUp" silently if needed.
|
||
// Actually, let's leave it to the UI. If NeedsFollowUp is true, maybe don't show confirmation card.
|
||
} else {
|
||
card := s.GenerateConfirmationCard(session)
|
||
response.ConfirmationCard = card
|
||
}
|
||
|
||
return response, nil
|
||
}
|
||
|
||
// detectQueryIntent 检测用户查询意图
|
||
func (s *AIBookkeepingService) detectQueryIntent(message string) string {
|
||
budgetKeywords := []string{"预算", "剩多少", "还能花", "余额"} // 注意:余额可能指账户余额,这里简化处理
|
||
assetKeywords := []string{"资产", "多少钱", "家底", "存款", "身家", "总钱"}
|
||
statsKeywords := []string{"花了多少", "支出", "账单", "消费", "统计"}
|
||
|
||
for _, kw := range budgetKeywords {
|
||
if strings.Contains(message, kw) {
|
||
return "query_budget"
|
||
}
|
||
}
|
||
for _, kw := range assetKeywords {
|
||
if strings.Contains(message, kw) {
|
||
return "query_assets"
|
||
}
|
||
}
|
||
for _, kw := range statsKeywords {
|
||
if strings.Contains(message, kw) {
|
||
return "query_stats"
|
||
}
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// handleQueryIntent 处理查询意图并生成 LLM 回复
|
||
func (s *AIBookkeepingService) handleQueryIntent(ctx context.Context, intent string, message string, fc *FinancialContext) string {
|
||
if s.config.OpenAIAPIKey == "" {
|
||
return "抱歉,我的大脑(API Key)似乎离家出走了,无法思考..."
|
||
}
|
||
|
||
// 计算人设模式
|
||
personaMode := "balance" // 默认平衡
|
||
healthScore := 60 // 默认及格
|
||
|
||
// 升级版健康分计算 (结合负债率)
|
||
if fc.TotalAssets > 0 {
|
||
// 基础分:资产/负债比
|
||
ratio := (fc.TotalAssets - fc.TotalLiabilities) / fc.TotalAssets
|
||
healthScore = 40 + int(ratio*50)
|
||
|
||
// 负债率惩罚
|
||
if fc.DebtRatio > 0.5 {
|
||
healthScore -= 20
|
||
}
|
||
if fc.DebtRatio > 0.8 {
|
||
healthScore -= 20 // 严重扣分
|
||
}
|
||
} else if fc.TotalLiabilities > 0 {
|
||
// 资不抵债
|
||
healthScore = 10
|
||
}
|
||
|
||
if healthScore > 80 {
|
||
personaMode = "rich"
|
||
} else if healthScore <= 40 {
|
||
personaMode = "poor"
|
||
}
|
||
|
||
// 构建 Prompt
|
||
systemPrompt := fmt.Sprintf(`你是「小金」,Novault 的首席财务 AI。
|
||
当前模式:%s (根据用户财务健康分 %d 判定)
|
||
|
||
角色设定:
|
||
- **rich (富裕)**:撒娇卖萌,夸用户会赚钱,鼓励适度享受。用词:哎哟、不错哦、老板大气。
|
||
- **balance (平衡)**:理性贴心,温和提醒。用词:虽然、但是、建议。
|
||
- **poor (吃土)**:毒舌、阴阳怪气、恨铁不成钢。用词:啧啧、清醒点、吃土、西北风。
|
||
|
||
用户意图:%s
|
||
用户问题:「%s」
|
||
|
||
财务数据上下文:
|
||
%s
|
||
|
||
要求:
|
||
1. 根据意图提取并回答关键数据(预算剩余、总资产、或本月支出)。
|
||
2. 必须符合当前人设模式的语气。
|
||
3. **不需要限制字数**,想说多少说多少,关键是要“有梗”和“有趣”。
|
||
4. 可以尽情使用比喻、夸张、反讽、网络流行语。
|
||
5. 不要罗列所有数据,只回答用户问的,但是回答的方式要出人意料。`,
|
||
personaMode, healthScore, intent, message, s.formatFinancialContextForLLM(fc))
|
||
|
||
messages := []ChatMessage{
|
||
{Role: "system", Content: systemPrompt},
|
||
{Role: "user", Content: message},
|
||
}
|
||
|
||
return s.callLLM(ctx, messages)
|
||
}
|
||
|
||
// formatFinancialContextForLLM 格式化上下文给 LLM
|
||
func (s *AIBookkeepingService) formatFinancialContextForLLM(fc *FinancialContext) string {
|
||
data, _ := json.MarshalIndent(fc, "", " ")
|
||
return string(data)
|
||
}
|
||
|
||
// callLLM 通用 LLM 调用 helper
|
||
func (s *AIBookkeepingService) callLLM(ctx context.Context, messages []ChatMessage) string {
|
||
reqBody := ChatCompletionRequest{
|
||
Model: s.config.ChatModel,
|
||
Messages: messages,
|
||
Temperature: 0.8, // 稍微调高以增加人设表现力
|
||
}
|
||
|
||
jsonBody, _ := json.Marshal(reqBody)
|
||
req, err := http.NewRequestWithContext(ctx, "POST", s.config.OpenAIBaseURL+"/chat/completions", bytes.NewReader(jsonBody))
|
||
if err != nil {
|
||
return "思考中断..."
|
||
}
|
||
req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
resp, err := s.llmService.httpClient.Do(req)
|
||
if err != nil || resp.StatusCode != http.StatusOK {
|
||
return "大脑短路了..."
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
var chatResp ChatCompletionResponse
|
||
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil || len(chatResp.Choices) == 0 {
|
||
return "..."
|
||
}
|
||
return strings.TrimSpace(chatResp.Choices[0].Message.Content)
|
||
}
|
||
|
||
// isSpendingAdviceIntent 检测是否为消费建议意图
|
||
func (s *AIBookkeepingService) isSpendingAdviceIntent(message string) bool {
|
||
keywords := []string{"想吃", "想喝", "想买", "想花", "打算买", "准备买", "要不要", "可以买", "能买", "想要", "推荐", "吃什么", "喝什么", "买什么"}
|
||
for _, kw := range keywords {
|
||
if strings.Contains(message, kw) {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// generateSpendingAdvice 生成消费建议
|
||
func (s *AIBookkeepingService) generateSpendingAdvice(ctx context.Context, message string, params *AITransactionParams, fc *FinancialContext) string {
|
||
if s.config.OpenAIAPIKey == "" || fc == nil {
|
||
if params.Amount != nil {
|
||
return fmt.Sprintf("记下来!%.0f元的%s", *params.Amount, params.Note)
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// 动态人设逻辑
|
||
personaMode := "balance"
|
||
healthScore := 60
|
||
if fc.TotalAssets > 0 {
|
||
ratio := (fc.TotalAssets - fc.TotalLiabilities) / fc.TotalAssets
|
||
healthScore = 40 + int(ratio*50)
|
||
if fc.DebtRatio > 0.5 {
|
||
healthScore -= 20
|
||
}
|
||
} else if fc.TotalLiabilities > 0 {
|
||
healthScore = 10
|
||
}
|
||
|
||
if healthScore > 80 {
|
||
personaMode = "rich"
|
||
} else if healthScore <= 40 {
|
||
personaMode = "poor"
|
||
}
|
||
|
||
prompt := fmt.Sprintf(`你是「小金」,Novault 的首席财务 AI。
|
||
当前模式:%s (根据用户财务健康分 %d 判定)
|
||
角色设定:
|
||
- **rich**: 鼓励享受,语气轻松。
|
||
- **balance**: 理性建议,温和提醒。
|
||
- **poor**: 毒舌劝阻,语气严厉。
|
||
|
||
用户说:「%s」
|
||
财务数据:
|
||
%s
|
||
|
||
请分析消费请求,给出建议。
|
||
要求:
|
||
1. 不要加前缀,直接回复。
|
||
2. **尽情发挥戏剧性**:如果是 rich 模式,就极尽奢华浮夸之能事;如果是 poor 模式,就极尽刻薄毒舌之能事。
|
||
3. 可以“顾左右而言他”,不一定非要正面回答买不买,可以吐槽商品、吐槽用户、吐槽世界。
|
||
4. 让用户觉得你是一个有独立思想的“杠精”或“捧哏”,而不是一个莫得感情的计算器。`,
|
||
personaMode, healthScore, message, s.formatFinancialContextForLLM(fc))
|
||
|
||
messages := []ChatMessage{
|
||
{Role: "user", Content: prompt},
|
||
}
|
||
|
||
return s.callLLM(ctx, messages)
|
||
}
|
||
|
||
// mergeParams merges new params into existing params
|
||
func (s *AIBookkeepingService) mergeParams(existing, new *AITransactionParams) {
|
||
if new.Amount != nil {
|
||
existing.Amount = new.Amount
|
||
}
|
||
if new.Category != "" {
|
||
existing.Category = new.Category
|
||
}
|
||
if new.CategoryID != nil {
|
||
existing.CategoryID = new.CategoryID
|
||
}
|
||
if new.Account != "" {
|
||
existing.Account = new.Account
|
||
}
|
||
if new.AccountID != nil {
|
||
existing.AccountID = new.AccountID
|
||
}
|
||
if new.Type != "" {
|
||
existing.Type = new.Type
|
||
}
|
||
if new.Date != "" {
|
||
existing.Date = new.Date
|
||
}
|
||
if new.Note != "" {
|
||
existing.Note = new.Note
|
||
}
|
||
}
|
||
|
||
// getDefaultAccount gets the default account based on transaction type
|
||
// If no default is set, returns the first available account
|
||
func (s *AIBookkeepingService) getDefaultAccount(userID uint, txType string) (*uint, string) {
|
||
// First try to get user's configured default account
|
||
settings, err := s.userSettingsRepo.GetOrCreate(userID)
|
||
if err == nil && settings != nil {
|
||
var accountID *uint
|
||
if txType == "expense" && settings.DefaultExpenseAccountID != nil {
|
||
accountID = settings.DefaultExpenseAccountID
|
||
} else if txType == "income" && settings.DefaultIncomeAccountID != nil {
|
||
accountID = settings.DefaultIncomeAccountID
|
||
}
|
||
|
||
if accountID != nil {
|
||
account, err := s.accountRepo.GetByID(userID, *accountID)
|
||
if err == nil && account != nil {
|
||
return accountID, account.Name
|
||
}
|
||
}
|
||
}
|
||
|
||
// Fallback: get the first available account
|
||
accounts, err := s.accountRepo.GetAll(userID)
|
||
if err != nil || len(accounts) == 0 {
|
||
return nil, ""
|
||
}
|
||
|
||
// Return the first account (usually sorted by sort_order)
|
||
return &accounts[0].ID, accounts[0].Name
|
||
}
|
||
|
||
// getDefaultCategory gets the first category of the given type
|
||
func (s *AIBookkeepingService) getDefaultCategory(userID uint, txType string) (*uint, string) {
|
||
categories, err := s.categoryRepo.GetAll(userID)
|
||
if err != nil || len(categories) == 0 {
|
||
return nil, ""
|
||
}
|
||
|
||
// Find the first category matching the transaction type
|
||
categoryType := "expense"
|
||
if txType == "income" {
|
||
categoryType = "income"
|
||
}
|
||
|
||
for _, cat := range categories {
|
||
if string(cat.Type) == categoryType {
|
||
return &cat.ID, cat.Name
|
||
}
|
||
}
|
||
|
||
// If no matching type found, return the first category
|
||
return &categories[0].ID, categories[0].Name
|
||
}
|
||
|
||
// getMissingFields returns list of missing required fields
|
||
func (s *AIBookkeepingService) getMissingFields(params *AITransactionParams) []string {
|
||
var missing []string
|
||
if params.Amount == nil || *params.Amount <= 0 {
|
||
missing = append(missing, "amount")
|
||
}
|
||
if params.CategoryID == nil && params.Category == "" {
|
||
missing = append(missing, "category")
|
||
}
|
||
if params.AccountID == nil && params.Account == "" {
|
||
missing = append(missing, "account")
|
||
}
|
||
return missing
|
||
}
|
||
|
||
// generateFollowUpQuestion generates a follow-up question for missing fields using LLM
|
||
func (s *AIBookkeepingService) generateFollowUpQuestion(ctx context.Context, missing []string, userMessage string, fc *FinancialContext) string {
|
||
if len(missing) == 0 {
|
||
return ""
|
||
}
|
||
|
||
fieldNames := map[string]string{
|
||
"amount": "金额",
|
||
"category": "分类",
|
||
"account": "账户",
|
||
}
|
||
|
||
var names []string
|
||
for _, field := range missing {
|
||
if name, ok := fieldNames[field]; ok {
|
||
names = append(names, name)
|
||
}
|
||
}
|
||
missingStr := strings.Join(names, "、")
|
||
|
||
// 如果没有 API Key,降级到模板回复
|
||
if s.config.OpenAIAPIKey == "" {
|
||
if len(names) == 1 {
|
||
return fmt.Sprintf("请问%s是多少?", names[0])
|
||
}
|
||
return fmt.Sprintf("请补充以下信息:%s", missingStr)
|
||
}
|
||
|
||
// 动态人设逻辑
|
||
personaMode := "balance"
|
||
healthScore := 60
|
||
if fc != nil && fc.TotalAssets > 0 {
|
||
ratio := (fc.TotalAssets - fc.TotalLiabilities) / fc.TotalAssets
|
||
healthScore = 40 + int(ratio*50)
|
||
}
|
||
if healthScore > 80 {
|
||
personaMode = "rich"
|
||
} else if healthScore <= 40 {
|
||
personaMode = "poor"
|
||
}
|
||
|
||
prompt := fmt.Sprintf(`你是「小金」,一个贱萌、毒舌、幽默的 AI 记账助手。
|
||
当前模式:%s
|
||
缺失信息:%s
|
||
用户说:「%s」
|
||
|
||
请用你的风格追问用户缺失的信息。
|
||
要求:
|
||
1. 针对缺失的 "%s",用幽默、调侃、甚至一点点“攻击性”的方式追问。
|
||
2. **拒绝机械模板**:不要每次都问一样的话,要根据心情随机应变。
|
||
3. 比如缺失金额,可以问"是白嫖的吗?"、"价格是国家机密吗?"、"快说花了多少,让我死心"。
|
||
4. 比如缺失分类,可以猜"这属于吃喝玩乐哪一类?"、"是用来填饱肚子的还是填补空虚的?"。
|
||
5. 字数不限,关键是要**有灵魂**,像个真人在聊天。`,
|
||
personaMode, missingStr, userMessage, missingStr)
|
||
|
||
messages := []ChatMessage{
|
||
{Role: "user", Content: prompt},
|
||
}
|
||
|
||
answer := s.callLLM(ctx, messages)
|
||
if answer == "" || answer == "..." {
|
||
return fmt.Sprintf("那个...你能告诉我%s吗?", missingStr)
|
||
}
|
||
return answer
|
||
}
|
||
|
||
// getTypeLabel returns Chinese label for transaction type
|
||
func (s *AIBookkeepingService) getTypeLabel(txType string) string {
|
||
if txType == "income" {
|
||
return "收入"
|
||
}
|
||
return "支出"
|
||
}
|
||
|
||
// GenerateConfirmationCard creates a confirmation card from session params
|
||
func (s *AIBookkeepingService) GenerateConfirmationCard(session *AISession) *ConfirmationCard {
|
||
params := session.Params
|
||
|
||
card := &ConfirmationCard{
|
||
SessionID: session.ID,
|
||
Type: params.Type,
|
||
Note: params.Note,
|
||
IsComplete: true,
|
||
}
|
||
|
||
if params.Amount != nil {
|
||
card.Amount = *params.Amount
|
||
}
|
||
if params.CategoryID != nil {
|
||
card.CategoryID = *params.CategoryID
|
||
}
|
||
card.Category = params.Category
|
||
if params.AccountID != nil {
|
||
card.AccountID = *params.AccountID
|
||
}
|
||
card.Account = params.Account
|
||
|
||
// Set date
|
||
if params.Date != "" {
|
||
card.Date = params.Date
|
||
} else {
|
||
card.Date = time.Now().Format("2006-01-02")
|
||
}
|
||
|
||
// Check for budget warnings
|
||
if params.Amount != nil && *params.Amount > 0 && params.Type == "expense" {
|
||
card.Warning = s.checkBudgetWarning(session.UserID, params.CategoryID, *params.Amount)
|
||
}
|
||
|
||
return card
|
||
}
|
||
|
||
// checkBudgetWarning checks if the transaction exceeds any budget
|
||
func (s *AIBookkeepingService) checkBudgetWarning(userID uint, categoryID *uint, amount float64) string {
|
||
now := time.Now()
|
||
var budgets []models.Budget
|
||
|
||
// Find active budgets
|
||
// We check:
|
||
// 1. Budgets specifically for this category
|
||
// 2. Global budgets (CategoryID is NULL)
|
||
query := s.db.Where("user_id = ?", userID).
|
||
Where("start_date <= ?", now).
|
||
Where("end_date IS NULL OR end_date >= ?", now)
|
||
|
||
if categoryID != nil {
|
||
query = query.Where("category_id = ? OR category_id IS NULL", *categoryID)
|
||
} else {
|
||
query = query.Where("category_id IS NULL")
|
||
}
|
||
|
||
if err := query.Find(&budgets).Error; err != nil {
|
||
return ""
|
||
}
|
||
|
||
for _, budget := range budgets {
|
||
// Calculate current period
|
||
start, end := s.calculateBudgetPeriod(&budget, now)
|
||
|
||
// Query spent amount
|
||
var totalSpent float64
|
||
q := s.db.Model(&models.Transaction{}).
|
||
Where("user_id = ? AND type = ? AND transaction_date BETWEEN ? AND ?",
|
||
userID, models.TransactionTypeExpense, start, end)
|
||
|
||
if budget.CategoryID != nil {
|
||
// Get sub-categories
|
||
var subCategoryIDs []uint
|
||
s.db.Model(&models.Category{}).Where("parent_id = ?", *budget.CategoryID).Pluck("id", &subCategoryIDs)
|
||
categoryIDs := append(subCategoryIDs, *budget.CategoryID)
|
||
q = q.Where("category_id IN ?", categoryIDs)
|
||
}
|
||
// If budget has account restriction, we should ideally check that too,
|
||
// but we don't always have accountID resolved here perfectly or it might be complex.
|
||
// For now, focusing on category budgets which are most common.
|
||
|
||
q.Select("COALESCE(SUM(amount), 0)").Scan(&totalSpent)
|
||
|
||
if totalSpent+amount > budget.Amount {
|
||
remaining := budget.Amount - totalSpent
|
||
if remaining < 0 {
|
||
remaining = 0
|
||
}
|
||
return fmt.Sprintf("⚠️ 预算预警:此交易将使【%s】预算超支 (当前剩余 %.2f)", budget.Name, remaining)
|
||
}
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// TranscribeAudio transcribes audio and returns text
|
||
func (s *AIBookkeepingService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) {
|
||
return s.whisperService.TranscribeAudio(ctx, audioData, filename)
|
||
}
|
||
|
||
// ConfirmTransaction creates a transaction from confirmed card
|
||
// Requirements: 7.10
|
||
func (s *AIBookkeepingService) ConfirmTransaction(ctx context.Context, sessionID string, userID uint) (*models.Transaction, error) {
|
||
s.sessionMutex.RLock()
|
||
session, ok := s.sessions[sessionID]
|
||
s.sessionMutex.RUnlock()
|
||
|
||
if !ok {
|
||
return nil, errors.New("session not found or expired")
|
||
}
|
||
|
||
params := session.Params
|
||
|
||
// Validate required fields
|
||
if params.Amount == nil || *params.Amount <= 0 {
|
||
return nil, errors.New("无效的金额")
|
||
}
|
||
if params.CategoryID == nil {
|
||
return nil, errors.New("未指定分类")
|
||
}
|
||
if params.AccountID == nil {
|
||
return nil, errors.New("未指定账户")
|
||
}
|
||
|
||
// Parse date
|
||
var txDate time.Time
|
||
if params.Date != "" {
|
||
var err error
|
||
txDate, err = time.Parse("2006-01-02", params.Date)
|
||
if err != nil {
|
||
txDate = time.Now()
|
||
}
|
||
} else {
|
||
txDate = time.Now()
|
||
}
|
||
|
||
// Determine transaction type
|
||
txType := models.TransactionTypeExpense
|
||
if params.Type == "income" {
|
||
txType = models.TransactionTypeIncome
|
||
}
|
||
|
||
var transaction *models.Transaction
|
||
|
||
// Execute within a database transaction to ensure atomicity
|
||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||
txAccountRepo := repository.NewAccountRepository(tx)
|
||
txTransactionRepo := repository.NewTransactionRepository(tx)
|
||
|
||
// Get account first to check balance
|
||
account, err := txAccountRepo.GetByID(userID, *params.AccountID)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to find account: %w", err)
|
||
}
|
||
|
||
// Calculate new balance
|
||
newBalance := account.Balance
|
||
if txType == models.TransactionTypeExpense {
|
||
newBalance -= *params.Amount
|
||
} else {
|
||
newBalance += *params.Amount
|
||
}
|
||
|
||
// Critical Check: Prevent negative balance for non-credit accounts
|
||
if !account.IsCredit && newBalance < 0 {
|
||
return fmt.Errorf("余额不足:账户“%s”不支持负余额 (当前: %.2f, 尝试扣款: %.2f)",
|
||
account.Name, account.Balance, *params.Amount)
|
||
}
|
||
|
||
// Create transaction model
|
||
transaction = &models.Transaction{
|
||
UserID: userID,
|
||
Amount: *params.Amount,
|
||
Type: txType,
|
||
CategoryID: *params.CategoryID,
|
||
AccountID: *params.AccountID,
|
||
TransactionDate: txDate,
|
||
Note: params.Note,
|
||
Currency: models.Currency(account.Currency), // Use account currency
|
||
}
|
||
if transaction.Currency == "" {
|
||
transaction.Currency = "CNY"
|
||
}
|
||
|
||
// Save transaction
|
||
if err := txTransactionRepo.Create(transaction); err != nil {
|
||
return fmt.Errorf("failed to create transaction: %w", err)
|
||
}
|
||
|
||
// Update account balance
|
||
account.Balance = newBalance
|
||
if err := txAccountRepo.Update(account); err != nil {
|
||
return fmt.Errorf("failed to update account balance: %w", err)
|
||
}
|
||
|
||
return nil
|
||
})
|
||
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// Clean up session only on success
|
||
s.sessionMutex.Lock()
|
||
delete(s.sessions, sessionID)
|
||
s.sessionMutex.Unlock()
|
||
|
||
return transaction, nil
|
||
}
|
||
|
||
// GetSession returns session by ID
|
||
func (s *AIBookkeepingService) GetSession(sessionID string) (*AISession, bool) {
|
||
s.sessionMutex.RLock()
|
||
defer s.sessionMutex.RUnlock()
|
||
|
||
session, ok := s.sessions[sessionID]
|
||
if !ok || time.Now().After(session.ExpiresAt) {
|
||
return nil, false
|
||
}
|
||
return session, true
|
||
}
|
||
|
||
// FinancialContext 用户财务上下文,供 AI 综合分析
|
||
type FinancialContext struct {
|
||
// 账户信息
|
||
TotalBalance float64 `json:"total_balance"` // 净资产 (资产 - 负债)
|
||
TotalAssets float64 `json:"total_assets"` // 总资产
|
||
TotalLiabilities float64 `json:"total_liabilities"` // 总负债
|
||
DebtRatio float64 `json:"debt_ratio"` // 负债率 (0-1)
|
||
AccountSummary []AccountBrief `json:"account_summary"` // 账户摘要
|
||
|
||
// 最近消费
|
||
Last30DaysSpend float64 `json:"last_30_days_spend"` // 近30天支出
|
||
Last7DaysSpend float64 `json:"last_7_days_spend"` // 近7天支出
|
||
TodaySpend float64 `json:"today_spend"` // 今日支出
|
||
MaxSingleSpend *TransactionBrief `json:"max_single_spend"` // 本月最大单笔支出
|
||
TopCategories []CategorySpend `json:"top_categories"` // 消费大类TOP3
|
||
RecentTransactions []TransactionBrief `json:"recent_transactions"` // 最近5笔交易
|
||
UpcomingRecurring []string `json:"upcoming_recurring"` // 未来7天固定支出提醒
|
||
|
||
// 预算与目标
|
||
ActiveBudgets []BudgetBrief `json:"active_budgets"` // 活跃预算
|
||
BudgetWarnings []string `json:"budget_warnings"` // 预算警告
|
||
SavingsProgress []string `json:"savings_progress"` // 存钱进度摘要
|
||
|
||
// 历史对比
|
||
LastMonthSpend float64 `json:"last_month_spend"` // 上月同期支出
|
||
SpendTrend string `json:"spend_trend"` // 消费趋势: up/down/stable
|
||
}
|
||
|
||
// AccountBrief 账户摘要
|
||
type AccountBrief struct {
|
||
Name string `json:"name"`
|
||
Balance float64 `json:"balance"`
|
||
Type string `json:"type"`
|
||
}
|
||
|
||
// CategorySpend 分类消费
|
||
type CategorySpend struct {
|
||
Category string `json:"category"`
|
||
Amount float64 `json:"amount"`
|
||
Percent float64 `json:"percent"`
|
||
}
|
||
|
||
// TransactionBrief 交易摘要
|
||
type TransactionBrief struct {
|
||
Amount float64 `json:"amount"`
|
||
Category string `json:"category"`
|
||
Note string `json:"note"`
|
||
Date string `json:"date"`
|
||
Type string `json:"type"`
|
||
}
|
||
|
||
// BudgetBrief 预算摘要
|
||
type BudgetBrief struct {
|
||
Name string `json:"name"`
|
||
Amount float64 `json:"amount"`
|
||
Spent float64 `json:"spent"`
|
||
Remaining float64 `json:"remaining"`
|
||
Progress float64 `json:"progress"` // 0-100
|
||
IsNearLimit bool `json:"is_near_limit"`
|
||
Category string `json:"category,omitempty"`
|
||
}
|
||
|
||
// GetUserFinancialContext 获取用户财务上下文
|
||
func (s *AIBookkeepingService) GetUserFinancialContext(ctx context.Context, userID uint) (*FinancialContext, error) {
|
||
fc := &FinancialContext{}
|
||
|
||
// 1. 获取账户信息
|
||
accounts, err := s.accountRepo.GetAll(userID)
|
||
if err == nil {
|
||
for _, acc := range accounts {
|
||
fc.TotalBalance += acc.Balance
|
||
|
||
// 根据余额正负判断资产/负债
|
||
// 余额 >= 0 计入资产
|
||
// 余额 < 0 计入负债(取绝对值)
|
||
if acc.Balance >= 0 {
|
||
fc.TotalAssets += acc.Balance
|
||
} else {
|
||
fc.TotalLiabilities += -acc.Balance
|
||
}
|
||
|
||
fc.AccountSummary = append(fc.AccountSummary, AccountBrief{
|
||
Name: acc.Name,
|
||
Balance: acc.Balance,
|
||
Type: string(acc.Type),
|
||
})
|
||
}
|
||
}
|
||
|
||
// 2. 获取最近交易统计
|
||
now := time.Now()
|
||
today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||
last7Days := today.AddDate(0, 0, -7)
|
||
last30Days := today.AddDate(0, 0, -30)
|
||
lastMonthStart := today.AddDate(0, -1, -today.Day()+1)
|
||
lastMonthEnd := today.AddDate(0, 0, -today.Day())
|
||
|
||
// 获取近30天交易
|
||
transactions, err := s.transactionRepo.GetByDateRange(userID, last30Days, now)
|
||
if err == nil {
|
||
categorySpend := make(map[string]float64)
|
||
|
||
for _, tx := range transactions {
|
||
if tx.Type == models.TransactionTypeExpense {
|
||
fc.Last30DaysSpend += tx.Amount
|
||
|
||
if tx.TransactionDate.After(last7Days) || tx.TransactionDate.Equal(last7Days) {
|
||
fc.Last7DaysSpend += tx.Amount
|
||
}
|
||
if tx.TransactionDate.After(today) || tx.TransactionDate.Equal(today) {
|
||
fc.TodaySpend += tx.Amount
|
||
}
|
||
|
||
// 分类统计
|
||
catName := "其他"
|
||
if tx.Category.ID != 0 {
|
||
catName = tx.Category.Name
|
||
}
|
||
categorySpend[catName] += tx.Amount
|
||
}
|
||
}
|
||
|
||
// 计算 TOP3 分类
|
||
type catAmount struct {
|
||
name string
|
||
amount float64
|
||
}
|
||
var cats []catAmount
|
||
for name, amount := range categorySpend {
|
||
cats = append(cats, catAmount{name, amount})
|
||
}
|
||
// 简单排序取 TOP3
|
||
for i := 0; i < len(cats)-1; i++ {
|
||
for j := i + 1; j < len(cats); j++ {
|
||
if cats[j].amount > cats[i].amount {
|
||
cats[i], cats[j] = cats[j], cats[i]
|
||
}
|
||
}
|
||
}
|
||
for i := 0; i < len(cats) && i < 3; i++ {
|
||
percent := 0.0
|
||
if fc.Last30DaysSpend > 0 {
|
||
percent = cats[i].amount / fc.Last30DaysSpend * 100
|
||
}
|
||
fc.TopCategories = append(fc.TopCategories, CategorySpend{
|
||
Category: cats[i].name,
|
||
Amount: cats[i].amount,
|
||
Percent: percent,
|
||
})
|
||
}
|
||
|
||
// 最近5笔交易
|
||
count := 0
|
||
for i := len(transactions) - 1; i >= 0 && count < 5; i-- {
|
||
tx := transactions[i]
|
||
catName := "其他"
|
||
if tx.Category.ID != 0 {
|
||
catName = tx.Category.Name
|
||
}
|
||
fc.RecentTransactions = append(fc.RecentTransactions, TransactionBrief{
|
||
Amount: tx.Amount,
|
||
Category: catName,
|
||
Note: tx.Note,
|
||
Date: tx.TransactionDate.Format("01-02"),
|
||
Type: string(tx.Type),
|
||
})
|
||
count++
|
||
}
|
||
}
|
||
|
||
// 3. 获取上月同期支出用于对比
|
||
lastMonthTx, err := s.transactionRepo.GetByDateRange(userID, lastMonthStart, lastMonthEnd)
|
||
if err == nil {
|
||
for _, tx := range lastMonthTx {
|
||
if tx.Type == models.TransactionTypeExpense {
|
||
fc.LastMonthSpend += tx.Amount
|
||
}
|
||
}
|
||
}
|
||
|
||
// 计算消费趋势
|
||
if fc.LastMonthSpend > 0 {
|
||
ratio := fc.Last30DaysSpend / fc.LastMonthSpend
|
||
if ratio > 1.1 {
|
||
fc.SpendTrend = "up"
|
||
} else if ratio < 0.9 {
|
||
fc.SpendTrend = "down"
|
||
} else {
|
||
fc.SpendTrend = "stable"
|
||
}
|
||
} else {
|
||
fc.SpendTrend = "stable"
|
||
}
|
||
|
||
// 4. 获取预算信息(通过直接查询数据库)
|
||
var budgets []models.Budget
|
||
if err := s.db.Where("user_id = ?", userID).
|
||
Where("start_date <= ?", now).
|
||
Where("end_date IS NULL OR end_date >= ?", now).
|
||
Preload("Category").
|
||
Find(&budgets).Error; err == nil {
|
||
|
||
for _, budget := range budgets {
|
||
// 计算当期支出
|
||
periodStart, periodEnd := s.calculateBudgetPeriod(&budget, now)
|
||
spent := 0.0
|
||
|
||
// 查询当期支出
|
||
var totalSpent float64
|
||
query := s.db.Model(&models.Transaction{}).
|
||
Where("user_id = ?", userID).
|
||
Where("type = ?", models.TransactionTypeExpense).
|
||
Where("transaction_date >= ? AND transaction_date <= ?", periodStart, periodEnd)
|
||
|
||
if budget.CategoryID != nil {
|
||
query = query.Where("category_id = ?", *budget.CategoryID)
|
||
}
|
||
if budget.AccountID != nil {
|
||
query = query.Where("account_id = ?", *budget.AccountID)
|
||
}
|
||
query.Select("COALESCE(SUM(amount), 0)").Scan(&totalSpent)
|
||
spent = totalSpent
|
||
|
||
progress := 0.0
|
||
if budget.Amount > 0 {
|
||
progress = spent / budget.Amount * 100
|
||
}
|
||
|
||
isNearLimit := progress >= 80.0
|
||
|
||
catName := ""
|
||
if budget.Category != nil {
|
||
catName = budget.Category.Name
|
||
}
|
||
|
||
fc.ActiveBudgets = append(fc.ActiveBudgets, BudgetBrief{
|
||
Name: budget.Name,
|
||
Amount: budget.Amount,
|
||
Spent: spent,
|
||
Remaining: budget.Amount - spent,
|
||
Progress: progress,
|
||
IsNearLimit: isNearLimit,
|
||
Category: catName,
|
||
})
|
||
|
||
// 生成预算警告
|
||
if progress >= 100 {
|
||
fc.BudgetWarnings = append(fc.BudgetWarnings,
|
||
fmt.Sprintf("【%s】预算已超支!", budget.Name))
|
||
} else if isNearLimit {
|
||
fc.BudgetWarnings = append(fc.BudgetWarnings,
|
||
fmt.Sprintf("【%s】预算已用%.0f%%,请注意控制", budget.Name, progress))
|
||
}
|
||
}
|
||
}
|
||
|
||
// 计算负债率
|
||
if fc.TotalAssets+fc.TotalLiabilities > 0 {
|
||
fc.DebtRatio = fc.TotalLiabilities / (fc.TotalAssets + fc.TotalLiabilities)
|
||
}
|
||
|
||
// 5. 获取存钱罐进度 (PiggyBank)
|
||
var piggyBanks []models.PiggyBank
|
||
if err := s.db.Where("user_id = ?", userID).Find(&piggyBanks).Error; err == nil {
|
||
for _, pb := range piggyBanks {
|
||
progress := 0.0
|
||
if pb.TargetAmount > 0 {
|
||
progress = pb.CurrentAmount / pb.TargetAmount * 100
|
||
}
|
||
status := "进行中"
|
||
if progress >= 100 {
|
||
status = "已达成"
|
||
}
|
||
fc.SavingsProgress = append(fc.SavingsProgress, fmt.Sprintf("%s: %.0f/%.0f (%.1f%%) - %s", pb.Name, pb.CurrentAmount, pb.TargetAmount, progress, status))
|
||
}
|
||
}
|
||
|
||
// 6. 获取未来7天即将到期的定期事务 (RecurringTransaction)
|
||
var recurrings []models.RecurringTransaction
|
||
next7Days := now.AddDate(0, 0, 7)
|
||
if err := s.db.Where("user_id = ? AND is_active = ? AND type = ?", userID, true, models.TransactionTypeExpense).
|
||
Where("next_occurrence BETWEEN ? AND ?", now, next7Days).
|
||
Find(&recurrings).Error; err == nil {
|
||
for _, r := range recurrings {
|
||
days := int(r.NextOccurrence.Sub(now).Hours() / 24)
|
||
msg := fmt.Sprintf("%s 还有 %d 天扣款 %.2f元", r.Note, days, r.Amount)
|
||
if r.Note == "" {
|
||
// 获取分类名
|
||
var cat models.Category
|
||
s.db.First(&cat, r.CategoryID)
|
||
msg = fmt.Sprintf("%s 还有 %d 天扣款 %.2f元", cat.Name, days, r.Amount)
|
||
}
|
||
fc.UpcomingRecurring = append(fc.UpcomingRecurring, msg)
|
||
}
|
||
}
|
||
|
||
// 7. 计算本月最大单笔支出
|
||
if len(transactions) > 0 {
|
||
var maxTx *models.Transaction
|
||
for _, tx := range transactions {
|
||
if tx.Type == models.TransactionTypeExpense {
|
||
if maxTx == nil || tx.Amount > maxTx.Amount {
|
||
currentTx := tx // 创建副本以避免取地址问题
|
||
maxTx = ¤tTx
|
||
}
|
||
}
|
||
}
|
||
if maxTx != nil {
|
||
catName := "其他"
|
||
if maxTx.Category.ID != 0 {
|
||
catName = maxTx.Category.Name
|
||
}
|
||
fc.MaxSingleSpend = &TransactionBrief{
|
||
Amount: maxTx.Amount,
|
||
Category: catName,
|
||
Note: maxTx.Note,
|
||
Date: maxTx.TransactionDate.Format("01-02"),
|
||
Type: string(maxTx.Type),
|
||
}
|
||
}
|
||
}
|
||
|
||
return fc, nil
|
||
}
|
||
|
||
// calculateBudgetPeriod 计算预算当前周期
|
||
func (s *AIBookkeepingService) calculateBudgetPeriod(budget *models.Budget, now time.Time) (time.Time, time.Time) {
|
||
switch budget.PeriodType {
|
||
case models.PeriodTypeDaily:
|
||
start := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||
end := start.AddDate(0, 0, 1).Add(-time.Second)
|
||
return start, end
|
||
case models.PeriodTypeWeekly:
|
||
weekday := int(now.Weekday())
|
||
if weekday == 0 {
|
||
weekday = 7
|
||
}
|
||
start := time.Date(now.Year(), now.Month(), now.Day()-weekday+1, 0, 0, 0, 0, now.Location())
|
||
end := start.AddDate(0, 0, 7).Add(-time.Second)
|
||
return start, end
|
||
case models.PeriodTypeMonthly:
|
||
start := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
|
||
end := start.AddDate(0, 1, 0).Add(-time.Second)
|
||
return start, end
|
||
case models.PeriodTypeYearly:
|
||
start := time.Date(now.Year(), 1, 1, 0, 0, 0, 0, now.Location())
|
||
end := start.AddDate(1, 0, 0).Add(-time.Second)
|
||
return start, end
|
||
default:
|
||
start := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
|
||
end := start.AddDate(0, 1, 0).Add(-time.Second)
|
||
return start, end
|
||
}
|
||
}
|