From c4d7825328e1b45e6314cea1bde8e1a1cbac138e Mon Sep 17 00:00:00 2001
From: admin <1297598740@qq.com>
Date: Fri, 30 Jan 2026 12:48:41 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=20AI=20=E8=AE=B0?=
 =?UTF-8?q?=E8=B4=A6=E5=8A=9F=E8=83=BD=EF=BC=8C=E5=8C=85=E6=8B=AC=E6=B5=81?=
 =?UTF-8?q?=E5=BC=8F=E8=81=8A=E5=A4=A9=E3=80=81=E6=B4=9E=E5=AF=9F=E7=94=9F?=
 =?UTF-8?q?=E6=88=90=E3=80=81=E8=AF=AD=E9=9F=B3=E8=BD=AC=E5=BD=95=E5=92=8C?=
 =?UTF-8?q?=E4=BA=A4=E6=98=93=E7=A1=AE=E8=AE=A4=E6=8E=A5=E5=8F=A3=E3=80=82?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 internal/handler/ai_handler.go             |  68 ++++
 internal/service/ai_bookkeeping_service.go | 311 +++++++++++++++++++++
 2 files changed, 379 insertions(+)

diff --git a/internal/handler/ai_handler.go b/internal/handler/ai_handler.go
index 749a5e7..2c66903 100644
--- a/internal/handler/ai_handler.go
+++ b/internal/handler/ai_handler.go
@@ -2,6 +2,7 @@ package handler
 
 import (
 	"encoding/json"
+	"io"
 	"net/http"
 
 	"accounting-app/internal/service"
@@ -14,6 +15,72 @@ type AIHandler struct {
 	aiService *service.AIBookkeepingService
 }
 
+// StreamChat streams an AI chat reply to the client as Server-Sent Events.
+// POST /api/v1/ai/chat/stream
+//
+// Events written to the response:
+//   - "message": {"text": "<chunk>"} for every chunk of the reply
+//   - "error": {"error": "<message>"} if processing fails
+//   - "result": the JSON-encoded final response once streaming is done
+func (h *AIHandler) StreamChat(c *gin.Context) {
+	var req ChatRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"success": false,
+			"error":   "Invalid request: " + err.Error(),
+		})
+		return
+	}
+
+	// Get user ID from context. The default user is kept when the value is
+	// missing or is not a uint, instead of panicking on a failed assertion.
+	userID := uint(1)
+	if id, exists := c.Get("user_id"); exists {
+		if uid, ok := id.(uint); ok {
+			userID = uid
+		}
+	}
+
+	// Set headers for SSE. Transfer-Encoding is not set by hand: net/http
+	// chooses the response framing itself.
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+
+	// sendEvent JSON-encodes payload, writes it as one SSE event and flushes
+	// right away. c.Stream only flushes after the step function returns, so
+	// without this flush the chunks would stay buffered until the whole
+	// reply is finished.
+	sendEvent := func(name string, payload interface{}) {
+		data, err := json.Marshal(payload)
+		if err != nil {
+			name = "error"
+			data = []byte(`{"error":"failed to encode event"}`)
+		}
+		c.SSEvent(name, string(data))
+		c.Writer.Flush()
+	}
+
+	c.Stream(func(w io.Writer) bool {
+		// Forward every chunk as a "message" event. The text is wrapped in
+		// JSON so that newlines inside a chunk cannot break the SSE framing.
+		onChunk := func(chunk string) {
+			sendEvent("message", map[string]string{"text": chunk})
+		}
+
+		response, err := h.aiService.StreamProcessChat(c.Request.Context(), userID, req.SessionID, req.Message, onChunk)
+		if err != nil {
+			sendEvent("error", map[string]string{"error": err.Error()})
+			return false
+		}
+
+		// Send final result event with metadata
+		sendEvent("result", response)
+
+		return false // Stop stream
+	})
+}
+
 // NewAIHandler creates a new AIHandler
 func NewAIHandler(aiService *service.AIBookkeepingService) *AIHandler {
 	return &AIHandler{
@@ -47,6 +114,7 @@ func (h *AIHandler) RegisterRoutes(rg *gin.RouterGroup) {
 	ai := rg.Group("/ai")
 	{
 		ai.POST("/chat", h.Chat)
+		ai.POST("/chat/stream", h.StreamChat) // New streaming endpoint
 		ai.POST("/transcribe", h.Transcribe)
 		ai.POST("/confirm", h.Confirm)
 		ai.POST("/insight", h.Insight)
diff --git a/internal/service/ai_bookkeeping_service.go b/internal/service/ai_bookkeeping_service.go
index f9aa6c0..d3af039 100644
--- a/internal/service/ai_bookkeeping_service.go
+++ b/internal/service/ai_bookkeeping_service.go
@@ -1,6 +1,7 @@
 package service
 
 import (
+	"bufio"
 	"bytes"
 	"context"
 	"crypto/md5"
@@ -430,6 +431,17 @@ type ChatCompletionResponse struct {
 	} `json:"choices"`
 }
 
+// ChatCompletionStreamResponse represents OpenAI chat completion stream response
+type ChatCompletionStreamResponse struct {
+	Choices []struct {
+		Delta struct {
+			Role    string `json:"role,omitempty"`
+			Content string `json:"content,omitempty"`
+		} `json:"delta"`
+		FinishReason string `json:"finish_reason,omitempty"`
+	} `json:"choices"`
+}
+
 // extractCustomPrompt 从用户消息中提取自定义 System/Persona prompt
 // 如果消息包含 "System:" 或 "Persona:",则提取其后的内容作为自定义 prompt
 func extractCustomPrompt(text string) string {
@@ -772,6 +784,109 @@ func (s *LLMService) GenerateReport(ctx context.Context, prompt string) (string,
 	return chatResp.Choices[0].Message.Content, nil
 }
 
+// StreamChat calls the OpenAI Chat Completion API with streaming enabled.
+//
+// It returns a channel that yields the text chunks of the reply in order and
+// is closed when the stream ends, fails, or ctx is cancelled. A non-nil error
+// is returned only if the request could not be started; errors that occur
+// after streaming has begun just end the stream early.
+func (s *LLMService) StreamChat(ctx context.Context, messages []ChatMessage) (<-chan string, error) {
+	if s.config.OpenAIAPIKey == "" || s.config.OpenAIBaseURL == "" {
+		return nil, errors.New("OpenAI API not configured")
+	}
+
+	reqBody := ChatCompletionRequest{
+		Model:       s.config.ChatModel,
+		Messages:    messages,
+		Temperature: 0.1, // Lower temperature for consistent output in chat
+	}
+
+	// ChatCompletionRequest is round-tripped through a map so that
+	// "stream": true can be injected into the payload. Both steps are checked:
+	// a failed Unmarshal would leave bodyMap nil and the assignment below
+	// would panic.
+	jsonBody, err := json.Marshal(reqBody)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal request: %w", err)
+	}
+	var bodyMap map[string]interface{}
+	if err := json.Unmarshal(jsonBody, &bodyMap); err != nil {
+		return nil, fmt.Errorf("failed to marshal request: %w", err)
+	}
+	bodyMap["stream"] = true
+
+	streamBody, err := json.Marshal(bodyMap)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal request: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, "POST", s.config.OpenAIBaseURL+"/chat/completions", bytes.NewReader(streamBody))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+
+	req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept", "text/event-stream")
+
+	// NOTE(review): the shared client is reused. If it has a Timeout set, that
+	// timeout also bounds reading the body and can cut off long replies.
+	resp, err := s.httpClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("stream request failed: %w", err)
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		resp.Body.Close()
+		return nil, fmt.Errorf("stream failed with status %d", resp.StatusCode)
+	}
+
+	outChan := make(chan string)
+
+	go func() {
+		defer resp.Body.Close()
+		defer close(outChan)
+
+		reader := bufio.NewReader(resp.Body)
+		for {
+			line, readErr := reader.ReadBytes('\n')
+
+			// ReadBytes can return data together with an error (for
+			// example a last line without a trailing newline), so the line
+			// is handled before readErr is looked at.
+			line = bytes.TrimSpace(line)
+			if bytes.HasPrefix(line, []byte("data:")) {
+				// The space after "data:" is optional in the SSE format.
+				data := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
+				if string(data) == "[DONE]" {
+					return
+				}
+
+				// Payloads that are not valid JSON are skipped.
+				var streamResp ChatCompletionStreamResponse
+				if err := json.Unmarshal(data, &streamResp); err == nil && len(streamResp.Choices) > 0 {
+					if content := streamResp.Choices[0].Delta.Content; content != "" {
+						select {
+						case outChan <- content:
+						case <-ctx.Done():
+							return
+						}
+					}
+				}
+			}
+
+			if readErr != nil {
+				// io.EOF is the normal end of the body. Other read errors
+				// cannot be reported through a <-chan string, so the stream
+				// just ends.
+				return
+			}
+		}
+	}()
+
+	return outChan, nil
+}
+
 // MapAccountName maps natural language account name to account ID
 func (s *LLMService) MapAccountName(ctx context.Context, name string, userID uint) (*uint, string, error) {
 	if name == "" {
@@ -1033,6 +1148,202 @@ func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, ses
 	return response, nil
 }
 
+// StreamProcessChat processes a chat message and streams the AI reply.
+//
+// onChunk is called for every text chunk as it becomes available. Once the
+// stream has finished, the full AIChatResponse (reply text, intent, merged
+// params and, if nothing is missing, a confirmation card) is returned. An
+// error is returned if the LLM stream cannot be started or ctx is cancelled
+// while streaming.
+func (s *AIBookkeepingService) StreamProcessChat(ctx context.Context, userID uint, sessionID string, message string, onChunk func(string)) (*AIChatResponse, error) {
+	session := s.getOrCreateSession(sessionID, userID)
+
+	// Add user message to history
+	session.Messages = append(session.Messages, ChatMessage{
+		Role:    "user",
+		Content: message,
+	})
+
+	// 1. Load the financial context. The error is ignored on purpose: a nil
+	// fc only disables the query shortcut below.
+	fc, _ := s.GetUserFinancialContext(ctx, userID)
+
+	// 2. Pure query intent: the answer is generated in one go and then
+	// replayed in small pieces to imitate streaming.
+	queryIntent := s.detectQueryIntent(message)
+	if queryIntent != "" && fc != nil {
+		responseMsg := s.handleQueryIntent(ctx, queryIntent, message, fc)
+
+		const chunkSize = 4 // runes per simulated chunk
+		runes := []rune(responseMsg)
+		for i := 0; i < len(runes); i += chunkSize {
+			end := i + chunkSize
+			if end > len(runes) {
+				end = len(runes)
+			}
+			onChunk(string(runes[i:end]))
+
+			// Stop replaying once ctx is cancelled instead of sleeping on.
+			if ctx.Err() != nil {
+				break
+			}
+			time.Sleep(20 * time.Millisecond) // small delay to imitate typing
+		}
+
+		response := &AIChatResponse{
+			SessionID: session.ID,
+			Message:   responseMsg,
+			Intent:    queryIntent,
+			Params:    session.Params,
+		}
+
+		session.Messages = append(session.Messages, ChatMessage{
+			Role:    "assistant",
+			Content: response.Message,
+		})
+		return response, nil
+	}
+
+	// 3. Detect spending-advice intent
+	isSpendingAdvice := s.isSpendingAdviceIntent(message)
+
+	// Fetch user categories for prompt context
+	var categoryNamesStr string
+	if categories, err := s.categoryRepo.GetAll(userID); err == nil && len(categories) > 0 {
+		names := make([]string, 0, len(categories))
+		for _, c := range categories {
+			names = append(names, c.Name)
+		}
+		categoryNamesStr = strings.Join(names, "/")
+	}
+
+	todayDate := time.Now().Format("2006-01-02")
+	catPrompt := "2. 分类:根据内容推断"
+	if categoryNamesStr != "" {
+		catPrompt = fmt.Sprintf("2. 分类:必须从以下已有分类中选择最匹配的一项:[%s]。", categoryNamesStr)
+	}
+
+	systemPrompt := fmt.Sprintf(`你是一个智能记账助手。请以自然的对话方式回复用户,同时在回复中确认识别到的记账信息。
+今天的日期是%s
+
+规则:
+1. 你的性格:贱萌、嘴硬心软、偶尔凡尔赛。
+2. 如果用户在记账:请提取金额、分类(%s)、描述。
+3. 如果用户在寻求建议:请给出毒舌但有用的建议。
+4. **不要返回 JSON**,直接以自然语言回复。`, todayDate, catPrompt)
+
+	// Construct minimal messages for stream: the system prompt followed by
+	// the most recent history entries (the new user message included).
+	const maxHistory = 4
+	recent := session.Messages
+	if len(recent) > maxHistory {
+		recent = recent[len(recent)-maxHistory:]
+	}
+	streamMessages := make([]ChatMessage, 0, len(recent)+1)
+	streamMessages = append(streamMessages, ChatMessage{Role: "system", Content: systemPrompt})
+	streamMessages = append(streamMessages, recent...)
+
+	// Execute Stream
+	streamChan, err := s.llmService.StreamChat(ctx, streamMessages)
+	if err != nil {
+		return nil, err
+	}
+
+	var fullResponseBuilder strings.Builder
+	for chunk := range streamChan {
+		fullResponseBuilder.WriteString(chunk)
+		onChunk(chunk)
+	}
+	fullResponse := fullResponseBuilder.String()
+
+	// 4. Update Session History with Assistant Message
+	session.Messages = append(session.Messages, ChatMessage{
+		Role:    "assistant",
+		Content: fullResponse,
+	})
+
+	// ctx was cancelled while streaming: the reply may be incomplete and the
+	// follow-up calls below would run on a cancelled ctx, so stop here.
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+
+	// 5. Post-stream: extract structured data for the confirmation card.
+	// The streaming prompt asks for natural language only, so ParseIntent is
+	// run on the user message; the history passed in excludes the user
+	// message and the assistant reply appended above.
+	history := session.Messages[:len(session.Messages)-2]
+	params, _, err := s.llmService.ParseIntent(ctx, message, history, userID)
+	if err != nil {
+		// Not fatal: the user has already seen the streamed text.
+		fmt.Printf("Stream ParseIntent failed: %v\n", err)
+	}
+
+	if params != nil {
+		s.mergeParams(session.Params, params)
+	}
+
+	// Map free-text account and category names to IDs (same mapping logic as
+	// ProcessChat).
+	if session.Params.Account != "" && session.Params.AccountID == nil {
+		accountID, accountName, _ := s.llmService.MapAccountName(ctx, session.Params.Account, userID)
+		if accountID != nil {
+			session.Params.AccountID = accountID
+			session.Params.Account = accountName
+		}
+	}
+
+	if session.Params.Category != "" && session.Params.CategoryID == nil {
+		categoryID, categoryName, _ := s.llmService.MapCategoryName(ctx, session.Params.Category, session.Params.Type, userID)
+		if categoryID != nil {
+			session.Params.CategoryID = categoryID
+			session.Params.Category = categoryName
+		}
+	}
+
+	// Defaults. A category that was named but could not be mapped gets the
+	// default category ID; the name given by the user is kept.
+	// NOTE(review): filling in the default name for an empty Category is
+	// unreachable under this condition, so it is not attempted here. Confirm
+	// whether an empty Category should get a default as well.
+	if session.Params.CategoryID == nil && session.Params.Category != "" {
+		if defaultCategoryID, _ := s.getDefaultCategory(userID, session.Params.Type); defaultCategoryID != nil {
+			session.Params.CategoryID = defaultCategoryID
+		}
+	}
+
+	if session.Params.AccountID == nil {
+		defaultAccountID, defaultAccountName := s.getDefaultAccount(userID, session.Params.Type)
+		if defaultAccountID != nil {
+			session.Params.AccountID = defaultAccountID
+			session.Params.Account = defaultAccountName
+		}
+	}
+
+	// Construct final response object
+	response := &AIChatResponse{
+		SessionID: session.ID,
+		Message:   fullResponse, // The text already streamed
+		Intent:    "create_transaction",
+		Params:    session.Params,
+	}
+
+	if isSpendingAdvice {
+		response.Intent = "spending_advice"
+	}
+
+	// The reply has already been streamed, so it is not replaced with a
+	// follow-up question; NeedsFollowUp only tells the UI that fields are
+	// missing and that no confirmation card is attached.
+	if missingFields := s.getMissingFields(session.Params); len(missingFields) > 0 {
+		response.NeedsFollowUp = true
+	} else {
+		response.ConfirmationCard = s.GenerateConfirmationCard(session)
+	}
+
+	return response, nil
+}
+
 // detectQueryIntent 检测用户查询意图
 func (s *AIBookkeepingService) detectQueryIntent(message string) string {
 	budgetKeywords := []string{"预算", "剩多少", "还能花", "余额"} // 注意:余额可能指账户余额,这里简化处理