158 lines
4.4 KiB
Go
158 lines
4.4 KiB
Go
package repository
|
|
|
|
import (
|
|
"errors"
|
|
"time"
|
|
|
|
"accounting-app/internal/models"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// StreakRepository handles database operations for user streaks
|
|
type StreakRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
// NewStreakRepository creates a new StreakRepository instance
|
|
func NewStreakRepository(db *gorm.DB) *StreakRepository {
|
|
return &StreakRepository{db: db}
|
|
}
|
|
|
|
// GetByUserID retrieves a user's streak record
|
|
func (r *StreakRepository) GetByUserID(userID uint) (*models.UserStreak, error) {
|
|
var streak models.UserStreak
|
|
if err := r.db.Where("user_id = ?", userID).First(&streak).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil // Return nil without error if not found
|
|
}
|
|
return nil, err
|
|
}
|
|
return &streak, nil
|
|
}
|
|
|
|
// Create creates a new streak record
|
|
func (r *StreakRepository) Create(streak *models.UserStreak) error {
|
|
return r.db.Create(streak).Error
|
|
}
|
|
|
|
// Update updates an existing streak record
|
|
func (r *StreakRepository) Update(streak *models.UserStreak) error {
|
|
return r.db.Save(streak).Error
|
|
}
|
|
|
|
// GetOrCreate retrieves existing streak record or creates a new one
|
|
func (r *StreakRepository) GetOrCreate(userID uint) (*models.UserStreak, error) {
|
|
streak, err := r.GetByUserID(userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if streak == nil {
|
|
// Create new streak record
|
|
streak = &models.UserStreak{
|
|
UserID: userID,
|
|
CurrentStreak: 0,
|
|
LongestStreak: 0,
|
|
TotalRecordDays: 0,
|
|
}
|
|
if err := r.Create(streak); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return streak, nil
|
|
}
|
|
|
|
// HasTransactionOnDate checks if user has any transaction on the given date
|
|
func (r *StreakRepository) HasTransactionOnDate(userID uint, date time.Time) (bool, error) {
|
|
var count int64
|
|
startOfDay := time.Date(date.Year(), date.Month(), date.Day(), 0, 0, 0, 0, date.Location())
|
|
endOfDay := startOfDay.Add(24 * time.Hour)
|
|
|
|
err := r.db.Model(&models.Transaction{}).
|
|
Where("user_id = ? AND transaction_date >= ? AND transaction_date < ?", userID, startOfDay, endOfDay).
|
|
Count(&count).Error
|
|
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return count > 0, nil
|
|
}
|
|
|
|
// GetTransactionDatesInRange returns all dates with transactions in a date range
|
|
func (r *StreakRepository) GetTransactionDatesInRange(userID uint, startDate, endDate time.Time) ([]time.Time, error) {
|
|
var dates []time.Time
|
|
|
|
rows, err := r.db.Model(&models.Transaction{}).
|
|
Select("DATE(transaction_date) as date").
|
|
Where("user_id = ? AND transaction_date >= ? AND transaction_date <= ?", userID, startDate, endDate).
|
|
Group("DATE(transaction_date)").
|
|
Order("date ASC").
|
|
Rows()
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var date time.Time
|
|
if err := rows.Scan(&date); err != nil {
|
|
return nil, err
|
|
}
|
|
dates = append(dates, date)
|
|
}
|
|
|
|
return dates, nil
|
|
}
|
|
|
|
// GetDailyContribution returns daily transaction counts in a date range
|
|
func (r *StreakRepository) GetDailyContribution(userID uint, startDate, endDate time.Time) ([]models.DailyContribution, error) {
|
|
var results []models.DailyContribution
|
|
|
|
// SQLite uses strftime, MySQL/Postgres uses DATE()
|
|
// Using a more generic approach compatible with SQLite (which is likely used locally)
|
|
// For production readiness with multiple DBs, raw SQL might be safer or check dialect
|
|
|
|
rows, err := r.db.Model(&models.Transaction{}).
|
|
Select("DATE(transaction_date) as date, COUNT(*) as count").
|
|
Where("user_id = ? AND transaction_date >= ? AND transaction_date <= ?", userID, startDate, endDate).
|
|
Group("DATE(transaction_date)").
|
|
Order("date ASC").
|
|
Rows()
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var dateStr string // Using string to handle potential format differences
|
|
var count int
|
|
// Scan into generic types then convert
|
|
// Some drivers return date as time.Time, some as string/bytes
|
|
// Let's try scan into string first (common for DATE() function result)
|
|
// Or scan into interface{} to be safe
|
|
if err := rows.Scan(&dateStr, &count); err != nil {
|
|
// If string scan fails, try time.Time
|
|
// Unfortunately we can't rewind rows. scan is one-way.
|
|
// But usually drivers handle string conversion for DATE()
|
|
// If this fails we might need to adjust based on specific DB driver
|
|
return nil, err
|
|
}
|
|
// Normalize date string to YYYY-MM-DD (take first 10 chars if it includes time)
|
|
if len(dateStr) > 10 {
|
|
dateStr = dateStr[:10]
|
|
}
|
|
|
|
results = append(results, models.DailyContribution{
|
|
Date: dateStr,
|
|
Count: count,
|
|
})
|
|
}
|
|
|
|
return results, nil
|
|
}
|