diff --git a/backend/internal/repository/user_subscription_repo.go b/backend/internal/repository/user_subscription_repo.go index 2b308674..cd3b9db6 100644 --- a/backend/internal/repository/user_subscription_repo.go +++ b/backend/internal/repository/user_subscription_repo.go @@ -291,13 +291,11 @@ func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } -// IncrementUsage 原子性地累加用量并校验限额。 -// 使用单条 SQL 语句同时检查 Group 的限额,如果任一限额即将超出则拒绝更新。 -// 当更新失败时,会执行额外查询确定具体超出的限额类型。 +// IncrementUsage 原子性地累加订阅用量。 +// 限额检查已在请求前由 BillingCacheService.CheckBillingEligibility 完成, +// 此处仅负责记录实际消费,确保消费数据的完整性。 func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { - // 使用 JOIN 的原子更新:只有当所有限额条件满足时才执行累加 - // NULL 限额表示无限制 - const atomicUpdateSQL = ` + const updateSQL = ` UPDATE user_subscriptions us SET daily_usage_usd = us.daily_usage_usd + $1, @@ -309,13 +307,10 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6 AND us.deleted_at IS NULL AND us.group_id = g.id AND g.deleted_at IS NULL - AND (g.daily_limit_usd IS NULL OR us.daily_usage_usd + $1 <= g.daily_limit_usd) - AND (g.weekly_limit_usd IS NULL OR us.weekly_usage_usd + $1 <= g.weekly_limit_usd) - AND (g.monthly_limit_usd IS NULL OR us.monthly_usage_usd + $1 <= g.monthly_limit_usd) ` client := clientFromContext(ctx, r.client) - result, err := client.ExecContext(ctx, atomicUpdateSQL, costUSD, id) + result, err := client.ExecContext(ctx, updateSQL, costUSD, id) if err != nil { return err } @@ -326,64 +321,11 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6 } if affected > 0 { - return nil // 更新成功 + return nil } - // affected == 0:可能是订阅不存在、分组已删除、或限额超出 - // 执行额外查询确定具体原因 - return r.checkIncrementFailureReason(ctx, id, costUSD) -} - -// checkIncrementFailureReason 查询更新失败的具体原因 -func (r *userSubscriptionRepository) checkIncrementFailureReason(ctx context.Context, id int64, costUSD float64) error { - const checkSQL = ` - SELECT - CASE WHEN us.deleted_at IS NOT NULL THEN 'subscription_deleted' - WHEN g.id IS NULL THEN 'subscription_not_found' - WHEN g.deleted_at IS NOT NULL THEN 'group_deleted' - WHEN g.daily_limit_usd IS NOT NULL AND us.daily_usage_usd + $1 > g.daily_limit_usd THEN 'daily_exceeded' - WHEN g.weekly_limit_usd IS NOT NULL AND us.weekly_usage_usd + $1 > g.weekly_limit_usd THEN 'weekly_exceeded' - WHEN g.monthly_limit_usd IS NOT NULL AND us.monthly_usage_usd + $1 > g.monthly_limit_usd THEN 'monthly_exceeded' - ELSE 'unknown' - END AS reason - FROM user_subscriptions us - LEFT JOIN groups g ON us.group_id = g.id - WHERE us.id = $2 - ` - - client := clientFromContext(ctx, r.client) - rows, err := client.QueryContext(ctx, checkSQL, costUSD, id) - if err != nil { - return err - } - defer func() { _ = rows.Close() }() - - if !rows.Next() { - return service.ErrSubscriptionNotFound - } - - var reason string - if err := rows.Scan(&reason); err != nil { - return err - } - - if err := rows.Err(); err != nil { - return err - } - - switch reason { - case "subscription_not_found", "subscription_deleted", "group_deleted": - return service.ErrSubscriptionNotFound - case "daily_exceeded": - return service.ErrDailyLimitExceeded - case "weekly_exceeded": - return service.ErrWeeklyLimitExceeded - case "monthly_exceeded": - return service.ErrMonthlyLimitExceeded - default: - // unknown 情况理论上不应发生,但作为兜底返回 - return service.ErrSubscriptionNotFound - } + // affected == 0:订阅不存在或已删除 + return service.ErrSubscriptionNotFound } func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { diff --git a/backend/internal/repository/user_subscription_repo_integration_test.go b/backend/internal/repository/user_subscription_repo_integration_test.go index 3a6c6434..2099e5d8 100644 --- a/backend/internal/repository/user_subscription_repo_integration_test.go +++ b/backend/internal/repository/user_subscription_repo_integration_test.go @@ -633,112 +633,7 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired") } -// --- 限额检查与软删除过滤测试 --- - -func (s *UserSubscriptionRepoSuite) mustCreateGroupWithLimits(name string, daily, weekly, monthly *float64) *service.Group { - s.T().Helper() - - create := s.client.Group.Create(). - SetName(name). - SetStatus(service.StatusActive). - SetSubscriptionType(service.SubscriptionTypeSubscription) - - if daily != nil { - create.SetDailyLimitUsd(*daily) - } - if weekly != nil { - create.SetWeeklyLimitUsd(*weekly) - } - if monthly != nil { - create.SetMonthlyLimitUsd(*monthly) - } - - g, err := create.Save(s.ctx) - s.Require().NoError(err, "create group with limits") - return groupEntityToService(g) -} - -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_DailyLimitExceeded() { - user := s.mustCreateUser("dailylimit@test.com", service.RoleUser) - dailyLimit := 10.0 - group := s.mustCreateGroupWithLimits("g-dailylimit", &dailyLimit, nil, nil) - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 先增加 9.0,应该成功 - err := s.repo.IncrementUsage(s.ctx, sub.ID, 9.0) - s.Require().NoError(err, "first increment should succeed") - - // 再增加 2.0,会超过 10.0 限额,应该失败 - err = s.repo.IncrementUsage(s.ctx, sub.ID, 2.0) - s.Require().Error(err, "should fail when daily limit exceeded") - s.Require().ErrorIs(err, service.ErrDailyLimitExceeded) - - // 验证用量没有变化 - got, err := s.repo.GetByID(s.ctx, sub.ID) - s.Require().NoError(err) - s.Require().InDelta(9.0, got.DailyUsageUSD, 1e-6, "usage should not change after failed increment") -} - -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_WeeklyLimitExceeded() { - user := s.mustCreateUser("weeklylimit@test.com", service.RoleUser) - weeklyLimit := 50.0 - group := s.mustCreateGroupWithLimits("g-weeklylimit", nil, &weeklyLimit, nil) - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 增加 45.0,应该成功 - err := s.repo.IncrementUsage(s.ctx, sub.ID, 45.0) - s.Require().NoError(err, "first increment should succeed") - - // 再增加 10.0,会超过 50.0 限额,应该失败 - err = s.repo.IncrementUsage(s.ctx, sub.ID, 10.0) - s.Require().Error(err, "should fail when weekly limit exceeded") - s.Require().ErrorIs(err, service.ErrWeeklyLimitExceeded) -} - -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_MonthlyLimitExceeded() { - user := s.mustCreateUser("monthlylimit@test.com", service.RoleUser) - monthlyLimit := 100.0 - group := s.mustCreateGroupWithLimits("g-monthlylimit", nil, nil, &monthlyLimit) - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 增加 90.0,应该成功 - err := s.repo.IncrementUsage(s.ctx, sub.ID, 90.0) - s.Require().NoError(err, "first increment should succeed") - - // 再增加 20.0,会超过 100.0 限额,应该失败 - err = s.repo.IncrementUsage(s.ctx, sub.ID, 20.0) - s.Require().Error(err, "should fail when monthly limit exceeded") - s.Require().ErrorIs(err, service.ErrMonthlyLimitExceeded) -} - -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NoLimits() { - user := s.mustCreateUser("nolimits@test.com", service.RoleUser) - group := s.mustCreateGroupWithLimits("g-nolimits", nil, nil, nil) // 无限额 - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 应该可以增加任意金额 - err := s.repo.IncrementUsage(s.ctx, sub.ID, 1000000.0) - s.Require().NoError(err, "should succeed without limits") - - got, err := s.repo.GetByID(s.ctx, sub.ID) - s.Require().NoError(err) - s.Require().InDelta(1000000.0, got.DailyUsageUSD, 1e-6) -} - -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_AtExactLimit() { - user := s.mustCreateUser("exactlimit@test.com", service.RoleUser) - dailyLimit := 10.0 - group := s.mustCreateGroupWithLimits("g-exactlimit", &dailyLimit, nil, nil) - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 正好达到限额应该成功 - err := s.repo.IncrementUsage(s.ctx, sub.ID, 10.0) - s.Require().NoError(err, "should succeed at exact limit") - - got, err := s.repo.GetByID(s.ctx, sub.ID) - s.Require().NoError(err) - s.Require().InDelta(10.0, got.DailyUsageUSD, 1e-6) -} +// --- 软删除过滤测试 --- func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() { user := s.mustCreateUser("softdeleted@test.com", service.RoleUser) @@ -779,7 +674,7 @@ func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() { func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() { user := s.mustCreateUser("concurrent@test.com", service.RoleUser) - group := s.mustCreateGroupWithLimits("g-concurrent", nil, nil, nil) // 无限额 + group := s.mustCreateGroup("g-concurrent") sub := s.mustCreateSubscription(user.ID, group.ID, nil) const numGoroutines = 10 @@ -808,34 +703,6 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() { s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated") } -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_ConcurrentWithLimit() { - user := s.mustCreateUser("concurrentlimit@test.com", service.RoleUser) - dailyLimit := 5.0 - group := s.mustCreateGroupWithLimits("g-concurrentlimit", &dailyLimit, nil, nil) - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 注意:事务内的操作是串行的,所以这里改为顺序执行以验证限额逻辑 - // 尝试增加 10 次,每次 1.0,但限额只有 5.0 - const numAttempts = 10 - const incrementPerAttempt = 1.0 - - successCount := 0 - for i := 0; i < numAttempts; i++ { - err := s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerAttempt) - if err == nil { - successCount++ - } - } - - // 验证:应该有 5 次成功(不超过限额),5 次失败(超出限额) - s.Require().Equal(5, successCount, "exactly 5 increments should succeed (limit=5, increment=1)") - - // 验证最终用量等于限额 - got, err := s.repo.GetByID(s.ctx, sub.ID) - s.Require().NoError(err) - s.Require().InDelta(dailyLimit, got.DailyUsageUSD, 1e-6, "daily usage should equal limit") -} - func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() { baseClient := testEntClient(s.T()) tx, err := baseClient.Tx(context.Background()) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 4be09810..feeb19a0 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -488,6 +488,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn subscriptionType = SubscriptionTypeStandard } + // 限额字段:0 和 nil 都表示"无限制" + dailyLimit := normalizeLimit(input.DailyLimitUSD) + weeklyLimit := normalizeLimit(input.WeeklyLimitUSD) + monthlyLimit := normalizeLimit(input.MonthlyLimitUSD) + group := &Group{ Name: input.Name, Description: input.Description, @@ -496,9 +501,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn IsExclusive: input.IsExclusive, Status: StatusActive, SubscriptionType: subscriptionType, - DailyLimitUSD: input.DailyLimitUSD, - WeeklyLimitUSD: input.WeeklyLimitUSD, - MonthlyLimitUSD: input.MonthlyLimitUSD, + DailyLimitUSD: dailyLimit, + WeeklyLimitUSD: weeklyLimit, + MonthlyLimitUSD: monthlyLimit, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -506,6 +511,14 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn return group, nil } +// normalizeLimit 将 0 或负数转换为 nil(表示无限制) +func normalizeLimit(limit *float64) *float64 { + if limit == nil || *limit <= 0 { + return nil + } + return limit +} + func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { group, err := s.groupRepo.GetByID(ctx, id) if err != nil { @@ -535,15 +548,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.SubscriptionType != "" { group.SubscriptionType = input.SubscriptionType } - // 限额字段支持设置为nil(清除限额)或具体值 + // 限额字段:0 和 nil 都表示"无限制",正数表示具体限额 if input.DailyLimitUSD != nil { - group.DailyLimitUSD = input.DailyLimitUSD + group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD) } if input.WeeklyLimitUSD != nil { - group.WeeklyLimitUSD = input.WeeklyLimitUSD + group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD) } if input.MonthlyLimitUSD != nil { - group.MonthlyLimitUSD = input.MonthlyLimitUSD + group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) } if err := s.groupRepo.Update(ctx, group); err != nil { diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index 09554c0f..f6aefb83 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -490,6 +490,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use } // CheckUsageLimits 检查使用限额(返回错误如果超限) +// 用于中间件的快速预检查,additionalCost 通常为 0 func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSubscription, group *Group, additionalCost float64) error { if !sub.CheckDailyLimit(group, additionalCost) { return ErrDailyLimitExceeded