fix(数据层): 修复数据完整性与仓储一致性问题

## 数据完整性修复 (fix-critical-data-integrity)
- 添加 error_translate.go 统一错误转换层
- 修复 nil 输入和 NotFound 错误处理
- 增强仓储层错误一致性

## 仓储一致性修复 (fix-high-repository-consistency)
- Group schema 添加 default_validity_days 字段
- Account schema 添加 proxy edge 关联
- 新增 UsageLog ent schema 定义
- 修复 UpdateBalance/UpdateConcurrency 受影响行数校验

## 数据卫生修复 (fix-medium-data-hygiene)
- UserSubscription 添加软删除支持 (SoftDeleteMixin)
- RedeemCode/Setting 添加硬删除策略文档
- account_groups/user_allowed_groups 的 created_at 声明 timestamptz
- 停止写入 legacy users.allowed_groups 列
- 新增迁移: 011-014 (索引优化、软删除、孤立数据审计、列清理)

## 测试补充
- 添加 UserSubscription 软删除测试
- 添加迁移回归测试
- 添加 NotFound 错误测试

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2025-12-31 14:11:57 +08:00
parent 820bb16ca7
commit 5906f9ab98
87 changed files with 15258 additions and 485 deletions

View File

@@ -14,6 +14,7 @@ import (
"context"
"database/sql"
"encoding/json"
"errors"
"strconv"
"time"
@@ -56,7 +57,7 @@ func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accoun
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
if account == nil {
return nil
return service.ErrAccountNilInput
}
builder := r.client.Account.Create().
@@ -98,7 +99,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
created, err := builder.Save(ctx)
if err != nil {
return err
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
account.ID = created.ID
@@ -231,11 +232,32 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
}
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
// 使用事务保证账号与关联分组的删除原子性
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return err
}
_, err := r.client.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx)
return err
var txClient *dbent.Client
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// 已处于外部事务中ErrTxStarted复用当前 client
txClient = r.client
}
if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
return err
}
if _, err := txClient.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx); err != nil {
return err
}
if tx != nil {
return tx.Commit()
}
return nil
}
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
@@ -393,25 +415,49 @@ func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]s
}
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil {
// 使用事务保证删除旧绑定与创建新绑定的原子性
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return err
}
var txClient *dbent.Client
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// 已处于外部事务中ErrTxStarted复用当前 client
txClient = r.client
}
if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil {
return err
}
if len(groupIDs) == 0 {
if tx != nil {
return tx.Commit()
}
return nil
}
builders := make([]*dbent.AccountGroupCreate, 0, len(groupIDs))
for i, groupID := range groupIDs {
builders = append(builders, r.client.AccountGroup.Create().
builders = append(builders, txClient.AccountGroup.Create().
SetAccountID(accountID).
SetGroupID(groupID).
SetPriority(i+1),
)
}
_, err := r.client.AccountGroup.CreateBulk(builders...).Save(ctx)
return err
if _, err := txClient.AccountGroup.CreateBulk(builders...).Save(ctx); err != nil {
return err
}
if tx != nil {
return tx.Commit()
}
return nil
}
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) {
@@ -555,24 +601,30 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
return nil
}
accountExtra, err := r.client.Account.Query().
Where(dbaccount.IDEQ(id)).
Select(dbaccount.FieldExtra).
Only(ctx)
// 使用 JSONB 合并操作实现原子更新,避免读-改-写的并发丢失更新问题
payload, err := json.Marshal(updates)
if err != nil {
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
return err
}
extra := normalizeJSONMap(accountExtra.Extra)
for k, v := range updates {
extra[k] = v
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) || $1::jsonb, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL",
payload, id,
)
if err != nil {
return err
}
_, err = r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
SetExtra(extra).
Save(ctx)
return err
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
return nil
}
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {

View File

@@ -318,12 +318,13 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
DefaultValidityDays: g.DefaultValidityDays,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
}

View File

@@ -1,6 +1,7 @@
package repository
import (
"context"
"database/sql"
"errors"
"strings"
@@ -10,6 +11,25 @@ import (
"github.com/lib/pq"
)
// clientFromContext 从 context 中获取事务 client如果不存在则返回默认 client。
//
// 这个辅助函数支持 repository 方法在事务上下文中工作:
// - 如果 context 中存在事务(通过 ent.NewTxContext 设置),返回事务的 client
// - 否则返回传入的默认 client
//
// 使用示例:
//
// func (r *someRepo) SomeMethod(ctx context.Context) error {
// client := clientFromContext(ctx, r.client)
// return client.SomeEntity.Create().Save(ctx)
// }
func clientFromContext(ctx context.Context, defaultClient *dbent.Client) *dbent.Client {
if tx := dbent.TxFromContext(ctx); tx != nil {
return tx.Client()
}
return defaultClient
}
// translatePersistenceError 将数据库层错误翻译为业务层错误。
//
// 这是 Repository 层的核心错误处理函数,确保数据库细节不会泄露到业务层。

View File

@@ -42,7 +42,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetSubscriptionType(groupIn.SubscriptionType).
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD)
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetDefaultValidityDays(groupIn.DefaultValidityDays)
created, err := builder.Save(ctx)
if err == nil {
@@ -79,6 +80,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
@@ -89,7 +91,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
return err
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
@@ -239,8 +241,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
// err 为 dbent.ErrTxStarted 时,复用当前 client 参与同一事务。
// Lock the group row to avoid concurrent writes while we cascade.
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分未找到与其他错误。
rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id)
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分"未找到"与其他错误。
rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 AND deleted_at IS NULL FOR UPDATE", id)
if err != nil {
return nil, err
}
@@ -263,7 +265,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
var affectedUserIDs []int64
if groupSvc.IsSubscriptionType() {
rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1", id)
// 只查询未软删除的订阅,避免通知已取消订阅的用户
rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1 AND deleted_at IS NULL", id)
if err != nil {
return nil, err
}
@@ -282,7 +285,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return nil, err
}
if _, err := exec.ExecContext(ctx, "DELETE FROM user_subscriptions WHERE group_id = $1", id); err != nil {
// 软删除订阅:设置 deleted_at 而非硬删除
if _, err := exec.ExecContext(ctx, "UPDATE user_subscriptions SET deleted_at = NOW() WHERE group_id = $1 AND deleted_at IS NULL", id); err != nil {
return nil, err
}
}
@@ -297,18 +301,11 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return nil, err
}
// 3. Remove the group id from users.allowed_groups array (legacy representation).
// Phase 1 compatibility: also delete from user_allowed_groups join table when present.
// 3. Remove the group id from user_allowed_groups join table.
// Legacy users.allowed_groups 列已弃用,不再同步。
if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil {
return nil, err
}
if _, err := exec.ExecContext(
ctx,
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1) WHERE $1 = ANY(allowed_groups)",
id,
); err != nil {
return nil, err
}
// 4. Delete account_groups join rows.
if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil {

View File

@@ -478,3 +478,58 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count)
}
// --- 软删除过滤测试 ---
func (s *GroupRepoSuite) TestDelete_SoftDelete_NotVisibleInList() {
group := &service.Group{
Name: "to-soft-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
// 获取删除前的列表数量
listBefore, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
beforeCount := len(listBefore)
// 软删除
err = s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete (soft delete)")
// 验证列表中不再包含软删除的 group
listAfter, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
s.Require().Len(listAfter, beforeCount-1, "soft deleted group should not appear in list")
// 验证 GetByID 也无法找到
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err)
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
func (s *GroupRepoSuite) TestDelete_SoftDeletedGroup_lockForUpdate() {
group := &service.Group{
Name: "lock-soft-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
// 软删除
err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err)
// 验证软删除的 group 在 GetByID 时返回 ErrGroupNotFound
// 这证明 lockForUpdate 的 deleted_at IS NULL 过滤正在工作
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err, "should fail to get soft-deleted group")
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}

View File

@@ -53,6 +53,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
var uagRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass))
require.True(t, uagRegclass.Valid, "expected user_allowed_groups table to exist")
// user_subscriptions: deleted_at for soft delete support (migration 012)
requireColumn(t, tx, "user_subscriptions", "deleted_at", "timestamp with time zone", 0, true)
// orphan_allowed_groups_audit table should exist (migration 013)
var orphanAuditRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.orphan_allowed_groups_audit')").Scan(&orphanAuditRegclass))
require.True(t, orphanAuditRegclass.Valid, "expected orphan_allowed_groups_audit table to exist")
// account_groups: created_at should be timestamptz
requireColumn(t, tx, "account_groups", "created_at", "timestamp with time zone", 0, false)
// user_allowed_groups: created_at should be timestamptz
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {

View File

@@ -178,7 +178,7 @@ func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL GROUP BY proxy_id")
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id")
if err != nil {
return nil, err
}

View File

@@ -168,7 +168,8 @@ func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemC
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now()
affected, err := r.client.RedeemCode.Update().
client := clientFromContext(ctx, r.client)
affected, err := client.RedeemCode.Update().
Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)).
SetStatus(service.StatusUsed).
SetUsedBy(userID).

View File

@@ -7,10 +7,12 @@ import (
"fmt"
"strings"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
@@ -111,3 +113,104 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
Only(mixins.SkipSoftDelete(ctx))
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")
}
// --- UserSubscription 软删除测试 ---
func createEntGroup(t *testing.T, ctx context.Context, client *dbent.Client, name string) *dbent.Group {
t.Helper()
g, err := client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err, "create ent group")
return g
}
func TestEntSoftDelete_UserSubscription_DefaultFilterAndSkip(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user")+"@example.com")
g := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group"))
repo := NewUserSubscriptionRepository(client)
sub := &service.UserSubscription{
UserID: u.ID,
GroupID: g.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub), "create user subscription")
require.NoError(t, repo.Delete(ctx, sub.ID), "soft delete user subscription")
_, err := repo.GetByID(ctx, sub.ID)
require.Error(t, err, "deleted rows should be hidden by default")
_, err = client.UserSubscription.Query().Where(usersubscription.IDEQ(sub.ID)).Only(ctx)
require.Error(t, err, "default ent query should not see soft-deleted rows")
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
got, err := client.UserSubscription.Query().
Where(usersubscription.IDEQ(sub.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
}
func TestEntSoftDelete_UserSubscription_DeleteIdempotent(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user2")+"@example.com")
g := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group2"))
repo := NewUserSubscriptionRepository(client)
sub := &service.UserSubscription{
UserID: u.ID,
GroupID: g.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub), "create user subscription")
require.NoError(t, repo.Delete(ctx, sub.ID), "first delete")
require.NoError(t, repo.Delete(ctx, sub.ID), "second delete should be idempotent")
}
func TestEntSoftDelete_UserSubscription_ListExcludesDeleted(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user3")+"@example.com")
g1 := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group3a"))
g2 := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group3b"))
repo := NewUserSubscriptionRepository(client)
sub1 := &service.UserSubscription{
UserID: u.ID,
GroupID: g1.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub1), "create subscription 1")
sub2 := &service.UserSubscription{
UserID: u.ID,
GroupID: g2.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub2), "create subscription 2")
// 软删除 sub1
require.NoError(t, repo.Delete(ctx, sub1.ID), "soft delete subscription 1")
// ListByUserID 应只返回未删除的订阅
subs, err := repo.ListByUserID(ctx, u.ID)
require.NoError(t, err, "ListByUserID")
require.Len(t, subs, 1, "should only return non-deleted subscriptions")
require.Equal(t, sub2.ID, subs[0].ID, "expected sub2 to be returned")
}

View File

@@ -1109,6 +1109,9 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
today := timezone.Today()
todayQuery := `
@@ -1135,6 +1138,9 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
@@ -1177,6 +1183,9 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
today := timezone.Today()
todayQuery := `
@@ -1203,6 +1212,9 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
type userRepository struct {
@@ -86,10 +85,11 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User,
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{id})
if err == nil {
if v, ok := groups[id]; ok {
out.AllowedGroups = v
}
if err != nil {
return nil, err
}
if v, ok := groups[id]; ok {
out.AllowedGroups = v
}
return out, nil
}
@@ -102,10 +102,11 @@ func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
if err == nil {
if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
}
if err != nil {
return nil, err
}
if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
}
return out, nil
}
@@ -240,11 +241,12 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
}
allowedGroupsByUser, err := r.loadAllowedGroups(ctx, userIDs)
if err == nil {
for id, u := range userMap {
if groups, ok := allowedGroupsByUser[id]; ok {
u.AllowedGroups = groups
}
if err != nil {
return nil, nil, err
}
for id, u := range userMap {
if groups, ok := allowedGroupsByUser[id]; ok {
u.AllowedGroups = groups
}
}
@@ -252,12 +254,20 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
}
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
_, err := r.client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
return err
client := clientFromContext(ctx, r.client)
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if n == 0 {
return service.ErrUserNotFound
}
return nil
}
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
n, err := r.client.User.Update().
client := clientFromContext(ctx, r.client)
n, err := client.User.Update().
Where(dbuser.IDEQ(id), dbuser.BalanceGTE(amount)).
AddBalance(-amount).
Save(ctx)
@@ -271,8 +281,15 @@ func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount flo
}
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
_, err := r.client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx)
return err
client := clientFromContext(ctx, r.client)
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if n == 0 {
return service.ErrUserNotFound
}
return nil
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
@@ -280,33 +297,14 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
}
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
exec := r.sql
if exec == nil {
// 未注入 sqlExecutor 时,退回到 ent client 的 ExecContext支持事务
exec = r.client
}
joinAffected, err := r.client.UserAllowedGroup.Delete().
// 仅操作 user_allowed_groups 联接表legacy users.allowed_groups 列已弃用。
affected, err := r.client.UserAllowedGroup.Delete().
Where(userallowedgroup.GroupIDEQ(groupID)).
Exec(ctx)
if err != nil {
return 0, err
}
arrayRes, err := exec.ExecContext(
ctx,
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1), updated_at = NOW() WHERE $1 = ANY(allowed_groups)",
groupID,
)
if err != nil {
return 0, err
}
arrayAffected, _ := arrayRes.RowsAffected()
if int64(joinAffected) > arrayAffected {
return int64(joinAffected), nil
}
return arrayAffected, nil
return int64(affected), nil
}
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
@@ -323,10 +321,11 @@ func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, erro
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
if err == nil {
if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
}
if err != nil {
return nil, err
}
if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
}
return out, nil
}
@@ -356,8 +355,7 @@ func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64)
}
// syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组:
// 1) 以 user_allowed_groups 为读写源,确保新旧逻辑一致;
// 2) 额外更新 users.allowed_groups历史字段以保持兼容。
// 仅操作 user_allowed_groups 联接表legacy users.allowed_groups 列已弃用。
func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error {
if client == nil {
return nil
@@ -376,12 +374,10 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
unique[id] = struct{}{}
}
legacyGroups := make([]int64, 0, len(unique))
if len(unique) > 0 {
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
for groupID := range unique {
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
legacyGroups = append(legacyGroups, groupID)
}
if err := client.UserAllowedGroup.
CreateBulk(creates...).
@@ -392,16 +388,6 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
}
}
// Phase 1 兼容:保持 users.allowed_groups数组字段同步避免旧查询路径读取到过期数据。
var legacy any
if len(legacyGroups) > 0 {
sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] })
legacy = pq.Array(legacyGroups)
}
if _, err := client.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil {
return err
}
return nil
}

View File

@@ -508,3 +508,24 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch")
}
// --- UpdateBalance/UpdateConcurrency 影响行数校验测试 ---
func (s *UserRepoSuite) TestUpdateBalance_NotFound() {
err := s.repo.UpdateBalance(s.ctx, 999999, 10.0)
s.Require().Error(err, "expected error for non-existent user")
s.Require().ErrorIs(err, service.ErrUserNotFound)
}
func (s *UserRepoSuite) TestUpdateConcurrency_NotFound() {
err := s.repo.UpdateConcurrency(s.ctx, 999999, 5)
s.Require().Error(err, "expected error for non-existent user")
s.Require().ErrorIs(err, service.ErrUserNotFound)
}
func (s *UserRepoSuite) TestDeductBalance_NotFound() {
err := s.repo.DeductBalance(s.ctx, 999999, 5)
s.Require().Error(err, "expected error for non-existent user")
// DeductBalance 在用户不存在时返回 ErrInsufficientBalance 因为 WHERE 条件不匹配
s.Require().ErrorIs(err, service.ErrInsufficientBalance)
}

View File

@@ -20,10 +20,11 @@ func NewUserSubscriptionRepository(client *dbent.Client) service.UserSubscriptio
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error {
if sub == nil {
return nil
return service.ErrSubscriptionNilInput
}
builder := r.client.UserSubscription.Create().
client := clientFromContext(ctx, r.client)
builder := client.UserSubscription.Create().
SetUserID(sub.UserID).
SetGroupID(sub.GroupID).
SetExpiresAt(sub.ExpiresAt).
@@ -57,7 +58,8 @@ func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.Us
}
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
m, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
m, err := client.UserSubscription.Query().
Where(usersubscription.IDEQ(id)).
WithUser().
WithGroup().
@@ -70,7 +72,8 @@ func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*se
}
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
m, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
m, err := client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
WithGroup().
Only(ctx)
@@ -81,7 +84,8 @@ func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context,
}
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
m, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
m, err := client.UserSubscription.Query().
Where(
usersubscription.UserIDEQ(userID),
usersubscription.GroupIDEQ(groupID),
@@ -98,10 +102,11 @@ func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Con
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error {
if sub == nil {
return nil
return service.ErrSubscriptionNilInput
}
builder := r.client.UserSubscription.UpdateOneID(sub.ID).
client := clientFromContext(ctx, r.client)
builder := client.UserSubscription.UpdateOneID(sub.ID).
SetUserID(sub.UserID).
SetGroupID(sub.GroupID).
SetStartsAt(sub.StartsAt).
@@ -127,12 +132,14 @@ func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.Us
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
// Match GORM semantics: deleting a missing row is not an error.
_, err := r.client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx)
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx)
return err
}
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
subs, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
subs, err := client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID)).
WithGroup().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
@@ -144,7 +151,8 @@ func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID in
}
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
subs, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
subs, err := client.UserSubscription.Query().
Where(
usersubscription.UserIDEQ(userID),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
@@ -160,7 +168,8 @@ func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
}
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
q := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID))
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID))
total, err := q.Clone().Count(ctx)
if err != nil {
@@ -182,7 +191,8 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
}
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
q := r.client.UserSubscription.Query()
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query()
if userID != nil {
q = q.Where(usersubscription.UserIDEQ(*userID))
}
@@ -214,34 +224,39 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
}
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
return r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
return client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
Exist(ctx)
}
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
SetExpiresAt(newExpiresAt).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
SetStatus(status).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
SetNotes(notes).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetDailyWindowStart(start).
SetWeeklyWindowStart(start).
SetMonthlyWindowStart(start).
@@ -250,7 +265,8 @@ func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int
}
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetDailyUsageUsd(0).
SetDailyWindowStart(newWindowStart).
Save(ctx)
@@ -258,7 +274,8 @@ func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
}
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetWeeklyUsageUsd(0).
SetWeeklyWindowStart(newWindowStart).
Save(ctx)
@@ -266,24 +283,112 @@ func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
}
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetMonthlyUsageUsd(0).
SetMonthlyWindowStart(newWindowStart).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
// IncrementUsage 原子性地累加用量并校验限额。
// 使用单条 SQL 语句同时检查 Group 的限额,如果任一限额即将超出则拒绝更新。
// 当更新失败时,会执行额外查询确定具体超出的限额类型。
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
AddDailyUsageUsd(costUSD).
AddWeeklyUsageUsd(costUSD).
AddMonthlyUsageUsd(costUSD).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
// 使用 JOIN 的原子更新:只有当所有限额条件满足时才执行累加
// NULL 限额表示无限制
const atomicUpdateSQL = `
UPDATE user_subscriptions us
SET
daily_usage_usd = us.daily_usage_usd + $1,
weekly_usage_usd = us.weekly_usage_usd + $1,
monthly_usage_usd = us.monthly_usage_usd + $1,
updated_at = NOW()
FROM groups g
WHERE us.id = $2
AND us.deleted_at IS NULL
AND us.group_id = g.id
AND g.deleted_at IS NULL
AND (g.daily_limit_usd IS NULL OR us.daily_usage_usd + $1 <= g.daily_limit_usd)
AND (g.weekly_limit_usd IS NULL OR us.weekly_usage_usd + $1 <= g.weekly_limit_usd)
AND (g.monthly_limit_usd IS NULL OR us.monthly_usage_usd + $1 <= g.monthly_limit_usd)
`
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(ctx, atomicUpdateSQL, costUSD, id)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
return nil // 更新成功
}
// affected == 0可能是订阅不存在、分组已删除、或限额超出
// 执行额外查询确定具体原因
return r.checkIncrementFailureReason(ctx, id, costUSD)
}
// checkIncrementFailureReason 查询更新失败的具体原因
func (r *userSubscriptionRepository) checkIncrementFailureReason(ctx context.Context, id int64, costUSD float64) error {
const checkSQL = `
SELECT
CASE WHEN us.deleted_at IS NOT NULL THEN 'subscription_deleted'
WHEN g.id IS NULL THEN 'subscription_not_found'
WHEN g.deleted_at IS NOT NULL THEN 'group_deleted'
WHEN g.daily_limit_usd IS NOT NULL AND us.daily_usage_usd + $1 > g.daily_limit_usd THEN 'daily_exceeded'
WHEN g.weekly_limit_usd IS NOT NULL AND us.weekly_usage_usd + $1 > g.weekly_limit_usd THEN 'weekly_exceeded'
WHEN g.monthly_limit_usd IS NOT NULL AND us.monthly_usage_usd + $1 > g.monthly_limit_usd THEN 'monthly_exceeded'
ELSE 'unknown'
END AS reason
FROM user_subscriptions us
LEFT JOIN groups g ON us.group_id = g.id
WHERE us.id = $2
`
client := clientFromContext(ctx, r.client)
rows, err := client.QueryContext(ctx, checkSQL, costUSD, id)
if err != nil {
return err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return service.ErrSubscriptionNotFound
}
var reason string
if err := rows.Scan(&reason); err != nil {
return err
}
if err := rows.Err(); err != nil {
return err
}
switch reason {
case "subscription_not_found", "subscription_deleted", "group_deleted":
return service.ErrSubscriptionNotFound
case "daily_exceeded":
return service.ErrDailyLimitExceeded
case "weekly_exceeded":
return service.ErrWeeklyLimitExceeded
case "monthly_exceeded":
return service.ErrMonthlyLimitExceeded
default:
// unknown 情况理论上不应发生,但作为兜底返回
return service.ErrSubscriptionNotFound
}
}
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
n, err := r.client.UserSubscription.Update().
client := clientFromContext(ctx, r.client)
n, err := client.UserSubscription.Update().
Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(time.Now()),
@@ -296,7 +401,8 @@ func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Contex
// Extra repository helpers (currently used only by integration tests).
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) {
subs, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
subs, err := client.UserSubscription.Query().
Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(time.Now()),
@@ -309,12 +415,14 @@ func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service
}
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
count, err := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx)
client := clientFromContext(ctx, r.client)
count, err := client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx)
return int64(count), err
}
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
count, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
count, err := client.UserSubscription.Query().
Where(
usersubscription.GroupIDEQ(groupID),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
@@ -325,7 +433,8 @@ func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, g
}
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
n, err := r.client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx)
client := clientFromContext(ctx, r.client)
n, err := client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx)
return int64(n), err
}

View File

@@ -4,6 +4,7 @@ package repository
import (
"context"
"fmt"
"testing"
"time"
@@ -631,3 +632,249 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
s.Require().NoError(err, "GetByID expired")
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
}
// --- 限额检查与软删除过滤测试 ---
func (s *UserSubscriptionRepoSuite) mustCreateGroupWithLimits(name string, daily, weekly, monthly *float64) *service.Group {
s.T().Helper()
create := s.client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
SetSubscriptionType(service.SubscriptionTypeSubscription)
if daily != nil {
create.SetDailyLimitUsd(*daily)
}
if weekly != nil {
create.SetWeeklyLimitUsd(*weekly)
}
if monthly != nil {
create.SetMonthlyLimitUsd(*monthly)
}
g, err := create.Save(s.ctx)
s.Require().NoError(err, "create group with limits")
return groupEntityToService(g)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_DailyLimitExceeded() {
user := s.mustCreateUser("dailylimit@test.com", service.RoleUser)
dailyLimit := 10.0
group := s.mustCreateGroupWithLimits("g-dailylimit", &dailyLimit, nil, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 先增加 9.0,应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 9.0)
s.Require().NoError(err, "first increment should succeed")
// 再增加 2.0,会超过 10.0 限额,应该失败
err = s.repo.IncrementUsage(s.ctx, sub.ID, 2.0)
s.Require().Error(err, "should fail when daily limit exceeded")
s.Require().ErrorIs(err, service.ErrDailyLimitExceeded)
// 验证用量没有变化
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(9.0, got.DailyUsageUSD, 1e-6, "usage should not change after failed increment")
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_WeeklyLimitExceeded() {
user := s.mustCreateUser("weeklylimit@test.com", service.RoleUser)
weeklyLimit := 50.0
group := s.mustCreateGroupWithLimits("g-weeklylimit", nil, &weeklyLimit, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 增加 45.0,应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 45.0)
s.Require().NoError(err, "first increment should succeed")
// 再增加 10.0,会超过 50.0 限额,应该失败
err = s.repo.IncrementUsage(s.ctx, sub.ID, 10.0)
s.Require().Error(err, "should fail when weekly limit exceeded")
s.Require().ErrorIs(err, service.ErrWeeklyLimitExceeded)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_MonthlyLimitExceeded() {
user := s.mustCreateUser("monthlylimit@test.com", service.RoleUser)
monthlyLimit := 100.0
group := s.mustCreateGroupWithLimits("g-monthlylimit", nil, nil, &monthlyLimit)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 增加 90.0,应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 90.0)
s.Require().NoError(err, "first increment should succeed")
// 再增加 20.0,会超过 100.0 限额,应该失败
err = s.repo.IncrementUsage(s.ctx, sub.ID, 20.0)
s.Require().Error(err, "should fail when monthly limit exceeded")
s.Require().ErrorIs(err, service.ErrMonthlyLimitExceeded)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NoLimits() {
user := s.mustCreateUser("nolimits@test.com", service.RoleUser)
group := s.mustCreateGroupWithLimits("g-nolimits", nil, nil, nil) // 无限额
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 应该可以增加任意金额
err := s.repo.IncrementUsage(s.ctx, sub.ID, 1000000.0)
s.Require().NoError(err, "should succeed without limits")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(1000000.0, got.DailyUsageUSD, 1e-6)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_AtExactLimit() {
user := s.mustCreateUser("exactlimit@test.com", service.RoleUser)
dailyLimit := 10.0
group := s.mustCreateGroupWithLimits("g-exactlimit", &dailyLimit, nil, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 正好达到限额应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 10.0)
s.Require().NoError(err, "should succeed at exact limit")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(10.0, got.DailyUsageUSD, 1e-6)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() {
user := s.mustCreateUser("softdeleted@test.com", service.RoleUser)
group := s.mustCreateGroup("g-softdeleted")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 软删除分组
_, err := s.client.Group.UpdateOneID(group.ID).SetDeletedAt(time.Now()).Save(s.ctx)
s.Require().NoError(err, "soft delete group")
// IncrementUsage 应该失败,因为分组已软删除
err = s.repo.IncrementUsage(s.ctx, sub.ID, 1.0)
s.Require().Error(err, "should fail for soft-deleted group")
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NotFound() {
err := s.repo.IncrementUsage(s.ctx, 999999, 1.0)
s.Require().Error(err, "should fail for non-existent subscription")
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
}
// --- nil 入参测试 ---
func (s *UserSubscriptionRepoSuite) TestCreate_NilInput() {
err := s.repo.Create(s.ctx, nil)
s.Require().Error(err, "Create should fail with nil input")
s.Require().ErrorIs(err, service.ErrSubscriptionNilInput)
}
func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() {
err := s.repo.Update(s.ctx, nil)
s.Require().Error(err, "Update should fail with nil input")
s.Require().ErrorIs(err, service.ErrSubscriptionNilInput)
}
// --- 并发用量更新测试 ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
user := s.mustCreateUser("concurrent@test.com", service.RoleUser)
group := s.mustCreateGroupWithLimits("g-concurrent", nil, nil, nil) // 无限额
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
const numGoroutines = 10
const incrementPerGoroutine = 1.5
// 启动多个 goroutine 并发调用 IncrementUsage
errCh := make(chan error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
errCh <- s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerGoroutine)
}()
}
// 等待所有 goroutine 完成
for i := 0; i < numGoroutines; i++ {
err := <-errCh
s.Require().NoError(err, "IncrementUsage should succeed")
}
// 验证累加结果正确
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
expectedUsage := float64(numGoroutines) * incrementPerGoroutine
s.Require().InDelta(expectedUsage, got.DailyUsageUSD, 1e-6, "daily usage should be correctly accumulated")
s.Require().InDelta(expectedUsage, got.WeeklyUsageUSD, 1e-6, "weekly usage should be correctly accumulated")
s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated")
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_ConcurrentWithLimit() {
user := s.mustCreateUser("concurrentlimit@test.com", service.RoleUser)
dailyLimit := 5.0
group := s.mustCreateGroupWithLimits("g-concurrentlimit", &dailyLimit, nil, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 注意:事务内的操作是串行的,所以这里改为顺序执行以验证限额逻辑
// 尝试增加 10 次,每次 1.0,但限额只有 5.0
const numAttempts = 10
const incrementPerAttempt = 1.0
successCount := 0
for i := 0; i < numAttempts; i++ {
err := s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerAttempt)
if err == nil {
successCount++
}
}
// 验证:应该有 5 次成功不超过限额5 次失败(超出限额)
s.Require().Equal(5, successCount, "exactly 5 increments should succeed (limit=5, increment=1)")
// 验证最终用量等于限额
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(dailyLimit, got.DailyUsageUSD, 1e-6, "daily usage should equal limit")
}
func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() {
baseClient := testEntClient(s.T())
tx, err := baseClient.Tx(context.Background())
s.Require().NoError(err, "begin tx")
defer func() {
if tx != nil {
_ = tx.Rollback()
}
}()
txCtx := dbent.NewTxContext(context.Background(), tx)
suffix := fmt.Sprintf("%d", time.Now().UnixNano())
userEnt, err := tx.Client().User.Create().
SetEmail("tx-user-" + suffix + "@example.com").
SetPasswordHash("test").
Save(txCtx)
s.Require().NoError(err, "create user in tx")
groupEnt, err := tx.Client().Group.Create().
SetName("tx-group-" + suffix).
Save(txCtx)
s.Require().NoError(err, "create group in tx")
repo := NewUserSubscriptionRepository(baseClient)
sub := &service.UserSubscription{
UserID: userEnt.ID,
GroupID: groupEnt.ID,
ExpiresAt: time.Now().AddDate(0, 0, 30),
Status: service.SubscriptionStatusActive,
AssignedAt: time.Now(),
Notes: "tx",
}
s.Require().NoError(repo.Create(txCtx, sub), "create subscription in tx")
s.Require().NoError(repo.UpdateNotes(txCtx, sub.ID, "tx-note"), "update subscription in tx")
s.Require().NoError(tx.Rollback(), "rollback tx")
tx = nil
_, err = repo.GetByID(context.Background(), sub.ID)
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
}

View File

@@ -11,6 +11,7 @@ import (
var (
ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found")
ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil")
)
type AccountRepository interface {

View File

@@ -11,10 +11,11 @@ type Group struct {
IsExclusive bool
Status string
SubscriptionType string
DailyLimitUSD *float64
WeeklyLimitUSD *float64
MonthlyLimitUSD *float64
SubscriptionType string
DailyLimitUSD *float64
WeeklyLimitUSD *float64
MonthlyLimitUSD *float64
DefaultValidityDays int
CreatedAt time.Time
UpdatedAt time.Time

View File

@@ -9,6 +9,7 @@ import (
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
@@ -72,6 +73,7 @@ type RedeemService struct {
subscriptionService *SubscriptionService
cache RedeemCache
billingCacheService *BillingCacheService
entClient *dbent.Client
}
// NewRedeemService 创建兑换码服务实例
@@ -81,6 +83,7 @@ func NewRedeemService(
subscriptionService *SubscriptionService,
cache RedeemCache,
billingCacheService *BillingCacheService,
entClient *dbent.Client,
) *RedeemService {
return &RedeemService{
redeemRepo: redeemRepo,
@@ -88,6 +91,7 @@ func NewRedeemService(
subscriptionService: subscriptionService,
cache: cache,
billingCacheService: billingCacheService,
entClient: entClient,
}
}
@@ -248,9 +252,19 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
}
_ = user // 使用变量避免未使用错误
// 使用数据库事务保证兑换码标记与权益发放的原子性
tx, err := s.entClient.Tx(ctx)
if err != nil {
return nil, fmt.Errorf("begin transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()
// 将事务放入 context使 repository 方法能够使用同一事务
txCtx := dbent.NewTxContext(ctx, tx)
// 【关键】先标记兑换码为已使用,确保并发安全
// 利用数据库乐观锁WHERE status = 'unused')保证原子性
if err := s.redeemRepo.Use(ctx, redeemCode.ID, userID); err != nil {
if err := s.redeemRepo.Use(txCtx, redeemCode.ID, userID); err != nil {
if errors.Is(err, ErrRedeemCodeNotFound) || errors.Is(err, ErrRedeemCodeUsed) {
return nil, ErrRedeemCodeUsed
}
@@ -261,21 +275,13 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
switch redeemCode.Type {
case RedeemTypeBalance:
// 增加用户余额
if err := s.userRepo.UpdateBalance(ctx, userID, redeemCode.Value); err != nil {
if err := s.userRepo.UpdateBalance(txCtx, userID, redeemCode.Value); err != nil {
return nil, fmt.Errorf("update user balance: %w", err)
}
// 失效余额缓存
if s.billingCacheService != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
}()
}
case RedeemTypeConcurrency:
// 增加用户并发数
if err := s.userRepo.UpdateConcurrency(ctx, userID, int(redeemCode.Value)); err != nil {
if err := s.userRepo.UpdateConcurrency(txCtx, userID, int(redeemCode.Value)); err != nil {
return nil, fmt.Errorf("update user concurrency: %w", err)
}
@@ -284,7 +290,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
if validityDays <= 0 {
validityDays = 30
}
_, _, err := s.subscriptionService.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
_, _, err := s.subscriptionService.AssignOrExtendSubscription(txCtx, &AssignSubscriptionInput{
UserID: userID,
GroupID: *redeemCode.GroupID,
ValidityDays: validityDays,
@@ -294,20 +300,19 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
if err != nil {
return nil, fmt.Errorf("assign or extend subscription: %w", err)
}
// 失效订阅缓存
if s.billingCacheService != nil {
groupID := *redeemCode.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
default:
return nil, fmt.Errorf("unsupported redeem type: %s", redeemCode.Type)
}
// 提交事务
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("commit transaction: %w", err)
}
// 事务提交成功后失效缓存
s.invalidateRedeemCaches(ctx, userID, redeemCode)
// 重新获取更新后的兑换码
redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID)
if err != nil {
@@ -317,6 +322,31 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
return redeemCode, nil
}
// invalidateRedeemCaches 失效兑换相关的缓存
func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) {
if s.billingCacheService == nil {
return
}
switch redeemCode.Type {
case RedeemTypeBalance:
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
}()
case RedeemTypeSubscription:
if redeemCode.GroupID != nil {
groupID := *redeemCode.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
}
}
// GetByID 根据ID获取兑换码
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) {
code, err := s.redeemRepo.GetByID(ctx, id)

View File

@@ -26,6 +26,7 @@ var (
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil")
)
// SubscriptionService 订阅服务