Files
Novault-backend/internal/repository/notification_repository.go

158 lines
4.5 KiB
Go

package repository
import (
"errors"
"fmt"
"time"
"accounting-app/internal/models"
"gorm.io/gorm"
)
// ErrNotificationNotFound is the sentinel error returned by repository
// methods when no notification matches the requested user and ID. Compare
// against it with errors.Is.
var ErrNotificationNotFound = errors.New("notification not found")
// NotificationRepository handles database operations for notifications.
// All of its methods issue their queries through the wrapped GORM handle.
type NotificationRepository struct {
	// db is the GORM handle every repository method queries through.
	db *gorm.DB
}
// NewNotificationRepository returns a NotificationRepository that runs its
// queries through db.
func NewNotificationRepository(db *gorm.DB) *NotificationRepository {
	repo := NotificationRepository{db: db}
	return &repo
}
// Create inserts notification into the database.
//
// Any error reported by GORM is wrapped with context and returned.
func (r *NotificationRepository) Create(notification *models.Notification) error {
	err := r.db.Create(notification).Error
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to create notification: %w", err)
}
// GetByID fetches the notification with primary key id, restricted to rows
// whose user_id equals userID.
//
// It returns ErrNotificationNotFound when no such row exists and a wrapped
// error for any other database failure.
func (r *NotificationRepository) GetByID(userID uint, id uint) (*models.Notification, error) {
	found := new(models.Notification)
	err := r.db.Where("user_id = ?", userID).First(found, id).Error

	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// Translate GORM's sentinel into the repository's own.
		return nil, ErrNotificationNotFound
	default:
		return nil, fmt.Errorf("failed to get notification: %w", err)
	}
}
// NotificationListOptions contains the filtering and pagination options
// accepted by NotificationRepository.List.
type NotificationListOptions struct {
	// Type, when non-nil, restricts results to notifications of that type.
	Type *models.NotificationType
	// IsRead, when non-nil, restricts results to read (true) or unread
	// (false) notifications.
	IsRead *bool
	// Offset is the number of rows to skip; applied only when greater than zero.
	Offset int
	// Limit is the maximum number of rows to return; applied only when
	// greater than zero.
	Limit int
}
// NotificationListResult contains the result of a paginated notification
// list query.
type NotificationListResult struct {
	// Notifications is the requested page of rows.
	Notifications []models.Notification
	// Total is the number of rows matching the filters, counted before
	// Limit and Offset are applied.
	Total int64
	// Offset echoes the Offset supplied in the list options.
	Offset int
	// Limit echoes the Limit supplied in the list options.
	Limit int
}
// List returns the notifications owned by userID, newest first, together
// with the total number of rows matching the filters.
//
// options.Type and options.IsRead narrow the result when non-nil.
// options.Limit and options.Offset are applied only when greater than zero.
// Total is counted before pagination, so it reflects every matching row, not
// just the returned page. Any database failure is returned wrapped.
func (r *NotificationRepository) List(userID uint, options NotificationListOptions) (*NotificationListResult, error) {
	query := r.db.Model(&models.Notification{}).Where("user_id = ?", userID)

	// Optional filters.
	if options.Type != nil {
		query = query.Where("type = ?", *options.Type)
	}
	if options.IsRead != nil {
		query = query.Where("is_read = ?", *options.IsRead)
	}

	// Count before ordering and pagination are attached to the query.
	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, fmt.Errorf("failed to count notifications: %w", err)
	}

	// Newest first. The id tiebreaker makes the order total: without it, rows
	// sharing a created_at value (for example those inserted together by
	// CreateBatch) may be repeated or skipped across pages.
	query = query.Order("created_at DESC, id DESC")

	if options.Limit > 0 {
		query = query.Limit(options.Limit)
	}
	if options.Offset > 0 {
		query = query.Offset(options.Offset)
	}

	var notifications []models.Notification
	if err := query.Find(&notifications).Error; err != nil {
		return nil, fmt.Errorf("failed to list notifications: %w", err)
	}

	return &NotificationListResult{
		Notifications: notifications,
		Total:         total,
		Offset:        options.Offset,
		Limit:         options.Limit,
	}, nil
}
// MarkAsRead marks the notification with the given id, owned by userID, as
// read and stamps read_at with the current time, matching the columns that
// MarkAllAsRead writes.
//
// It returns ErrNotificationNotFound when the update affects no rows and a
// wrapped error for any database failure. Calling it again on a notification
// that is already read refreshes read_at.
func (r *NotificationRepository) MarkAsRead(userID uint, id uint) error {
	result := r.db.Model(&models.Notification{}).
		Where("user_id = ? AND id = ?", userID, id).
		Updates(map[string]interface{}{
			"is_read": true,
			"read_at": time.Now(),
		})
	if result.Error != nil {
		return fmt.Errorf("failed to mark notification as read: %w", result.Error)
	}
	if result.RowsAffected == 0 {
		return ErrNotificationNotFound
	}
	return nil
}
// MarkAllAsRead marks every unread notification owned by userID as read and
// stamps read_at with the current time.
//
// Having no unread notifications is not an error. Any database failure is
// returned wrapped, in the same style as the other repository methods.
func (r *NotificationRepository) MarkAllAsRead(userID uint) error {
	err := r.db.Model(&models.Notification{}).
		Where("user_id = ? AND is_read = ?", userID, false).
		Updates(map[string]interface{}{
			"is_read": true,
			"read_at": time.Now(),
		}).Error
	if err != nil {
		return fmt.Errorf("failed to mark all notifications as read: %w", err)
	}
	return nil
}
// Delete removes the notification with the given id, restricted to rows
// owned by userID.
//
// It returns ErrNotificationNotFound when the delete affects no rows and a
// wrapped error for any database failure.
func (r *NotificationRepository) Delete(userID uint, id uint) error {
	outcome := r.db.
		Where("user_id = ? AND id = ?", userID, id).
		Delete(&models.Notification{})

	switch {
	case outcome.Error != nil:
		return fmt.Errorf("failed to delete notification: %w", outcome.Error)
	case outcome.RowsAffected == 0:
		return ErrNotificationNotFound
	}
	return nil
}
// GetUnreadCount reports how many notifications owned by userID have
// is_read set to false.
//
// On a database failure it returns zero and a wrapped error.
func (r *NotificationRepository) GetUnreadCount(userID uint) (int64, error) {
	var unread int64

	unreadQuery := r.db.Model(&models.Notification{}).
		Where("user_id = ? AND is_read = ?", userID, false)
	err := unreadQuery.Count(&unread).Error
	if err != nil {
		return 0, fmt.Errorf("failed to count unread notifications: %w", err)
	}
	return unread, nil
}
// CreateBatch inserts notifications using GORM's CreateInBatches, in chunks
// of at most 100 rows.
//
// An empty slice is a no-op. Any database failure is returned wrapped, in the
// same style as the other repository methods.
//
// NOTE(review): whether the chunks share one transaction depends on GORM's
// default-transaction behaviour and the handle's configuration — confirm
// before relying on all-or-nothing semantics.
func (r *NotificationRepository) CreateBatch(notifications []models.Notification) error {
	// Maximum number of rows handed to GORM per chunk.
	const batchSize = 100

	if len(notifications) == 0 {
		return nil
	}
	if err := r.db.CreateInBatches(notifications, batchSize).Error; err != nil {
		return fmt.Errorf("failed to create notifications in batch: %w", err)
	}
	return nil
}