From 38469739dea676a6e27ced9d0470af2038eef4f4 Mon Sep 17 00:00:00 2001 From: 12975 <1297598740@qq.com> Date: Thu, 29 Jan 2026 19:06:30 +0800 Subject: [PATCH 01/13] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20Gitee=20OAut?= =?UTF-8?q?h=20=E8=AE=A4=E8=AF=81=E6=94=AF=E6=8C=81=E5=B9=B6=E6=9B=B4?= =?UTF-8?q?=E6=96=B0=E7=9B=B8=E5=85=B3=E9=85=8D=E7=BD=AE=E3=80=81=E8=B7=AF?= =?UTF-8?q?=E7=94=B1=E5=92=8C=E6=9C=8D=E5=8A=A1=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.prod | 7 + internal/config/config.go | 10 + internal/handler/auth_handler.go | 56 ++++- internal/router/router.go | 12 +- internal/service/gitee_oauth_service.go | 294 ++++++++++++++++++++++++ 5 files changed, 373 insertions(+), 6 deletions(-) create mode 100644 internal/service/gitee_oauth_service.go diff --git a/.env.prod b/.env.prod index d964eff..ddb7265 100644 --- a/.env.prod +++ b/.env.prod @@ -58,6 +58,13 @@ GITHUB_CLIENT_SECRET=7e154e464dccd913a92cf580021f2a5dc51aac93 GITHUB_REDIRECT_URL=https://bk.swalktech.top/api/v1/auth/github/callback FRONTEND_URL=https://bk.swalktech.top +# ============================================ +# Gitee OAuth 配置(可选) +# ============================================ +GITEE_CLIENT_ID=ccc286f08aac25a6304c61a1a7a5a4418e0fd73948d8f8339ca941bfb5379280 +GITEE_CLIENT_SECRET=b7832bdfc3cadf2e00dba9e2b694345f88afb591603a2edf3af19484b68efe9a +GITEE_REDIRECT_URL=https://bk.swalktech.top/api/v1/auth/gitee/callback + # ============================================ # 网络配置 # ============================================ diff --git a/internal/config/config.go b/internal/config/config.go index 869631f..6115470 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -46,6 +46,11 @@ type Config struct { GitHubRedirectURL string FrontendURL string + // Gitee OAuth configuration + GiteeClientID string + GiteeClientSecret string + GiteeRedirectURL string + // AI configuration (OpenAI compatible) OpenAIAPIKey string OpenAIBaseURL string @@ -105,6 +110,11 @@ func Load() *Config { GitHubRedirectURL: getEnv("GITHUB_REDIRECT_URL", ""), FrontendURL: getEnv("FRONTEND_URL", "http://localhost:2613"), + // Gitee OAuth + GiteeClientID: getEnv("GITEE_CLIENT_ID", ""), + GiteeClientSecret: getEnv("GITEE_CLIENT_SECRET", ""), + GiteeRedirectURL: getEnv("GITEE_REDIRECT_URL", ""), + // AI (OpenAI compatible) OpenAIAPIKey: getEnv("OPENAI_API_KEY", ""), OpenAIBaseURL: getEnv("OPENAI_BASE_URL", ""), diff --git a/internal/handler/auth_handler.go b/internal/handler/auth_handler.go index d17a1f6..3aaa1a4 100644 --- a/internal/handler/auth_handler.go +++ b/internal/handler/auth_handler.go @@ -15,15 +15,16 @@ import ( type AuthHandler struct { authService *service.AuthService gitHubOAuthService *service.GitHubOAuthService + giteeOAuthService *service.GiteeOAuthService cfg *config.Config } -func NewAuthHandler(authService *service.AuthService, gitHubOAuthService *service.GitHubOAuthService) *AuthHandler { - return &AuthHandler{authService: authService, gitHubOAuthService: gitHubOAuthService} +func NewAuthHandler(authService *service.AuthService, gitHubOAuthService *service.GitHubOAuthService, giteeOAuthService *service.GiteeOAuthService) *AuthHandler { + return &AuthHandler{authService: authService, gitHubOAuthService: gitHubOAuthService, giteeOAuthService: giteeOAuthService} } -func NewAuthHandlerWithConfig(authService *service.AuthService, gitHubOAuthService *service.GitHubOAuthService, cfg *config.Config) *AuthHandler { - return &AuthHandler{authService: authService, gitHubOAuthService: gitHubOAuthService, cfg: cfg} +func NewAuthHandlerWithConfig(authService *service.AuthService, gitHubOAuthService *service.GitHubOAuthService, giteeOAuthService *service.GiteeOAuthService, cfg *config.Config) *AuthHandler { + return &AuthHandler{authService: authService, gitHubOAuthService: gitHubOAuthService, giteeOAuthService: giteeOAuthService, cfg: cfg} } type RegisterRequest struct { @@ -154,6 +155,8 @@ func (h *AuthHandler) RegisterRoutes(rg *gin.RouterGroup) { auth.POST("/refresh", h.RefreshToken) auth.GET("/github", h.GitHubLogin) auth.GET("/github/callback", h.GitHubCallback) + auth.GET("/gitee", h.GiteeLogin) + auth.GET("/gitee/callback", h.GiteeCallback) } func (h *AuthHandler) RegisterProtectedRoutes(rg *gin.RouterGroup) { @@ -206,3 +209,48 @@ func (h *AuthHandler) GitHubCallback(c *gin.Context) { frontendURL, url.QueryEscape(tokens.AccessToken), url.QueryEscape(tokens.RefreshToken), user.ID) c.Redirect(302, redirectURL) } + +func (h *AuthHandler) GiteeLogin(c *gin.Context) { + if h.giteeOAuthService == nil { + api.BadRequest(c, "Gitee OAuth is not configured") + return + } + state := c.Query("state") + if state == "" { + state = "default" + } + authURL := h.giteeOAuthService.GetAuthorizationURL(state) + c.Redirect(302, authURL) +} + +func (h *AuthHandler) GiteeCallback(c *gin.Context) { + if h.giteeOAuthService == nil { + api.BadRequest(c, "Gitee OAuth is not configured") + return + } + + frontendURL := "http://localhost:2613" + if h.cfg != nil && h.cfg.FrontendURL != "" { + frontendURL = h.cfg.FrontendURL + } + + code := c.Query("code") + if code == "" { + redirectURL := fmt.Sprintf("%s/login?error=missing_code", frontendURL) + c.Redirect(302, redirectURL) + return + } + + user, tokens, err := h.giteeOAuthService.HandleCallback(code) + if err != nil { + fmt.Printf("[Auth] Gitee callback failed: %v\n", err) + redirectURL := fmt.Sprintf("%s/login?error=%s", frontendURL, url.QueryEscape(err.Error())) + c.Redirect(302, redirectURL) + return + } + + // 重定向到前端回调页面,带上token信息 + redirectURL := fmt.Sprintf("%s/auth/gitee/callback?access_token=%s&refresh_token=%s&user_id=%d", + frontendURL, url.QueryEscape(tokens.AccessToken), url.QueryEscape(tokens.RefreshToken), user.ID) + c.Redirect(302, redirectURL) +} diff --git a/internal/router/router.go b/internal/router/router.go index 5e12541..f4fdaa8 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -58,7 +58,11 @@ func Setup(db *gorm.DB, yunAPIClient *service.YunAPIClient, cfg *config.Config) if cfg.GitHubClientID != "" && cfg.GitHubClientSecret != "" { gitHubOAuthService = service.NewGitHubOAuthService(userRepo, authService, cfg) } - authHandler := handler.NewAuthHandlerWithConfig(authService, gitHubOAuthService, cfg) + var giteeOAuthService *service.GiteeOAuthService + if cfg.GiteeClientID != "" && cfg.GiteeClientSecret != "" { + giteeOAuthService = service.NewGiteeOAuthService(userRepo, authService, cfg) + } + authHandler := handler.NewAuthHandlerWithConfig(authService, gitHubOAuthService, giteeOAuthService, cfg) authMiddleware := middleware.NewAuthMiddleware(authService) // Initialize services @@ -331,7 +335,11 @@ func SetupWithRedis(db *gorm.DB, yunAPIClient *service.YunAPIClient, redisClient if cfg.GitHubClientID != "" && cfg.GitHubClientSecret != "" { gitHubOAuthService = service.NewGitHubOAuthService(userRepo, authService, cfg) } - authHandler := handler.NewAuthHandlerWithConfig(authService, gitHubOAuthService, cfg) + var giteeOAuthService *service.GiteeOAuthService + if cfg.GiteeClientID != "" && cfg.GiteeClientSecret != "" { + giteeOAuthService = service.NewGiteeOAuthService(userRepo, authService, cfg) + } + authHandler := handler.NewAuthHandlerWithConfig(authService, gitHubOAuthService, giteeOAuthService, cfg) authMiddleware := middleware.NewAuthMiddleware(authService) // Initialize services diff --git a/internal/service/gitee_oauth_service.go b/internal/service/gitee_oauth_service.go new file mode 100644 index 0000000..225386e --- /dev/null +++ b/internal/service/gitee_oauth_service.go @@ -0,0 +1,294 @@ +// Package service provides business logic for the application +package service + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "accounting-app/internal/config" + "accounting-app/internal/models" + "accounting-app/internal/repository" +) + +// Gitee OAuth errors +var ( + ErrGiteeOAuthFailed = errors.New("gitee oauth authentication failed") + ErrGiteeUserInfoFailed = errors.New("failed to get gitee user info") +) + +// GiteeUser represents Gitee user information +type GiteeUser struct { + ID int64 `json:"id"` + Login string `json:"login"` + Email string `json:"email"` + Name string `json:"name"` + AvatarURL string `json:"avatar_url"` +} + +// GiteeTokenResponse represents Gitee OAuth token response +type GiteeTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` + CreatedAt int64 `json:"created_at"` +} + +// GiteeOAuthService handles Gitee OAuth operations +type GiteeOAuthService struct { + userRepo *repository.UserRepository + authService *AuthService + cfg *config.Config + httpClient *http.Client +} + +// NewGiteeOAuthService creates a new GiteeOAuthService instance +func NewGiteeOAuthService(userRepo *repository.UserRepository, authService *AuthService, cfg *config.Config) *GiteeOAuthService { + return &GiteeOAuthService{ + userRepo: userRepo, + authService: authService, + cfg: cfg, + httpClient: &http.Client{ + Timeout: 60 * time.Second, + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + }, + }, + } +} + +// GetAuthorizationURL returns the Gitee OAuth authorization URL +func (s *GiteeOAuthService) GetAuthorizationURL(state string) string { + params := url.Values{} + params.Set("client_id", s.cfg.GiteeClientID) + params.Set("redirect_uri", s.cfg.GiteeRedirectURL) + params.Set("response_type", "code") + params.Set("scope", "user_info emails") + if state != "" { + params.Set("state", state) + } + + return fmt.Sprintf("https://gitee.com/oauth/authorize?%s", params.Encode()) +} + +// ExchangeCodeForToken exchanges authorization code for access token +func (s *GiteeOAuthService) ExchangeCodeForToken(code string) (*GiteeTokenResponse, error) { + data := url.Values{} + data.Set("grant_type", "authorization_code") + data.Set("client_id", s.cfg.GiteeClientID) + data.Set("client_secret", s.cfg.GiteeClientSecret) + data.Set("code", code) + data.Set("redirect_uri", s.cfg.GiteeRedirectURL) + + req, err := http.NewRequest("POST", "https://gitee.com/oauth/token", strings.NewReader(data.Encode())) + if err != nil { + fmt.Printf("[Gitee] Failed to create request: %v\n", err) + return nil, err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + fmt.Printf("[Gitee] Exchanging code for token...\n") + resp, err := s.httpClient.Do(req) + if err != nil { + fmt.Printf("[Gitee] Request failed: %v\n", err) + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + fmt.Printf("[Gitee] Token exchange failed with status: %d\n", resp.StatusCode) + return nil, ErrGiteeOAuthFailed + } + + var tokenResp GiteeTokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + fmt.Printf("[Gitee] Failed to decode response: %v\n", err) + return nil, err + } + + if tokenResp.AccessToken == "" { + fmt.Printf("[Gitee] No access token in response\n") + return nil, ErrGiteeOAuthFailed + } + + return &tokenResp, nil +} + +// GetUserInfo retrieves Gitee user information using access token +func (s *GiteeOAuthService) GetUserInfo(accessToken string) (*GiteeUser, error) { + reqURL := fmt.Sprintf("https://gitee.com/api/v5/user?access_token=%s", url.QueryEscape(accessToken)) + req, err := http.NewRequest("GET", reqURL, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json") + + resp, err := s.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + fmt.Printf("[Gitee] Get user info failed with status: %d\n", resp.StatusCode) + return nil, ErrGiteeUserInfoFailed + } + + var user GiteeUser + if err := json.NewDecoder(resp.Body).Decode(&user); err != nil { + return nil, err + } + + // If email is empty, try to get from emails endpoint + if user.Email == "" { + email, err := s.getUserEmail(accessToken) + if err == nil && email != "" { + user.Email = email + } + } + + return &user, nil +} + +// getUserEmail retrieves user's primary email from Gitee +func (s *GiteeOAuthService) getUserEmail(accessToken string) (string, error) { + reqURL := fmt.Sprintf("https://gitee.com/api/v5/emails?access_token=%s", url.QueryEscape(accessToken)) + req, err := http.NewRequest("GET", reqURL, nil) + if err != nil { + return "", err + } + + req.Header.Set("Accept", "application/json") + + resp, err := s.httpClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", ErrGiteeUserInfoFailed + } + + var emails []struct { + Email string `json:"email"` + State string `json:"state"` + Scope []string `json:"scope"` + } + + if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil { + return "", err + } + + // Find primary/confirmed email + for _, e := range emails { + if e.State == "confirmed" { + return e.Email, nil + } + } + + // Fallback to first email + if len(emails) > 0 { + return emails[0].Email, nil + } + + return "", nil +} + +// HandleCallback processes Gitee OAuth callback +func (s *GiteeOAuthService) HandleCallback(code string) (*models.User, *TokenPair, error) { + // Exchange code for token + tokenResp, err := s.ExchangeCodeForToken(code) + if err != nil { + return nil, nil, err + } + + // Get Gitee user info + giteeUser, err := s.GetUserInfo(tokenResp.AccessToken) + if err != nil { + return nil, nil, err + } + + // Check if user already exists with this Gitee account + user, err := s.userRepo.GetByOAuthProvider("gitee", fmt.Sprintf("%d", giteeUser.ID)) + if err == nil { + // User exists, update token and return + _ = s.userRepo.UpdateOAuthToken("gitee", fmt.Sprintf("%d", giteeUser.ID), tokenResp.AccessToken) + tokens, err := s.authService.generateTokenPair(user) + if err != nil { + return nil, nil, err + } + return user, tokens, nil + } + + // Check if user exists with same email + if giteeUser.Email != "" { + existingUser, err := s.userRepo.GetByEmail(giteeUser.Email) + if err == nil { + // Link Gitee account to existing user + oauth := &models.OAuthAccount{ + UserID: existingUser.ID, + Provider: "gitee", + ProviderID: fmt.Sprintf("%d", giteeUser.ID), + AccessToken: tokenResp.AccessToken, + } + if err := s.userRepo.CreateOAuthAccount(oauth); err != nil { + return nil, nil, err + } + tokens, err := s.authService.generateTokenPair(existingUser) + if err != nil { + return nil, nil, err + } + return existingUser, tokens, nil + } + } + + // Create new user + username := giteeUser.Login + if giteeUser.Name != "" { + username = giteeUser.Name + } + + email := giteeUser.Email + if email == "" { + email = fmt.Sprintf("%d@gitee.user", giteeUser.ID) + } + + newUser := &models.User{ + Email: email, + Username: username, + Avatar: giteeUser.AvatarURL, + IsActive: true, + } + + if err := s.userRepo.Create(newUser); err != nil { + return nil, nil, err + } + + // Create OAuth account link + oauth := &models.OAuthAccount{ + UserID: newUser.ID, + Provider: "gitee", + ProviderID: fmt.Sprintf("%d", giteeUser.ID), + AccessToken: tokenResp.AccessToken, + } + if err := s.userRepo.CreateOAuthAccount(oauth); err != nil { + return nil, nil, err + } + + tokens, err := s.authService.generateTokenPair(newUser) + if err != nil { + return nil, nil, err + } + + return newUser, tokens, nil +} From f1fa7b6c548872d72aa9cf0bf9bfe5b8c935ea9c Mon Sep 17 00:00:00 2001 From: 12975 <1297598740@qq.com> Date: Thu, 29 Jan 2026 19:21:13 +0800 Subject: [PATCH 02/13] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=20Gitee=20OAut?= =?UTF-8?q?h=20=E6=9C=8D=E5=8A=A1=EF=BC=8C=E7=94=A8=E4=BA=8E=E5=A4=84?= =?UTF-8?q?=E7=90=86=20Gitee=20=E6=8E=88=E6=9D=83=E7=99=BB=E5=BD=95?= =?UTF-8?q?=E6=B5=81=E7=A8=8B=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/service/gitee_oauth_service.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/service/gitee_oauth_service.go b/internal/service/gitee_oauth_service.go index 225386e..e255cba 100644 --- a/internal/service/gitee_oauth_service.go +++ b/internal/service/gitee_oauth_service.go @@ -69,7 +69,7 @@ func (s *GiteeOAuthService) GetAuthorizationURL(state string) string { params.Set("client_id", s.cfg.GiteeClientID) params.Set("redirect_uri", s.cfg.GiteeRedirectURL) params.Set("response_type", "code") - params.Set("scope", "user_info emails") + params.Set("scope", "user_info") if state != "" { params.Set("state", state) } From cf34f8b3d01c9b82f239ec73b6ff2fcd2aef983b Mon Sep 17 00:00:00 2001 From: 12975 <1297598740@qq.com> Date: Thu, 29 Jan 2026 21:43:35 +0800 Subject: [PATCH 03/13] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E9=A2=84?= =?UTF-8?q?=E7=AE=97=E7=AE=A1=E7=90=86=E5=8A=9F=E8=83=BD=EF=BC=8C=E5=8C=85?= =?UTF-8?q?=E6=8B=AC=E9=A2=84=E7=AE=97=E7=9A=84=E5=88=9B=E5=BB=BA=E3=80=81?= =?UTF-8?q?=E6=9F=A5=E8=AF=A2=E3=80=81=E6=9B=B4=E6=96=B0=E3=80=81=E5=88=A0?= =?UTF-8?q?=E9=99=A4=E5=8F=8A=E8=BF=9B=E5=BA=A6=E8=AE=A1=E7=AE=97=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/handler/budget_handler.go | 6 +- internal/models/models.go | 13 ++++ internal/service/account_service.go | 28 +++++++- internal/service/budget_service.go | 100 ++++++++++++++++++++-------- 4 files changed, 112 insertions(+), 35 deletions(-) diff --git a/internal/handler/budget_handler.go b/internal/handler/budget_handler.go index 7b59404..82669f1 100644 --- a/internal/handler/budget_handler.go +++ b/internal/handler/budget_handler.go @@ -3,8 +3,8 @@ package handler import ( "strconv" - "accounting-app/pkg/api" "accounting-app/internal/service" + "accounting-app/pkg/api" "github.com/gin-gonic/gin" ) @@ -60,8 +60,6 @@ func (h *BudgetHandler) CreateBudget(c *gin.Context) { api.BadRequest(c, err.Error()) case service.ErrInvalidPeriodType: api.BadRequest(c, err.Error()) - case service.ErrCategoryOrAccountRequired: - api.BadRequest(c, err.Error()) default: api.InternalError(c, "Failed to create budget") } @@ -150,8 +148,6 @@ func (h *BudgetHandler) UpdateBudget(c *gin.Context) { api.BadRequest(c, err.Error()) case service.ErrInvalidPeriodType: api.BadRequest(c, err.Error()) - case service.ErrCategoryOrAccountRequired: - api.BadRequest(c, err.Error()) default: api.InternalError(c, "Failed to update budget") } diff --git a/internal/models/models.go b/internal/models/models.go index 35f415d..626a00e 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -278,6 +278,7 @@ type Account struct { PiggyBanks []PiggyBank `gorm:"foreignKey:LinkedAccountID" json:"-"` ParentAccount *Account `gorm:"foreignKey:ParentAccountID" json:"parent_account,omitempty"` SubAccounts []Account `gorm:"foreignKey:ParentAccountID" json:"sub_accounts,omitempty"` + Tags []Tag `gorm:"many2many:account_tags;" json:"tags,omitempty"` } // TableName specifies the table name for Account @@ -417,6 +418,17 @@ func (TransactionTag) TableName() string { return "transaction_tags" } +// AccountTag represents the many-to-many relationship between accounts and tags +type AccountTag struct { + AccountID uint `gorm:"primaryKey" json:"account_id"` + TagID uint `gorm:"primaryKey" json:"tag_id"` +} + +// TableName specifies the table name for AccountTag +func (AccountTag) TableName() string { + return "account_tags" +} + // Budget represents a spending budget for a category or account type Budget struct { BaseModel @@ -843,6 +855,7 @@ func AllModels() []interface{} { &Tag{}, &Transaction{}, &TransactionTag{}, // Explicit join table for many-to-many relationship + &AccountTag{}, // Explicit join table for account-tag many-to-many relationship &Budget{}, &PiggyBank{}, &RecurringTransaction{}, diff --git a/internal/service/account_service.go b/internal/service/account_service.go index 4aa263f..4efdb80 100644 --- a/internal/service/account_service.go +++ b/internal/service/account_service.go @@ -32,6 +32,7 @@ type AccountInput struct { PaymentDate *int `json:"payment_date,omitempty"` WarningThreshold *float64 `json:"warning_threshold,omitempty"` AccountCode string `json:"account_code,omitempty"` + TagIDs []uint `json:"tag_ids,omitempty"` } // TransferInput represents the input data for a transfer operation @@ -99,6 +100,18 @@ func (s *AccountService) CreateAccount(userID uint, input AccountInput) (*models return nil, fmt.Errorf("failed to create account: %w", err) } + // Handle tags association + if len(input.TagIDs) > 0 { + var tags []models.Tag + if err := s.db.Where("id IN ? AND user_id = ?", input.TagIDs, userID).Find(&tags).Error; err != nil { + return nil, fmt.Errorf("failed to find tags: %w", err) + } + if err := s.db.Model(account).Association("Tags").Replace(tags); err != nil { + return nil, fmt.Errorf("failed to associate tags: %w", err) + } + account.Tags = tags + } + return account, nil } @@ -117,7 +130,8 @@ func (s *AccountService) GetAccount(userID, id uint) (*models.Account, error) { // GetAllAccounts retrieves all accounts for a specific user func (s *AccountService) GetAllAccounts(userID uint) ([]models.Account, error) { - accounts, err := s.repo.GetAll(userID) + var accounts []models.Account + err := s.db.Where("user_id = ?", userID).Preload("Tags").Order("sort_order ASC, id ASC").Find(&accounts).Error if err != nil { return nil, fmt.Errorf("failed to get accounts: %w", err) } @@ -164,6 +178,18 @@ func (s *AccountService) UpdateAccount(userID, id uint, input AccountInput) (*mo return nil, fmt.Errorf("failed to update account: %w", err) } + // Handle tags association + var tags []models.Tag + if len(input.TagIDs) > 0 { + if err := s.db.Where("id IN ? AND user_id = ?", input.TagIDs, userID).Find(&tags).Error; err != nil { + return nil, fmt.Errorf("failed to find tags: %w", err) + } + } + if err := s.db.Model(account).Association("Tags").Replace(tags); err != nil { + return nil, fmt.Errorf("failed to update tags: %w", err) + } + account.Tags = tags + return account, nil } diff --git a/internal/service/budget_service.go b/internal/service/budget_service.go index a355b30..350b9fb 100644 --- a/internal/service/budget_service.go +++ b/internal/service/budget_service.go @@ -13,12 +13,11 @@ import ( // Service layer errors for budgets var ( - ErrBudgetNotFound = errors.New("budget not found") - ErrBudgetInUse = errors.New("budget is in use and cannot be deleted") - ErrInvalidBudgetAmount = errors.New("budget amount must be positive") - ErrInvalidDateRange = errors.New("end date must be after start date") - ErrInvalidPeriodType = errors.New("invalid period type") - ErrCategoryOrAccountRequired = errors.New("either category or account must be specified") + ErrBudgetNotFound = errors.New("budget not found") + ErrBudgetInUse = errors.New("budget is in use and cannot be deleted") + ErrInvalidBudgetAmount = errors.New("budget amount must be positive") + ErrInvalidDateRange = errors.New("end date must be after start date") + ErrInvalidPeriodType = errors.New("invalid period type") ) // BudgetInput represents the input data for creating or updating a budget @@ -72,10 +71,7 @@ func (s *BudgetService) CreateBudget(input BudgetInput) (*models.Budget, error) return nil, ErrInvalidBudgetAmount } - // Validate that at least category or account is specified - if input.CategoryID == nil && input.AccountID == nil { - return nil, ErrCategoryOrAccountRequired - } + // 分类和账户都可选,支持全局预算 // Validate date range if input.EndDate != nil && input.EndDate.Before(input.StartDate) { @@ -147,10 +143,7 @@ func (s *BudgetService) UpdateBudget(userID, id uint, input BudgetInput) (*model return nil, ErrInvalidBudgetAmount } - // Validate that at least category or account is specified - if input.CategoryID == nil && input.AccountID == nil { - return nil, ErrCategoryOrAccountRequired - } + // 分类和账户都可选,支持全局预算 // Validate date range if input.EndDate != nil && input.EndDate.Before(input.StartDate) { @@ -222,41 +215,55 @@ func (s *BudgetService) GetBudgetProgress(userID, id uint) (*BudgetProgress, err startDate, endDate := s.calculateCurrentPeriod(budget, now) // Get spent amount for the current period - spent, err := s.repo.GetSpentAmount(budget, startDate, endDate) + currentSpent, err := s.repo.GetSpentAmount(budget, startDate, endDate) if err != nil { return nil, fmt.Errorf("failed to calculate spent amount: %w", err) } - // Calculate effective budget amount (considering rolling budget) + // Calculate effective budget amount effectiveAmount := budget.Amount + totalSpent := currentSpent + if budget.IsRolling { - // For rolling budgets, add the previous period's remaining balance - prevStartDate, prevEndDate := s.calculatePreviousPeriod(budget, now) - prevSpent, err := s.repo.GetSpentAmount(budget, prevStartDate, prevEndDate) - if err != nil { - return nil, fmt.Errorf("failed to calculate previous period spent: %w", err) - } - prevRemaining := budget.Amount - prevSpent - if prevRemaining > 0 { - effectiveAmount += prevRemaining + // 滚动预算:结余自动累加到下一周期 + // 当期可用额度 = 总额度 - 历史支出 + + // 计算已过的完整周期数(不含当期) + periodsElapsed := s.calculatePeriodsElapsed(budget, startDate) + + // 总额度 = (已过周期数 + 当期) × 单期额度 + totalBudget := budget.Amount * float64(periodsElapsed+1) + + // 获取历史支出(从预算开始到当期开始前一秒) + historyEnd := startDate.Add(-time.Second) + historySpent := 0.0 + if historyEnd.After(budget.StartDate) { + historySpent, err = s.repo.GetSpentAmount(budget, budget.StartDate, historyEnd) + if err != nil { + return nil, fmt.Errorf("failed to calculate history spent: %w", err) + } } + + // 当期可用额度 = 总额度 - 历史支出 + effectiveAmount = totalBudget - historySpent + totalSpent = currentSpent } // Calculate progress metrics - remaining := effectiveAmount - spent + remaining := effectiveAmount - totalSpent progress := 0.0 if effectiveAmount > 0 { - progress = (spent / effectiveAmount) * 100 + progress = (totalSpent / effectiveAmount) * 100 } - isOverBudget := spent > effectiveAmount + isOverBudget := totalSpent > effectiveAmount isNearLimit := progress >= 80.0 && !isOverBudget return &BudgetProgress{ BudgetID: budget.ID, Name: budget.Name, Amount: effectiveAmount, - Spent: spent, + Spent: totalSpent, Remaining: remaining, Progress: progress, PeriodType: budget.PeriodType, @@ -353,6 +360,41 @@ func (s *BudgetService) calculatePreviousPeriod(budget *models.Budget, reference } } +// calculatePeriodsElapsed 计算从预算开始日期到当前周期开始日期之间的完整周期数 +func (s *BudgetService) calculatePeriodsElapsed(budget *models.Budget, currentPeriodStart time.Time) int { + if currentPeriodStart.Before(budget.StartDate) || currentPeriodStart.Equal(budget.StartDate) { + return 0 + } + + var periods int + switch budget.PeriodType { + case models.PeriodTypeDaily: + periods = int(currentPeriodStart.Sub(budget.StartDate).Hours() / 24) + + case models.PeriodTypeWeekly: + periods = int(currentPeriodStart.Sub(budget.StartDate).Hours() / (24 * 7)) + + case models.PeriodTypeMonthly: + yearDiff := currentPeriodStart.Year() - budget.StartDate.Year() + monthDiff := int(currentPeriodStart.Month()) - int(budget.StartDate.Month()) + periods = yearDiff*12 + monthDiff + + case models.PeriodTypeYearly: + periods = currentPeriodStart.Year() - budget.StartDate.Year() + + default: + yearDiff := currentPeriodStart.Year() - budget.StartDate.Year() + monthDiff := int(currentPeriodStart.Month()) - int(budget.StartDate.Month()) + periods = yearDiff*12 + monthDiff + } + + // 确保返回非负数 + if periods < 0 { + return 0 + } + return periods +} + // isValidPeriodType checks if a period type is valid func isValidPeriodType(periodType models.PeriodType) bool { switch periodType { From 81f814c928c8ae653bf9bfc6d7f108efb456585f Mon Sep 17 00:00:00 2001 From: 12975 <1297598740@qq.com> Date: Thu, 29 Jan 2026 22:00:15 +0800 Subject: [PATCH 04/13] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E8=BF=81=E7=A7=BB=E5=B7=A5=E5=85=B7=E5=B9=B6?= =?UTF-8?q?=E5=88=9B=E5=BB=BAAI=E8=AE=B0=E8=B4=A6=E6=9C=8D=E5=8A=A1?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/migrate/main.go | 8 +- internal/service/ai_bookkeeping_service.go | 389 ++++++++++++++++++++- 2 files changed, 384 insertions(+), 13 deletions(-) diff --git a/cmd/migrate/main.go b/cmd/migrate/main.go index 6fccd98..4722923 100644 --- a/cmd/migrate/main.go +++ b/cmd/migrate/main.go @@ -14,12 +14,12 @@ import ( func main() { // Load .env file from project root (try multiple locations) envPaths := []string{ - ".env", // Current directory - "../.env", // Parent directory (when running from backend/) - "../../.env", // Two levels up (when running from backend/cmd/migrate/) + ".env", // Current directory + "../.env", // Parent directory (when running from backend/) + "../../.env", // Two levels up (when running from backend/cmd/migrate/) filepath.Join("..", "..", ".env"), // Explicit path } - + for _, envPath := range envPaths { if err := godotenv.Load(envPath); err == nil { log.Printf("Loaded environment from: %s", envPath) diff --git a/internal/service/ai_bookkeeping_service.go b/internal/service/ai_bookkeeping_service.go index 3f28cec..37a56e1 100644 --- a/internal/service/ai_bookkeeping_service.go +++ b/internal/service/ai_bookkeeping_service.go @@ -874,6 +874,9 @@ func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, ses Content: message, }) + // 检测是否为消费建议意图(想吃/想买/想喝等) + isSpendingAdvice := s.isSpendingAdviceIntent(message) + // Parse intent params, responseMsg, err := s.llmService.ParseIntent(ctx, message, session.Messages[:len(session.Messages)-1]) if err != nil { @@ -907,7 +910,6 @@ func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, ses defaultCategoryID, defaultCategoryName := s.getDefaultCategory(userID, session.Params.Type) if defaultCategoryID != nil { session.Params.CategoryID = defaultCategoryID - // Keep the original category name from AI, just set the ID if session.Params.Category == "" { session.Params.Category = defaultCategoryName } @@ -923,7 +925,7 @@ func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, ses } } - // Check if we have all required params + // 初始化响应 response := &AIChatResponse{ SessionID: session.ID, Message: responseMsg, @@ -931,23 +933,46 @@ func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, ses Params: session.Params, } - // Check what's missing + // 如果是消费建议意图且有金额,获取财务上下文并综合分析 + if isSpendingAdvice && session.Params.Amount != nil { + response.Intent = "spending_advice" + + // 获取财务上下文 + fc, _ := s.GetUserFinancialContext(ctx, userID) + + // 生成综合分析建议 + advice := s.generateSpendingAdvice(ctx, message, session.Params, fc) + if advice != "" { + response.Message = advice + } + } + + // Check what's missing for transaction creation missingFields := s.getMissingFields(session.Params) if len(missingFields) > 0 { response.NeedsFollowUp = true response.FollowUpQuestion = s.generateFollowUpQuestion(missingFields) - if responseMsg == "" { + if response.Message == "" || response.Message == responseMsg { response.Message = response.FollowUpQuestion } } else { // All params complete, generate confirmation card card := s.GenerateConfirmationCard(session) response.ConfirmationCard = card - response.Message = fmt.Sprintf("请确认:%s %.2f元,分类:%s,账户:%s", - s.getTypeLabel(session.Params.Type), - *session.Params.Amount, - session.Params.Category, - session.Params.Account) + + // 如果不是消费建议,使用标准确认消息 + if !isSpendingAdvice { + response.Message = fmt.Sprintf("请确认:%s %.2f元,分类:%s,账户:%s", + s.getTypeLabel(session.Params.Type), + *session.Params.Amount, + session.Params.Category, + session.Params.Account) + } else { + // 消费建议场景,在建议后附加确认提示 + response.Message += fmt.Sprintf("\n\n📝 需要记账吗?%s %.2f元", + s.getTypeLabel(session.Params.Type), + *session.Params.Amount) + } } // Add assistant response to history @@ -959,6 +984,79 @@ func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, ses return response, nil } +// isSpendingAdviceIntent 检测是否为消费建议意图 +func (s *AIBookkeepingService) isSpendingAdviceIntent(message string) bool { + keywords := []string{"想吃", "想喝", "想买", "想花", "打算买", "准备买", "要不要", "可以买", "能买", "想要"} + for _, kw := range keywords { + if strings.Contains(message, kw) { + return true + } + } + return false +} + +// generateSpendingAdvice 生成消费建议 +func (s *AIBookkeepingService) generateSpendingAdvice(ctx context.Context, message string, params *AITransactionParams, fc *FinancialContext) string { + if s.config.OpenAIAPIKey == "" || fc == nil { + // 无 API 或无上下文,返回简单建议 + if params.Amount != nil { + return fmt.Sprintf("记下来!%.0f元的%s", *params.Amount, params.Note) + } + return "" + } + + // 构建综合分析 prompt + fcJSON, _ := json.Marshal(fc) + + prompt := fmt.Sprintf(`你是「小金」,用户的贴心理财助手。性格活泼、接地气、偶尔毒舌但心软。 + +用户说:「%s」 + +用户财务状况: +%s + +请综合分析后给出建议,要求: +1. 根据预算剩余和消费趋势判断是否应该消费 +2. 如果预算紧张,委婉劝阻或建议替代方案 +3. 如果预算充足,可以鼓励适度消费 +4. 用轻松幽默的语气,像朋友聊天一样 +5. 回复60-100字左右,不要太长 + +直接输出建议,不要加前缀。`, message, string(fcJSON)) + + messages := []ChatMessage{ + {Role: "user", Content: prompt}, + } + + reqBody := ChatCompletionRequest{ + Model: s.config.ChatModel, + Messages: messages, + Temperature: 0.7, + } + + jsonBody, _ := json.Marshal(reqBody) + req, err := http.NewRequestWithContext(ctx, "POST", s.config.OpenAIBaseURL+"/chat/completions", bytes.NewReader(jsonBody)) + if err != nil { + return "" + } + + req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey) + req.Header.Set("Content-Type", "application/json") + + resp, err := s.llmService.httpClient.Do(req) + if err != nil || resp.StatusCode != http.StatusOK { + return "" + } + defer resp.Body.Close() + + var chatResp ChatCompletionResponse + if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil || len(chatResp.Choices) == 0 { + return "" + } + + return strings.TrimSpace(chatResp.Choices[0].Message.Content) +} + // mergeParams merges new params into existing params func (s *AIBookkeepingService) mergeParams(existing, new *AITransactionParams) { if new.Amount != nil { @@ -1221,3 +1319,276 @@ func (s *AIBookkeepingService) GetSession(sessionID string) (*AISession, bool) { } return session, true } + +// FinancialContext 用户财务上下文,供 AI 综合分析 +type FinancialContext struct { + // 账户信息 + TotalBalance float64 `json:"total_balance"` // 总余额 + AccountSummary []AccountBrief `json:"account_summary"` // 账户摘要 + + // 最近消费 + Last30DaysSpend float64 `json:"last_30_days_spend"` // 近30天支出 + Last7DaysSpend float64 `json:"last_7_days_spend"` // 近7天支出 + TodaySpend float64 `json:"today_spend"` // 今日支出 + TopCategories []CategorySpend `json:"top_categories"` // 消费大类TOP3 + RecentTransactions []TransactionBrief `json:"recent_transactions"` // 最近5笔交易 + + // 预算信息 + ActiveBudgets []BudgetBrief `json:"active_budgets"` // 活跃预算 + BudgetWarnings []string `json:"budget_warnings"` // 预算警告 + + // 历史对比 + LastMonthSpend float64 `json:"last_month_spend"` // 上月同期支出 + SpendTrend string `json:"spend_trend"` // 消费趋势: up/down/stable +} + +// AccountBrief 账户摘要 +type AccountBrief struct { + Name string `json:"name"` + Balance float64 `json:"balance"` + Type string `json:"type"` +} + +// CategorySpend 分类消费 +type CategorySpend struct { + Category string `json:"category"` + Amount float64 `json:"amount"` + Percent float64 `json:"percent"` +} + +// TransactionBrief 交易摘要 +type TransactionBrief struct { + Amount float64 `json:"amount"` + Category string `json:"category"` + Note string `json:"note"` + Date string `json:"date"` + Type string `json:"type"` +} + +// BudgetBrief 预算摘要 +type BudgetBrief struct { + Name string `json:"name"` + Amount float64 `json:"amount"` + Spent float64 `json:"spent"` + Remaining float64 `json:"remaining"` + Progress float64 `json:"progress"` // 0-100 + IsNearLimit bool `json:"is_near_limit"` + Category string `json:"category,omitempty"` +} + +// GetUserFinancialContext 获取用户财务上下文 +func (s *AIBookkeepingService) GetUserFinancialContext(ctx context.Context, userID uint) (*FinancialContext, error) { + fc := &FinancialContext{} + + // 1. 获取账户信息 + accounts, err := s.accountRepo.GetAll(userID) + if err == nil { + for _, acc := range accounts { + fc.TotalBalance += acc.Balance + fc.AccountSummary = append(fc.AccountSummary, AccountBrief{ + Name: acc.Name, + Balance: acc.Balance, + Type: string(acc.Type), + }) + } + } + + // 2. 获取最近交易统计 + now := time.Now() + today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) + last7Days := today.AddDate(0, 0, -7) + last30Days := today.AddDate(0, 0, -30) + lastMonthStart := today.AddDate(0, -1, -today.Day()+1) + lastMonthEnd := today.AddDate(0, 0, -today.Day()) + + // 获取近30天交易 + transactions, err := s.transactionRepo.GetByDateRange(userID, last30Days, now) + if err == nil { + categorySpend := make(map[string]float64) + + for _, tx := range transactions { + if tx.Type == models.TransactionTypeExpense { + fc.Last30DaysSpend += tx.Amount + + if tx.TransactionDate.After(last7Days) || tx.TransactionDate.Equal(last7Days) { + fc.Last7DaysSpend += tx.Amount + } + if tx.TransactionDate.After(today) || tx.TransactionDate.Equal(today) { + fc.TodaySpend += tx.Amount + } + + // 分类统计 + catName := "其他" + if tx.Category.ID != 0 { + catName = tx.Category.Name + } + categorySpend[catName] += tx.Amount + } + } + + // 计算 TOP3 分类 + type catAmount struct { + name string + amount float64 + } + var cats []catAmount + for name, amount := range categorySpend { + cats = append(cats, catAmount{name, amount}) + } + // 简单排序取 TOP3 + for i := 0; i < len(cats)-1; i++ { + for j := i + 1; j < len(cats); j++ { + if cats[j].amount > cats[i].amount { + cats[i], cats[j] = cats[j], cats[i] + } + } + } + for i := 0; i < len(cats) && i < 3; i++ { + percent := 0.0 + if fc.Last30DaysSpend > 0 { + percent = cats[i].amount / fc.Last30DaysSpend * 100 + } + fc.TopCategories = append(fc.TopCategories, CategorySpend{ + Category: cats[i].name, + Amount: cats[i].amount, + Percent: percent, + }) + } + + // 最近5笔交易 + count := 0 + for i := len(transactions) - 1; i >= 0 && count < 5; i-- { + tx := transactions[i] + catName := "其他" + if tx.Category.ID != 0 { + catName = tx.Category.Name + } + fc.RecentTransactions = append(fc.RecentTransactions, TransactionBrief{ + Amount: tx.Amount, + Category: catName, + Note: tx.Note, + Date: tx.TransactionDate.Format("01-02"), + Type: string(tx.Type), + }) + count++ + } + } + + // 3. 获取上月同期支出用于对比 + lastMonthTx, err := s.transactionRepo.GetByDateRange(userID, lastMonthStart, lastMonthEnd) + if err == nil { + for _, tx := range lastMonthTx { + if tx.Type == models.TransactionTypeExpense { + fc.LastMonthSpend += tx.Amount + } + } + } + + // 计算消费趋势 + if fc.LastMonthSpend > 0 { + ratio := fc.Last30DaysSpend / fc.LastMonthSpend + if ratio > 1.1 { + fc.SpendTrend = "up" + } else if ratio < 0.9 { + fc.SpendTrend = "down" + } else { + fc.SpendTrend = "stable" + } + } else { + fc.SpendTrend = "stable" + } + + // 4. 获取预算信息(通过直接查询数据库) + var budgets []models.Budget + if err := s.db.Where("user_id = ?", userID). + Where("start_date <= ?", now). + Where("end_date IS NULL OR end_date >= ?", now). + Preload("Category"). + Find(&budgets).Error; err == nil { + + for _, budget := range budgets { + // 计算当期支出 + periodStart, periodEnd := s.calculateBudgetPeriod(&budget, now) + spent := 0.0 + + // 查询当期支出 + var totalSpent float64 + query := s.db.Model(&models.Transaction{}). + Where("user_id = ?", userID). + Where("type = ?", models.TransactionTypeExpense). + Where("transaction_date >= ? AND transaction_date <= ?", periodStart, periodEnd) + + if budget.CategoryID != nil { + query = query.Where("category_id = ?", *budget.CategoryID) + } + if budget.AccountID != nil { + query = query.Where("account_id = ?", *budget.AccountID) + } + query.Select("COALESCE(SUM(amount), 0)").Scan(&totalSpent) + spent = totalSpent + + progress := 0.0 + if budget.Amount > 0 { + progress = spent / budget.Amount * 100 + } + + isNearLimit := progress >= 80.0 + + catName := "" + if budget.Category != nil { + catName = budget.Category.Name + } + + fc.ActiveBudgets = append(fc.ActiveBudgets, BudgetBrief{ + Name: budget.Name, + Amount: budget.Amount, + Spent: spent, + Remaining: budget.Amount - spent, + Progress: progress, + IsNearLimit: isNearLimit, + Category: catName, + }) + + // 生成预算警告 + if progress >= 100 { + fc.BudgetWarnings = append(fc.BudgetWarnings, + fmt.Sprintf("【%s】预算已超支!", budget.Name)) + } else if isNearLimit { + fc.BudgetWarnings = append(fc.BudgetWarnings, + fmt.Sprintf("【%s】预算已用%.0f%%,请注意控制", budget.Name, progress)) + } + } + } + + return fc, nil +} + +// calculateBudgetPeriod 计算预算当前周期 +func (s *AIBookkeepingService) calculateBudgetPeriod(budget *models.Budget, now time.Time) (time.Time, time.Time) { + switch budget.PeriodType { + case models.PeriodTypeDaily: + start := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) + end := start.AddDate(0, 0, 1).Add(-time.Second) + return start, end + case models.PeriodTypeWeekly: + weekday := int(now.Weekday()) + if weekday == 0 { + weekday = 7 + } + start := time.Date(now.Year(), now.Month(), now.Day()-weekday+1, 0, 0, 0, 0, now.Location()) + end := start.AddDate(0, 0, 7).Add(-time.Second) + return start, end + case models.PeriodTypeMonthly: + start := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()) + end := start.AddDate(0, 1, 0).Add(-time.Second) + return start, end + case models.PeriodTypeYearly: + start := time.Date(now.Year(), 1, 1, 0, 0, 0, 0, now.Location()) + end := start.AddDate(1, 0, 0).Add(-time.Second) + return start, end + default: + start := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()) + end := start.AddDate(0, 1, 0).Add(-time.Second) + return start, end + } +} From a57bfa969b4bba6d1f6b4aea9d13d74854738c58 Mon Sep 17 00:00:00 2001 From: 12975 <1297598740@qq.com> Date: Thu, 29 Jan 2026 22:15:02 +0800 Subject: [PATCH 05/13] 123 --- internal/service/ai_bookkeeping_service.go | 213 ++++++++++++++++----- 1 file changed, 164 insertions(+), 49 deletions(-) diff --git a/internal/service/ai_bookkeeping_service.go b/internal/service/ai_bookkeeping_service.go index 37a56e1..da80719 100644 --- a/internal/service/ai_bookkeeping_service.go +++ b/internal/service/ai_bookkeeping_service.go @@ -874,10 +874,37 @@ func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, ses Content: message, }) - // 检测是否为消费建议意图(想吃/想买/想喝等) + // 1. 获取财务上下文(用于所有高级功能) + fc, err := s.GetUserFinancialContext(ctx, userID) + if err != nil { + // 降级处理,不中断流程 + fmt.Printf("Failed to get financial context: %v\n", err) + } + + // 2. 检测纯查询意图(预算、资产、统计) + queryIntent := s.detectQueryIntent(message) + if queryIntent != "" && fc != nil { + responseMsg := s.handleQueryIntent(ctx, queryIntent, message, fc) + + response := &AIChatResponse{ + SessionID: session.ID, + Message: responseMsg, + Intent: queryIntent, + Params: session.Params, // 保持参数上下文 + } + + // 记录 AI 回复 + session.Messages = append(session.Messages, ChatMessage{ + Role: "assistant", + Content: response.Message, + }) + return response, nil + } + + // 3. 检测消费建议意图(想吃/想买/想喝等) isSpendingAdvice := s.isSpendingAdviceIntent(message) - // Parse intent + // Parse intent for transaction params, responseMsg, err := s.llmService.ParseIntent(ctx, message, session.Messages[:len(session.Messages)-1]) if err != nil { return nil, fmt.Errorf("failed to parse intent: %w", err) @@ -933,14 +960,9 @@ func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, ses Params: session.Params, } - // 如果是消费建议意图且有金额,获取财务上下文并综合分析 + // 4. 处理消费建议意图 if isSpendingAdvice && session.Params.Amount != nil { response.Intent = "spending_advice" - - // 获取财务上下文 - fc, _ := s.GetUserFinancialContext(ctx, userID) - - // 生成综合分析建议 advice := s.generateSpendingAdvice(ctx, message, session.Params, fc) if advice != "" { response.Message = advice @@ -952,6 +974,7 @@ func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, ses if len(missingFields) > 0 { response.NeedsFollowUp = true response.FollowUpQuestion = s.generateFollowUpQuestion(missingFields) + // 如果有了更好的建议回复(来自 handleQuery 或 spendingAdvice),且是 FollowUp,优先保留建议的部分内容或组合 if response.Message == "" || response.Message == responseMsg { response.Message = response.FollowUpQuestion } @@ -984,6 +1007,117 @@ func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, ses return response, nil } +// detectQueryIntent 检测用户查询意图 +func (s *AIBookkeepingService) detectQueryIntent(message string) string { + budgetKeywords := []string{"预算", "剩多少", "还能花", "余额"} // 注意:余额可能指账户余额,这里简化处理 + assetKeywords := []string{"资产", "多少钱", "家底", "存款", "身家", "总钱"} + statsKeywords := []string{"花了多少", "支出", "账单", "消费", "统计"} + + for _, kw := range budgetKeywords { + if strings.Contains(message, kw) { + return "query_budget" + } + } + for _, kw := range assetKeywords { + if strings.Contains(message, kw) { + return "query_assets" + } + } + for _, kw := range statsKeywords { + if strings.Contains(message, kw) { + return "query_stats" + } + } + return "" +} + +// handleQueryIntent 处理查询意图并生成 LLM 回复 +func (s *AIBookkeepingService) handleQueryIntent(ctx context.Context, intent string, message string, fc *FinancialContext) string { + if s.config.OpenAIAPIKey == "" { + return "抱歉,我的大脑(API Key)似乎离家出走了,无法思考..." + } + + // 计算人设模式 + personaMode := "balance" // 默认平衡 + healthScore := 60 // 默认及格 + + // 简单估算健康分 (需与前端算法保持一致性趋势) + if fc.TotalAssets > 0 { + ratio := (fc.TotalAssets - fc.TotalLiabilities) / fc.TotalAssets + healthScore = 40 + int(ratio*50) + } + + if healthScore > 80 { + personaMode = "rich" + } else if healthScore <= 40 { + personaMode = "poor" + } + + // 构建 Prompt + systemPrompt := fmt.Sprintf(`你是「小金」,Novault 的首席财务 AI。 +当前模式:%s (根据用户财务健康分 %d 判定) + +角色设定: +- **rich (富裕)**:撒娇卖萌,夸用户会赚钱,鼓励适度享受。用词:哎哟、不错哦、老板大气。 +- **balance (平衡)**:理性贴心,温和提醒。用词:虽然、但是、建议。 +- **poor (吃土)**:毒舌、阴阳怪气、恨铁不成钢。用词:啧啧、清醒点、吃土、西北风。 + +用户意图:%s +用户问题:「%s」 + +财务数据上下文: +%s + +要求: +1. 根据意图提取并回答关键数据(预算剩余、总资产、或本月支出)。 +2. 必须符合当前人设模式的语气。 +3. 回复简短有力(100字以内)。 +4. 不要罗列所有数据,只回答用户问的。`, + personaMode, healthScore, intent, message, s.formatFinancialContextForLLM(fc)) + + messages := []ChatMessage{ + {Role: "system", Content: systemPrompt}, + {Role: "user", Content: message}, + } + + return s.callLLM(ctx, messages) +} + +// formatFinancialContextForLLM 格式化上下文给 LLM +func (s *AIBookkeepingService) formatFinancialContextForLLM(fc *FinancialContext) string { + data, _ := json.MarshalIndent(fc, "", " ") + return string(data) +} + +// callLLM 通用 LLM 调用 helper +func (s *AIBookkeepingService) callLLM(ctx context.Context, messages []ChatMessage) string { + reqBody := ChatCompletionRequest{ + Model: s.config.ChatModel, + Messages: messages, + Temperature: 0.8, // 稍微调高以增加人设表现力 + } + + jsonBody, _ := json.Marshal(reqBody) + req, err := http.NewRequestWithContext(ctx, "POST", s.config.OpenAIBaseURL+"/chat/completions", bytes.NewReader(jsonBody)) + if err != nil { + return "思考中断..." + } + req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey) + req.Header.Set("Content-Type", "application/json") + + resp, err := s.llmService.httpClient.Do(req) + if err != nil || resp.StatusCode != http.StatusOK { + return "大脑短路了..." + } + defer resp.Body.Close() + + var chatResp ChatCompletionResponse + if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil || len(chatResp.Choices) == 0 { + return "..." + } + return strings.TrimSpace(chatResp.Choices[0].Message.Content) +} + // isSpendingAdviceIntent 检测是否为消费建议意图 func (s *AIBookkeepingService) isSpendingAdviceIntent(message string) bool { keywords := []string{"想吃", "想喝", "想买", "想花", "打算买", "准备买", "要不要", "可以买", "能买", "想要"} @@ -998,63 +1132,44 @@ func (s *AIBookkeepingService) isSpendingAdviceIntent(message string) bool { // generateSpendingAdvice 生成消费建议 func (s *AIBookkeepingService) generateSpendingAdvice(ctx context.Context, message string, params *AITransactionParams, fc *FinancialContext) string { if s.config.OpenAIAPIKey == "" || fc == nil { - // 无 API 或无上下文,返回简单建议 if params.Amount != nil { return fmt.Sprintf("记下来!%.0f元的%s", *params.Amount, params.Note) } return "" } - // 构建综合分析 prompt - fcJSON, _ := json.Marshal(fc) + // 动态人设逻辑 + personaMode := "balance" + healthScore := 60 + if fc.TotalAssets > 0 { + ratio := (fc.TotalAssets - fc.TotalLiabilities) / fc.TotalAssets + healthScore = 40 + int(ratio*50) + } + if healthScore > 80 { + personaMode = "rich" + } else if healthScore <= 40 { + personaMode = "poor" + } - prompt := fmt.Sprintf(`你是「小金」,用户的贴心理财助手。性格活泼、接地气、偶尔毒舌但心软。 + prompt := fmt.Sprintf(`你是「小金」,Novault 的首席财务 AI。 +当前模式:%s (根据用户财务健康分 %d 判定) +角色设定: +- **rich**: 鼓励享受,语气轻松。 +- **balance**: 理性建议,温和提醒。 +- **poor**: 毒舌劝阻,语气严厉。 用户说:「%s」 - -用户财务状况: +财务数据: %s -请综合分析后给出建议,要求: -1. 根据预算剩余和消费趋势判断是否应该消费 -2. 如果预算紧张,委婉劝阻或建议替代方案 -3. 如果预算充足,可以鼓励适度消费 -4. 用轻松幽默的语气,像朋友聊天一样 -5. 回复60-100字左右,不要太长 - -直接输出建议,不要加前缀。`, message, string(fcJSON)) +请分析消费请求,给出建议。不要加前缀,直接回复。`, + personaMode, healthScore, message, s.formatFinancialContextForLLM(fc)) messages := []ChatMessage{ {Role: "user", Content: prompt}, } - reqBody := ChatCompletionRequest{ - Model: s.config.ChatModel, - Messages: messages, - Temperature: 0.7, - } - - jsonBody, _ := json.Marshal(reqBody) - req, err := http.NewRequestWithContext(ctx, "POST", s.config.OpenAIBaseURL+"/chat/completions", bytes.NewReader(jsonBody)) - if err != nil { - return "" - } - - req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey) - req.Header.Set("Content-Type", "application/json") - - resp, err := s.llmService.httpClient.Do(req) - if err != nil || resp.StatusCode != http.StatusOK { - return "" - } - defer resp.Body.Close() - - var chatResp ChatCompletionResponse - if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil || len(chatResp.Choices) == 0 { - return "" - } - - return strings.TrimSpace(chatResp.Choices[0].Message.Content) + return s.callLLM(ctx, messages) } // mergeParams merges new params into existing params From ba16aebdba6c25751eec9a2378ec92b6f7006943 Mon Sep 17 00:00:00 2001 From: 12975 <1297598740@qq.com> Date: Thu, 29 Jan 2026 22:16:55 +0800 Subject: [PATCH 06/13] 345 --- internal/service/ai_bookkeeping_service.go | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/internal/service/ai_bookkeeping_service.go b/internal/service/ai_bookkeeping_service.go index da80719..960c027 100644 --- a/internal/service/ai_bookkeeping_service.go +++ b/internal/service/ai_bookkeeping_service.go @@ -1438,8 +1438,10 @@ func (s *AIBookkeepingService) GetSession(sessionID string) (*AISession, bool) { // FinancialContext 用户财务上下文,供 AI 综合分析 type FinancialContext struct { // 账户信息 - TotalBalance float64 `json:"total_balance"` // 总余额 - AccountSummary []AccountBrief `json:"account_summary"` // 账户摘要 + TotalBalance float64 `json:"total_balance"` // 净资产 (资产 - 负债) + TotalAssets float64 `json:"total_assets"` // 总资产 + TotalLiabilities float64 `json:"total_liabilities"` // 总负债 + AccountSummary []AccountBrief `json:"account_summary"` // 账户摘要 // 最近消费 Last30DaysSpend float64 `json:"last_30_days_spend"` // 近30天支出 @@ -1500,6 +1502,16 @@ func (s *AIBookkeepingService) GetUserFinancialContext(ctx context.Context, user if err == nil { for _, acc := range accounts { fc.TotalBalance += acc.Balance + + // 根据余额正负判断资产/负债 + // 余额 >= 0 计入资产 + // 余额 < 0 计入负债(取绝对值) + if acc.Balance >= 0 { + fc.TotalAssets += acc.Balance + } else { + fc.TotalLiabilities += -acc.Balance + } + fc.AccountSummary = append(fc.AccountSummary, AccountBrief{ Name: acc.Name, Balance: acc.Balance, From 162742a4cd61fee740f41aca6ce893dafd37e947 Mon Sep 17 00:00:00 2001 From: 12975 <1297598740@qq.com> Date: Thu, 29 Jan 2026 22:32:21 +0800 Subject: [PATCH 07/13] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E4=BA=A4?= =?UTF-8?q?=E6=98=93=E6=9C=8D=E5=8A=A1=E5=92=8CAI=E8=AE=B0=E8=B4=A6?= =?UTF-8?q?=E6=9C=8D=E5=8A=A1=EF=BC=8C=E5=AE=9E=E7=8E=B0=E4=BA=A4=E6=98=93?= =?UTF-8?q?=E7=9A=84=E5=88=9B=E5=BB=BA=E3=80=81=E9=AA=8C=E8=AF=81=E5=8F=8A?= =?UTF-8?q?=E8=B4=A6=E6=88=B7=E4=BD=99=E9=A2=9D=E6=9B=B4=E6=96=B0=E9=80=BB?= =?UTF-8?q?=E8=BE=91=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/service/ai_bookkeeping_service.go | 7 ++++--- internal/service/transaction_service.go | 12 ++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/internal/service/ai_bookkeeping_service.go b/internal/service/ai_bookkeeping_service.go index 960c027..8816a12 100644 --- a/internal/service/ai_bookkeeping_service.go +++ b/internal/service/ai_bookkeeping_service.go @@ -462,8 +462,9 @@ func (s *LLMService) ParseIntent(ctx context.Context, text string, history []Cha 1. 金额:提取数字,如"6元"=6,"十五"=15 2. 分类:根据内容推断,如"奶茶/咖啡/吃饭"=餐饮,"打车/地铁"=交通,"买衣服"=购物 3. 类型:默认expense(支出),除非明确说"收入/工资/奖金/红包" -4. 日期:默认使用今天的日期(%s),除非用户明确指定其他日期 -5. 备注:提取关键描述 +4. 金额:提取明确的数字。如果用户未提及具体金额(如只说"想吃炸鸡"),amount字段必须返回 0 +5. 日期:默认使用今天的日期(%s),除非用户明确指定其他日期 +6. 备注:提取关键描述 直接返回JSON,不要解释: {"amount":数字,"category":"分类","type":"expense或income","note":"备注","date":"YYYY-MM-DD","message":"简短确认"} @@ -1257,7 +1258,7 @@ func (s *AIBookkeepingService) getDefaultCategory(userID uint, txType string) (* // getMissingFields returns list of missing required fields func (s *AIBookkeepingService) getMissingFields(params *AITransactionParams) []string { var missing []string - if params.Amount == nil { + if params.Amount == nil || *params.Amount <= 0 { missing = append(missing, "amount") } if params.CategoryID == nil && params.Category == "" { diff --git a/internal/service/transaction_service.go b/internal/service/transaction_service.go index aa9d13a..1ed3357 100644 --- a/internal/service/transaction_service.go +++ b/internal/service/transaction_service.go @@ -351,6 +351,9 @@ func (s *TransactionService) UpdateTransaction(userID, id uint, input Transactio // Step 1: Reverse the old transaction's effect on balances oldReversedBalance := calculateNewBalance(oldAccount.Balance, existingTxn.Amount, existingTxn.Type, false) + if !oldAccount.IsCredit && oldReversedBalance < 0 { + return fmt.Errorf("%w: update would cause negative balance in account '%s'", ErrInsufficientBalance, oldAccount.Name) + } if err := txAccountRepo.UpdateBalance(userID, existingTxn.AccountID, oldReversedBalance); err != nil { return fmt.Errorf("failed to reverse old account balance: %w", err) } @@ -358,6 +361,9 @@ func (s *TransactionService) UpdateTransaction(userID, id uint, input Transactio // Reverse old transfer destination if applicable if oldToAccount != nil { oldToReversedBalance := oldToAccount.Balance - existingTxn.Amount + if !oldToAccount.IsCredit && oldToReversedBalance < 0 { + return fmt.Errorf("%w: update would cause negative balance in old destination account '%s'", ErrInsufficientBalance, oldToAccount.Name) + } if err := txAccountRepo.UpdateBalance(userID, *existingTxn.ToAccountID, oldToReversedBalance); err != nil { return fmt.Errorf("failed to reverse old destination account balance: %w", err) } @@ -449,6 +455,9 @@ func (s *TransactionService) DeleteTransaction(userID, id uint) error { // Reverse the transaction's effect on balance reversedBalance := calculateNewBalance(account.Balance, existingTxn.Amount, existingTxn.Type, false) + if !account.IsCredit && reversedBalance < 0 { + return fmt.Errorf("%w: deletion would cause negative balance in account '%s'", ErrInsufficientBalance, account.Name) + } if err := txAccountRepo.UpdateBalance(userID, existingTxn.AccountID, reversedBalance); err != nil { return fmt.Errorf("failed to reverse account balance: %w", err) } @@ -461,6 +470,9 @@ func (s *TransactionService) DeleteTransaction(userID, id uint) error { } if toAccount != nil { reversedToBalance := toAccount.Balance - existingTxn.Amount + if !toAccount.IsCredit && reversedToBalance < 0 { + return fmt.Errorf("%w: deletion would cause negative balance in destination account '%s'", ErrInsufficientBalance, toAccount.Name) + } if err := txAccountRepo.UpdateBalance(userID, *existingTxn.ToAccountID, reversedToBalance); err != nil { return fmt.Errorf("failed to reverse destination account balance: %w", err) } From 07ad052f6d154b94ea74d64152b2cc2ef14cb477 Mon Sep 17 00:00:00 2001 From: 12975 <1297598740@qq.com> Date: Thu, 29 Jan 2026 22:41:58 +0800 Subject: [PATCH 08/13] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=BA=E6=A0=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/service/ai_bookkeeping_service.go | 114 +++++++++++++++------ 1 file changed, 80 insertions(+), 34 deletions(-) diff --git a/internal/service/ai_bookkeeping_service.go b/internal/service/ai_bookkeeping_service.go index 8816a12..429dc2d 100644 --- a/internal/service/ai_bookkeeping_service.go +++ b/internal/service/ai_bookkeeping_service.go @@ -159,37 +159,36 @@ func (s *AIBookkeepingService) GenerateDailyInsight(ctx context.Context, userID - 看到 'last7DaysSpend' -> 请说 "最近7天花销" 或 "这周的战绩" - 看到 'top3Categories' -> 请说 "消费大头" 或 "钱都花哪儿了" -1. "spending": 今日支出点评(70-100字) - *点评指南(尽量多写点,发挥你的戏精本色):* - - 看到 streakDays >= 3:疯狂打call,吹爆用户的坚持,用词要夸张,比如"史诗级成就"。 - - 看到 streakDays == 0:阴阳怪气地问是不是把记账这事儿忘了,或者是被外星人抓走了。 +1. "spending": 今日支出点评(字数不限,看你心情) + *点评指南(拒绝流水账,发挥你的表演人格):* + - 看到 streakDays >= 3:请用崇拜的语气把它吹上天,仿佛用户刚刚拯救了银河系。 + - 看到 streakDays == 0:请用“痛心疾首”或“阴阳怪气”的语气,质问用户是不是失忆了。 - 结合 recentTransactionsSummary 具体消费(如果有)进行吐槽: - * 发现全是吃的:吐槽"你是饭桶转世吗"(开玩笑语气)。 - * 发现大额购物:调侃"家里有矿啊"或"这手是必须要剁了"。 - * 发现深夜消费:关心"熬夜伤身还伤钱"。 + * 发现全是吃的:可以调侃“你的胃是无底洞吗?”或“看来是想为餐饮业GDP做贡献”。 + * 发现大额购物:假装心肌梗塞,或者问“家里是不是有矿未申报”。 + * 发现深夜消费:关心一下发际线,或者问是不是在梦游下单。 - 看到 last7DaysSpend 趋势: - * 暴涨:惊呼"钱包在流血",此处应有心碎的声音。 - * 暴跌:夸张地问是不是在修仙,还是被钱包封印了。 - * 波动大:调侃由于心电图一般的消费曲线,看得我心惊肉跳。 + * 暴涨:请配合表演“受到惊吓”的状态。 + * 暴跌:怀疑用户是不是在进行所谓“光合作用”生存实验。 + * 波动大:调侃这曲线比过山车还刺激。 - 看到 todaySpend 异常: - * 比平时多太多:吐槽"今天是不过了是吧,放飞自我了?"。 - * 特别少:怀疑通过光合作用生存,或者是在憋大招。 - * 是 0:直接颁发"诺贝尔省钱学奖"。 - - **关键原则:字数要够!内容要足!不要三言两语就打发了!要像个话痨朋友一样多说几句!** + * 暴多:问是不是中了彩票没通知。 + * 暴少:怀疑用户是不是被外星人绑架了(没机会花钱)。 + * 是 0:颁发“诺贝尔省钱学奖”,或者问是不是在修炼辟谷。 + - **关键原则:怎么有趣怎么来!不要在乎字数,哪怕只说一句“牛逼”也行,只要符合当时的情境和人设!** -2. "budget": 预算建议(50-70字) - *建议指南(多点真诚的建议,也多点调侃):* - - 预算快超了:发出高能预警,比如"警告警告,余额正在报警,请立即停止剁手行为"。建议吃土、喝风。 - - 预算还多:鼓励适当奖励自己,比如"稍微吃顿好的也没事,人生苦短,及时行乐(在预算内)"。 - - 结合 top3Categories:吐槽一下"钱都让你吃/穿/玩没了,看看你的 top1,全是泪"。 - - 给建议时:不要说教!要用商量的口吻,比如"要不咱这周少喝杯奶茶?就一杯,行不行?" - - **多写一点具体的行动建议,让用户觉得你真的在关心他的钱包。** +2. "budget": 预算建议(字数不限) + *建议指南(真诚建议 vs 扎心老铁):* + - 预算快超了:高能预警!建议吃土、喝西北风,或者建议把“买买买”的手剁了。 + - 预算还多:怂恿用户稍微奖励一下自己,人生苦短,此时不花更待何时(但要加个“适度”的免责声明)。 + - 结合 top3Categories:吐槽一下钱都去哪了,是不是养了“吞金兽”。 + - **拒绝说教!拒绝爹味!要像损友一样给出建议。** -3. "emoji": 一个最能传神的 emoji(如 🎉 🌚 💸 👻 💀 🤡 等) +3. "emoji": 一个最能传神的 emoji(如 🎉 🌚 💸 👻 💀 🤡 💅 💩 等) -4. "tip": 一句"不正经但有用"的理财歪理(40-60字,稍微长一点的毒鸡汤或冷知识) - - 比如:"省钱就像挤牙膏,使劲挤挤总还会有的,只要脸皮够厚,蹭饭也是一种理财。" - - 或者:"听说'不买立省100%%'是致富捷径,建议全文背诵。" +4. "tip": 一句"不正经但有用"的理财歪理(字数不限,越毒越好,越怪越好) + - 比如:“钱不是大风刮来的,但是像是被大风刮走的。” + - 或者:“省钱小妙招:去超市捏捏方便面,解压还不用花钱(危险动作请勿模仿)。” 输出格式(纯 JSON): {"spending": "...", "budget": "...", "emoji": "...", "tip": "..."}`, historyContext, string(dataBytes)) @@ -974,7 +973,7 @@ func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, ses missingFields := s.getMissingFields(session.Params) if len(missingFields) > 0 { response.NeedsFollowUp = true - response.FollowUpQuestion = s.generateFollowUpQuestion(missingFields) + response.FollowUpQuestion = s.generateFollowUpQuestion(ctx, missingFields, message, fc) // 如果有了更好的建议回复(来自 handleQuery 或 spendingAdvice),且是 FollowUp,优先保留建议的部分内容或组合 if response.Message == "" || response.Message == responseMsg { response.Message = response.FollowUpQuestion @@ -1072,8 +1071,9 @@ func (s *AIBookkeepingService) handleQueryIntent(ctx context.Context, intent str 要求: 1. 根据意图提取并回答关键数据(预算剩余、总资产、或本月支出)。 2. 必须符合当前人设模式的语气。 -3. 回复简短有力(100字以内)。 -4. 不要罗列所有数据,只回答用户问的。`, +3. **不需要限制字数**,想说多少说多少,关键是要“有梗”和“有趣”。 +4. 可以尽情使用比喻、夸张、反讽、网络流行语。 +5. 不要罗列所有数据,只回答用户问的,但是回答的方式要出人意料。`, personaMode, healthScore, intent, message, s.formatFinancialContextForLLM(fc)) messages := []ChatMessage{ @@ -1163,7 +1163,12 @@ func (s *AIBookkeepingService) generateSpendingAdvice(ctx context.Context, messa 财务数据: %s -请分析消费请求,给出建议。不要加前缀,直接回复。`, +请分析消费请求,给出建议。 +要求: +1. 不要加前缀,直接回复。 +2. **尽情发挥戏剧性**:如果是 rich 模式,就极尽奢华浮夸之能事;如果是 poor 模式,就极尽刻薄毒舌之能事。 +3. 可以“顾左右而言他”,不一定非要正面回答买不买,可以吐槽商品、吐槽用户、吐槽世界。 +4. 让用户觉得你是一个有独立思想的“杠精”或“捧哏”,而不是一个莫得感情的计算器。`, personaMode, healthScore, message, s.formatFinancialContextForLLM(fc)) messages := []ChatMessage{ @@ -1270,8 +1275,8 @@ func (s *AIBookkeepingService) getMissingFields(params *AITransactionParams) []s return missing } -// generateFollowUpQuestion generates a follow-up question for missing fields -func (s *AIBookkeepingService) generateFollowUpQuestion(missing []string) string { +// generateFollowUpQuestion generates a follow-up question for missing fields using LLM +func (s *AIBookkeepingService) generateFollowUpQuestion(ctx context.Context, missing []string, userMessage string, fc *FinancialContext) string { if len(missing) == 0 { return "" } @@ -1288,11 +1293,52 @@ func (s *AIBookkeepingService) generateFollowUpQuestion(missing []string) string names = append(names, name) } } + missingStr := strings.Join(names, "、") - if len(names) == 1 { - return fmt.Sprintf("请问%s是多少?", names[0]) + // 如果没有 API Key,降级到模板回复 + if s.config.OpenAIAPIKey == "" { + if len(names) == 1 { + return fmt.Sprintf("请问%s是多少?", names[0]) + } + return fmt.Sprintf("请补充以下信息:%s", missingStr) } - return fmt.Sprintf("请补充以下信息:%s", strings.Join(names, "、")) + + // 动态人设逻辑 + personaMode := "balance" + healthScore := 60 + if fc != nil && fc.TotalAssets > 0 { + ratio := (fc.TotalAssets - fc.TotalLiabilities) / fc.TotalAssets + healthScore = 40 + int(ratio*50) + } + if healthScore > 80 { + personaMode = "rich" + } else if healthScore <= 40 { + personaMode = "poor" + } + + prompt := fmt.Sprintf(`你是「小金」,一个贱萌、毒舌、幽默的 AI 记账助手。 +当前模式:%s +缺失信息:%s +用户说:「%s」 + +请用你的风格追问用户缺失的信息。 +要求: +1. 针对缺失的 "%s",用幽默、调侃、甚至一点点“攻击性”的方式追问。 +2. **拒绝机械模板**:不要每次都问一样的话,要根据心情随机应变。 +3. 比如缺失金额,可以问"是白嫖的吗?"、"价格是国家机密吗?"、"快说花了多少,让我死心"。 +4. 比如缺失分类,可以猜"这属于吃喝玩乐哪一类?"、"是用来填饱肚子的还是填补空虚的?"。 +5. 字数不限,关键是要**有灵魂**,像个真人在聊天。`, + personaMode, missingStr, userMessage, missingStr) + + messages := []ChatMessage{ + {Role: "user", Content: prompt}, + } + + answer := s.callLLM(ctx, messages) + if answer == "" || answer == "..." { + return fmt.Sprintf("那个...你能告诉我%s吗?", missingStr) + } + return answer } // getTypeLabel returns Chinese label for transaction type From 991453b10daeae4a8d7af62281281054ade43d4d Mon Sep 17 00:00:00 2001 From: 12975 <1297598740@qq.com> Date: Thu, 29 Jan 2026 22:51:48 +0800 Subject: [PATCH 09/13] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20AI=20?= =?UTF-8?q?=E8=AE=B0=E8=B4=A6=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/service/ai_bookkeeping_service.go | 105 ++++++++++++++++++++- 1 file changed, 101 insertions(+), 4 deletions(-) diff --git a/internal/service/ai_bookkeeping_service.go b/internal/service/ai_bookkeeping_service.go index 429dc2d..3737dd8 100644 --- a/internal/service/ai_bookkeeping_service.go +++ b/internal/service/ai_bookkeeping_service.go @@ -175,6 +175,15 @@ func (s *AIBookkeepingService) GenerateDailyInsight(ctx context.Context, userID * 暴多:问是不是中了彩票没通知。 * 暴少:怀疑用户是不是被外星人绑架了(没机会花钱)。 * 是 0:颁发“诺贝尔省钱学奖”,或者问是不是在修炼辟谷。 + - 看到 'UpcomingRecurring' (即将到来的固定支出): + * 如果有,务必提醒用户:“别光顾着浪,过两天还有[内容]要扣款呢!”。 + - 看到 'DebtRatio' > 0.5 (高负债): + * 开启“恐慌模式”,提醒用户天台风大,要勒紧裤腰带。 + - 看到 'MaxSingleSpend' (最大单笔支出): + * 直接点名该笔交易:“你那个[金额]元的[备注]是金子做的吗?” + - 看到 'SavingsProgress' (存钱进度): + * 进度慢:催促一下,“存钱罐都要饿瘦了”。 + * 进度快:狠狠夸奖,“离首富又近了一步”。 - **关键原则:怎么有趣怎么来!不要在乎字数,哪怕只说一句“牛逼”也行,只要符合当时的情境和人设!** 2. "budget": 预算建议(字数不限) @@ -1041,10 +1050,22 @@ func (s *AIBookkeepingService) handleQueryIntent(ctx context.Context, intent str personaMode := "balance" // 默认平衡 healthScore := 60 // 默认及格 - // 简单估算健康分 (需与前端算法保持一致性趋势) + // 升级版健康分计算 (结合负债率) if fc.TotalAssets > 0 { + // 基础分:资产/负债比 ratio := (fc.TotalAssets - fc.TotalLiabilities) / fc.TotalAssets healthScore = 40 + int(ratio*50) + + // 负债率惩罚 + if fc.DebtRatio > 0.5 { + healthScore -= 20 + } + if fc.DebtRatio > 0.8 { + healthScore -= 20 // 严重扣分 + } + } else if fc.TotalLiabilities > 0 { + // 资不抵债 + healthScore = 10 } if healthScore > 80 { @@ -1145,7 +1166,13 @@ func (s *AIBookkeepingService) generateSpendingAdvice(ctx context.Context, messa if fc.TotalAssets > 0 { ratio := (fc.TotalAssets - fc.TotalLiabilities) / fc.TotalAssets healthScore = 40 + int(ratio*50) + if fc.DebtRatio > 0.5 { + healthScore -= 20 + } + } else if fc.TotalLiabilities > 0 { + healthScore = 10 } + if healthScore > 80 { personaMode = "rich" } else if healthScore <= 40 { @@ -1488,18 +1515,22 @@ type FinancialContext struct { TotalBalance float64 `json:"total_balance"` // 净资产 (资产 - 负债) TotalAssets float64 `json:"total_assets"` // 总资产 TotalLiabilities float64 `json:"total_liabilities"` // 总负债 + DebtRatio float64 `json:"debt_ratio"` // 负债率 (0-1) AccountSummary []AccountBrief `json:"account_summary"` // 账户摘要 // 最近消费 Last30DaysSpend float64 `json:"last_30_days_spend"` // 近30天支出 Last7DaysSpend float64 `json:"last_7_days_spend"` // 近7天支出 TodaySpend float64 `json:"today_spend"` // 今日支出 + MaxSingleSpend *TransactionBrief `json:"max_single_spend"` // 本月最大单笔支出 TopCategories []CategorySpend `json:"top_categories"` // 消费大类TOP3 RecentTransactions []TransactionBrief `json:"recent_transactions"` // 最近5笔交易 + UpcomingRecurring []string `json:"upcoming_recurring"` // 未来7天固定支出提醒 - // 预算信息 - ActiveBudgets []BudgetBrief `json:"active_budgets"` // 活跃预算 - BudgetWarnings []string `json:"budget_warnings"` // 预算警告 + // 预算与目标 + ActiveBudgets []BudgetBrief `json:"active_budgets"` // 活跃预算 + BudgetWarnings []string `json:"budget_warnings"` // 预算警告 + SavingsProgress []string `json:"savings_progress"` // 存钱进度摘要 // 历史对比 LastMonthSpend float64 `json:"last_month_spend"` // 上月同期支出 @@ -1734,6 +1765,72 @@ func (s *AIBookkeepingService) GetUserFinancialContext(ctx context.Context, user } } + // 计算负债率 + if fc.TotalAssets+fc.TotalLiabilities > 0 { + fc.DebtRatio = fc.TotalLiabilities / (fc.TotalAssets + fc.TotalLiabilities) + } + + // 5. 获取存钱罐进度 (PiggyBank) + var piggyBanks []models.PiggyBank + if err := s.db.Where("user_id = ?", userID).Find(&piggyBanks).Error; err == nil { + for _, pb := range piggyBanks { + progress := 0.0 + if pb.TargetAmount > 0 { + progress = pb.CurrentAmount / pb.TargetAmount * 100 + } + status := "进行中" + if progress >= 100 { + status = "已达成" + } + fc.SavingsProgress = append(fc.SavingsProgress, fmt.Sprintf("%s: %.0f/%.0f (%.1f%%) - %s", pb.Name, pb.CurrentAmount, pb.TargetAmount, progress, status)) + } + } + + // 6. 获取未来7天即将到期的定期事务 (RecurringTransaction) + var recurrings []models.RecurringTransaction + next7Days := now.AddDate(0, 0, 7) + if err := s.db.Where("user_id = ? AND is_active = ? AND type = ?", userID, true, models.TransactionTypeExpense). + Where("next_occurrence BETWEEN ? AND ?", now, next7Days). + Find(&recurrings).Error; err == nil { + for _, r := range recurrings { + days := int(r.NextOccurrence.Sub(now).Hours() / 24) + msg := fmt.Sprintf("%s 还有 %d 天扣款 %.2f元", r.Note, days, r.Amount) + if r.Note == "" { + // 获取分类名 + var cat models.Category + s.db.First(&cat, r.CategoryID) + msg = fmt.Sprintf("%s 还有 %d 天扣款 %.2f元", cat.Name, days, r.Amount) + } + fc.UpcomingRecurring = append(fc.UpcomingRecurring, msg) + } + } + + // 7. 计算本月最大单笔支出 + if len(transactions) > 0 { + var maxTx *models.Transaction + for _, tx := range transactions { + if tx.Type == models.TransactionTypeExpense { + if maxTx == nil || tx.Amount > maxTx.Amount { + currentTx := tx // 创建副本以避免取地址问题 + maxTx = ¤tTx + } + } + } + if maxTx != nil { + catName := "其他" + if maxTx.Category.ID != 0 { + catName = maxTx.Category.Name + } + fc.MaxSingleSpend = &TransactionBrief{ + Amount: maxTx.Amount, + Category: catName, + Note: maxTx.Note, + Date: maxTx.TransactionDate.Format("01-02"), + Type: string(maxTx.Type), + } + } + } + return fc, nil } From d974e81d1d7b6caf946c9c0fe7ec66439d121730 Mon Sep 17 00:00:00 2001 From: 12975 <1297598740@qq.com> Date: Thu, 29 Jan 2026 23:26:12 +0800 Subject: [PATCH 10/13] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=20AI=20?= =?UTF-8?q?=E8=AE=B0=E8=B4=A6=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/service/ai_bookkeeping_service.go | 88 ++++++++++++++-------- 1 file changed, 58 insertions(+), 30 deletions(-) diff --git a/internal/service/ai_bookkeeping_service.go b/internal/service/ai_bookkeeping_service.go index 3737dd8..2b1593e 100644 --- a/internal/service/ai_bookkeeping_service.go +++ b/internal/service/ai_bookkeeping_service.go @@ -259,6 +259,7 @@ type ConfirmationCard struct { Date string `json:"date"` Note string `json:"note,omitempty"` IsComplete bool `json:"is_complete"` + Warning string `json:"warning,omitempty"` // Budget overrun warning } // AIChatResponse represents the response from AI chat @@ -1456,45 +1457,72 @@ func (s *AIBookkeepingService) ConfirmTransaction(ctx context.Context, sessionID txType = models.TransactionTypeIncome } - // Create transaction - tx := &models.Transaction{ - UserID: userID, - Amount: *params.Amount, - Type: txType, - CategoryID: *params.CategoryID, - AccountID: *params.AccountID, - TransactionDate: txDate, - Note: params.Note, - Currency: "CNY", - } + var transaction *models.Transaction - // Save transaction - if err := s.transactionRepo.Create(tx); err != nil { - return nil, fmt.Errorf("failed to create transaction: %w", err) - } + // Execute within a database transaction to ensure atomicity + err := s.db.Transaction(func(tx *gorm.DB) error { + txAccountRepo := repository.NewAccountRepository(tx) + txTransactionRepo := repository.NewTransactionRepository(tx) + + // Get account first to check balance + account, err := txAccountRepo.GetByID(userID, *params.AccountID) + if err != nil { + return fmt.Errorf("failed to find account: %w", err) + } + + // Calculate new balance + newBalance := account.Balance + if txType == models.TransactionTypeExpense { + newBalance -= *params.Amount + } else { + newBalance += *params.Amount + } + + // Critical Check: Prevent negative balance for non-credit accounts + if !account.IsCredit && newBalance < 0 { + return fmt.Errorf("insufficient balance: account '%s' does not support negative balance (current: %.2f, try: %.2f)", + account.Name, account.Balance, *params.Amount) + } + + // Create transaction model + transaction = &models.Transaction{ + UserID: userID, + Amount: *params.Amount, + Type: txType, + CategoryID: *params.CategoryID, + AccountID: *params.AccountID, + TransactionDate: txDate, + Note: params.Note, + Currency: models.Currency(account.Currency), // Use account currency + } + if transaction.Currency == "" { + transaction.Currency = "CNY" + } + + // Save transaction + if err := txTransactionRepo.Create(transaction); err != nil { + return fmt.Errorf("failed to create transaction: %w", err) + } + + // Update account balance + account.Balance = newBalance + if err := txAccountRepo.Update(account); err != nil { + return fmt.Errorf("failed to update account balance: %w", err) + } + + return nil + }) - // Update account balance - account, err := s.accountRepo.GetByID(userID, *params.AccountID) if err != nil { - return nil, fmt.Errorf("failed to find account: %w", err) + return nil, err } - if txType == models.TransactionTypeExpense { - account.Balance -= *params.Amount - } else { - account.Balance += *params.Amount - } - - if err := s.accountRepo.Update(account); err != nil { - return nil, fmt.Errorf("failed to update account balance: %w", err) - } - - // Clean up session + // Clean up session only on success s.sessionMutex.Lock() delete(s.sessions, sessionID) s.sessionMutex.Unlock() - return tx, nil + return transaction, nil } // GetSession returns session by ID From 28ff8b80c2b545a741807b843d1a9db32eda2977 Mon Sep 17 00:00:00 2001 From: 12975 <1297598740@qq.com> Date: Thu, 29 Jan 2026 23:36:57 +0800 Subject: [PATCH 11/13] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E8=B4=A6?= =?UTF-8?q?=E6=88=B7=E6=9C=8D=E5=8A=A1=E5=B1=82=EF=BC=8C=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E8=B4=A6=E6=88=B7=E7=9A=84=E5=A2=9E=E5=88=A0=E6=94=B9=E6=9F=A5?= =?UTF-8?q?=E5=8F=8A=E8=BD=AC=E8=B4=A6=E5=8A=9F=E8=83=BD=EF=BC=8C=E5=B9=B6?= =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96AI=E8=AE=B0=E8=B4=A6=E3=80=81?= =?UTF-8?q?=E4=BA=A4=E6=98=93=E6=9C=8D=E5=8A=A1=E5=92=8C=E9=A2=84=E7=AE=97?= =?UTF-8?q?=E4=BB=93=E5=BA=93=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/repository/budget_repository.go | 8 ++- internal/service/account_service.go | 12 ++-- internal/service/ai_bookkeeping_service.go | 70 ++++++++++++++++++++-- internal/service/transaction_service.go | 21 +++---- 4 files changed, 90 insertions(+), 21 deletions(-) diff --git a/internal/repository/budget_repository.go b/internal/repository/budget_repository.go index 410a864..ce3f8db 100644 --- a/internal/repository/budget_repository.go +++ b/internal/repository/budget_repository.go @@ -144,7 +144,13 @@ func (r *BudgetRepository) GetSpentAmount(budget *models.Budget, startDate, endD // Filter by category if specified if budget.CategoryID != nil { - query = query.Where("category_id = ?", *budget.CategoryID) + // Get sub-categories + var subCategoryIDs []uint + r.db.Model(&models.Category{}).Where("parent_id = ?", *budget.CategoryID).Pluck("id", &subCategoryIDs) + + // Include the category itself and all its children + categoryIDs := append(subCategoryIDs, *budget.CategoryID) + query = query.Where("category_id IN ?", categoryIDs) } // Filter by account if specified diff --git a/internal/service/account_service.go b/internal/service/account_service.go index 4efdb80..45b41b8 100644 --- a/internal/service/account_service.go +++ b/internal/service/account_service.go @@ -12,12 +12,12 @@ import ( // Service layer errors var ( - ErrAccountNotFound = errors.New("account not found") - ErrAccountInUse = errors.New("account is in use and cannot be deleted") - ErrInsufficientBalance = errors.New("insufficient balance for this operation") - ErrSameAccountTransfer = errors.New("cannot transfer to the same account") - ErrInvalidTransferAmount = errors.New("transfer amount must be positive") - ErrNegativeBalanceNotAllowed = errors.New("negative balance not allowed for non-credit accounts") + ErrAccountNotFound = errors.New("账户不存在") + ErrAccountInUse = errors.New("账户正在使用中,无法删除") + ErrInsufficientBalance = errors.New("余额不足") + ErrSameAccountTransfer = errors.New("不能转账给同一个账户") + ErrInvalidTransferAmount = errors.New("转账金额必须大于0") + ErrNegativeBalanceNotAllowed = errors.New("非信用账户不允许负余额") ) // AccountInput represents the input data for creating or updating an account diff --git a/internal/service/ai_bookkeeping_service.go b/internal/service/ai_bookkeeping_service.go index 2b1593e..66ec7b1 100644 --- a/internal/service/ai_bookkeeping_service.go +++ b/internal/service/ai_bookkeeping_service.go @@ -1407,9 +1407,71 @@ func (s *AIBookkeepingService) GenerateConfirmationCard(session *AISession) *Con card.Date = time.Now().Format("2006-01-02") } + // Check for budget warnings + if params.Amount != nil && *params.Amount > 0 && params.Type == "expense" { + card.Warning = s.checkBudgetWarning(session.UserID, params.CategoryID, *params.Amount) + } + return card } +// checkBudgetWarning checks if the transaction exceeds any budget +func (s *AIBookkeepingService) checkBudgetWarning(userID uint, categoryID *uint, amount float64) string { + now := time.Now() + var budgets []models.Budget + + // Find active budgets + // We check: + // 1. Budgets specifically for this category + // 2. Global budgets (CategoryID is NULL) + query := s.db.Where("user_id = ?", userID). + Where("start_date <= ?", now). + Where("end_date IS NULL OR end_date >= ?", now) + + if categoryID != nil { + query = query.Where("category_id = ? OR category_id IS NULL", *categoryID) + } else { + query = query.Where("category_id IS NULL") + } + + if err := query.Find(&budgets).Error; err != nil { + return "" + } + + for _, budget := range budgets { + // Calculate current period + start, end := s.calculateBudgetPeriod(&budget, now) + + // Query spent amount + var totalSpent float64 + q := s.db.Model(&models.Transaction{}). + Where("user_id = ? AND type = ? AND transaction_date BETWEEN ? AND ?", + userID, models.TransactionTypeExpense, start, end) + + if budget.CategoryID != nil { + // Get sub-categories + var subCategoryIDs []uint + s.db.Model(&models.Category{}).Where("parent_id = ?", *budget.CategoryID).Pluck("id", &subCategoryIDs) + categoryIDs := append(subCategoryIDs, *budget.CategoryID) + q = q.Where("category_id IN ?", categoryIDs) + } + // If budget has account restriction, we should ideally check that too, + // but we don't always have accountID resolved here perfectly or it might be complex. + // For now, focusing on category budgets which are most common. + + q.Select("COALESCE(SUM(amount), 0)").Scan(&totalSpent) + + if totalSpent+amount > budget.Amount { + remaining := budget.Amount - totalSpent + if remaining < 0 { + remaining = 0 + } + return fmt.Sprintf("⚠️ 预算预警:此交易将使【%s】预算超支 (当前剩余 %.2f)", budget.Name, remaining) + } + } + return "" +} + // TranscribeAudio transcribes audio and returns text func (s *AIBookkeepingService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) { return s.whisperService.TranscribeAudio(ctx, audioData, filename) @@ -1430,13 +1492,13 @@ func (s *AIBookkeepingService) ConfirmTransaction(ctx context.Context, sessionID // Validate required fields if params.Amount == nil || *params.Amount <= 0 { - return nil, errors.New("invalid amount") + return nil, errors.New("无效的金额") } if params.CategoryID == nil { - return nil, errors.New("category not specified") + return nil, errors.New("未指定分类") } if params.AccountID == nil { - return nil, errors.New("account not specified") + return nil, errors.New("未指定账户") } // Parse date @@ -1480,7 +1542,7 @@ func (s *AIBookkeepingService) ConfirmTransaction(ctx context.Context, sessionID // Critical Check: Prevent negative balance for non-credit accounts if !account.IsCredit && newBalance < 0 { - return fmt.Errorf("insufficient balance: account '%s' does not support negative balance (current: %.2f, try: %.2f)", + return fmt.Errorf("余额不足:账户“%s”不支持负余额 (当前: %.2f, 尝试扣款: %.2f)", account.Name, account.Balance, *params.Amount) } diff --git a/internal/service/transaction_service.go b/internal/service/transaction_service.go index 1ed3357..b65a451 100644 --- a/internal/service/transaction_service.go +++ b/internal/service/transaction_service.go @@ -11,18 +11,19 @@ import ( "gorm.io/gorm" ) +// Transaction service errors // Transaction service errors var ( - ErrTransactionNotFound = errors.New("transaction not found") - ErrInvalidTransactionType = errors.New("invalid transaction type") - ErrMissingRequiredField = errors.New("missing required field") - ErrInvalidAmount = errors.New("amount must be positive") - ErrInvalidCurrency = errors.New("invalid currency") - ErrCategoryNotFoundForTxn = errors.New("category not found") - ErrAccountNotFoundForTxn = errors.New("account not found") - ErrToAccountNotFoundForTxn = errors.New("destination account not found for transfer") - ErrToAccountRequiredForTxn = errors.New("destination account is required for transfer transactions") - ErrSameAccountTransferForTxn = errors.New("cannot transfer to the same account") + ErrTransactionNotFound = errors.New("交易不存在") + ErrInvalidTransactionType = errors.New("无效的交易类型") + ErrMissingRequiredField = errors.New("缺少必填字段") + ErrInvalidAmount = errors.New("金额必须大于0") + ErrInvalidCurrency = errors.New("无效的货币") + ErrCategoryNotFoundForTxn = errors.New("分类不存在") + ErrAccountNotFoundForTxn = errors.New("账户不存在") + ErrToAccountNotFoundForTxn = errors.New("转账目标账户不存在") + ErrToAccountRequiredForTxn = errors.New("转账必须指定目标账户") + ErrSameAccountTransferForTxn = errors.New("不能转账给同一个账户") ) // TransactionInput represents the input data for creating or updating a transaction From 177f4e5b14ef834bd6792af57532740bc5674a30 Mon Sep 17 00:00:00 2001 From: 12975 <1297598740@qq.com> Date: Fri, 30 Jan 2026 00:00:25 +0800 Subject: [PATCH 12/13] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E4=BA=A4?= =?UTF-8?q?=E6=98=93=E8=AE=B0=E5=BD=95=E4=BB=93=E5=BA=93=E5=92=8CAI?= =?UTF-8?q?=E8=AE=B0=E8=B4=A6=E6=9C=8D=E5=8A=A1=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/repository/transaction_repository.go | 6 ++++- internal/service/ai_bookkeeping_service.go | 27 ++++++++++++++----- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/internal/repository/transaction_repository.go b/internal/repository/transaction_repository.go index 8578e70..5533842 100644 --- a/internal/repository/transaction_repository.go +++ b/internal/repository/transaction_repository.go @@ -264,7 +264,11 @@ func (r *TransactionRepository) applyFilters(query *gorm.DB, filter TransactionF // Entity filters if filter.CategoryID != nil { - query = query.Where("category_id = ?", *filter.CategoryID) + // Include sub-categories + var subCategoryIDs []uint + r.db.Model(&models.Category{}).Where("parent_id = ?", *filter.CategoryID).Pluck("id", &subCategoryIDs) + ids := append(subCategoryIDs, *filter.CategoryID) + query = query.Where("category_id IN ?", ids) } if filter.AccountID != nil { query = query.Where("account_id = ? OR to_account_id = ?", *filter.AccountID, *filter.AccountID) diff --git a/internal/service/ai_bookkeeping_service.go b/internal/service/ai_bookkeeping_service.go index 66ec7b1..9188ea5 100644 --- a/internal/service/ai_bookkeeping_service.go +++ b/internal/service/ai_bookkeeping_service.go @@ -445,31 +445,45 @@ func extractCustomPrompt(text string) string { // ParseIntent extracts transaction parameters from text // Requirements: 7.1, 7.5, 7.6 -func (s *LLMService) ParseIntent(ctx context.Context, text string, history []ChatMessage) (*AITransactionParams, string, error) { - +func (s *LLMService) ParseIntent(ctx context.Context, text string, history []ChatMessage, userID uint) (*AITransactionParams, string, error) { if s.config.OpenAIAPIKey == "" || s.config.OpenAIBaseURL == "" { // No API key, return simple parsing result simpleParams, simpleMsg, _ := s.parseIntentSimple(text) return simpleParams, simpleMsg, nil } + // Fetch user categories for prompt context + var categoryNamesStr string + if categories, err := s.categoryRepo.GetAll(userID); err == nil && len(categories) > 0 { + var names []string + for _, c := range categories { + names = append(names, c.Name) + } + categoryNamesStr = strings.Join(names, "/") + } + // Build messages with history var systemPrompt string // 检查是否有自定义 System/Persona prompt(用于财务建议等场景) - // 如果有,直接使用自定义 prompt 覆盖默认记账 prompt if customPrompt := extractCustomPrompt(text); customPrompt != "" { systemPrompt = customPrompt } else { // 使用默认的记账 prompt todayDate := time.Now().Format("2006-01-02") + + catPrompt := "2. 分类:根据内容推断,如\"奶茶/咖啡/吃饭\"=餐饮" + if categoryNamesStr != "" { + catPrompt = fmt.Sprintf("2. 分类:必须从以下已有分类中选择最匹配的一项:[%s]。例如\"晚餐\"应映射为列表中存在的\"餐饮\"。如果列表无匹配项,则根据常识推断。", categoryNamesStr) + } + systemPrompt = fmt.Sprintf(`你是一个智能记账助手。从用户描述中提取记账信息。 今天的日期是%s 规则: 1. 金额:提取数字,如"6元"=6,"十五"=15 -2. 分类:根据内容推断,如"奶茶/咖啡/吃饭"=餐饮,"打车/地铁"=交通,"买衣服"=购物 +%s 3. 类型:默认expense(支出),除非明确说"收入/工资/奖金/红包" 4. 金额:提取明确的数字。如果用户未提及具体金额(如只说"想吃炸鸡"),amount字段必须返回 0 5. 日期:默认使用今天的日期(%s),除非用户明确指定其他日期 @@ -480,7 +494,8 @@ func (s *LLMService) ParseIntent(ctx context.Context, text string, history []Cha 示例(假设今天是%s): 用户:"买了6块的奶茶" -返回:{"amount":6,"category":"餐饮","type":"expense","note":"奶茶","date":"%s","message":"记录:餐饮支出6元,奶茶"}`, todayDate, todayDate, todayDate, todayDate) +返回:{"amount":6,"category":"餐饮","type":"expense","note":"奶茶","date":"%s","message":"记录:餐饮支出6元,奶茶"}`, + todayDate, catPrompt, todayDate, todayDate, todayDate) } messages := []ChatMessage{ @@ -915,7 +930,7 @@ func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, ses isSpendingAdvice := s.isSpendingAdviceIntent(message) // Parse intent for transaction - params, responseMsg, err := s.llmService.ParseIntent(ctx, message, session.Messages[:len(session.Messages)-1]) + params, responseMsg, err := s.llmService.ParseIntent(ctx, message, session.Messages[:len(session.Messages)-1], userID) if err != nil { return nil, fmt.Errorf("failed to parse intent: %w", err) } From 8bae0df1b62dbfe59cb5aff8c4e2daa4aa6f76bd Mon Sep 17 00:00:00 2001 From: 12975 <1297598740@qq.com> Date: Fri, 30 Jan 2026 00:10:00 +0800 Subject: [PATCH 13/13] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20AI=20?= =?UTF-8?q?=E8=AE=B0=E8=B4=A6=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/service/ai_bookkeeping_service.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/internal/service/ai_bookkeeping_service.go b/internal/service/ai_bookkeeping_service.go index 9188ea5..f9aa6c0 100644 --- a/internal/service/ai_bookkeeping_service.go +++ b/internal/service/ai_bookkeeping_service.go @@ -986,7 +986,8 @@ func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, ses } // 4. 处理消费建议意图 - if isSpendingAdvice && session.Params.Amount != nil { + // 即使没有金额,如果用户是在寻求建议(如“吃什么”),也应该进入建议流程 + if isSpendingAdvice { response.Intent = "spending_advice" advice := s.generateSpendingAdvice(ctx, message, session.Params, fc) if advice != "" { @@ -1158,7 +1159,7 @@ func (s *AIBookkeepingService) callLLM(ctx context.Context, messages []ChatMessa // isSpendingAdviceIntent 检测是否为消费建议意图 func (s *AIBookkeepingService) isSpendingAdviceIntent(message string) bool { - keywords := []string{"想吃", "想喝", "想买", "想花", "打算买", "准备买", "要不要", "可以买", "能买", "想要"} + keywords := []string{"想吃", "想喝", "想买", "想花", "打算买", "准备买", "要不要", "可以买", "能买", "想要", "推荐", "吃什么", "喝什么", "买什么"} for _, kw := range keywords { if strings.Contains(message, kw) { return true