## 数据完整性修复 (fix-critical-data-integrity) - 添加 error_translate.go 统一错误转换层 - 修复 nil 输入和 NotFound 错误处理 - 增强仓储层错误一致性 ## 仓储一致性修复 (fix-high-repository-consistency) - Group schema 添加 default_validity_days 字段 - Account schema 添加 proxy edge 关联 - 新增 UsageLog ent schema 定义 - 修复 UpdateBalance/UpdateConcurrency 受影响行数校验 ## 数据卫生修复 (fix-medium-data-hygiene) - UserSubscription 添加软删除支持 (SoftDeleteMixin) - RedeemCode/Setting 添加硬删除策略文档 - account_groups/user_allowed_groups 的 created_at 声明 timestamptz - 停止写入 legacy users.allowed_groups 列 - 新增迁移: 011-014 (索引优化、软删除、孤立数据审计、列清理) ## 测试补充 - 添加 UserSubscription 软删除测试 - 添加迁移回归测试 - 添加 NotFound 错误测试 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
494 lines
16 KiB
Go
494 lines
16 KiB
Go
package repository
|
||
|
||
import (
|
||
"context"
|
||
"time"
|
||
|
||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||
)
|
||
|
||
type userSubscriptionRepository struct {
|
||
client *dbent.Client
|
||
}
|
||
|
||
func NewUserSubscriptionRepository(client *dbent.Client) service.UserSubscriptionRepository {
|
||
return &userSubscriptionRepository{client: client}
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||
if sub == nil {
|
||
return service.ErrSubscriptionNilInput
|
||
}
|
||
|
||
client := clientFromContext(ctx, r.client)
|
||
builder := client.UserSubscription.Create().
|
||
SetUserID(sub.UserID).
|
||
SetGroupID(sub.GroupID).
|
||
SetExpiresAt(sub.ExpiresAt).
|
||
SetNillableDailyWindowStart(sub.DailyWindowStart).
|
||
SetNillableWeeklyWindowStart(sub.WeeklyWindowStart).
|
||
SetNillableMonthlyWindowStart(sub.MonthlyWindowStart).
|
||
SetDailyUsageUsd(sub.DailyUsageUSD).
|
||
SetWeeklyUsageUsd(sub.WeeklyUsageUSD).
|
||
SetMonthlyUsageUsd(sub.MonthlyUsageUSD).
|
||
SetNillableAssignedBy(sub.AssignedBy)
|
||
|
||
if sub.StartsAt.IsZero() {
|
||
builder.SetStartsAt(time.Now())
|
||
} else {
|
||
builder.SetStartsAt(sub.StartsAt)
|
||
}
|
||
if sub.Status != "" {
|
||
builder.SetStatus(sub.Status)
|
||
}
|
||
if !sub.AssignedAt.IsZero() {
|
||
builder.SetAssignedAt(sub.AssignedAt)
|
||
}
|
||
// Keep compatibility with historical behavior: always store notes as a string value.
|
||
builder.SetNotes(sub.Notes)
|
||
|
||
created, err := builder.Save(ctx)
|
||
if err == nil {
|
||
applyUserSubscriptionEntityToService(sub, created)
|
||
}
|
||
return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
|
||
client := clientFromContext(ctx, r.client)
|
||
m, err := client.UserSubscription.Query().
|
||
Where(usersubscription.IDEQ(id)).
|
||
WithUser().
|
||
WithGroup().
|
||
WithAssignedByUser().
|
||
Only(ctx)
|
||
if err != nil {
|
||
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||
}
|
||
return userSubscriptionEntityToService(m), nil
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||
client := clientFromContext(ctx, r.client)
|
||
m, err := client.UserSubscription.Query().
|
||
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
|
||
WithGroup().
|
||
Only(ctx)
|
||
if err != nil {
|
||
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||
}
|
||
return userSubscriptionEntityToService(m), nil
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||
client := clientFromContext(ctx, r.client)
|
||
m, err := client.UserSubscription.Query().
|
||
Where(
|
||
usersubscription.UserIDEQ(userID),
|
||
usersubscription.GroupIDEQ(groupID),
|
||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||
usersubscription.ExpiresAtGT(time.Now()),
|
||
).
|
||
WithGroup().
|
||
Only(ctx)
|
||
if err != nil {
|
||
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||
}
|
||
return userSubscriptionEntityToService(m), nil
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error {
|
||
if sub == nil {
|
||
return service.ErrSubscriptionNilInput
|
||
}
|
||
|
||
client := clientFromContext(ctx, r.client)
|
||
builder := client.UserSubscription.UpdateOneID(sub.ID).
|
||
SetUserID(sub.UserID).
|
||
SetGroupID(sub.GroupID).
|
||
SetStartsAt(sub.StartsAt).
|
||
SetExpiresAt(sub.ExpiresAt).
|
||
SetStatus(sub.Status).
|
||
SetNillableDailyWindowStart(sub.DailyWindowStart).
|
||
SetNillableWeeklyWindowStart(sub.WeeklyWindowStart).
|
||
SetNillableMonthlyWindowStart(sub.MonthlyWindowStart).
|
||
SetDailyUsageUsd(sub.DailyUsageUSD).
|
||
SetWeeklyUsageUsd(sub.WeeklyUsageUSD).
|
||
SetMonthlyUsageUsd(sub.MonthlyUsageUSD).
|
||
SetNillableAssignedBy(sub.AssignedBy).
|
||
SetAssignedAt(sub.AssignedAt).
|
||
SetNotes(sub.Notes)
|
||
|
||
updated, err := builder.Save(ctx)
|
||
if err == nil {
|
||
applyUserSubscriptionEntityToService(sub, updated)
|
||
return nil
|
||
}
|
||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, service.ErrSubscriptionAlreadyExists)
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
|
||
// Match GORM semantics: deleting a missing row is not an error.
|
||
client := clientFromContext(ctx, r.client)
|
||
_, err := client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx)
|
||
return err
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||
client := clientFromContext(ctx, r.client)
|
||
subs, err := client.UserSubscription.Query().
|
||
Where(usersubscription.UserIDEQ(userID)).
|
||
WithGroup().
|
||
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
|
||
All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return userSubscriptionEntitiesToService(subs), nil
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||
client := clientFromContext(ctx, r.client)
|
||
subs, err := client.UserSubscription.Query().
|
||
Where(
|
||
usersubscription.UserIDEQ(userID),
|
||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||
usersubscription.ExpiresAtGT(time.Now()),
|
||
).
|
||
WithGroup().
|
||
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
|
||
All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return userSubscriptionEntitiesToService(subs), nil
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||
client := clientFromContext(ctx, r.client)
|
||
q := client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID))
|
||
|
||
total, err := q.Clone().Count(ctx)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
subs, err := q.
|
||
WithUser().
|
||
WithGroup().
|
||
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
|
||
Offset(params.Offset()).
|
||
Limit(params.Limit()).
|
||
All(ctx)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||
client := clientFromContext(ctx, r.client)
|
||
q := client.UserSubscription.Query()
|
||
if userID != nil {
|
||
q = q.Where(usersubscription.UserIDEQ(*userID))
|
||
}
|
||
if groupID != nil {
|
||
q = q.Where(usersubscription.GroupIDEQ(*groupID))
|
||
}
|
||
if status != "" {
|
||
q = q.Where(usersubscription.StatusEQ(status))
|
||
}
|
||
|
||
total, err := q.Clone().Count(ctx)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
subs, err := q.
|
||
WithUser().
|
||
WithGroup().
|
||
WithAssignedByUser().
|
||
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
|
||
Offset(params.Offset()).
|
||
Limit(params.Limit()).
|
||
All(ctx)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||
client := clientFromContext(ctx, r.client)
|
||
return client.UserSubscription.Query().
|
||
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
|
||
Exist(ctx)
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
|
||
client := clientFromContext(ctx, r.client)
|
||
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
|
||
SetExpiresAt(newExpiresAt).
|
||
Save(ctx)
|
||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
|
||
client := clientFromContext(ctx, r.client)
|
||
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
|
||
SetStatus(status).
|
||
Save(ctx)
|
||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
|
||
client := clientFromContext(ctx, r.client)
|
||
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
|
||
SetNotes(notes).
|
||
Save(ctx)
|
||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
|
||
client := clientFromContext(ctx, r.client)
|
||
_, err := client.UserSubscription.UpdateOneID(id).
|
||
SetDailyWindowStart(start).
|
||
SetWeeklyWindowStart(start).
|
||
SetMonthlyWindowStart(start).
|
||
Save(ctx)
|
||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||
client := clientFromContext(ctx, r.client)
|
||
_, err := client.UserSubscription.UpdateOneID(id).
|
||
SetDailyUsageUsd(0).
|
||
SetDailyWindowStart(newWindowStart).
|
||
Save(ctx)
|
||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||
client := clientFromContext(ctx, r.client)
|
||
_, err := client.UserSubscription.UpdateOneID(id).
|
||
SetWeeklyUsageUsd(0).
|
||
SetWeeklyWindowStart(newWindowStart).
|
||
Save(ctx)
|
||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||
client := clientFromContext(ctx, r.client)
|
||
_, err := client.UserSubscription.UpdateOneID(id).
|
||
SetMonthlyUsageUsd(0).
|
||
SetMonthlyWindowStart(newWindowStart).
|
||
Save(ctx)
|
||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||
}
|
||
|
||
// IncrementUsage 原子性地累加用量并校验限额。
|
||
// 使用单条 SQL 语句同时检查 Group 的限额,如果任一限额即将超出则拒绝更新。
|
||
// 当更新失败时,会执行额外查询确定具体超出的限额类型。
|
||
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||
// 使用 JOIN 的原子更新:只有当所有限额条件满足时才执行累加
|
||
// NULL 限额表示无限制
|
||
const atomicUpdateSQL = `
|
||
UPDATE user_subscriptions us
|
||
SET
|
||
daily_usage_usd = us.daily_usage_usd + $1,
|
||
weekly_usage_usd = us.weekly_usage_usd + $1,
|
||
monthly_usage_usd = us.monthly_usage_usd + $1,
|
||
updated_at = NOW()
|
||
FROM groups g
|
||
WHERE us.id = $2
|
||
AND us.deleted_at IS NULL
|
||
AND us.group_id = g.id
|
||
AND g.deleted_at IS NULL
|
||
AND (g.daily_limit_usd IS NULL OR us.daily_usage_usd + $1 <= g.daily_limit_usd)
|
||
AND (g.weekly_limit_usd IS NULL OR us.weekly_usage_usd + $1 <= g.weekly_limit_usd)
|
||
AND (g.monthly_limit_usd IS NULL OR us.monthly_usage_usd + $1 <= g.monthly_limit_usd)
|
||
`
|
||
|
||
client := clientFromContext(ctx, r.client)
|
||
result, err := client.ExecContext(ctx, atomicUpdateSQL, costUSD, id)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
affected, err := result.RowsAffected()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if affected > 0 {
|
||
return nil // 更新成功
|
||
}
|
||
|
||
// affected == 0:可能是订阅不存在、分组已删除、或限额超出
|
||
// 执行额外查询确定具体原因
|
||
return r.checkIncrementFailureReason(ctx, id, costUSD)
|
||
}
|
||
|
||
// checkIncrementFailureReason 查询更新失败的具体原因
|
||
func (r *userSubscriptionRepository) checkIncrementFailureReason(ctx context.Context, id int64, costUSD float64) error {
|
||
const checkSQL = `
|
||
SELECT
|
||
CASE WHEN us.deleted_at IS NOT NULL THEN 'subscription_deleted'
|
||
WHEN g.id IS NULL THEN 'subscription_not_found'
|
||
WHEN g.deleted_at IS NOT NULL THEN 'group_deleted'
|
||
WHEN g.daily_limit_usd IS NOT NULL AND us.daily_usage_usd + $1 > g.daily_limit_usd THEN 'daily_exceeded'
|
||
WHEN g.weekly_limit_usd IS NOT NULL AND us.weekly_usage_usd + $1 > g.weekly_limit_usd THEN 'weekly_exceeded'
|
||
WHEN g.monthly_limit_usd IS NOT NULL AND us.monthly_usage_usd + $1 > g.monthly_limit_usd THEN 'monthly_exceeded'
|
||
ELSE 'unknown'
|
||
END AS reason
|
||
FROM user_subscriptions us
|
||
LEFT JOIN groups g ON us.group_id = g.id
|
||
WHERE us.id = $2
|
||
`
|
||
|
||
client := clientFromContext(ctx, r.client)
|
||
rows, err := client.QueryContext(ctx, checkSQL, costUSD, id)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer func() { _ = rows.Close() }()
|
||
|
||
if !rows.Next() {
|
||
return service.ErrSubscriptionNotFound
|
||
}
|
||
|
||
var reason string
|
||
if err := rows.Scan(&reason); err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := rows.Err(); err != nil {
|
||
return err
|
||
}
|
||
|
||
switch reason {
|
||
case "subscription_not_found", "subscription_deleted", "group_deleted":
|
||
return service.ErrSubscriptionNotFound
|
||
case "daily_exceeded":
|
||
return service.ErrDailyLimitExceeded
|
||
case "weekly_exceeded":
|
||
return service.ErrWeeklyLimitExceeded
|
||
case "monthly_exceeded":
|
||
return service.ErrMonthlyLimitExceeded
|
||
default:
|
||
// unknown 情况理论上不应发生,但作为兜底返回
|
||
return service.ErrSubscriptionNotFound
|
||
}
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||
client := clientFromContext(ctx, r.client)
|
||
n, err := client.UserSubscription.Update().
|
||
Where(
|
||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||
usersubscription.ExpiresAtLTE(time.Now()),
|
||
).
|
||
SetStatus(service.SubscriptionStatusExpired).
|
||
Save(ctx)
|
||
return int64(n), err
|
||
}
|
||
|
||
// Extra repository helpers (currently used only by integration tests).
|
||
|
||
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) {
|
||
client := clientFromContext(ctx, r.client)
|
||
subs, err := client.UserSubscription.Query().
|
||
Where(
|
||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||
usersubscription.ExpiresAtLTE(time.Now()),
|
||
).
|
||
All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return userSubscriptionEntitiesToService(subs), nil
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||
client := clientFromContext(ctx, r.client)
|
||
count, err := client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx)
|
||
return int64(count), err
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||
client := clientFromContext(ctx, r.client)
|
||
count, err := client.UserSubscription.Query().
|
||
Where(
|
||
usersubscription.GroupIDEQ(groupID),
|
||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||
usersubscription.ExpiresAtGT(time.Now()),
|
||
).
|
||
Count(ctx)
|
||
return int64(count), err
|
||
}
|
||
|
||
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||
client := clientFromContext(ctx, r.client)
|
||
n, err := client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx)
|
||
return int64(n), err
|
||
}
|
||
|
||
func userSubscriptionEntityToService(m *dbent.UserSubscription) *service.UserSubscription {
|
||
if m == nil {
|
||
return nil
|
||
}
|
||
out := &service.UserSubscription{
|
||
ID: m.ID,
|
||
UserID: m.UserID,
|
||
GroupID: m.GroupID,
|
||
StartsAt: m.StartsAt,
|
||
ExpiresAt: m.ExpiresAt,
|
||
Status: m.Status,
|
||
DailyWindowStart: m.DailyWindowStart,
|
||
WeeklyWindowStart: m.WeeklyWindowStart,
|
||
MonthlyWindowStart: m.MonthlyWindowStart,
|
||
DailyUsageUSD: m.DailyUsageUsd,
|
||
WeeklyUsageUSD: m.WeeklyUsageUsd,
|
||
MonthlyUsageUSD: m.MonthlyUsageUsd,
|
||
AssignedBy: m.AssignedBy,
|
||
AssignedAt: m.AssignedAt,
|
||
Notes: derefString(m.Notes),
|
||
CreatedAt: m.CreatedAt,
|
||
UpdatedAt: m.UpdatedAt,
|
||
}
|
||
if m.Edges.User != nil {
|
||
out.User = userEntityToService(m.Edges.User)
|
||
}
|
||
if m.Edges.Group != nil {
|
||
out.Group = groupEntityToService(m.Edges.Group)
|
||
}
|
||
if m.Edges.AssignedByUser != nil {
|
||
out.AssignedByUser = userEntityToService(m.Edges.AssignedByUser)
|
||
}
|
||
return out
|
||
}
|
||
|
||
func userSubscriptionEntitiesToService(models []*dbent.UserSubscription) []service.UserSubscription {
|
||
out := make([]service.UserSubscription, 0, len(models))
|
||
for i := range models {
|
||
if s := userSubscriptionEntityToService(models[i]); s != nil {
|
||
out = append(out, *s)
|
||
}
|
||
}
|
||
return out
|
||
}
|
||
|
||
func applyUserSubscriptionEntityToService(dst *service.UserSubscription, src *dbent.UserSubscription) {
|
||
if dst == nil || src == nil {
|
||
return
|
||
}
|
||
dst.ID = src.ID
|
||
dst.CreatedAt = src.CreatedAt
|
||
dst.UpdatedAt = src.UpdatedAt
|
||
}
|