Files
sub2api/backend/internal/repository/user_subscription_repo.go
2026-02-02 22:13:50 +08:00

477 lines
15 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// userSubscriptionRepository is the ent-backed type that
// NewUserSubscriptionRepository returns as a service.UserSubscriptionRepository.
type userSubscriptionRepository struct {
// client is the fallback ent client that the methods in this file pass to clientFromContext.
client *dbent.Client
}
// NewUserSubscriptionRepository wraps the given ent client in a
// userSubscriptionRepository and exposes it through the
// service.UserSubscriptionRepository interface.
func NewUserSubscriptionRepository(client *dbent.Client) service.UserSubscriptionRepository {
	repo := new(userSubscriptionRepository)
	repo.client = client
	return repo
}
// Create inserts a new subscription row built from sub.
//
// A nil sub yields service.ErrSubscriptionNilInput. StartsAt falls back to the
// current time when it is the zero value, while Status and AssignedAt are only
// set on the builder when they are non-empty / non-zero. After a successful
// save the generated ID and the CreatedAt/UpdatedAt timestamps are copied back
// into sub. The save error (nil or not) is returned through
// translatePersistenceError, which is handed service.ErrSubscriptionAlreadyExists.
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error {
	if sub == nil {
		return service.ErrSubscriptionNilInput
	}

	create := clientFromContext(ctx, r.client).UserSubscription.Create()

	// Ownership and lifetime.
	create.SetUserID(sub.UserID).SetGroupID(sub.GroupID)
	create.SetExpiresAt(sub.ExpiresAt)
	startsAt := sub.StartsAt
	if startsAt.IsZero() {
		startsAt = time.Now()
	}
	create.SetStartsAt(startsAt)

	// Usage windows and counters.
	create.SetNillableDailyWindowStart(sub.DailyWindowStart)
	create.SetNillableWeeklyWindowStart(sub.WeeklyWindowStart)
	create.SetNillableMonthlyWindowStart(sub.MonthlyWindowStart)
	create.SetDailyUsageUsd(sub.DailyUsageUSD)
	create.SetWeeklyUsageUsd(sub.WeeklyUsageUSD)
	create.SetMonthlyUsageUsd(sub.MonthlyUsageUSD)

	// Assignment metadata and optional fields.
	create.SetNillableAssignedBy(sub.AssignedBy)
	if !sub.AssignedAt.IsZero() {
		create.SetAssignedAt(sub.AssignedAt)
	}
	if sub.Status != "" {
		create.SetStatus(sub.Status)
	}

	// Keep compatibility with historical behavior: always store notes as a string value.
	create.SetNotes(sub.Notes)

	saved, err := create.Save(ctx)
	if err == nil {
		applyUserSubscriptionEntityToService(sub, saved)
	}
	return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
}
// GetByID loads the subscription whose ID equals id, eager-loading its user,
// group, and assigned-by-user edges. Query errors are returned through
// translatePersistenceError together with service.ErrSubscriptionNotFound.
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
	query := clientFromContext(ctx, r.client).UserSubscription.Query().
		Where(usersubscription.IDEQ(id)).
		WithUser().
		WithGroup().
		WithAssignedByUser()

	entity, err := query.Only(ctx)
	if err == nil {
		return userSubscriptionEntityToService(entity), nil
	}
	return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
// GetByUserIDAndGroupID fetches the subscription matching both userID and
// groupID, with its group edge loaded. Query errors are returned through
// translatePersistenceError together with service.ErrSubscriptionNotFound.
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
	query := clientFromContext(ctx, r.client).UserSubscription.Query().
		Where(usersubscription.UserIDEQ(userID)).
		Where(usersubscription.GroupIDEQ(groupID)).
		WithGroup()

	entity, err := query.Only(ctx)
	if err == nil {
		return userSubscriptionEntityToService(entity), nil
	}
	return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
// GetActiveByUserIDAndGroupID fetches the subscription for userID and groupID
// that is currently usable: its status equals service.SubscriptionStatusActive
// and its expiry lies after the current time. The group edge is loaded. Query
// errors are returned through translatePersistenceError together with
// service.ErrSubscriptionNotFound.
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
	client := clientFromContext(ctx, r.client)
	now := time.Now()

	query := client.UserSubscription.Query().
		Where(usersubscription.UserIDEQ(userID)).
		Where(usersubscription.GroupIDEQ(groupID)).
		Where(usersubscription.StatusEQ(service.SubscriptionStatusActive)).
		Where(usersubscription.ExpiresAtGT(now)).
		WithGroup()

	entity, err := query.Only(ctx)
	if err == nil {
		return userSubscriptionEntityToService(entity), nil
	}
	return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
// Update writes the fields of sub onto the row identified by sub.ID.
//
// A nil sub yields service.ErrSubscriptionNilInput. The window-start and
// AssignedBy pointers go through the SetNillable* setters; every other field
// is set unconditionally. On success the ID and the CreatedAt/UpdatedAt
// timestamps of the saved entity are copied back into sub. Save errors are
// returned through translatePersistenceError together with
// service.ErrSubscriptionNotFound and service.ErrSubscriptionAlreadyExists.
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error {
	if sub == nil {
		return service.ErrSubscriptionNilInput
	}

	update := clientFromContext(ctx, r.client).UserSubscription.UpdateOneID(sub.ID)

	// Ownership, lifetime, and state.
	update.SetUserID(sub.UserID).SetGroupID(sub.GroupID)
	update.SetStartsAt(sub.StartsAt).SetExpiresAt(sub.ExpiresAt)
	update.SetStatus(sub.Status)

	// Usage windows and counters.
	update.SetNillableDailyWindowStart(sub.DailyWindowStart)
	update.SetNillableWeeklyWindowStart(sub.WeeklyWindowStart)
	update.SetNillableMonthlyWindowStart(sub.MonthlyWindowStart)
	update.SetDailyUsageUsd(sub.DailyUsageUSD)
	update.SetWeeklyUsageUsd(sub.WeeklyUsageUSD)
	update.SetMonthlyUsageUsd(sub.MonthlyUsageUSD)

	// Assignment metadata.
	update.SetNillableAssignedBy(sub.AssignedBy)
	update.SetAssignedAt(sub.AssignedAt)
	update.SetNotes(sub.Notes)

	saved, err := update.Save(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrSubscriptionNotFound, service.ErrSubscriptionAlreadyExists)
	}
	applyUserSubscriptionEntityToService(sub, saved)
	return nil
}
// Delete removes the subscription whose ID equals id.
//
// Match GORM semantics: deleting a missing row is not an error. The number of
// affected rows is discarded, so only an execution error is reported.
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
	deletion := clientFromContext(ctx, r.client).UserSubscription.Delete().
		Where(usersubscription.IDEQ(id))
	if _, err := deletion.Exec(ctx); err != nil {
		return err
	}
	return nil
}
// ListByUserID returns every subscription belonging to userID, newest first
// (ordered by created_at descending), with the group edge loaded.
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
	query := clientFromContext(ctx, r.client).UserSubscription.Query().
		Where(usersubscription.UserIDEQ(userID)).
		WithGroup().
		Order(dbent.Desc(usersubscription.FieldCreatedAt))

	entities, err := query.All(ctx)
	if err == nil {
		return userSubscriptionEntitiesToService(entities), nil
	}
	return nil, err
}
// ListActiveByUserID returns the subscriptions of userID whose status equals
// service.SubscriptionStatusActive and whose expiry lies after the current
// time, newest first (created_at descending), with the group edge loaded.
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
	client := clientFromContext(ctx, r.client)
	now := time.Now()

	query := client.UserSubscription.Query().
		Where(usersubscription.UserIDEQ(userID)).
		Where(usersubscription.StatusEQ(service.SubscriptionStatusActive)).
		Where(usersubscription.ExpiresAtGT(now)).
		WithGroup().
		Order(dbent.Desc(usersubscription.FieldCreatedAt))

	entities, err := query.All(ctx)
	if err == nil {
		return userSubscriptionEntitiesToService(entities), nil
	}
	return nil, err
}
// ListByGroupID returns one page of the subscriptions attached to groupID,
// newest first (created_at descending), with user and group edges loaded.
// The pagination result is derived from the total number of subscriptions in
// the group, which is counted before paging is applied.
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
	base := clientFromContext(ctx, r.client).UserSubscription.Query().
		Where(usersubscription.GroupIDEQ(groupID))

	// Count on a clone so the page query below can keep extending base.
	total, err := base.Clone().Count(ctx)
	if err != nil {
		return nil, nil, err
	}

	page := base.
		WithUser().
		WithGroup().
		Order(dbent.Desc(usersubscription.FieldCreatedAt)).
		Offset(params.Offset()).
		Limit(params.Limit())

	entities, err := page.All(ctx)
	if err != nil {
		return nil, nil, err
	}

	items := userSubscriptionEntitiesToService(entities)
	return items, paginationResultFromTotal(int64(total), params), nil
}
// List returns one page of subscriptions, optionally filtered and sorted.
//
// Filters:
//   - userID / groupID: applied only when non-nil.
//   - status: service.SubscriptionStatusActive matches rows that are active
//     AND not yet expired; service.SubscriptionStatusExpired matches rows
//     stored as expired OR stored as active but already past expires_at;
//     the empty string disables the filter; any other value is matched
//     literally against the status column.
//
// Sorting: sortBy accepts "expires_at" or "status"; anything else sorts by
// created_at. Ascending order is used only when sortOrder is "asc" and sortBy
// is non-empty; otherwise the order is descending. The primary key is always
// appended as a tie-breaker in the same direction, because the sortable
// columns are not unique and offset/limit paging over a non-unique order can
// repeat or skip rows between pages.
//
// The returned pagination result is computed from the total number of rows
// matching the filters, counted before paging is applied. User, group, and
// assigned-by-user edges are loaded on the returned rows.
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
	client := clientFromContext(ctx, r.client)
	q := client.UserSubscription.Query()
	if userID != nil {
		q = q.Where(usersubscription.UserIDEQ(*userID))
	}
	if groupID != nil {
		q = q.Where(usersubscription.GroupIDEQ(*groupID))
	}

	// Status filtering with real-time expiration check.
	now := time.Now()
	switch status {
	case service.SubscriptionStatusActive:
		// Active: status is active AND not yet expired.
		q = q.Where(
			usersubscription.StatusEQ(service.SubscriptionStatusActive),
			usersubscription.ExpiresAtGT(now),
		)
	case service.SubscriptionStatusExpired:
		// Expired: status is expired OR (status is active but already expired).
		q = q.Where(
			usersubscription.Or(
				usersubscription.StatusEQ(service.SubscriptionStatusExpired),
				usersubscription.And(
					usersubscription.StatusEQ(service.SubscriptionStatusActive),
					usersubscription.ExpiresAtLTE(now),
				),
			),
		)
	case "":
		// No filter.
	default:
		// Other status (e.g., revoked).
		q = q.Where(usersubscription.StatusEQ(status))
	}

	// Count on a clone so the filtered query can still be extended for paging.
	total, err := q.Clone().Count(ctx)
	if err != nil {
		return nil, nil, err
	}

	q = q.WithUser().WithGroup().WithAssignedByUser()

	// Determine sort field (default: created_at).
	field := usersubscription.FieldCreatedAt
	switch sortBy {
	case "expires_at":
		field = usersubscription.FieldExpiresAt
	case "status":
		field = usersubscription.FieldStatus
	}

	// Determine sort order (default: desc). The ID column is added as a
	// secondary key so rows with equal primary sort values keep a stable
	// position across pages.
	if sortOrder == "asc" && sortBy != "" {
		q = q.Order(dbent.Asc(field, usersubscription.FieldID))
	} else {
		q = q.Order(dbent.Desc(field, usersubscription.FieldID))
	}

	subs, err := q.
		Offset(params.Offset()).
		Limit(params.Limit()).
		All(ctx)
	if err != nil {
		return nil, nil, err
	}
	return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
}
// ExistsByUserIDAndGroupID reports whether at least one subscription row
// matches both userID and groupID.
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
	query := clientFromContext(ctx, r.client).UserSubscription.Query().
		Where(usersubscription.UserIDEQ(userID)).
		Where(usersubscription.GroupIDEQ(groupID))

	found, err := query.Exist(ctx)
	return found, err
}
// ExtendExpiry sets expires_at of the subscription identified by
// subscriptionID to newExpiresAt. The save error is returned through
// translatePersistenceError together with service.ErrSubscriptionNotFound.
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
	update := clientFromContext(ctx, r.client).UserSubscription.UpdateOneID(subscriptionID)
	update.SetExpiresAt(newExpiresAt)

	_, err := update.Save(ctx)
	return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
// UpdateStatus overwrites the status of the subscription identified by
// subscriptionID. The save error is returned through
// translatePersistenceError together with service.ErrSubscriptionNotFound.
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
	update := clientFromContext(ctx, r.client).UserSubscription.UpdateOneID(subscriptionID)
	update.SetStatus(status)

	_, err := update.Save(ctx)
	return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
// UpdateNotes overwrites the notes of the subscription identified by
// subscriptionID. The save error is returned through
// translatePersistenceError together with service.ErrSubscriptionNotFound.
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
	update := clientFromContext(ctx, r.client).UserSubscription.UpdateOneID(subscriptionID)
	update.SetNotes(notes)

	_, err := update.Save(ctx)
	return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
// ActivateWindows sets the daily, weekly, and monthly window starts of the
// subscription identified by id to the same instant, start. The save error is
// returned through translatePersistenceError together with
// service.ErrSubscriptionNotFound.
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
	update := clientFromContext(ctx, r.client).UserSubscription.UpdateOneID(id)
	update.SetDailyWindowStart(start)
	update.SetWeeklyWindowStart(start)
	update.SetMonthlyWindowStart(start)

	_, err := update.Save(ctx)
	return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
// ResetDailyUsage zeroes the daily usage counter of the subscription
// identified by id and moves its daily window start to newWindowStart. The
// save error is returned through translatePersistenceError together with
// service.ErrSubscriptionNotFound.
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
	update := clientFromContext(ctx, r.client).UserSubscription.UpdateOneID(id)
	update.SetDailyUsageUsd(0)
	update.SetDailyWindowStart(newWindowStart)

	_, err := update.Save(ctx)
	return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
// ResetWeeklyUsage zeroes the weekly usage counter of the subscription
// identified by id and moves its weekly window start to newWindowStart. The
// save error is returned through translatePersistenceError together with
// service.ErrSubscriptionNotFound.
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
	update := clientFromContext(ctx, r.client).UserSubscription.UpdateOneID(id)
	update.SetWeeklyUsageUsd(0)
	update.SetWeeklyWindowStart(newWindowStart)

	_, err := update.Save(ctx)
	return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
// ResetMonthlyUsage zeroes the monthly usage counter of the subscription
// identified by id and moves its monthly window start to newWindowStart. The
// save error is returned through translatePersistenceError together with
// service.ErrSubscriptionNotFound.
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
	update := clientFromContext(ctx, r.client).UserSubscription.UpdateOneID(id)
	update.SetMonthlyUsageUsd(0)
	update.SetMonthlyWindowStart(newWindowStart)

	_, err := update.Save(ctx)
	return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
// IncrementUsage atomically adds costUSD to the daily, weekly, and monthly
// usage counters of the subscription identified by id.
//
// Quota checks have already been performed before the request by
// BillingCacheService.CheckBillingEligibility; this method is only responsible
// for recording the actual spend so that consumption data stays complete.
//
// The statement only touches a row whose deleted_at is NULL and whose group
// row also has deleted_at NULL. When no row is updated,
// service.ErrSubscriptionNotFound is returned.
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
	const incrementUsageSQL = `
UPDATE user_subscriptions us
SET
daily_usage_usd = us.daily_usage_usd + $1,
weekly_usage_usd = us.weekly_usage_usd + $1,
monthly_usage_usd = us.monthly_usage_usd + $1,
updated_at = NOW()
FROM groups g
WHERE us.id = $2
AND us.deleted_at IS NULL
AND us.group_id = g.id
AND g.deleted_at IS NULL
`
	res, err := clientFromContext(ctx, r.client).ExecContext(ctx, incrementUsageSQL, costUSD, id)
	if err != nil {
		return err
	}

	rows, err := res.RowsAffected()
	if err != nil {
		return err
	}

	// No row updated: the subscription does not exist or has been deleted.
	if rows <= 0 {
		return service.ErrSubscriptionNotFound
	}
	return nil
}
// BatchUpdateExpiredStatus sets the status to
// service.SubscriptionStatusExpired on every subscription that is still stored
// as service.SubscriptionStatusActive but whose expires_at is at or before the
// current time. It returns the number of rows updated.
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
	client := clientFromContext(ctx, r.client)
	now := time.Now()

	update := client.UserSubscription.Update().
		Where(usersubscription.StatusEQ(service.SubscriptionStatusActive)).
		Where(usersubscription.ExpiresAtLTE(now)).
		SetStatus(service.SubscriptionStatusExpired)

	updated, err := update.Save(ctx)
	return int64(updated), err
}
// Extra repository helpers (currently used only by integration tests).

// ListExpired returns the subscriptions that are still stored as
// service.SubscriptionStatusActive although their expires_at is at or before
// the current time. No edges are loaded.
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) {
	client := clientFromContext(ctx, r.client)
	now := time.Now()

	query := client.UserSubscription.Query().
		Where(usersubscription.StatusEQ(service.SubscriptionStatusActive)).
		Where(usersubscription.ExpiresAtLTE(now))

	entities, err := query.All(ctx)
	if err == nil {
		return userSubscriptionEntitiesToService(entities), nil
	}
	return nil, err
}
// CountByGroupID returns the number of subscription rows whose group ID equals
// groupID, regardless of status or expiry.
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
	query := clientFromContext(ctx, r.client).UserSubscription.Query().
		Where(usersubscription.GroupIDEQ(groupID))

	n, err := query.Count(ctx)
	return int64(n), err
}
// CountActiveByGroupID returns the number of subscriptions in groupID whose
// status equals service.SubscriptionStatusActive and whose expires_at lies
// after the current time.
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
	client := clientFromContext(ctx, r.client)
	now := time.Now()

	query := client.UserSubscription.Query().
		Where(usersubscription.GroupIDEQ(groupID)).
		Where(usersubscription.StatusEQ(service.SubscriptionStatusActive)).
		Where(usersubscription.ExpiresAtGT(now))

	n, err := query.Count(ctx)
	return int64(n), err
}
// DeleteByGroupID deletes every subscription whose group ID equals groupID and
// returns the number of rows affected.
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
	deletion := clientFromContext(ctx, r.client).UserSubscription.Delete().
		Where(usersubscription.GroupIDEQ(groupID))

	removed, err := deletion.Exec(ctx)
	return int64(removed), err
}
// userSubscriptionEntityToService converts an ent entity into its service
// representation. A nil entity converts to nil. Notes is dereferenced through
// derefString, and the User, Group, and AssignedByUser edges are converted
// only when they were loaded (non-nil).
func userSubscriptionEntityToService(m *dbent.UserSubscription) *service.UserSubscription {
	if m == nil {
		return nil
	}

	out := new(service.UserSubscription)

	// Identity and ownership.
	out.ID = m.ID
	out.UserID = m.UserID
	out.GroupID = m.GroupID

	// Lifetime and state.
	out.StartsAt = m.StartsAt
	out.ExpiresAt = m.ExpiresAt
	out.Status = m.Status

	// Usage windows and counters.
	out.DailyWindowStart = m.DailyWindowStart
	out.WeeklyWindowStart = m.WeeklyWindowStart
	out.MonthlyWindowStart = m.MonthlyWindowStart
	out.DailyUsageUSD = m.DailyUsageUsd
	out.WeeklyUsageUSD = m.WeeklyUsageUsd
	out.MonthlyUsageUSD = m.MonthlyUsageUsd

	// Assignment metadata and bookkeeping timestamps.
	out.AssignedBy = m.AssignedBy
	out.AssignedAt = m.AssignedAt
	out.Notes = derefString(m.Notes)
	out.CreatedAt = m.CreatedAt
	out.UpdatedAt = m.UpdatedAt

	// Eager-loaded edges, when present.
	if owner := m.Edges.User; owner != nil {
		out.User = userEntityToService(owner)
	}
	if group := m.Edges.Group; group != nil {
		out.Group = groupEntityToService(group)
	}
	if assigner := m.Edges.AssignedByUser; assigner != nil {
		out.AssignedByUser = userEntityToService(assigner)
	}
	return out
}
// userSubscriptionEntitiesToService converts a slice of ent entities into a
// slice of service values, skipping nil entries. The result is always a
// non-nil slice, empty when there is nothing to convert.
func userSubscriptionEntitiesToService(models []*dbent.UserSubscription) []service.UserSubscription {
	out := make([]service.UserSubscription, 0, len(models))
	for _, model := range models {
		converted := userSubscriptionEntityToService(model)
		if converted == nil {
			continue
		}
		out = append(out, *converted)
	}
	return out
}
// applyUserSubscriptionEntityToService copies the ID and the
// CreatedAt/UpdatedAt timestamps from src into dst. It does nothing when
// either argument is nil; no other field of dst is touched.
func applyUserSubscriptionEntityToService(dst *service.UserSubscription, src *dbent.UserSubscription) {
	if dst != nil && src != nil {
		dst.ID, dst.CreatedAt, dst.UpdatedAt = src.ID, src.CreatedAt, src.UpdatedAt
	}
}