fix(数据层): 修复数据完整性与仓储一致性问题

## 数据完整性修复 (fix-critical-data-integrity)
- 添加 error_translate.go 统一错误转换层
- 修复 nil 输入和 NotFound 错误处理
- 增强仓储层错误一致性

## 仓储一致性修复 (fix-high-repository-consistency)
- Group schema 添加 default_validity_days 字段
- Account schema 添加 proxy edge 关联
- 新增 UsageLog ent schema 定义
- 修复 UpdateBalance/UpdateConcurrency 受影响行数校验

## 数据卫生修复 (fix-medium-data-hygiene)
- UserSubscription 添加软删除支持 (SoftDeleteMixin)
- RedeemCode/Setting 添加硬删除策略文档
- account_groups/user_allowed_groups 的 created_at 声明 timestamptz
- 停止写入 legacy users.allowed_groups 列
- 新增迁移: 011-014 (索引优化、软删除、孤立数据审计、列清理)

## 测试补充
- 添加 UserSubscription 软删除测试
- 添加迁移回归测试
- 添加 NotFound 错误测试

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2025-12-31 14:11:57 +08:00
parent 820bb16ca7
commit 5906f9ab98
87 changed files with 15258 additions and 485 deletions

View File

@@ -20,10 +20,11 @@ func NewUserSubscriptionRepository(client *dbent.Client) service.UserSubscriptio
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error {
if sub == nil {
return nil
return service.ErrSubscriptionNilInput
}
builder := r.client.UserSubscription.Create().
client := clientFromContext(ctx, r.client)
builder := client.UserSubscription.Create().
SetUserID(sub.UserID).
SetGroupID(sub.GroupID).
SetExpiresAt(sub.ExpiresAt).
@@ -57,7 +58,8 @@ func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.Us
}
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
m, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
m, err := client.UserSubscription.Query().
Where(usersubscription.IDEQ(id)).
WithUser().
WithGroup().
@@ -70,7 +72,8 @@ func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*se
}
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
m, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
m, err := client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
WithGroup().
Only(ctx)
@@ -81,7 +84,8 @@ func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context,
}
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
m, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
m, err := client.UserSubscription.Query().
Where(
usersubscription.UserIDEQ(userID),
usersubscription.GroupIDEQ(groupID),
@@ -98,10 +102,11 @@ func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Con
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error {
if sub == nil {
return nil
return service.ErrSubscriptionNilInput
}
builder := r.client.UserSubscription.UpdateOneID(sub.ID).
client := clientFromContext(ctx, r.client)
builder := client.UserSubscription.UpdateOneID(sub.ID).
SetUserID(sub.UserID).
SetGroupID(sub.GroupID).
SetStartsAt(sub.StartsAt).
@@ -127,12 +132,14 @@ func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.Us
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
// Match GORM semantics: deleting a missing row is not an error.
_, err := r.client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx)
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx)
return err
}
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
subs, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
subs, err := client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID)).
WithGroup().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
@@ -144,7 +151,8 @@ func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID in
}
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
subs, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
subs, err := client.UserSubscription.Query().
Where(
usersubscription.UserIDEQ(userID),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
@@ -160,7 +168,8 @@ func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
}
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
q := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID))
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID))
total, err := q.Clone().Count(ctx)
if err != nil {
@@ -182,7 +191,8 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
}
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
q := r.client.UserSubscription.Query()
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query()
if userID != nil {
q = q.Where(usersubscription.UserIDEQ(*userID))
}
@@ -214,34 +224,39 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
}
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
return r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
return client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
Exist(ctx)
}
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
SetExpiresAt(newExpiresAt).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
SetStatus(status).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
SetNotes(notes).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetDailyWindowStart(start).
SetWeeklyWindowStart(start).
SetMonthlyWindowStart(start).
@@ -250,7 +265,8 @@ func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int
}
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetDailyUsageUsd(0).
SetDailyWindowStart(newWindowStart).
Save(ctx)
@@ -258,7 +274,8 @@ func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
}
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetWeeklyUsageUsd(0).
SetWeeklyWindowStart(newWindowStart).
Save(ctx)
@@ -266,24 +283,112 @@ func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
}
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetMonthlyUsageUsd(0).
SetMonthlyWindowStart(newWindowStart).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
// IncrementUsage 原子性地累加用量并校验限额。
// 使用单条 SQL 语句同时检查 Group 的限额,如果任一限额即将超出则拒绝更新。
// 当更新失败时,会执行额外查询确定具体超出的限额类型。
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
AddDailyUsageUsd(costUSD).
AddWeeklyUsageUsd(costUSD).
AddMonthlyUsageUsd(costUSD).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
// 使用 JOIN 的原子更新:只有当所有限额条件满足时才执行累加
// NULL 限额表示无限制
const atomicUpdateSQL = `
UPDATE user_subscriptions us
SET
daily_usage_usd = us.daily_usage_usd + $1,
weekly_usage_usd = us.weekly_usage_usd + $1,
monthly_usage_usd = us.monthly_usage_usd + $1,
updated_at = NOW()
FROM groups g
WHERE us.id = $2
AND us.deleted_at IS NULL
AND us.group_id = g.id
AND g.deleted_at IS NULL
AND (g.daily_limit_usd IS NULL OR us.daily_usage_usd + $1 <= g.daily_limit_usd)
AND (g.weekly_limit_usd IS NULL OR us.weekly_usage_usd + $1 <= g.weekly_limit_usd)
AND (g.monthly_limit_usd IS NULL OR us.monthly_usage_usd + $1 <= g.monthly_limit_usd)
`
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(ctx, atomicUpdateSQL, costUSD, id)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
return nil // 更新成功
}
// affected == 0可能是订阅不存在、分组已删除、或限额超出
// 执行额外查询确定具体原因
return r.checkIncrementFailureReason(ctx, id, costUSD)
}
// checkIncrementFailureReason 查询更新失败的具体原因
func (r *userSubscriptionRepository) checkIncrementFailureReason(ctx context.Context, id int64, costUSD float64) error {
const checkSQL = `
SELECT
CASE WHEN us.deleted_at IS NOT NULL THEN 'subscription_deleted'
WHEN g.id IS NULL THEN 'subscription_not_found'
WHEN g.deleted_at IS NOT NULL THEN 'group_deleted'
WHEN g.daily_limit_usd IS NOT NULL AND us.daily_usage_usd + $1 > g.daily_limit_usd THEN 'daily_exceeded'
WHEN g.weekly_limit_usd IS NOT NULL AND us.weekly_usage_usd + $1 > g.weekly_limit_usd THEN 'weekly_exceeded'
WHEN g.monthly_limit_usd IS NOT NULL AND us.monthly_usage_usd + $1 > g.monthly_limit_usd THEN 'monthly_exceeded'
ELSE 'unknown'
END AS reason
FROM user_subscriptions us
LEFT JOIN groups g ON us.group_id = g.id
WHERE us.id = $2
`
client := clientFromContext(ctx, r.client)
rows, err := client.QueryContext(ctx, checkSQL, costUSD, id)
if err != nil {
return err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return service.ErrSubscriptionNotFound
}
var reason string
if err := rows.Scan(&reason); err != nil {
return err
}
if err := rows.Err(); err != nil {
return err
}
switch reason {
case "subscription_not_found", "subscription_deleted", "group_deleted":
return service.ErrSubscriptionNotFound
case "daily_exceeded":
return service.ErrDailyLimitExceeded
case "weekly_exceeded":
return service.ErrWeeklyLimitExceeded
case "monthly_exceeded":
return service.ErrMonthlyLimitExceeded
default:
// unknown 情况理论上不应发生,但作为兜底返回
return service.ErrSubscriptionNotFound
}
}
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
n, err := r.client.UserSubscription.Update().
client := clientFromContext(ctx, r.client)
n, err := client.UserSubscription.Update().
Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(time.Now()),
@@ -296,7 +401,8 @@ func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Contex
// Extra repository helpers (currently used only by integration tests).
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) {
subs, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
subs, err := client.UserSubscription.Query().
Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(time.Now()),
@@ -309,12 +415,14 @@ func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service
}
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
count, err := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx)
client := clientFromContext(ctx, r.client)
count, err := client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx)
return int64(count), err
}
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
count, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
count, err := client.UserSubscription.Query().
Where(
usersubscription.GroupIDEQ(groupID),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
@@ -325,7 +433,8 @@ func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, g
}
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
n, err := r.client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx)
client := clientFromContext(ctx, r.client)
n, err := client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx)
return int64(n), err
}