-feat 修复部分bug

This commit is contained in:
2026-01-27 18:42:35 +08:00
parent a339a9adce
commit 6422f2c45f
10 changed files with 1001 additions and 1 deletions

View File

@@ -0,0 +1,151 @@
package repository
import (
"errors"
"fmt"
"accounting-app/internal/models"
"gorm.io/gorm"
)
// Notification repository errors
var (
ErrNotificationNotFound = errors.New("notification not found")
)
// NotificationRepository handles database operations for notifications
type NotificationRepository struct {
db *gorm.DB
}
// NewNotificationRepository creates a new NotificationRepository instance
func NewNotificationRepository(db *gorm.DB) *NotificationRepository {
return &NotificationRepository{db: db}
}
// Create creates a new notification in the database
func (r *NotificationRepository) Create(notification *models.Notification) error {
if err := r.db.Create(notification).Error; err != nil {
return fmt.Errorf("failed to create notification: %w", err)
}
return nil
}
// GetByID retrieves a notification by its ID
func (r *NotificationRepository) GetByID(userID uint, id uint) (*models.Notification, error) {
var notification models.Notification
if err := r.db.Where("user_id = ?", userID).First(&notification, id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotificationNotFound
}
return nil, fmt.Errorf("failed to get notification: %w", err)
}
return &notification, nil
}
// NotificationListOptions contains options for listing notifications
type NotificationListOptions struct {
Type *models.NotificationType
IsRead *bool
Offset int
Limit int
}
// NotificationListResult contains the result of a paginated notification list query
type NotificationListResult struct {
Notifications []models.Notification
Total int64
Offset int
Limit int
}
// List retrieves notifications with pagination and filtering
func (r *NotificationRepository) List(userID uint, options NotificationListOptions) (*NotificationListResult, error) {
query := r.db.Model(&models.Notification{}).Where("user_id = ?", userID)
// Apply filters
if options.Type != nil {
query = query.Where("type = ?", *options.Type)
}
if options.IsRead != nil {
query = query.Where("is_read = ?", *options.IsRead)
}
// Count total before pagination
var total int64
if err := query.Count(&total).Error; err != nil {
return nil, fmt.Errorf("failed to count notifications: %w", err)
}
// Apply sorting (newest first)
query = query.Order("created_at DESC")
// Apply pagination
if options.Limit > 0 {
query = query.Limit(options.Limit)
}
if options.Offset > 0 {
query = query.Offset(options.Offset)
}
// Execute query
var notifications []models.Notification
if err := query.Find(&notifications).Error; err != nil {
return nil, fmt.Errorf("failed to list notifications: %w", err)
}
return &NotificationListResult{
Notifications: notifications,
Total: total,
Offset: options.Offset,
Limit: options.Limit,
}, nil
}
// MarkAsRead marks a notification as read
func (r *NotificationRepository) MarkAsRead(userID uint, id uint) error {
result := r.db.Model(&models.Notification{}).
Where("user_id = ? AND id = ?", userID, id).
Update("is_read", true)
if result.Error != nil {
return fmt.Errorf("failed to mark notification as read: %w", result.Error)
}
if result.RowsAffected == 0 {
return ErrNotificationNotFound
}
return nil
}
// MarkAllAsRead marks all notifications as read for a user
func (r *NotificationRepository) MarkAllAsRead(userID uint) error {
if err := r.db.Model(&models.Notification{}).
Where("user_id = ? AND is_read = ?", userID, false).
Update("is_read", true).Error; err != nil {
return fmt.Errorf("failed to mark all notifications as read: %w", err)
}
return nil
}
// Delete deletes a notification by its ID
func (r *NotificationRepository) Delete(userID uint, id uint) error {
result := r.db.Where("user_id = ? AND id = ?", userID, id).Delete(&models.Notification{})
if result.Error != nil {
return fmt.Errorf("failed to delete notification: %w", result.Error)
}
if result.RowsAffected == 0 {
return ErrNotificationNotFound
}
return nil
}
// GetUnreadCount returns the count of unread notifications for a user
func (r *NotificationRepository) GetUnreadCount(userID uint) (int64, error) {
var count int64
if err := r.db.Model(&models.Notification{}).
Where("user_id = ? AND is_read = ?", userID, false).
Count(&count).Error; err != nil {
return 0, fmt.Errorf("failed to count unread notifications: %w", err)
}
return count, nil
}

View File

@@ -24,6 +24,7 @@ type TransactionFilter struct {
// Entity filters
CategoryID *uint
AccountID *uint
LedgerID *uint // 账本过滤
TagIDs []uint
Type *models.TransactionType
Currency *models.Currency
@@ -277,6 +278,13 @@ func (r *TransactionRepository) applyFilters(query *gorm.DB, filter TransactionF
if filter.RecurringID != nil {
query = query.Where("recurring_id = ?", *filter.RecurringID)
}
// LedgerID filter - requires subquery to get accounts belonging to the ledger
if filter.LedgerID != nil {
query = query.Where("account_id IN (?)",
r.db.Model(&models.Account{}).
Select("id").
Where("ledger_id = ?", *filter.LedgerID))
}
// UserID provided in argument takes precedence, but if filter has it, we can redundant check or ignore.
// The caller `List` already applied `Where("user_id = ?", userID)`.