Files
Novault-backend/internal/service/ai_bookkeeping_service.go

1001 lines
29 KiB
Go
Raw Normal View History

2026-01-25 21:59:00 +08:00
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"regexp"
"strconv"
"strings"
"sync"
"time"
"accounting-app/internal/config"
"accounting-app/internal/models"
"accounting-app/internal/repository"
"gorm.io/gorm"
)
// TranscriptionResult is the JSON body decoded from the transcription
// endpoint's reply (see WhisperService.TranscribeAudio).
//
// Text is the recognized transcript. Language and Duration are optional in
// the reply; when absent, JSON decoding leaves them at their zero values.
type TranscriptionResult struct {
Text string `json:"text"`
Language string `json:"language,omitempty"`
// NOTE(review): presumably the audio length in seconds — the unit is not established in this file; confirm against the API.
Duration float64 `json:"duration,omitempty"`
}
// AITransactionParams holds the transaction fields extracted so far from a
// conversation. It is produced by the intent parsers (LLM or regex fallback)
// and is also the accumulated state of an AISession, merged turn by turn in
// mergeParams.
//
// Pointer fields are nil until a value is known; getMissingFields treats a
// nil Amount as "amount still missing". CategoryID/AccountID are filled in
// once the free-text Category/Account names are resolved against the user's
// categories and accounts.
type AITransactionParams struct {
Amount *float64 `json:"amount,omitempty"`
// Category is the category name, either as extracted or as matched in the user's categories.
Category string `json:"category,omitempty"`
CategoryID *uint `json:"category_id,omitempty"`
// Account is the account name, either as extracted or as matched in the user's accounts.
Account string `json:"account,omitempty"`
AccountID *uint `json:"account_id,omitempty"`
Type string `json:"type,omitempty"` // "expense" or "income"
// Date is a calendar date string; ConfirmTransaction parses it with layout "2006-01-02".
Date string `json:"date,omitempty"`
Note string `json:"note,omitempty"`
}
// ConfirmationCard is the resolved transaction shown to the user for a final
// confirmation before it is saved.
//
// GenerateConfirmationCard builds it from a session's AITransactionParams:
// pointer fields are dereferenced (zero when unset), Date falls back to
// today's date, and IsComplete is always set to true there.
type ConfirmationCard struct {
SessionID string `json:"session_id"`
Amount float64 `json:"amount"`
Category string `json:"category"`
CategoryID uint `json:"category_id"`
Account string `json:"account"`
AccountID uint `json:"account_id"`
// Type is the transaction type string copied from the params ("expense" or "income").
Type string `json:"type"`
// Date is a YYYY-MM-DD date string.
Date string `json:"date"`
Note string `json:"note,omitempty"`
IsComplete bool `json:"is_complete"`
}
// AIChatResponse is the reply ProcessChat returns for one user turn.
//
// If required fields are still missing, NeedsFollowUp is true and
// FollowUpQuestion asks for them. Otherwise ConfirmationCard is populated.
// Params exposes the session's accumulated parameters.
type AIChatResponse struct {
SessionID string `json:"session_id"`
Message string `json:"message"`
// NOTE(review): ProcessChat in this file only ever sets "create_transaction".
Intent string `json:"intent,omitempty"` // "create_transaction", "query", "unknown"
Params *AITransactionParams `json:"params,omitempty"`
ConfirmationCard *ConfirmationCard `json:"confirmation_card,omitempty"`
NeedsFollowUp bool `json:"needs_follow_up"`
FollowUpQuestion string `json:"follow_up_question,omitempty"`
}
// AISession is the in-memory state of one AI bookkeeping conversation.
//
// Sessions are stored in AIBookkeepingService.sessions, keyed by ID. They are
// discarded once ExpiresAt has passed, both on lookup and by the periodic
// cleanup goroutine.
type AISession struct {
ID string
// UserID is the user the session was created for.
UserID uint
// Params accumulates the transaction fields gathered across turns.
Params *AITransactionParams
// Messages is the conversation history (user and assistant turns) in order.
Messages []ChatMessage
CreatedAt time.Time
// ExpiresAt is set at creation from config.AISessionTimeout.
ExpiresAt time.Time
}
// ChatMessage is one turn of a conversation, in the role/content shape that
// is sent to the chat-completions API (ChatCompletionRequest.Messages) and
// kept in AISession.Messages.
type ChatMessage struct {
Role string `json:"role"` // "user", "assistant", "system"
Content string `json:"content"`
}
// WhisperService transcribes audio through an OpenAI-compatible
// "/audio/transcriptions" endpoint (see TranscribeAudio). The endpoint's base
// URL, the API key and the model name are read from config.
type WhisperService struct {
config *config.Config
// httpClient sends the transcription requests; NewWhisperService gives it a 120s timeout.
httpClient *http.Client
}
// NewWhisperService returns a WhisperService that reads its endpoint,
// credentials and model from cfg.
//
// The service gets its own HTTP client with a 120-second overall timeout.
// Uploading and transcribing audio can take far longer than an ordinary API
// call.
func NewWhisperService(cfg *config.Config) *WhisperService {
	client := &http.Client{Timeout: 120 * time.Second}
	return &WhisperService{httpClient: client, config: cfg}
}
// TranscribeAudio uploads audioData to the OpenAI-compatible
// "/audio/transcriptions" endpoint and returns the decoded transcript.
//
// filename is sent as the multipart file name. Its extension (case-insensitive)
// must be one of mp3, wav, m4a, webm, ogg or flac. A name without any
// extension is rejected. The request uses config.WhisperModel as the model and
// always includes the "zh" language hint. The whole audio payload is buffered
// in memory before it is sent.
//
// An error is returned when the API key or base URL is not configured, the
// format is unsupported, the request cannot be built or sent, the server
// replies with a non-200 status (the error includes a bounded excerpt of the
// body), or the reply cannot be decoded.
//
// Requirements: 6.1-6.7
func (s *WhisperService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) {
	if s.config.OpenAIAPIKey == "" {
		return nil, errors.New("OpenAI API key not configured (OPENAI_API_KEY)")
	}
	if s.config.OpenAIBaseURL == "" {
		return nil, errors.New("OpenAI base URL not configured (OPENAI_BASE_URL)")
	}

	// Validate the file format. Without a dot, LastIndex returns -1 and the
	// whole name would be taken as the extension (a file literally named
	// "mp3" would pass), so only text after an actual dot counts.
	ext := ""
	if dot := strings.LastIndex(filename, "."); dot >= 0 {
		ext = strings.ToLower(filename[dot+1:])
	}
	validFormats := map[string]bool{"mp3": true, "wav": true, "m4a": true, "webm": true, "ogg": true, "flac": true}
	if !validFormats[ext] {
		return nil, fmt.Errorf("unsupported audio format: %s", ext)
	}

	// Build the multipart body: the audio file, the model name and a
	// Chinese-language hint.
	var buf bytes.Buffer
	writer := multipart.NewWriter(&buf)
	part, err := writer.CreateFormFile("file", filename)
	if err != nil {
		return nil, fmt.Errorf("failed to create form file: %w", err)
	}
	if _, err := io.Copy(part, audioData); err != nil {
		return nil, fmt.Errorf("failed to copy audio data: %w", err)
	}
	if err := writer.WriteField("model", s.config.WhisperModel); err != nil {
		return nil, fmt.Errorf("failed to write model field: %w", err)
	}
	if err := writer.WriteField("language", "zh"); err != nil {
		return nil, fmt.Errorf("failed to write language field: %w", err)
	}
	// Close writes the terminating boundary; the body is incomplete without it.
	if err := writer.Close(); err != nil {
		return nil, fmt.Errorf("failed to close writer: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.config.OpenAIBaseURL+"/audio/transcriptions", &buf)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
	req.Header.Set("Content-Type", writer.FormDataContentType())

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("transcription request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Only an excerpt is needed for the error message. Cap it so a huge
		// error page from the upstream is not buffered in full.
		const maxErrorBody = 4 << 10
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		return nil, fmt.Errorf("transcription failed with status %d: %s", resp.StatusCode, string(body))
	}

	var result TranscriptionResult
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	return &result, nil
}
// LLMService turns free-text bookkeeping requests into AITransactionParams.
// It uses an OpenAI-compatible chat-completions endpoint when one is
// configured, with a local regex parser as the fallback. It also maps
// free-text account/category names onto the user's stored records via the
// repositories.
type LLMService struct {
config *config.Config
// httpClient sends chat-completion requests; NewLLMService gives it a 60s timeout.
httpClient *http.Client
accountRepo *repository.AccountRepository
categoryRepo *repository.CategoryRepository
}
// NewLLMService wires an LLMService to cfg and to the account and category
// repositories used for name-to-ID mapping.
//
// Chat requests go through a dedicated HTTP client with a 60-second timeout,
// which leaves room for slow model responses.
func NewLLMService(cfg *config.Config, accountRepo *repository.AccountRepository, categoryRepo *repository.CategoryRepository) *LLMService {
	svc := new(LLMService)
	svc.config = cfg
	svc.httpClient = &http.Client{Timeout: 60 * time.Second}
	svc.accountRepo = accountRepo
	svc.categoryRepo = categoryRepo
	return svc
}
// ChatCompletionRequest is the JSON body POSTed to the "/chat/completions"
// endpoint.
//
// Functions is omitted from the JSON when empty (ParseIntent never sets it).
// Temperature has no omitempty, so it is always sent, including 0.
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
Functions []Function `json:"functions,omitempty"`
Temperature float64 `json:"temperature"`
}
// Function describes a callable function offered to the model in
// ChatCompletionRequest.Functions. It is not used by the request-building
// code in this file.
type Function struct {
Name string `json:"name"`
Description string `json:"description"`
// NOTE(review): presumably a JSON Schema object describing the arguments, per the OpenAI functions API — confirm before use.
Parameters map[string]interface{} `json:"parameters"`
}
// ChatCompletionResponse is the subset of the "/chat/completions" reply that
// this package decodes. ParseIntent reads only Choices[0].Message.Content.
// FunctionCall is decoded when present but is not consumed in this file.
type ChatCompletionResponse struct {
Choices []struct {
Message struct {
Role string `json:"role"`
Content string `json:"content"`
FunctionCall *struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function_call,omitempty"`
} `json:"message"`
} `json:"choices"`
}
// ParseIntent extracts transaction parameters from a user utterance.
//
// If no OpenAI API key or base URL is configured, it falls back to the local
// regex parser (parseIntentSimple). Otherwise it sends a system prompt, at
// most the last 4 history messages, and text to the chat-completions
// endpoint, then decodes the model's JSON reply. A markdown code fence or
// surrounding prose around the JSON object is tolerated.
//
// It returns the parsed parameters and the model's short confirmation
// message. If the reply contains no parsable JSON object, it returns nil
// params and the raw reply as the message, so the caller can show it. An
// error is returned only for request/transport/HTTP/decoding failures or an
// empty choice list.
//
// Requirements: 7.1, 7.5, 7.6
func (s *LLMService) ParseIntent(ctx context.Context, text string, history []ChatMessage) (*AITransactionParams, string, error) {
	// Fast path: simple parsing could avoid an LLM call for common inputs
	// such as "6块钱奶茶".
	// TODO: the local-parsing fast path is temporarily disabled; the LLM is
	// always used when configured.
	// simpleParams, simpleMsg, _ := s.parseIntentSimple(text)
	// if simpleParams != nil && simpleParams.Amount != nil && simpleParams.Category != "" && simpleParams.Category != "其他" {
	// 	return simpleParams, simpleMsg, nil
	// }

	if s.config.OpenAIAPIKey == "" || s.config.OpenAIBaseURL == "" {
		// No API configured: use the local parser (it never returns an error).
		simpleParams, simpleMsg, _ := s.parseIntentSimple(text)
		return simpleParams, simpleMsg, nil
	}

	todayDate := time.Now().Format("2006-01-02")
	// Every occurrence of today's date uses the explicit index %[1]s, so the
	// verbs cannot drift out of sync with the single argument. A verb/argument
	// mismatch makes fmt append "%!(EXTRA ...)" noise to the prompt.
	systemPrompt := fmt.Sprintf(`你是一个智能记账助手，从用户描述中提取记账信息。
今天的日期是：%[1]s
规则：
1. 金额：提取数字，"6元"=6，"十五"=15
2. 分类：根据内容推断，"奶茶/咖啡/吃饭"=餐饮，"打车/地铁"=交通，"买衣服"=购物
3. 类型：默认expense(支出)，除非明确说"收入/工资/奖金/红包"
4. 日期：默认使用今天的日期（%[1]s），除非用户明确指定其他日期
5. 备注：提取关键描述
直接返回JSON，不要解释：
{"amount":数字,"category":"分类","type":"expense或income","note":"备注","date":"YYYY-MM-DD","message":"简短确认"}
示例（假设今天是%[1]s）：
用户："买了6块的奶茶"
返回：{"amount":6,"category":"餐饮","type":"expense","note":"奶茶","date":"%[1]s","message":"记录：餐饮支出6元，奶茶"}`, todayDate)

	// Keep only the most recent history messages to bound the prompt size.
	const maxHistory = 4
	if len(history) > maxHistory {
		history = history[len(history)-maxHistory:]
	}
	messages := make([]ChatMessage, 0, len(history)+2)
	messages = append(messages, ChatMessage{Role: "system", Content: systemPrompt})
	messages = append(messages, history...)
	messages = append(messages, ChatMessage{Role: "user", Content: text})

	reqBody := ChatCompletionRequest{
		Model:       s.config.ChatModel,
		Messages:    messages,
		Temperature: 0.1, // low temperature for consistent, parseable output
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return nil, "", fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.config.OpenAIBaseURL+"/chat/completions", bytes.NewReader(jsonBody))
	if err != nil {
		return nil, "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, "", fmt.Errorf("chat request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Only an excerpt is needed for the error message; cap the read.
		const maxErrorBody = 4 << 10
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		return nil, "", fmt.Errorf("chat failed with status %d: %s", resp.StatusCode, string(body))
	}

	var chatResp ChatCompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return nil, "", fmt.Errorf("failed to decode response: %w", err)
	}
	if len(chatResp.Choices) == 0 {
		return nil, "", errors.New("no response from AI")
	}

	content := strings.TrimSpace(chatResp.Choices[0].Message.Content)
	// Strip a markdown code fence (```json ... ```) if the model added one.
	if strings.HasPrefix(content, "```") {
		// Drop the opening fence line (``` or ```json).
		if idx := strings.Index(content, "\n"); idx != -1 {
			content = content[idx+1:]
		}
		// Drop the closing fence.
		if idx := strings.LastIndex(content, "```"); idx != -1 {
			content = content[:idx]
		}
		content = strings.TrimSpace(content)
	}

	var parsed struct {
		Amount   *float64 `json:"amount"`
		Category string   `json:"category"`
		Type     string   `json:"type"`
		Note     string   `json:"note"`
		Date     string   `json:"date"`
		Message  string   `json:"message"`
	}
	if err := json.Unmarshal([]byte(content), &parsed); err != nil {
		// Models sometimes wrap the object in prose. Retry on the outermost
		// {...} span before treating the reply as a plain message.
		start, end := strings.Index(content, "{"), strings.LastIndex(content, "}")
		if start < 0 || end <= start || json.Unmarshal([]byte(content[start:end+1]), &parsed) != nil {
			return nil, content, nil
		}
	}

	params := &AITransactionParams{
		Amount:   parsed.Amount,
		Category: parsed.Category,
		Type:     parsed.Type,
		Note:     parsed.Note,
		Date:     parsed.Date,
	}
	return params, parsed.Message, nil
}
// intentAmountRegexes are tried in order by parseIntentSimple. The first one
// that matches supplies the amount in capture group 1. They are compiled once
// at package initialization instead of on every call.
var intentAmountRegexes = []*regexp.Regexp{
	// "6块钱", "12.5元", "6¥". Multi-character units are listed first because
	// Go's regexp alternation is leftmost-first.
	regexp.MustCompile(`(\d+(?:\.\d+)?)\s*(?:块钱|元钱|元|块|\x{00A5}|\x{FFE5})`),
	// "花了6", "付6", "消费 30"
	regexp.MustCompile(`(?:花了?|付了?|买了?|消费)\s*(\d+(?:\.\d+)?)`),
	// "6的", "6块的"
	regexp.MustCompile(`(\d+(?:\.\d+)?)\s*(?:的|块的)`),
}

// intentBareNumberRegex is parseIntentSimple's last-resort amount pattern:
// the first number anywhere in the text.
var intentBareNumberRegex = regexp.MustCompile(`(\d+(?:\.\d+)?)`)

// intentNoteAmountRegex removes amounts, with an optional leading currency
// sign (¥ or full-width ￥) and an optional trailing unit, when deriving the
// note. "块钱"/"元钱" must precede "块"/"元" so the whole unit is removed;
// otherwise "6块钱奶茶" would leave "钱奶茶" behind.
var intentNoteAmountRegex = regexp.MustCompile(`(?:[\x{00A5}\x{FFE5}]\s*)?\d+(?:\.\d+)?\s*(?:块钱|元钱|元|块|\x{00A5}|\x{FFE5})?`)

// parseIntentSimple is the regex/keyword fallback parser, used when no LLM is
// configured.
//
// It extracts:
//   - Amount: the first unit- or verb-anchored number, else any number; nil
//     when the text has none.
//   - Category: the first keyword-table hit, else "其他".
//   - Type: "income" if any income keyword appears, else "expense".
//   - Date: today.
//   - Note: the text with amounts and filler words removed.
//
// It returns a user-facing summary, or a request for the amount when none
// was found. The error is always nil.
func (s *LLMService) parseIntentSimple(text string) (*AITransactionParams, string, error) {
	params := &AITransactionParams{
		Type: "expense", // default unless an income keyword is found below
		Date: time.Now().Format("2006-01-02"),
	}

	for _, re := range intentAmountRegexes {
		if m := re.FindStringSubmatch(text); len(m) > 1 {
			if amount, err := strconv.ParseFloat(m[1], 64); err == nil {
				params.Amount = &amount
				break
			}
		}
	}
	if params.Amount == nil {
		if m := intentBareNumberRegex.FindStringSubmatch(text); len(m) > 1 {
			if amount, err := strconv.ParseFloat(m[1], 64); err == nil {
				params.Amount = &amount
			}
		}
	}

	// Ordered keyword table: the first matching group wins, so more specific
	// groups come before generic ones (e.g. "奶茶" before "喝").
	categoryPatterns := []struct {
		keywords []string
		category string
	}{
		{[]string{"奶茶", "咖啡", "茶", "饮料", "柠檬", "果汁"}, "餐饮"},
		{[]string{"吃", "喝", "餐", "外卖", "饭", "面", "粉", "粥", "包子", "早餐", "午餐", "晚餐", "宵夜"}, "餐饮"},
		{[]string{"打车", "滴滴", "出租", "的士", "uber", "曹操"}, "交通"},
		{[]string{"地铁", "公交", "公车", "巴士", "轻轨", "高铁", "火车", "飞机", "机票"}, "交通"},
		{[]string{"加油", "油费", "停车", "过路费"}, "交通"},
		{[]string{"超市", "便利店", "商场", "购物", "淘宝", "京东", "拼多多"}, "购物"},
		{[]string{"买", "购"}, "购物"},
		{[]string{"水电", "电费", "水费", "燃气", "煤气", "物业"}, "生活缴费"},
		{[]string{"房租", "租金", "房贷"}, "住房"},
		{[]string{"电影", "游戏", "KTV", "唱歌", "娱乐", "玩"}, "娱乐"},
		{[]string{"医院", "药", "看病", "挂号", "医疗"}, "医疗"},
		{[]string{"话费", "流量", "充值", "手机费"}, "通讯"},
		{[]string{"工资", "薪水", "薪资", "月薪"}, "工资"},
		{[]string{"奖金", "年终奖", "绩效"}, "奖金"},
		{[]string{"红包", "转账", "收款"}, "其他收入"},
	}
categorySearch:
	for _, cp := range categoryPatterns {
		for _, keyword := range cp.keywords {
			if strings.Contains(text, keyword) {
				params.Category = cp.category
				break categorySearch
			}
		}
	}
	if params.Category == "" {
		params.Category = "其他"
	}

	incomeKeywords := []string{"工资", "薪", "奖金", "红包", "收入", "进账", "到账", "收到", "收款"}
	for _, keyword := range incomeKeywords {
		if strings.Contains(text, keyword) {
			params.Type = "income"
			break
		}
	}

	// Note: the text without amounts and common filler words.
	note := text
	if params.Amount != nil {
		note = intentNoteAmountRegex.ReplaceAllString(note, "")
	}
	fillerWords := []string{"买了", "花了", "付了", "消费了", "一个", "一条", "一份", "的"}
	for _, word := range fillerWords {
		note = strings.ReplaceAll(note, word, "")
	}
	params.Note = strings.TrimSpace(note)

	var message string
	if params.Amount == nil {
		message = "请问金额是多少?"
	} else {
		typeLabel := "支出"
		if params.Type == "income" {
			typeLabel = "收入"
		}
		message = fmt.Sprintf("记录:%s %.2f元,分类:%s", typeLabel, *params.Amount, params.Category)
		if params.Note != "" {
			message += ",备注:" + params.Note
		}
	}
	return params, message, nil
}
// MapAccountName resolves a free-text account name to one of userID's
// accounts.
//
// Matching is case-insensitive. An exact name match wins. Otherwise, the
// first account whose name contains name, or is contained in it, is chosen.
// It returns (nil, "", nil) when name is empty or nothing matches, and the
// repository error if the accounts cannot be loaded.
func (s *LLMService) MapAccountName(ctx context.Context, name string, userID uint) (*uint, string, error) {
	if name == "" {
		return nil, "", nil
	}
	accounts, err := s.accountRepo.GetAll(userID)
	if err != nil {
		return nil, "", err
	}

	for i := range accounts {
		if strings.EqualFold(accounts[i].Name, name) {
			return &accounts[i].ID, accounts[i].Name, nil
		}
	}

	wanted := strings.ToLower(name)
	for i := range accounts {
		have := strings.ToLower(accounts[i].Name)
		if strings.Contains(have, wanted) || strings.Contains(wanted, have) {
			return &accounts[i].ID, accounts[i].Name, nil
		}
	}
	return nil, "", nil
}
// MapCategoryName resolves a free-text category name to one of userID's
// categories of the given transaction type.
//
// txType "expense" or "income" restricts the search to categories of that
// type. An empty txType searches all categories, and any other value matches
// nothing. Matching is case-insensitive: exact match first, then the first
// name that contains, or is contained in, name. It returns (nil, "", nil)
// when name is empty or nothing matches, and the repository error if the
// categories cannot be loaded.
func (s *LLMService) MapCategoryName(ctx context.Context, name string, txType string, userID uint) (*uint, string, error) {
	if name == "" {
		return nil, "", nil
	}
	categories, err := s.categoryRepo.GetAll(userID)
	if err != nil {
		return nil, "", err
	}

	var candidates []models.Category
	for _, cat := range categories {
		switch txType {
		case "":
			candidates = append(candidates, cat)
		case "expense", "income":
			if string(cat.Type) == txType {
				candidates = append(candidates, cat)
			}
		}
	}

	for i := range candidates {
		if strings.EqualFold(candidates[i].Name, name) {
			return &candidates[i].ID, candidates[i].Name, nil
		}
	}

	wanted := strings.ToLower(name)
	for i := range candidates {
		have := strings.ToLower(candidates[i].Name)
		if strings.Contains(have, wanted) || strings.Contains(wanted, have) {
			return &candidates[i].ID, candidates[i].Name, nil
		}
	}
	return nil, "", nil
}
// AIBookkeepingService orchestrates AI bookkeeping. It transcribes speech
// (WhisperService), parses requests into transaction parameters
// (LLMService), keeps per-conversation sessions in memory, and writes
// confirmed transactions through the repositories.
type AIBookkeepingService struct {
whisperService *WhisperService
llmService *LLMService
transactionRepo *repository.TransactionRepository
accountRepo *repository.AccountRepository
categoryRepo *repository.CategoryRepository
userSettingsRepo *repository.UserSettingsRepository
// NOTE(review): db is stored by the constructor but not used by the code visible in this file.
db *gorm.DB
// sessions maps session ID to session; the map itself is accessed under sessionMutex.
// NOTE(review): the *AISession values are mutated by ProcessChat without holding sessionMutex.
sessions map[string]*AISession
sessionMutex sync.RWMutex
config *config.Config
}
// NewAIBookkeepingService assembles the service from its repositories and
// configuration. It creates the Whisper and LLM sub-services from cfg and
// starts the background session-cleanup goroutine. That goroutine never
// returns, so it lives for the rest of the process.
func NewAIBookkeepingService(
	cfg *config.Config,
	transactionRepo *repository.TransactionRepository,
	accountRepo *repository.AccountRepository,
	categoryRepo *repository.CategoryRepository,
	userSettingsRepo *repository.UserSettingsRepository,
	db *gorm.DB,
) *AIBookkeepingService {
	svc := &AIBookkeepingService{
		config:           cfg,
		db:               db,
		whisperService:   NewWhisperService(cfg),
		llmService:       NewLLMService(cfg, accountRepo, categoryRepo),
		transactionRepo:  transactionRepo,
		accountRepo:      accountRepo,
		categoryRepo:     categoryRepo,
		userSettingsRepo: userSettingsRepo,
		sessions:         make(map[string]*AISession),
	}

	go svc.cleanupExpiredSessions()
	return svc
}
// sessionIDMu guards sessionIDSeq, the process-wide counter that makes
// generated session IDs unique.
var (
	sessionIDMu  sync.Mutex
	sessionIDSeq uint64
)

// generateSessionID returns a new session ID of the form
// "ai_<unix-nanos>_<sequence>".
//
// Uniqueness comes from the sequence number, not the timestamp. The clock can
// return the same value for calls made in quick succession (concurrent
// requests, or platforms with a coarse wall clock). A colliding ID would
// overwrite another live session in the session map.
//
// NOTE: IDs are unique but predictable; they must not be relied on as secrets.
func generateSessionID() string {
	sessionIDMu.Lock()
	sessionIDSeq++
	seq := sessionIDSeq
	sessionIDMu.Unlock()
	return fmt.Sprintf("ai_%d_%d", time.Now().UnixNano(), seq)
}
// getOrCreateSession returns the session registered under sessionID when it
// exists, has not expired and belongs to userID. Otherwise it creates,
// registers and returns a fresh session for userID that expires after
// config.AISessionTimeout.
//
// An expired session found under sessionID is removed. A live session owned
// by a different user is never returned and is left untouched, so one user
// cannot read or extend another user's conversation by presenting its ID.
// The new session's ID is guaranteed not to replace an existing entry.
func (s *AIBookkeepingService) getOrCreateSession(sessionID string, userID uint) *AISession {
	s.sessionMutex.Lock()
	defer s.sessionMutex.Unlock()

	now := time.Now()
	if sessionID != "" {
		if session, ok := s.sessions[sessionID]; ok {
			if !now.Before(session.ExpiresAt) {
				delete(s.sessions, sessionID)
			} else if session.UserID == userID {
				return session
			}
		}
	}

	// Never overwrite a live session, even if the generator repeats an ID.
	newID := generateSessionID()
	for {
		if _, taken := s.sessions[newID]; !taken {
			break
		}
		newID = generateSessionID()
	}

	session := &AISession{
		ID:        newID,
		UserID:    userID,
		Params:    &AITransactionParams{},
		Messages:  []ChatMessage{},
		CreatedAt: now,
		ExpiresAt: now.Add(s.config.AISessionTimeout),
	}
	s.sessions[newID] = session
	return session
}
// cleanupExpiredSessions wakes every five minutes and evicts sessions whose
// ExpiresAt lies in the past, holding the write lock during each sweep.
//
// It loops forever (the ticker is never stopped), so it is meant to run in
// its own goroutine for the lifetime of the process.
func (s *AIBookkeepingService) cleanupExpiredSessions() {
	ticker := time.NewTicker(5 * time.Minute)
	for {
		<-ticker.C

		s.sessionMutex.Lock()
		cutoff := time.Now()
		for sid, sess := range s.sessions {
			if sess.ExpiresAt.Before(cutoff) {
				delete(s.sessions, sid)
			}
		}
		s.sessionMutex.Unlock()
	}
}
// ProcessChat handles one user turn of an AI bookkeeping conversation.
//
// It resolves (or creates) the session via getOrCreateSession and asks the
// LLM to parse message against the prior history. It then merges the result
// into the session's parameters and maps category and account names to the
// user's IDs, falling back to a default category (keeping the AI-provided
// name) and to the default account. Finally it returns either a follow-up
// question for missing fields or a confirmation card.
//
// The user turn and the assistant reply are appended to the session history
// only after parsing succeeds. A failed LLM call therefore leaves no dangling
// user turn that would be replayed on the next attempt.
//
// Requirements: 7.2-7.4, 7.7-7.10, 12.5, 12.8
func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, sessionID string, message string) (*AIChatResponse, error) {
	session := s.getOrCreateSession(sessionID, userID)

	// The history excludes the current message, which is passed separately.
	params, responseMsg, err := s.llmService.ParseIntent(ctx, message, session.Messages)
	if err != nil {
		return nil, fmt.Errorf("failed to parse intent: %w", err)
	}
	session.Messages = append(session.Messages, ChatMessage{Role: "user", Content: message})

	p := session.Params
	if params != nil {
		s.mergeParams(p, params)
	}

	// Lookup errors are deliberately ignored here: an unresolved name simply
	// falls through to the defaults below.
	if p.Account != "" && p.AccountID == nil {
		if id, name, _ := s.llmService.MapAccountName(ctx, p.Account, userID); id != nil {
			p.AccountID, p.Account = id, name
		}
	}
	if p.Category != "" && p.CategoryID == nil {
		if id, name, _ := s.llmService.MapCategoryName(ctx, p.Category, p.Type, userID); id != nil {
			p.CategoryID, p.Category = id, name
		}
	}
	// Unmatched category name: attach a default category ID but keep the
	// AI-provided name for display.
	if p.CategoryID == nil && p.Category != "" {
		if id, _ := s.getDefaultCategory(userID, p.Type); id != nil {
			p.CategoryID = id
		}
	}
	if p.AccountID == nil {
		if id, name := s.getDefaultAccount(userID, p.Type); id != nil {
			p.AccountID, p.Account = id, name
		}
	}

	response := &AIChatResponse{
		SessionID: session.ID,
		Message:   responseMsg,
		Intent:    "create_transaction",
		Params:    p,
	}

	if missing := s.getMissingFields(p); len(missing) > 0 {
		response.NeedsFollowUp = true
		response.FollowUpQuestion = s.generateFollowUpQuestion(missing)
		if responseMsg == "" {
			response.Message = response.FollowUpQuestion
		}
	} else {
		// Complete: getMissingFields guarantees p.Amount is non-nil here.
		response.ConfirmationCard = s.GenerateConfirmationCard(session)
		response.Message = fmt.Sprintf("请确认:%s %.2f元,分类:%s,账户:%s",
			s.getTypeLabel(p.Type),
			*p.Amount,
			p.Category,
			p.Account)
	}

	session.Messages = append(session.Messages, ChatMessage{
		Role:    "assistant",
		Content: response.Message,
	})
	return response, nil
}
// mergeParams folds the fields set in incoming (non-nil pointers, non-empty
// strings) into existing. All other fields of existing are left untouched.
//
// When a turn changes the category name or the transaction type without
// supplying a CategoryID, the previously resolved CategoryID is cleared so it
// is re-mapped. Otherwise the session would display the new category while
// the confirmed transaction is saved under the old one (categories are
// type-specific, so a type change invalidates the ID too). A changed account
// name likewise clears AccountID. An explicit ID in incoming always wins.
func (s *AIBookkeepingService) mergeParams(existing, incoming *AITransactionParams) {
	if incoming.Amount != nil {
		existing.Amount = incoming.Amount
	}
	if incoming.Category != "" && incoming.Category != existing.Category {
		existing.Category = incoming.Category
		existing.CategoryID = nil
	}
	if incoming.Type != "" && incoming.Type != existing.Type {
		existing.Type = incoming.Type
		existing.CategoryID = nil
	}
	if incoming.CategoryID != nil {
		existing.CategoryID = incoming.CategoryID
	}
	if incoming.Account != "" && incoming.Account != existing.Account {
		existing.Account = incoming.Account
		existing.AccountID = nil
	}
	if incoming.AccountID != nil {
		existing.AccountID = incoming.AccountID
	}
	if incoming.Date != "" {
		existing.Date = incoming.Date
	}
	if incoming.Note != "" {
		existing.Note = incoming.Note
	}
}
// getDefaultAccount picks an account for a transaction of type txType.
//
// It prefers the user's configured default expense or income account
// (from user settings), provided that account can still be loaded. Otherwise
// it falls back to the first account returned by the repository. It returns
// (nil, "") when the user has no accounts or they cannot be loaded.
func (s *AIBookkeepingService) getDefaultAccount(userID uint, txType string) (*uint, string) {
	if settings, err := s.userSettingsRepo.GetOrCreate(userID); err == nil && settings != nil {
		var preferred *uint
		switch txType {
		case "expense":
			preferred = settings.DefaultExpenseAccountID
		case "income":
			preferred = settings.DefaultIncomeAccountID
		}
		if preferred != nil {
			if acc, err := s.accountRepo.GetByID(userID, *preferred); err == nil && acc != nil {
				return preferred, acc.Name
			}
		}
	}

	accounts, err := s.accountRepo.GetAll(userID)
	if err == nil && len(accounts) > 0 {
		// The first account in repository order.
		return &accounts[0].ID, accounts[0].Name
	}
	return nil, ""
}
// getDefaultCategory returns the first of userID's categories whose type
// matches txType ("income" selects income; anything else selects expense).
// If none matches, it returns the very first category. It returns (nil, "")
// when the user has no categories or they cannot be loaded.
func (s *AIBookkeepingService) getDefaultCategory(userID uint, txType string) (*uint, string) {
	categories, err := s.categoryRepo.GetAll(userID)
	if err != nil || len(categories) == 0 {
		return nil, ""
	}

	want := "expense"
	if txType == "income" {
		want = "income"
	}
	for i := range categories {
		if string(categories[i].Type) == want {
			return &categories[i].ID, categories[i].Name
		}
	}
	return &categories[0].ID, categories[0].Name
}
// getMissingFields lists the required fields that params still lacks, in
// the fixed order "amount", "category", "account". Category and account count
// as present if either the resolved ID or the free-text name is set. It
// returns nil when nothing is missing.
func (s *AIBookkeepingService) getMissingFields(params *AITransactionParams) []string {
	checks := []struct {
		field   string
		missing bool
	}{
		{"amount", params.Amount == nil},
		{"category", params.CategoryID == nil && params.Category == ""},
		{"account", params.AccountID == nil && params.Account == ""},
	}

	var out []string
	for _, c := range checks {
		if c.missing {
			out = append(out, c.field)
		}
	}
	return out
}
// generateFollowUpQuestion builds the question asking the user for the
// missing fields named in missing ("amount", "category", "account").
//
// A single missing field gets a question phrased for that field. The generic
// "请问X是多少" ("how much is X") only makes sense for the amount. Several
// missing fields are listed together. Unknown field names are ignored, and
// "" is returned when no known field is missing.
func (s *AIBookkeepingService) generateFollowUpQuestion(missing []string) string {
	fieldNames := map[string]string{
		"amount":   "金额",
		"category": "分类",
		"account":  "账户",
	}
	singleFieldQuestions := map[string]string{
		"amount":   "请问金额是多少",
		"category": "请问是什么分类",
		"account":  "请问使用哪个账户",
	}

	var names []string
	var known []string
	for _, field := range missing {
		if name, ok := fieldNames[field]; ok {
			names = append(names, name)
			known = append(known, field)
		}
	}

	switch len(known) {
	case 0:
		return ""
	case 1:
		return singleFieldQuestions[known[0]]
	default:
		return fmt.Sprintf("请补充以下信息:%s", strings.Join(names, "、"))
	}
}
// getTypeLabel maps a transaction type to its Chinese display label:
// "收入" (income) for "income", and "支出" (expense) for anything else.
func (s *AIBookkeepingService) getTypeLabel(txType string) string {
	switch txType {
	case "income":
		return "收入"
	default:
		return "支出"
	}
}
// GenerateConfirmationCard snapshots the session's parameters into a
// ConfirmationCard.
//
// Nil Amount/CategoryID/AccountID become zero values, an empty Date becomes
// today's date (YYYY-MM-DD), and IsComplete is always true.
func (s *AIBookkeepingService) GenerateConfirmationCard(session *AISession) *ConfirmationCard {
	p := session.Params

	date := p.Date
	if date == "" {
		date = time.Now().Format("2006-01-02")
	}

	card := ConfirmationCard{
		SessionID:  session.ID,
		Category:   p.Category,
		Account:    p.Account,
		Type:       p.Type,
		Date:       date,
		Note:       p.Note,
		IsComplete: true,
	}
	if p.Amount != nil {
		card.Amount = *p.Amount
	}
	if p.CategoryID != nil {
		card.CategoryID = *p.CategoryID
	}
	if p.AccountID != nil {
		card.AccountID = *p.AccountID
	}
	return &card
}
// TranscribeAudio converts speech to text by delegating to the service's
// WhisperService. Accepted formats and error conditions are those of
// WhisperService.TranscribeAudio.
func (s *AIBookkeepingService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) {
	transcriber := s.whisperService
	return transcriber.TranscribeAudio(ctx, audioData, filename)
}
// ConfirmTransaction saves the transaction described by a session's
// parameters and updates the chosen account's balance.
//
// The session must exist, not be expired, and belong to userID. Otherwise
// "session not found or expired" is returned, so foreign session IDs are
// indistinguishable from missing ones. Amount must be positive, and
// CategoryID and AccountID must be resolved. An empty or unparsable Date
// falls back to the current time. Expenses decrease and income increases the
// account balance.
//
// The account is loaded before anything is written, so a missing account
// cannot leave an orphaned transaction behind. The session is then claimed
// (removed from the map) before the insert. A repeated or concurrent confirm
// of the same session therefore cannot create a duplicate transaction. If the
// insert fails, the session is restored so the user can retry.
//
// NOTE(review): the insert and the balance update are two separate
// repository calls, not one database transaction, and the balance is a
// read-modify-write that is not protected against concurrent updates.
//
// Requirements: 7.10
func (s *AIBookkeepingService) ConfirmTransaction(ctx context.Context, sessionID string, userID uint) (*models.Transaction, error) {
	s.sessionMutex.RLock()
	session, ok := s.sessions[sessionID]
	s.sessionMutex.RUnlock()
	if !ok || time.Now().After(session.ExpiresAt) || session.UserID != userID {
		return nil, errors.New("session not found or expired")
	}
	params := session.Params

	// Validate required fields.
	if params.Amount == nil || *params.Amount <= 0 {
		return nil, errors.New("invalid amount")
	}
	if params.CategoryID == nil {
		return nil, errors.New("category not specified")
	}
	if params.AccountID == nil {
		return nil, errors.New("account not specified")
	}

	txDate := time.Now()
	if params.Date != "" {
		if parsed, err := time.Parse("2006-01-02", params.Date); err == nil {
			txDate = parsed
		}
	}

	txType := models.TransactionTypeExpense
	if params.Type == "income" {
		txType = models.TransactionTypeIncome
	}

	// Resolve the account before writing anything.
	account, err := s.accountRepo.GetByID(userID, *params.AccountID)
	if err != nil {
		return nil, fmt.Errorf("failed to find account: %w", err)
	}
	if account == nil {
		return nil, errors.New("failed to find account")
	}

	// Claim the session. Only the caller that removes it may create the transaction.
	s.sessionMutex.Lock()
	if s.sessions[sessionID] != session {
		s.sessionMutex.Unlock()
		return nil, errors.New("session not found or expired")
	}
	delete(s.sessions, sessionID)
	s.sessionMutex.Unlock()

	tx := &models.Transaction{
		UserID:          userID,
		Amount:          *params.Amount,
		Type:            txType,
		CategoryID:      *params.CategoryID,
		AccountID:       *params.AccountID,
		TransactionDate: txDate,
		Note:            params.Note,
		Currency:        "CNY",
	}
	if err := s.transactionRepo.Create(tx); err != nil {
		// Nothing was persisted: give the session back so the user can retry.
		s.sessionMutex.Lock()
		if _, taken := s.sessions[sessionID]; !taken {
			s.sessions[sessionID] = session
		}
		s.sessionMutex.Unlock()
		return nil, fmt.Errorf("failed to create transaction: %w", err)
	}

	if txType == models.TransactionTypeExpense {
		account.Balance -= *params.Amount
	} else {
		account.Balance += *params.Amount
	}
	if err := s.accountRepo.Update(account); err != nil {
		// The transaction exists, so the session stays consumed to prevent a duplicate.
		return nil, fmt.Errorf("failed to update account balance: %w", err)
	}
	return tx, nil
}
// GetSession looks up a session by ID under the read lock. It reports false
// (with a nil session) when the ID is unknown or the session's ExpiresAt has
// already passed.
func (s *AIBookkeepingService) GetSession(sessionID string) (*AISession, bool) {
	s.sessionMutex.RLock()
	defer s.sessionMutex.RUnlock()

	if session, found := s.sessions[sessionID]; found && !time.Now().After(session.ExpiresAt) {
		return session, true
	}
	return nil, false
}