323 lines
10 KiB
Go
323 lines
10 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
"sub2api/internal/model"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// UserSubscriptionRepository 用户订阅仓库
|
|
type UserSubscriptionRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
// NewUserSubscriptionRepository 创建用户订阅仓库
|
|
func NewUserSubscriptionRepository(db *gorm.DB) *UserSubscriptionRepository {
|
|
return &UserSubscriptionRepository{db: db}
|
|
}
|
|
|
|
// Create 创建订阅
|
|
func (r *UserSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error {
|
|
return r.db.WithContext(ctx).Create(sub).Error
|
|
}
|
|
|
|
// GetByID 根据ID获取订阅
|
|
func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
|
|
var sub model.UserSubscription
|
|
err := r.db.WithContext(ctx).
|
|
Preload("User").
|
|
Preload("Group").
|
|
Preload("AssignedByUser").
|
|
First(&sub, id).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &sub, nil
|
|
}
|
|
|
|
// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅
|
|
func (r *UserSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
|
|
var sub model.UserSubscription
|
|
err := r.db.WithContext(ctx).
|
|
Preload("Group").
|
|
Where("user_id = ? AND group_id = ?", userID, groupID).
|
|
First(&sub).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &sub, nil
|
|
}
|
|
|
|
// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅
|
|
func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
|
|
var sub model.UserSubscription
|
|
err := r.db.WithContext(ctx).
|
|
Preload("Group").
|
|
Where("user_id = ? AND group_id = ? AND status = ? AND expires_at > ?",
|
|
userID, groupID, model.SubscriptionStatusActive, time.Now()).
|
|
First(&sub).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &sub, nil
|
|
}
|
|
|
|
// Update 更新订阅
|
|
func (r *UserSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
|
|
sub.UpdatedAt = time.Now()
|
|
return r.db.WithContext(ctx).Save(sub).Error
|
|
}
|
|
|
|
// Delete 删除订阅
|
|
func (r *UserSubscriptionRepository) Delete(ctx context.Context, id int64) error {
|
|
return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error
|
|
}
|
|
|
|
// ListByUserID 获取用户的所有订阅
|
|
func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
|
var subs []model.UserSubscription
|
|
err := r.db.WithContext(ctx).
|
|
Preload("Group").
|
|
Where("user_id = ?", userID).
|
|
Order("created_at DESC").
|
|
Find(&subs).Error
|
|
return subs, err
|
|
}
|
|
|
|
// ListActiveByUserID 获取用户的所有有效订阅
|
|
func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
|
var subs []model.UserSubscription
|
|
err := r.db.WithContext(ctx).
|
|
Preload("Group").
|
|
Where("user_id = ? AND status = ? AND expires_at > ?",
|
|
userID, model.SubscriptionStatusActive, time.Now()).
|
|
Order("created_at DESC").
|
|
Find(&subs).Error
|
|
return subs, err
|
|
}
|
|
|
|
// ListByGroupID 获取分组的所有订阅(分页)
|
|
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.UserSubscription, *PaginationResult, error) {
|
|
var subs []model.UserSubscription
|
|
var total int64
|
|
|
|
query := r.db.WithContext(ctx).Model(&model.UserSubscription{}).Where("group_id = ?", groupID)
|
|
|
|
if err := query.Count(&total).Error; err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
err := query.
|
|
Preload("User").
|
|
Preload("Group").
|
|
Order("created_at DESC").
|
|
Offset(params.Offset()).
|
|
Limit(params.Limit()).
|
|
Find(&subs).Error
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
pages := int(total) / params.Limit()
|
|
if int(total)%params.Limit() > 0 {
|
|
pages++
|
|
}
|
|
|
|
return subs, &PaginationResult{
|
|
Total: total,
|
|
Page: params.Page,
|
|
PageSize: params.Limit(),
|
|
Pages: pages,
|
|
}, nil
|
|
}
|
|
|
|
// List 获取所有订阅(分页,支持筛选)
|
|
func (r *UserSubscriptionRepository) List(ctx context.Context, params PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *PaginationResult, error) {
|
|
var subs []model.UserSubscription
|
|
var total int64
|
|
|
|
query := r.db.WithContext(ctx).Model(&model.UserSubscription{})
|
|
|
|
if userID != nil {
|
|
query = query.Where("user_id = ?", *userID)
|
|
}
|
|
if groupID != nil {
|
|
query = query.Where("group_id = ?", *groupID)
|
|
}
|
|
if status != "" {
|
|
query = query.Where("status = ?", status)
|
|
}
|
|
|
|
if err := query.Count(&total).Error; err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
err := query.
|
|
Preload("User").
|
|
Preload("Group").
|
|
Preload("AssignedByUser").
|
|
Order("created_at DESC").
|
|
Offset(params.Offset()).
|
|
Limit(params.Limit()).
|
|
Find(&subs).Error
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
pages := int(total) / params.Limit()
|
|
if int(total)%params.Limit() > 0 {
|
|
pages++
|
|
}
|
|
|
|
return subs, &PaginationResult{
|
|
Total: total,
|
|
Page: params.Page,
|
|
PageSize: params.Limit(),
|
|
Pages: pages,
|
|
}, nil
|
|
}
|
|
|
|
// IncrementUsage 增加使用量
|
|
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
|
Where("id = ?", id).
|
|
Updates(map[string]interface{}{
|
|
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
|
|
"weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD),
|
|
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
|
|
"updated_at": time.Now(),
|
|
}).Error
|
|
}
|
|
|
|
// ResetDailyUsage 重置日使用量
|
|
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
|
Where("id = ?", id).
|
|
Updates(map[string]interface{}{
|
|
"daily_usage_usd": 0,
|
|
"daily_window_start": newWindowStart,
|
|
"updated_at": time.Now(),
|
|
}).Error
|
|
}
|
|
|
|
// ResetWeeklyUsage 重置周使用量
|
|
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
|
Where("id = ?", id).
|
|
Updates(map[string]interface{}{
|
|
"weekly_usage_usd": 0,
|
|
"weekly_window_start": newWindowStart,
|
|
"updated_at": time.Now(),
|
|
}).Error
|
|
}
|
|
|
|
// ResetMonthlyUsage 重置月使用量
|
|
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
|
Where("id = ?", id).
|
|
Updates(map[string]interface{}{
|
|
"monthly_usage_usd": 0,
|
|
"monthly_window_start": newWindowStart,
|
|
"updated_at": time.Now(),
|
|
}).Error
|
|
}
|
|
|
|
// ActivateWindows 激活所有窗口(首次使用时)
|
|
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
|
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
|
Where("id = ?", id).
|
|
Updates(map[string]interface{}{
|
|
"daily_window_start": activateTime,
|
|
"weekly_window_start": activateTime,
|
|
"monthly_window_start": activateTime,
|
|
"updated_at": time.Now(),
|
|
}).Error
|
|
}
|
|
|
|
// UpdateStatus 更新订阅状态
|
|
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
|
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
|
Where("id = ?", id).
|
|
Updates(map[string]interface{}{
|
|
"status": status,
|
|
"updated_at": time.Now(),
|
|
}).Error
|
|
}
|
|
|
|
// ExtendExpiry 延长订阅过期时间
|
|
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
|
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
|
Where("id = ?", id).
|
|
Updates(map[string]interface{}{
|
|
"expires_at": newExpiresAt,
|
|
"updated_at": time.Now(),
|
|
}).Error
|
|
}
|
|
|
|
// UpdateNotes 更新订阅备注
|
|
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
|
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
|
Where("id = ?", id).
|
|
Updates(map[string]interface{}{
|
|
"notes": notes,
|
|
"updated_at": time.Now(),
|
|
}).Error
|
|
}
|
|
|
|
// ListExpired 获取所有已过期但状态仍为active的订阅
|
|
func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
|
|
var subs []model.UserSubscription
|
|
err := r.db.WithContext(ctx).
|
|
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
|
|
Find(&subs).Error
|
|
return subs, err
|
|
}
|
|
|
|
// BatchUpdateExpiredStatus 批量更新过期订阅状态
|
|
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
|
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
|
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
|
|
Updates(map[string]interface{}{
|
|
"status": model.SubscriptionStatusExpired,
|
|
"updated_at": time.Now(),
|
|
})
|
|
return result.RowsAffected, result.Error
|
|
}
|
|
|
|
// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅
|
|
func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
|
var count int64
|
|
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
|
Where("user_id = ? AND group_id = ?", userID, groupID).
|
|
Count(&count).Error
|
|
return count > 0, err
|
|
}
|
|
|
|
// CountByGroupID 获取分组的订阅数量
|
|
func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
|
var count int64
|
|
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
|
Where("group_id = ?", groupID).
|
|
Count(&count).Error
|
|
return count, err
|
|
}
|
|
|
|
// CountActiveByGroupID 获取分组的有效订阅数量
|
|
func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
|
var count int64
|
|
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
|
Where("group_id = ? AND status = ? AND expires_at > ?",
|
|
groupID, model.SubscriptionStatusActive, time.Now()).
|
|
Count(&count).Error
|
|
return count, err
|
|
}
|
|
|
|
// DeleteByGroupID 删除分组相关的所有订阅记录
|
|
func (r *UserSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
|
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{})
|
|
return result.RowsAffected, result.Error
|
|
}
|