1001 lines
29 KiB
Go
1001 lines
29 KiB
Go
|
|
package service
|
|||
|
|
|
|||
|
|
import (
	"bytes"
	"context"
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	"accounting-app/internal/config"
	"accounting-app/internal/models"
	"accounting-app/internal/repository"

	"gorm.io/gorm"
)
|
|||
|
|
|
|||
|
|
// TranscriptionResult is the JSON body decoded from the /audio/transcriptions
// response by WhisperService.TranscribeAudio.
type TranscriptionResult struct {
	// Text is the transcribed speech.
	Text string `json:"text"`
	// Language is the language reported by the API, if any.
	Language string `json:"language,omitempty"`
	// Duration is the audio length reported by the API, if any
	// (presumably seconds — confirm against the provider).
	Duration float64 `json:"duration,omitempty"`
}
|
|||
|
|
|
|||
|
|
// AITransactionParams collects the transaction fields extracted from user input
// across one or more chat turns. A nil pointer or an empty string means "not
// provided yet"; every field is omitted from JSON while unset.
type AITransactionParams struct {
	Amount     *float64 `json:"amount,omitempty"`      // transaction amount
	Category   string   `json:"category,omitempty"`    // category name, as parsed or as resolved
	CategoryID *uint    `json:"category_id,omitempty"` // resolved category ID
	Account    string   `json:"account,omitempty"`     // account name, as parsed or as resolved
	AccountID  *uint    `json:"account_id,omitempty"`  // resolved account ID
	Type       string   `json:"type,omitempty"`        // "expense" or "income"
	Date       string   `json:"date,omitempty"`        // date in "2006-01-02" layout
	Note       string   `json:"note,omitempty"`        // free-text memo
}
|
|||
|
|
|
|||
|
|
// ConfirmationCard is the transaction summary shown to the user for approval
// once no required field is missing. Every field except Note is always
// serialized, even when zero.
type ConfirmationCard struct {
	SessionID  string  `json:"session_id"`     // session the card belongs to
	Amount     float64 `json:"amount"`         // transaction amount
	Category   string  `json:"category"`       // category name shown to the user
	CategoryID uint    `json:"category_id"`    // category ID to record
	Account    string  `json:"account"`        // account name shown to the user
	AccountID  uint    `json:"account_id"`     // account ID to record
	Type       string  `json:"type"`           // "expense" or "income"
	Date       string  `json:"date"`           // "2006-01-02" layout
	Note       string  `json:"note,omitempty"` // optional memo
	IsComplete bool    `json:"is_complete"`    // true when all required fields are set
}
|
|||
|
|
|
|||
|
|
// AIChatResponse represents the response from AI chat
|
|||
|
|
type AIChatResponse struct {
|
|||
|
|
SessionID string `json:"session_id"`
|
|||
|
|
Message string `json:"message"`
|
|||
|
|
Intent string `json:"intent,omitempty"` // "create_transaction", "query", "unknown"
|
|||
|
|
Params *AITransactionParams `json:"params,omitempty"`
|
|||
|
|
ConfirmationCard *ConfirmationCard `json:"confirmation_card,omitempty"`
|
|||
|
|
NeedsFollowUp bool `json:"needs_follow_up"`
|
|||
|
|
FollowUpQuestion string `json:"follow_up_question,omitempty"`
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// AISession represents an AI conversation session
|
|||
|
|
type AISession struct {
|
|||
|
|
ID string
|
|||
|
|
UserID uint
|
|||
|
|
Params *AITransactionParams
|
|||
|
|
Messages []ChatMessage
|
|||
|
|
CreatedAt time.Time
|
|||
|
|
ExpiresAt time.Time
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ChatMessage is a single conversation turn, in the {role, content} shape the
// chat-completions API expects.
type ChatMessage struct {
	Role    string `json:"role"`    // "user", "assistant" or "system"
	Content string `json:"content"` // message text
}
|
|||
|
|
|
|||
|
|
// WhisperService handles audio transcription
|
|||
|
|
type WhisperService struct {
|
|||
|
|
config *config.Config
|
|||
|
|
httpClient *http.Client
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewWhisperService creates a new WhisperService
|
|||
|
|
func NewWhisperService(cfg *config.Config) *WhisperService {
|
|||
|
|
return &WhisperService{
|
|||
|
|
config: cfg,
|
|||
|
|
httpClient: &http.Client{
|
|||
|
|
Timeout: 120 * time.Second, // Increased timeout for audio transcription
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TranscribeAudio transcribes audio file to text using Whisper API
|
|||
|
|
// Supports formats: mp3, wav, m4a, webm
|
|||
|
|
// Requirements: 6.1-6.7
|
|||
|
|
func (s *WhisperService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) {
|
|||
|
|
if s.config.OpenAIAPIKey == "" {
|
|||
|
|
return nil, errors.New("OpenAI API key not configured (OPENAI_API_KEY)")
|
|||
|
|
}
|
|||
|
|
if s.config.OpenAIBaseURL == "" {
|
|||
|
|
return nil, errors.New("OpenAI base URL not configured (OPENAI_BASE_URL)")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Validate file format
|
|||
|
|
ext := strings.ToLower(filename[strings.LastIndex(filename, ".")+1:])
|
|||
|
|
validFormats := map[string]bool{"mp3": true, "wav": true, "m4a": true, "webm": true, "ogg": true, "flac": true}
|
|||
|
|
if !validFormats[ext] {
|
|||
|
|
return nil, fmt.Errorf("unsupported audio format: %s", ext)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Create multipart form
|
|||
|
|
var buf bytes.Buffer
|
|||
|
|
writer := multipart.NewWriter(&buf)
|
|||
|
|
|
|||
|
|
// Add audio file
|
|||
|
|
part, err := writer.CreateFormFile("file", filename)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to create form file: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if _, err := io.Copy(part, audioData); err != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to copy audio data: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Add model field
|
|||
|
|
if err := writer.WriteField("model", s.config.WhisperModel); err != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to write model field: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Add language hint for Chinese
|
|||
|
|
if err := writer.WriteField("language", "zh"); err != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to write language field: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err := writer.Close(); err != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to close writer: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Create request
|
|||
|
|
req, err := http.NewRequestWithContext(ctx, "POST", s.config.OpenAIBaseURL+"/audio/transcriptions", &buf)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
|
|||
|
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
|||
|
|
|
|||
|
|
// Send request
|
|||
|
|
resp, err := s.httpClient.Do(req)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("transcription request failed: %w", err)
|
|||
|
|
}
|
|||
|
|
defer resp.Body.Close()
|
|||
|
|
|
|||
|
|
if resp.StatusCode != http.StatusOK {
|
|||
|
|
body, _ := io.ReadAll(resp.Body)
|
|||
|
|
return nil, fmt.Errorf("transcription failed with status %d: %s", resp.StatusCode, string(body))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Parse response
|
|||
|
|
var result TranscriptionResult
|
|||
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return &result, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// LLMService handles natural language understanding
|
|||
|
|
type LLMService struct {
|
|||
|
|
config *config.Config
|
|||
|
|
httpClient *http.Client
|
|||
|
|
accountRepo *repository.AccountRepository
|
|||
|
|
categoryRepo *repository.CategoryRepository
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewLLMService creates a new LLMService
|
|||
|
|
func NewLLMService(cfg *config.Config, accountRepo *repository.AccountRepository, categoryRepo *repository.CategoryRepository) *LLMService {
|
|||
|
|
return &LLMService{
|
|||
|
|
config: cfg,
|
|||
|
|
httpClient: &http.Client{
|
|||
|
|
Timeout: 60 * time.Second, // Increased timeout for slow API responses
|
|||
|
|
},
|
|||
|
|
accountRepo: accountRepo,
|
|||
|
|
categoryRepo: categoryRepo,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ChatCompletionRequest represents OpenAI chat completion request
|
|||
|
|
type ChatCompletionRequest struct {
|
|||
|
|
Model string `json:"model"`
|
|||
|
|
Messages []ChatMessage `json:"messages"`
|
|||
|
|
Functions []Function `json:"functions,omitempty"`
|
|||
|
|
Temperature float64 `json:"temperature"`
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Function describes a callable function offered to the model in a
// ChatCompletionRequest.
type Function struct {
	Name        string                 `json:"name"`        // function identifier
	Description string                 `json:"description"` // what the function does, for the model
	Parameters  map[string]interface{} `json:"parameters"`  // parameter schema, serialized as-is
}
|
|||
|
|
|
|||
|
|
// ChatCompletionResponse is the part of the /chat/completions response body
// that this service reads: the list of choices, each with its message and an
// optional function call.
type ChatCompletionResponse struct {
	Choices []struct {
		Message struct {
			Role    string `json:"role"`    // message author role
			Content string `json:"content"` // reply text; ParseIntent expects JSON in it
			// FunctionCall stays nil unless the response includes "function_call".
			FunctionCall *struct {
				Name      string `json:"name"`      // function the model wants to call
				Arguments string `json:"arguments"` // raw argument string (presumably JSON)
			} `json:"function_call,omitempty"`
		} `json:"message"`
	} `json:"choices"`
}
|
|||
|
|
|
|||
|
|
// ParseIntent extracts transaction parameters from text
|
|||
|
|
// Requirements: 7.1, 7.5, 7.6
|
|||
|
|
func (s *LLMService) ParseIntent(ctx context.Context, text string, history []ChatMessage) (*AITransactionParams, string, error) {
|
|||
|
|
// Fast path: try simple parsing first for common patterns
|
|||
|
|
// This avoids LLM call for simple inputs like "6块钱奶茶"
|
|||
|
|
// TODO: 暂时禁用本地解析快速路径,始终使用 LLM
|
|||
|
|
// simpleParams, simpleMsg, _ := s.parseIntentSimple(text)
|
|||
|
|
// if simpleParams != nil && simpleParams.Amount != nil && simpleParams.Category != "" && simpleParams.Category != "其他" {
|
|||
|
|
// // Simple parsing succeeded with amount and category, use it directly
|
|||
|
|
// return simpleParams, simpleMsg, nil
|
|||
|
|
// }
|
|||
|
|
|
|||
|
|
if s.config.OpenAIAPIKey == "" || s.config.OpenAIBaseURL == "" {
|
|||
|
|
// No API key, return simple parsing result
|
|||
|
|
simpleParams, simpleMsg, _ := s.parseIntentSimple(text)
|
|||
|
|
return simpleParams, simpleMsg, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Build messages with history
|
|||
|
|
todayDate := time.Now().Format("2006-01-02")
|
|||
|
|
systemPrompt := fmt.Sprintf(`你是一个智能记账助手。从用户描述中提取记账信息<EFBFBD>?
|
|||
|
|
|
|||
|
|
今天的日期是<EFBFBD>?s
|
|||
|
|
|
|||
|
|
规则<EFBFBD>?
|
|||
|
|
1. 金额:提取数字,<EFBFBD>?6<EFBFBD>?=6<EFBFBD>?十五<EFBFBD>?=15
|
|||
|
|
2. 分类:根据内容推断,<EFBFBD>?奶茶/咖啡/吃饭"=餐饮<E9A490>?打车/地铁"=交通,"买衣<EFBFBD>?=购物
|
|||
|
|
3. 类型:默认expense(支出),除非明确说"收入/工资/奖金/红包"
|
|||
|
|
4. 日期:默认使用今天的日期<EFBFBD>?s),除非用户明确指定其他日期
|
|||
|
|
5. 备注:提取关键描<EFBFBD>?
|
|||
|
|
|
|||
|
|
直接返回JSON,不要解释:
|
|||
|
|
{"amount":数字,"category":"分类","type":"expense或income","note":"备注","date":"YYYY-MM-DD","message":"简短确<EFBFBD>?}
|
|||
|
|
|
|||
|
|
示例(假设今天是%s):
|
|||
|
|
用户<EFBFBD>?买了<EFBFBD>?块的奶茶"
|
|||
|
|
返回:{"amount":6,"category":"餐饮","type":"expense","note":"奶茶","date":"%s","message":"记录:餐饮支<E9A5AE>?元,奶茶"}`, todayDate, todayDate, todayDate, todayDate)
|
|||
|
|
|
|||
|
|
messages := []ChatMessage{
|
|||
|
|
{
|
|||
|
|
Role: "system",
|
|||
|
|
Content: systemPrompt,
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Only add last 2 messages from history to reduce context
|
|||
|
|
historyLen := len(history)
|
|||
|
|
if historyLen > 4 {
|
|||
|
|
history = history[historyLen-4:]
|
|||
|
|
}
|
|||
|
|
messages = append(messages, history...)
|
|||
|
|
|
|||
|
|
// Add current user message
|
|||
|
|
messages = append(messages, ChatMessage{
|
|||
|
|
Role: "user",
|
|||
|
|
Content: text,
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
// Create request
|
|||
|
|
reqBody := ChatCompletionRequest{
|
|||
|
|
Model: s.config.ChatModel,
|
|||
|
|
Messages: messages,
|
|||
|
|
Temperature: 0.1, // Lower temperature for more consistent output
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
jsonBody, err := json.Marshal(reqBody)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, "", fmt.Errorf("failed to marshal request: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
req, err := http.NewRequestWithContext(ctx, "POST", s.config.OpenAIBaseURL+"/chat/completions", bytes.NewReader(jsonBody))
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, "", fmt.Errorf("failed to create request: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
|
|||
|
|
req.Header.Set("Content-Type", "application/json")
|
|||
|
|
|
|||
|
|
resp, err := s.httpClient.Do(req)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, "", fmt.Errorf("chat request failed: %w", err)
|
|||
|
|
}
|
|||
|
|
defer resp.Body.Close()
|
|||
|
|
|
|||
|
|
if resp.StatusCode != http.StatusOK {
|
|||
|
|
body, _ := io.ReadAll(resp.Body)
|
|||
|
|
return nil, "", fmt.Errorf("chat failed with status %d: %s", resp.StatusCode, string(body))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var chatResp ChatCompletionResponse
|
|||
|
|
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
|||
|
|
return nil, "", fmt.Errorf("failed to decode response: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if len(chatResp.Choices) == 0 {
|
|||
|
|
return nil, "", errors.New("no response from AI")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
content := chatResp.Choices[0].Message.Content
|
|||
|
|
|
|||
|
|
// Remove markdown code block if present (```json ... ```)
|
|||
|
|
content = strings.TrimSpace(content)
|
|||
|
|
if strings.HasPrefix(content, "```") {
|
|||
|
|
// Find the end of the first line (```json or ```)
|
|||
|
|
if idx := strings.Index(content, "\n"); idx != -1 {
|
|||
|
|
content = content[idx+1:]
|
|||
|
|
}
|
|||
|
|
// Remove trailing ```
|
|||
|
|
if idx := strings.LastIndex(content, "```"); idx != -1 {
|
|||
|
|
content = content[:idx]
|
|||
|
|
}
|
|||
|
|
content = strings.TrimSpace(content)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Parse JSON response
|
|||
|
|
var parsed struct {
|
|||
|
|
Amount *float64 `json:"amount"`
|
|||
|
|
Category string `json:"category"`
|
|||
|
|
Type string `json:"type"`
|
|||
|
|
Note string `json:"note"`
|
|||
|
|
Date string `json:"date"`
|
|||
|
|
Message string `json:"message"`
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err := json.Unmarshal([]byte(content), &parsed); err != nil {
|
|||
|
|
// If not JSON, return as message
|
|||
|
|
return nil, content, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
params := &AITransactionParams{
|
|||
|
|
Amount: parsed.Amount,
|
|||
|
|
Category: parsed.Category,
|
|||
|
|
Type: parsed.Type,
|
|||
|
|
Note: parsed.Note,
|
|||
|
|
Date: parsed.Date,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return params, parsed.Message, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// parseIntentSimple provides simple regex-based parsing as fallback
|
|||
|
|
// This is also used as a fast path for simple inputs
|
|||
|
|
func (s *LLMService) parseIntentSimple(text string) (*AITransactionParams, string, error) {
|
|||
|
|
params := &AITransactionParams{
|
|||
|
|
Type: "expense", // Default to expense
|
|||
|
|
Date: time.Now().Format("2006-01-02"),
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Extract amount using regex - support various formats
|
|||
|
|
amountPatterns := []string{
|
|||
|
|
`(\d+(?:\.\d+)?)\s*(?:元|块|¥|¥|块钱|元钱)`,
|
|||
|
|
`(?:花了?|付了?|买了?|消费)\s*(\d+(?:\.\d+)?)`,
|
|||
|
|
`(\d+(?:\.\d+)?)\s*(?:的|块的)`,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, pattern := range amountPatterns {
|
|||
|
|
amountRegex := regexp.MustCompile(pattern)
|
|||
|
|
if matches := amountRegex.FindStringSubmatch(text); len(matches) > 1 {
|
|||
|
|
if amount, err := strconv.ParseFloat(matches[1], 64); err == nil {
|
|||
|
|
params.Amount = &amount
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// If still no amount, try simple number extraction
|
|||
|
|
if params.Amount == nil {
|
|||
|
|
simpleAmountRegex := regexp.MustCompile(`(\d+(?:\.\d+)?)`)
|
|||
|
|
if matches := simpleAmountRegex.FindStringSubmatch(text); len(matches) > 1 {
|
|||
|
|
if amount, err := strconv.ParseFloat(matches[1], 64); err == nil {
|
|||
|
|
params.Amount = &amount
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Enhanced category detection with priority
|
|||
|
|
categoryPatterns := []struct {
|
|||
|
|
keywords []string
|
|||
|
|
category string
|
|||
|
|
}{
|
|||
|
|
{[]string{"奶茶", "咖啡", "茶", "饮料", "柠檬", "果汁"}, "餐饮"},
|
|||
|
|
{[]string{"吃", "喝", "餐", "外卖", "饭", "面", "粉", "粥", "包子", "早餐", "午餐", "晚餐", "宵夜"}, "餐饮"},
|
|||
|
|
{[]string{"打车", "滴滴", "出租", "的士", "uber", "曹操"}, "交通"},
|
|||
|
|
{[]string{"地铁", "公交", "公车", "巴士", "轻轨", "高铁", "火车", "飞机", "机票"}, "交通"},
|
|||
|
|
{[]string{"加油", "油费", "停车", "过路费"}, "交通"},
|
|||
|
|
{[]string{"超市", "便利店", "商场", "购物", "淘宝", "京东", "拼多多"}, "购物"},
|
|||
|
|
{[]string{"买", "购"}, "购物"},
|
|||
|
|
{[]string{"水电", "电费", "水费", "燃气", "煤气", "物业"}, "生活缴费"},
|
|||
|
|
{[]string{"房租", "租金", "房贷"}, "住房"},
|
|||
|
|
{[]string{"电影", "游戏", "KTV", "唱歌", "娱乐", "玩"}, "娱乐"},
|
|||
|
|
{[]string{"医院", "药", "看病", "挂号", "医疗"}, "医疗"},
|
|||
|
|
{[]string{"话费", "流量", "充值", "手机费"}, "通讯"},
|
|||
|
|
{[]string{"工资", "薪水", "薪资", "月薪"}, "工资"},
|
|||
|
|
{[]string{"奖金", "年终奖", "绩效"}, "奖金"},
|
|||
|
|
{[]string{"红包", "转账", "收款"}, "其他收入"},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, cp := range categoryPatterns {
|
|||
|
|
for _, keyword := range cp.keywords {
|
|||
|
|
if strings.Contains(text, keyword) {
|
|||
|
|
params.Category = cp.category
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if params.Category != "" {
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Default category if not detected
|
|||
|
|
if params.Category == "" {
|
|||
|
|
params.Category = "其他"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Detect income keywords
|
|||
|
|
incomeKeywords := []string{"工资", "薪", "奖金", "红包", "收入", "进账", "到账", "收到", "收款"}
|
|||
|
|
for _, keyword := range incomeKeywords {
|
|||
|
|
if strings.Contains(text, keyword) {
|
|||
|
|
params.Type = "income"
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Extract note - remove amount and common words
|
|||
|
|
note := text
|
|||
|
|
if params.Amount != nil {
|
|||
|
|
note = regexp.MustCompile(`\d+(?:\.\d+)?\s*(?:元|块|¥|¥|块钱|元钱)?`).ReplaceAllString(note, "")
|
|||
|
|
}
|
|||
|
|
note = strings.TrimSpace(note)
|
|||
|
|
// Remove common filler words
|
|||
|
|
fillerWords := []string{"买了", "花了", "付了", "消费了", "一个", "一条", "一份", "的"}
|
|||
|
|
for _, word := range fillerWords {
|
|||
|
|
note = strings.ReplaceAll(note, word, "")
|
|||
|
|
}
|
|||
|
|
note = strings.TrimSpace(note)
|
|||
|
|
if note != "" {
|
|||
|
|
params.Note = note
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Generate response message
|
|||
|
|
var message string
|
|||
|
|
if params.Amount == nil {
|
|||
|
|
message = "请问金额是多少?"
|
|||
|
|
} else {
|
|||
|
|
typeLabel := "支出"
|
|||
|
|
if params.Type == "income" {
|
|||
|
|
typeLabel = "收入"
|
|||
|
|
}
|
|||
|
|
message = fmt.Sprintf("记录:%s %.2f元,分类:%s", typeLabel, *params.Amount, params.Category)
|
|||
|
|
if params.Note != "" {
|
|||
|
|
message += ",备注:" + params.Note
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return params, message, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// MapAccountName maps natural language account name to account ID
|
|||
|
|
func (s *LLMService) MapAccountName(ctx context.Context, name string, userID uint) (*uint, string, error) {
|
|||
|
|
if name == "" {
|
|||
|
|
return nil, "", nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
accounts, err := s.accountRepo.GetAll(userID)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, "", err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Try exact match first
|
|||
|
|
for _, acc := range accounts {
|
|||
|
|
if strings.EqualFold(acc.Name, name) {
|
|||
|
|
return &acc.ID, acc.Name, nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Try partial match
|
|||
|
|
for _, acc := range accounts {
|
|||
|
|
if strings.Contains(strings.ToLower(acc.Name), strings.ToLower(name)) ||
|
|||
|
|
strings.Contains(strings.ToLower(name), strings.ToLower(acc.Name)) {
|
|||
|
|
return &acc.ID, acc.Name, nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil, "", nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// MapCategoryName maps natural language category name to category ID
|
|||
|
|
func (s *LLMService) MapCategoryName(ctx context.Context, name string, txType string, userID uint) (*uint, string, error) {
|
|||
|
|
if name == "" {
|
|||
|
|
return nil, "", nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
categories, err := s.categoryRepo.GetAll(userID)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, "", err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Filter by transaction type
|
|||
|
|
var filtered []models.Category
|
|||
|
|
for _, cat := range categories {
|
|||
|
|
if (txType == "expense" && cat.Type == "expense") ||
|
|||
|
|
(txType == "income" && cat.Type == "income") ||
|
|||
|
|
txType == "" {
|
|||
|
|
filtered = append(filtered, cat)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Try exact match first
|
|||
|
|
for _, cat := range filtered {
|
|||
|
|
if strings.EqualFold(cat.Name, name) {
|
|||
|
|
return &cat.ID, cat.Name, nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Try partial match
|
|||
|
|
for _, cat := range filtered {
|
|||
|
|
if strings.Contains(strings.ToLower(cat.Name), strings.ToLower(name)) ||
|
|||
|
|
strings.Contains(strings.ToLower(name), strings.ToLower(cat.Name)) {
|
|||
|
|
return &cat.ID, cat.Name, nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil, "", nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// AIBookkeepingService orchestrates AI bookkeeping functionality
|
|||
|
|
type AIBookkeepingService struct {
|
|||
|
|
whisperService *WhisperService
|
|||
|
|
llmService *LLMService
|
|||
|
|
transactionRepo *repository.TransactionRepository
|
|||
|
|
accountRepo *repository.AccountRepository
|
|||
|
|
categoryRepo *repository.CategoryRepository
|
|||
|
|
userSettingsRepo *repository.UserSettingsRepository
|
|||
|
|
db *gorm.DB
|
|||
|
|
sessions map[string]*AISession
|
|||
|
|
sessionMutex sync.RWMutex
|
|||
|
|
config *config.Config
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewAIBookkeepingService creates a new AIBookkeepingService
|
|||
|
|
func NewAIBookkeepingService(
|
|||
|
|
cfg *config.Config,
|
|||
|
|
transactionRepo *repository.TransactionRepository,
|
|||
|
|
accountRepo *repository.AccountRepository,
|
|||
|
|
categoryRepo *repository.CategoryRepository,
|
|||
|
|
userSettingsRepo *repository.UserSettingsRepository,
|
|||
|
|
db *gorm.DB,
|
|||
|
|
) *AIBookkeepingService {
|
|||
|
|
whisperService := NewWhisperService(cfg)
|
|||
|
|
llmService := NewLLMService(cfg, accountRepo, categoryRepo)
|
|||
|
|
|
|||
|
|
svc := &AIBookkeepingService{
|
|||
|
|
whisperService: whisperService,
|
|||
|
|
llmService: llmService,
|
|||
|
|
transactionRepo: transactionRepo,
|
|||
|
|
accountRepo: accountRepo,
|
|||
|
|
categoryRepo: categoryRepo,
|
|||
|
|
userSettingsRepo: userSettingsRepo,
|
|||
|
|
db: db,
|
|||
|
|
sessions: make(map[string]*AISession),
|
|||
|
|
config: cfg,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Start session cleanup goroutine
|
|||
|
|
go svc.cleanupExpiredSessions()
|
|||
|
|
|
|||
|
|
return svc
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// generateSessionID returns a new opaque session identifier of the form
// "ai_<unix-nanoseconds>_<16 hex chars>".
//
// The random suffix comes from crypto/rand. The previous time-only scheme was
// guessable and could produce the same ID for two sessions created within one
// clock tick. If the random source fails, which should not happen on supported
// platforms, it falls back to the time-only form rather than failing.
func generateSessionID() string {
	var suffix [8]byte
	if _, err := rand.Read(suffix[:]); err != nil {
		return fmt.Sprintf("ai_%d_%d", time.Now().UnixNano(), time.Now().Unix()%1000)
	}
	return fmt.Sprintf("ai_%d_%s", time.Now().UnixNano(), hex.EncodeToString(suffix[:]))
}
|
|||
|
|
|
|||
|
|
// getOrCreateSession gets existing session or creates new one
|
|||
|
|
func (s *AIBookkeepingService) getOrCreateSession(sessionID string, userID uint) *AISession {
|
|||
|
|
s.sessionMutex.Lock()
|
|||
|
|
defer s.sessionMutex.Unlock()
|
|||
|
|
|
|||
|
|
if sessionID != "" {
|
|||
|
|
if session, ok := s.sessions[sessionID]; ok {
|
|||
|
|
if time.Now().Before(session.ExpiresAt) {
|
|||
|
|
return session
|
|||
|
|
}
|
|||
|
|
delete(s.sessions, sessionID)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Create new session
|
|||
|
|
newID := generateSessionID()
|
|||
|
|
session := &AISession{
|
|||
|
|
ID: newID,
|
|||
|
|
UserID: userID,
|
|||
|
|
Params: &AITransactionParams{},
|
|||
|
|
Messages: []ChatMessage{},
|
|||
|
|
CreatedAt: time.Now(),
|
|||
|
|
ExpiresAt: time.Now().Add(s.config.AISessionTimeout),
|
|||
|
|
}
|
|||
|
|
s.sessions[newID] = session
|
|||
|
|
return session
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// cleanupExpiredSessions periodically removes expired sessions
|
|||
|
|
func (s *AIBookkeepingService) cleanupExpiredSessions() {
|
|||
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|||
|
|
for range ticker.C {
|
|||
|
|
s.sessionMutex.Lock()
|
|||
|
|
now := time.Now()
|
|||
|
|
for id, session := range s.sessions {
|
|||
|
|
if now.After(session.ExpiresAt) {
|
|||
|
|
delete(s.sessions, id)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
s.sessionMutex.Unlock()
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ProcessChat processes a chat message and returns AI response
|
|||
|
|
// Requirements: 7.2-7.4, 7.7-7.10, 12.5, 12.8
|
|||
|
|
func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, sessionID string, message string) (*AIChatResponse, error) {
|
|||
|
|
session := s.getOrCreateSession(sessionID, userID)
|
|||
|
|
|
|||
|
|
// Add user message to history
|
|||
|
|
session.Messages = append(session.Messages, ChatMessage{
|
|||
|
|
Role: "user",
|
|||
|
|
Content: message,
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
// Parse intent
|
|||
|
|
params, responseMsg, err := s.llmService.ParseIntent(ctx, message, session.Messages[:len(session.Messages)-1])
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to parse intent: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Merge with existing session params
|
|||
|
|
if params != nil {
|
|||
|
|
s.mergeParams(session.Params, params)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Map account and category names to IDs
|
|||
|
|
if session.Params.Account != "" && session.Params.AccountID == nil {
|
|||
|
|
accountID, accountName, _ := s.llmService.MapAccountName(ctx, session.Params.Account, userID)
|
|||
|
|
if accountID != nil {
|
|||
|
|
session.Params.AccountID = accountID
|
|||
|
|
session.Params.Account = accountName
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if session.Params.Category != "" && session.Params.CategoryID == nil {
|
|||
|
|
categoryID, categoryName, _ := s.llmService.MapCategoryName(ctx, session.Params.Category, session.Params.Type, userID)
|
|||
|
|
if categoryID != nil {
|
|||
|
|
session.Params.CategoryID = categoryID
|
|||
|
|
session.Params.Category = categoryName
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// If category still not mapped, try to get a default category
|
|||
|
|
if session.Params.CategoryID == nil && session.Params.Category != "" {
|
|||
|
|
defaultCategoryID, defaultCategoryName := s.getDefaultCategory(userID, session.Params.Type)
|
|||
|
|
if defaultCategoryID != nil {
|
|||
|
|
session.Params.CategoryID = defaultCategoryID
|
|||
|
|
// Keep the original category name from AI, just set the ID
|
|||
|
|
if session.Params.Category == "" {
|
|||
|
|
session.Params.Category = defaultCategoryName
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// If no account specified, use default account
|
|||
|
|
if session.Params.AccountID == nil {
|
|||
|
|
defaultAccountID, defaultAccountName := s.getDefaultAccount(userID, session.Params.Type)
|
|||
|
|
if defaultAccountID != nil {
|
|||
|
|
session.Params.AccountID = defaultAccountID
|
|||
|
|
session.Params.Account = defaultAccountName
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Check if we have all required params
|
|||
|
|
response := &AIChatResponse{
|
|||
|
|
SessionID: session.ID,
|
|||
|
|
Message: responseMsg,
|
|||
|
|
Intent: "create_transaction",
|
|||
|
|
Params: session.Params,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Check what's missing
|
|||
|
|
missingFields := s.getMissingFields(session.Params)
|
|||
|
|
if len(missingFields) > 0 {
|
|||
|
|
response.NeedsFollowUp = true
|
|||
|
|
response.FollowUpQuestion = s.generateFollowUpQuestion(missingFields)
|
|||
|
|
if responseMsg == "" {
|
|||
|
|
response.Message = response.FollowUpQuestion
|
|||
|
|
}
|
|||
|
|
} else {
|
|||
|
|
// All params complete, generate confirmation card
|
|||
|
|
card := s.GenerateConfirmationCard(session)
|
|||
|
|
response.ConfirmationCard = card
|
|||
|
|
response.Message = fmt.Sprintf("请确认:%s %.2f元,分类<EFBFBD>?s,账户:%s",
|
|||
|
|
s.getTypeLabel(session.Params.Type),
|
|||
|
|
*session.Params.Amount,
|
|||
|
|
session.Params.Category,
|
|||
|
|
session.Params.Account)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Add assistant response to history
|
|||
|
|
session.Messages = append(session.Messages, ChatMessage{
|
|||
|
|
Role: "assistant",
|
|||
|
|
Content: response.Message,
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
return response, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// mergeParams merges new params into existing params
|
|||
|
|
func (s *AIBookkeepingService) mergeParams(existing, new *AITransactionParams) {
|
|||
|
|
if new.Amount != nil {
|
|||
|
|
existing.Amount = new.Amount
|
|||
|
|
}
|
|||
|
|
if new.Category != "" {
|
|||
|
|
existing.Category = new.Category
|
|||
|
|
}
|
|||
|
|
if new.CategoryID != nil {
|
|||
|
|
existing.CategoryID = new.CategoryID
|
|||
|
|
}
|
|||
|
|
if new.Account != "" {
|
|||
|
|
existing.Account = new.Account
|
|||
|
|
}
|
|||
|
|
if new.AccountID != nil {
|
|||
|
|
existing.AccountID = new.AccountID
|
|||
|
|
}
|
|||
|
|
if new.Type != "" {
|
|||
|
|
existing.Type = new.Type
|
|||
|
|
}
|
|||
|
|
if new.Date != "" {
|
|||
|
|
existing.Date = new.Date
|
|||
|
|
}
|
|||
|
|
if new.Note != "" {
|
|||
|
|
existing.Note = new.Note
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// getDefaultAccount gets the default account based on transaction type
|
|||
|
|
// If no default is set, returns the first available account
|
|||
|
|
func (s *AIBookkeepingService) getDefaultAccount(userID uint, txType string) (*uint, string) {
|
|||
|
|
// First try to get user's configured default account
|
|||
|
|
settings, err := s.userSettingsRepo.GetOrCreate(userID)
|
|||
|
|
if err == nil && settings != nil {
|
|||
|
|
var accountID *uint
|
|||
|
|
if txType == "expense" && settings.DefaultExpenseAccountID != nil {
|
|||
|
|
accountID = settings.DefaultExpenseAccountID
|
|||
|
|
} else if txType == "income" && settings.DefaultIncomeAccountID != nil {
|
|||
|
|
accountID = settings.DefaultIncomeAccountID
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if accountID != nil {
|
|||
|
|
account, err := s.accountRepo.GetByID(userID, *accountID)
|
|||
|
|
if err == nil && account != nil {
|
|||
|
|
return accountID, account.Name
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Fallback: get the first available account
|
|||
|
|
accounts, err := s.accountRepo.GetAll(userID)
|
|||
|
|
if err != nil || len(accounts) == 0 {
|
|||
|
|
return nil, ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Return the first account (usually sorted by sort_order)
|
|||
|
|
return &accounts[0].ID, accounts[0].Name
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// getDefaultCategory gets the first category of the given type
|
|||
|
|
func (s *AIBookkeepingService) getDefaultCategory(userID uint, txType string) (*uint, string) {
|
|||
|
|
categories, err := s.categoryRepo.GetAll(userID)
|
|||
|
|
if err != nil || len(categories) == 0 {
|
|||
|
|
return nil, ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Find the first category matching the transaction type
|
|||
|
|
categoryType := "expense"
|
|||
|
|
if txType == "income" {
|
|||
|
|
categoryType = "income"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, cat := range categories {
|
|||
|
|
if string(cat.Type) == categoryType {
|
|||
|
|
return &cat.ID, cat.Name
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// If no matching type found, return the first category
|
|||
|
|
return &categories[0].ID, categories[0].Name
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// getMissingFields returns list of missing required fields
|
|||
|
|
func (s *AIBookkeepingService) getMissingFields(params *AITransactionParams) []string {
|
|||
|
|
var missing []string
|
|||
|
|
if params.Amount == nil {
|
|||
|
|
missing = append(missing, "amount")
|
|||
|
|
}
|
|||
|
|
if params.CategoryID == nil && params.Category == "" {
|
|||
|
|
missing = append(missing, "category")
|
|||
|
|
}
|
|||
|
|
if params.AccountID == nil && params.Account == "" {
|
|||
|
|
missing = append(missing, "account")
|
|||
|
|
}
|
|||
|
|
return missing
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// generateFollowUpQuestion generates a follow-up question for missing fields
|
|||
|
|
func (s *AIBookkeepingService) generateFollowUpQuestion(missing []string) string {
|
|||
|
|
if len(missing) == 0 {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
fieldNames := map[string]string{
|
|||
|
|
"amount": "金额",
|
|||
|
|
"category": "分类",
|
|||
|
|
"account": "账户",
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var names []string
|
|||
|
|
for _, field := range missing {
|
|||
|
|
if name, ok := fieldNames[field]; ok {
|
|||
|
|
names = append(names, name)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if len(names) == 1 {
|
|||
|
|
return fmt.Sprintf("请问%s是多少?", names[0])
|
|||
|
|
}
|
|||
|
|
return fmt.Sprintf("请补充以下信息:%s", strings.Join(names, "、"))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// getTypeLabel returns Chinese label for transaction type
|
|||
|
|
func (s *AIBookkeepingService) getTypeLabel(txType string) string {
|
|||
|
|
if txType == "income" {
|
|||
|
|
return "收入"
|
|||
|
|
}
|
|||
|
|
return "支出"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GenerateConfirmationCard creates a confirmation card from session params
|
|||
|
|
func (s *AIBookkeepingService) GenerateConfirmationCard(session *AISession) *ConfirmationCard {
|
|||
|
|
params := session.Params
|
|||
|
|
|
|||
|
|
card := &ConfirmationCard{
|
|||
|
|
SessionID: session.ID,
|
|||
|
|
Type: params.Type,
|
|||
|
|
Note: params.Note,
|
|||
|
|
IsComplete: true,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if params.Amount != nil {
|
|||
|
|
card.Amount = *params.Amount
|
|||
|
|
}
|
|||
|
|
if params.CategoryID != nil {
|
|||
|
|
card.CategoryID = *params.CategoryID
|
|||
|
|
}
|
|||
|
|
card.Category = params.Category
|
|||
|
|
if params.AccountID != nil {
|
|||
|
|
card.AccountID = *params.AccountID
|
|||
|
|
}
|
|||
|
|
card.Account = params.Account
|
|||
|
|
|
|||
|
|
// Set date
|
|||
|
|
if params.Date != "" {
|
|||
|
|
card.Date = params.Date
|
|||
|
|
} else {
|
|||
|
|
card.Date = time.Now().Format("2006-01-02")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return card
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TranscribeAudio transcribes audio and returns text
|
|||
|
|
func (s *AIBookkeepingService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) {
|
|||
|
|
return s.whisperService.TranscribeAudio(ctx, audioData, filename)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ConfirmTransaction creates a transaction from confirmed card
|
|||
|
|
// Requirements: 7.10
|
|||
|
|
func (s *AIBookkeepingService) ConfirmTransaction(ctx context.Context, sessionID string, userID uint) (*models.Transaction, error) {
|
|||
|
|
s.sessionMutex.RLock()
|
|||
|
|
session, ok := s.sessions[sessionID]
|
|||
|
|
s.sessionMutex.RUnlock()
|
|||
|
|
|
|||
|
|
if !ok {
|
|||
|
|
return nil, errors.New("session not found or expired")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
params := session.Params
|
|||
|
|
|
|||
|
|
// Validate required fields
|
|||
|
|
if params.Amount == nil || *params.Amount <= 0 {
|
|||
|
|
return nil, errors.New("invalid amount")
|
|||
|
|
}
|
|||
|
|
if params.CategoryID == nil {
|
|||
|
|
return nil, errors.New("category not specified")
|
|||
|
|
}
|
|||
|
|
if params.AccountID == nil {
|
|||
|
|
return nil, errors.New("account not specified")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Parse date
|
|||
|
|
var txDate time.Time
|
|||
|
|
if params.Date != "" {
|
|||
|
|
var err error
|
|||
|
|
txDate, err = time.Parse("2006-01-02", params.Date)
|
|||
|
|
if err != nil {
|
|||
|
|
txDate = time.Now()
|
|||
|
|
}
|
|||
|
|
} else {
|
|||
|
|
txDate = time.Now()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Determine transaction type
|
|||
|
|
txType := models.TransactionTypeExpense
|
|||
|
|
if params.Type == "income" {
|
|||
|
|
txType = models.TransactionTypeIncome
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Create transaction
|
|||
|
|
tx := &models.Transaction{
|
|||
|
|
UserID: userID,
|
|||
|
|
Amount: *params.Amount,
|
|||
|
|
Type: txType,
|
|||
|
|
CategoryID: *params.CategoryID,
|
|||
|
|
AccountID: *params.AccountID,
|
|||
|
|
TransactionDate: txDate,
|
|||
|
|
Note: params.Note,
|
|||
|
|
Currency: "CNY",
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Save transaction
|
|||
|
|
if err := s.transactionRepo.Create(tx); err != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to create transaction: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Update account balance
|
|||
|
|
account, err := s.accountRepo.GetByID(userID, *params.AccountID)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to find account: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if txType == models.TransactionTypeExpense {
|
|||
|
|
account.Balance -= *params.Amount
|
|||
|
|
} else {
|
|||
|
|
account.Balance += *params.Amount
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err := s.accountRepo.Update(account); err != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to update account balance: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Clean up session
|
|||
|
|
s.sessionMutex.Lock()
|
|||
|
|
delete(s.sessions, sessionID)
|
|||
|
|
s.sessionMutex.Unlock()
|
|||
|
|
|
|||
|
|
return tx, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetSession returns session by ID
|
|||
|
|
func (s *AIBookkeepingService) GetSession(sessionID string) (*AISession, bool) {
|
|||
|
|
s.sessionMutex.RLock()
|
|||
|
|
defer s.sessionMutex.RUnlock()
|
|||
|
|
|
|||
|
|
session, ok := s.sessions[sessionID]
|
|||
|
|
if !ok || time.Now().After(session.ExpiresAt) {
|
|||
|
|
return nil, false
|
|||
|
|
}
|
|||
|
|
return session, true
|
|||
|
|
}
|