package repository

import (
	"errors"
	"fmt"
	"time"

	"accounting-app/internal/models"
	"gorm.io/gorm"
)

// streakDateLayout is the YYYY-MM-DD layout used for the date-only values
// produced by the DATE(...) SQL expressions in this file.
const streakDateLayout = "2006-01-02"

// StreakRepository handles database operations for user streaks.
type StreakRepository struct {
	db *gorm.DB
}

// NewStreakRepository creates a new StreakRepository backed by db.
func NewStreakRepository(db *gorm.DB) *StreakRepository {
	return &StreakRepository{db: db}
}

// GetByUserID retrieves the streak record for userID.
//
// A missing record is not treated as an error: the method returns (nil, nil)
// when no row matches, so callers must check the pointer as well as the error.
func (r *StreakRepository) GetByUserID(userID uint) (*models.UserStreak, error) {
	var streak models.UserStreak
	if err := r.db.Where("user_id = ?", userID).First(&streak).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil // Not found is reported as a nil record, not an error.
		}
		return nil, err
	}
	return &streak, nil
}

// Create inserts a new streak record.
func (r *StreakRepository) Create(streak *models.UserStreak) error {
	return r.db.Create(streak).Error
}

// Update persists an existing streak record via gorm's Save.
func (r *StreakRepository) Update(streak *models.UserStreak) error {
	return r.db.Save(streak).Error
}

// GetOrCreate returns the streak record for userID, creating a zeroed record
// first if none exists yet.
func (r *StreakRepository) GetOrCreate(userID uint) (*models.UserStreak, error) {
	streak, err := r.GetByUserID(userID)
	if err != nil {
		return nil, err
	}
	if streak != nil {
		return streak, nil
	}

	// No record yet: start the user off with an empty streak.
	streak = &models.UserStreak{
		UserID:          userID,
		CurrentStreak:   0,
		LongestStreak:   0,
		TotalRecordDays: 0,
	}
	if err := r.Create(streak); err != nil {
		return nil, err
	}
	return streak, nil
}

// HasTransactionOnDate reports whether the user has at least one transaction
// on the calendar day containing date, evaluated in date's own location.
func (r *StreakRepository) HasTransactionOnDate(userID uint, date time.Time) (bool, error) {
	startOfDay := time.Date(date.Year(), date.Month(), date.Day(), 0, 0, 0, 0, date.Location())
	// Use calendar arithmetic rather than Add(24h): on daylight-saving
	// transition days a local day is 23 or 25 hours long, and a fixed 24h
	// offset would not land on the next local midnight.
	endOfDay := startOfDay.AddDate(0, 0, 1)

	var count int64
	err := r.db.Model(&models.Transaction{}).
		Where("user_id = ? AND transaction_date >= ? AND transaction_date < ?", userID, startOfDay, endOfDay).
		Count(&count).Error
	if err != nil {
		return false, err
	}
	return count > 0, nil
}

// GetTransactionDatesInRange returns the distinct dates, in ascending order,
// on which the user has transactions with transaction_date between startDate
// and endDate (both bounds inclusive).
func (r *StreakRepository) GetTransactionDatesInRange(userID uint, startDate, endDate time.Time) ([]time.Time, error) {
	rows, err := r.db.Model(&models.Transaction{}).
		Select("DATE(transaction_date) as date").
		Where("user_id = ? AND transaction_date >= ? AND transaction_date <= ?", userID, startDate, endDate).
		Group("DATE(transaction_date)").
		Order("date ASC").
		Rows()
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var dates []time.Time
	for rows.Next() {
		// Scan into an untyped value: database/sql cannot convert a textual
		// driver value into *time.Time, so a direct scan fails on drivers
		// that return DATE(...) as a string or bytes.
		var raw interface{}
		if err := rows.Scan(&raw); err != nil {
			return nil, err
		}
		date, err := parseScannedDate(raw)
		if err != nil {
			return nil, err
		}
		dates = append(dates, date)
	}
	// rows.Next returns false both at end of data and on failure; only
	// rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return dates, nil
}

// GetDailyContribution returns, for each date on which the user has
// transactions between startDate and endDate (both bounds inclusive), the
// number of transactions on that date. Results are in ascending date order and
// dates are formatted as YYYY-MM-DD.
func (r *StreakRepository) GetDailyContribution(userID uint, startDate, endDate time.Time) ([]models.DailyContribution, error) {
	rows, err := r.db.Model(&models.Transaction{}).
		Select("DATE(transaction_date) as date, COUNT(*) as count").
		Where("user_id = ? AND transaction_date >= ? AND transaction_date <= ?", userID, startDate, endDate).
		Group("DATE(transaction_date)").
		Order("date ASC").
		Rows()
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var results []models.DailyContribution
	for rows.Next() {
		// Scanning into a string works whether the driver returns text or a
		// time.Time: database/sql formats time.Time values as RFC 3339 when
		// the destination is a *string.
		var (
			dateStr string
			count   int
		)
		if err := rows.Scan(&dateStr, &count); err != nil {
			return nil, err
		}

		// Keep only the leading YYYY-MM-DD part if a time component follows.
		if len(dateStr) > len(streakDateLayout) {
			dateStr = dateStr[:len(streakDateLayout)]
		}

		results = append(results, models.DailyContribution{
			Date:  dateStr,
			Count: count,
		})
	}
	// Surface errors that ended the iteration early instead of returning a
	// silently truncated result.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return results, nil
}

// parseScannedDate converts a raw driver value read from a DATE(...) column
// into a time.Time.
//
// A time.Time value is returned unchanged. A string or []byte value is
// truncated to its first ten bytes and parsed as YYYY-MM-DD, which yields
// midnight UTC. Any other type, including nil, is reported as an error.
func parseScannedDate(raw interface{}) (time.Time, error) {
	var text string
	switch v := raw.(type) {
	case time.Time:
		return v, nil
	case string:
		text = v
	case []byte:
		text = string(v)
	default:
		return time.Time{}, fmt.Errorf("unsupported date column type %T", raw)
	}

	if len(text) > len(streakDateLayout) {
		text = text[:len(streakDateLayout)]
	}
	parsed, err := time.Parse(streakDateLayout, text)
	if err != nil {
		return time.Time{}, fmt.Errorf("parsing transaction date %q: %w", text, err)
	}
	return parsed, nil
}