diff --git a/backend/ent/schema/mixins/soft_delete.go b/backend/ent/schema/mixins/soft_delete.go index 9f98a422..d62cf4a9 100644 --- a/backend/ent/schema/mixins/soft_delete.go +++ b/backend/ent/schema/mixins/soft_delete.go @@ -7,6 +7,7 @@ import ( "fmt" "time" + "github.com/Wei-Shaw/sub2api/ent/intercept" "entgo.io/ent" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" @@ -79,16 +80,13 @@ func SkipSoftDelete(parent context.Context) context.Context { // 确保软删除的记录不会出现在普通查询结果中。 func (d SoftDeleteMixin) Interceptors() []ent.Interceptor { return []ent.Interceptor{ - ent.TraverseFunc(func(ctx context.Context, q ent.Query) error { + intercept.TraverseFunc(func(ctx context.Context, q intercept.Query) error { // 检查是否需要跳过软删除过滤 if skip, _ := ctx.Value(softDeleteKey{}).(bool); skip { return nil } // 为查询添加 deleted_at IS NULL 条件 - w, ok := q.(interface{ WhereP(...func(*sql.Selector)) }) - if ok { - d.applyPredicate(w) - } + d.applyPredicate(q) return nil }), } diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index e2818fa4..4659a77b 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -28,6 +28,7 @@ import ( "github.com/lib/pq" entsql "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqljson" ) // accountRepository 实现 service.AccountRepository 接口。 @@ -36,11 +37,9 @@ import ( // 设计说明: // - client: Ent 客户端,用于类型安全的 ORM 操作 // - sql: 原生 SQL 执行器,用于复杂查询和批量操作 -// - begin: SQL 事务开启器,用于需要事务的操作 type accountRepository struct { - client *dbent.Client // Ent ORM 客户端 - sql sqlExecutor // 原生 SQL 执行接口 - begin sqlBeginner // 事务开启接口 + client *dbent.Client // Ent ORM 客户端 + sql sqlExecutor // 原生 SQL 执行接口 } // NewAccountRepository 创建账户仓储实例。 @@ -52,11 +51,7 @@ func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRe // newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。 // 这种设计便于单元测试时注入 mock 对象。 func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository { - var beginner sqlBeginner - if b, ok := sqlq.(sqlBeginner); ok { - beginner = b - } - return &accountRepository{client: client, sql: sqlq, begin: beginner} + return &accountRepository{client: client, sql: sqlq} } func (r *accountRepository) Create(ctx context.Context, account *service.Account) error { @@ -146,9 +141,10 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID return nil, nil } + // 使用 sqljson.ValueEQ 生成 JSON 路径过滤,避免手写 SQL 片段导致语法兼容问题。 m, err := r.client.Account.Query(). Where(func(s *entsql.Selector) { - s.Where(entsql.ExprP("extra->>'crs_account_id' = ?", crsAccountID)) + s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, crsAccountID, sqljson.Path("crs_account_id"))) }). Only(ctx) if err != nil { diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index 41874549..84a88f23 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -4,7 +4,6 @@ package repository import ( "context" - "database/sql" "testing" "time" @@ -17,18 +16,16 @@ import ( type AccountRepoSuite struct { suite.Suite - ctx context.Context - tx *sql.Tx + ctx context.Context client *dbent.Client - repo *accountRepository + repo *accountRepository } func (s *AccountRepoSuite) SetupTest() { s.ctx = context.Background() - client, tx := testEntSQLTx(s.T()) - s.client = client - s.tx = tx - s.repo = newAccountRepositoryWithSQL(client, tx) + tx := testEntTx(s.T()) + s.client = tx.Client() + s.repo = newAccountRepositoryWithSQL(s.client, tx) } func TestAccountRepoSuite(t *testing.T) { @@ -175,7 +172,8 @@ func (s *AccountRepoSuite) TestListWithFilters() { for _, tt := range tests { s.Run(tt.name, func() { // 每个 case 重新获取隔离资源 - client, tx := testEntSQLTx(s.T()) + tx := testEntTx(s.T()) + client := tx.Client() repo := newAccountRepositoryWithSQL(client, tx) ctx := context.Background() diff --git a/backend/internal/repository/allowed_groups_contract_integration_test.go b/backend/internal/repository/allowed_groups_contract_integration_test.go index c2aa945c..02cde527 100644 --- a/backend/internal/repository/allowed_groups_contract_integration_test.go +++ b/backend/internal/repository/allowed_groups_contract_integration_test.go @@ -20,7 +20,8 @@ func uniqueTestValue(t *testing.T, prefix string) string { func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) { ctx := context.Background() - entClient, sqlTx := testEntSQLTx(t) + tx := testEntTx(t) + entClient := tx.Client() targetGroup, err := entClient.Group.Create(). SetName(uniqueTestValue(t, "target-group")). @@ -33,7 +34,7 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te Save(ctx) require.NoError(t, err) - repo := newUserRepositoryWithSQL(entClient, sqlTx) + repo := newUserRepositoryWithSQL(entClient, tx) u1 := &service.User{ Email: uniqueTestValue(t, "u1") + "@example.com", @@ -81,7 +82,8 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) { ctx := context.Background() - entClient, sqlTx := testEntSQLTx(t) + tx := testEntTx(t) + entClient := tx.Client() targetGroup, err := entClient.Group.Create(). SetName(uniqueTestValue(t, "delete-cascade-target")). @@ -94,8 +96,8 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t Save(ctx) require.NoError(t, err) - userRepo := newUserRepositoryWithSQL(entClient, sqlTx) - groupRepo := newGroupRepositoryWithSQL(entClient, sqlTx) + userRepo := newUserRepositoryWithSQL(entClient, tx) + groupRepo := newGroupRepositoryWithSQL(entClient, tx) apiKeyRepo := NewApiKeyRepository(entClient) u := &service.User{ @@ -141,4 +143,3 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t require.NoError(t, err) require.Nil(t, keyAfter.GroupID) } - diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 5b998b1b..7bd846e1 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -2,9 +2,11 @@ package repository import ( "context" + "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" @@ -18,6 +20,11 @@ func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository { return &apiKeyRepository{client: client} } +func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery { + // 默认过滤已软删除记录,避免删除后仍被查询到。 + return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil()) +} + func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error { created, err := r.client.ApiKey.Create(). SetUserID(key.UserID). @@ -35,7 +42,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) erro } func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { - m, err := r.client.ApiKey.Query(). + m, err := r.activeQuery(). Where(apikey.IDEQ(id)). WithUser(). WithGroup(). @@ -55,7 +62,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK // - 不加载完整的 ApiKey 实体及其关联数据(User、Group 等) // - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查) func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) { - m, err := r.client.ApiKey.Query(). + m, err := r.activeQuery(). Where(apikey.IDEQ(id)). Select(apikey.FieldUserID). Only(ctx) @@ -69,7 +76,7 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err } func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { - m, err := r.client.ApiKey.Query(). + m, err := r.activeQuery(). Where(apikey.KeyEQ(key)). WithUser(). WithGroup(). @@ -84,6 +91,14 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A } func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error { + exists, err := r.activeQuery().Where(apikey.IDEQ(key.ID)).Exist(ctx) + if err != nil { + return err + } + if !exists { + return service.ErrApiKeyNotFound + } + builder := r.client.ApiKey.UpdateOneID(key.ID). SetName(key.Name). SetStatus(key.Status) @@ -105,12 +120,34 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro } func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error { - _, err := r.client.ApiKey.Delete().Where(apikey.IDEQ(id)).Exec(ctx) - return err + // 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。 + affected, err := r.client.ApiKey.Update(). + Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). + SetDeletedAt(time.Now()). + Save(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return service.ErrApiKeyNotFound + } + return err + } + if affected == 0 { + exists, err := r.client.ApiKey.Query(). + Where(apikey.IDEQ(id)). + Exist(mixins.SkipSoftDelete(ctx)) + if err != nil { + return err + } + if exists { + return nil + } + return service.ErrApiKeyNotFound + } + return nil } func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { - q := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID)) + q := r.activeQuery().Where(apikey.UserIDEQ(userID)) total, err := q.Count(ctx) if err != nil { @@ -141,7 +178,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap } ids, err := r.client.ApiKey.Query(). - Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...)). + Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()). IDs(ctx) if err != nil { return nil, err @@ -150,17 +187,17 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap } func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) { - count, err := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID)).Count(ctx) + count, err := r.activeQuery().Where(apikey.UserIDEQ(userID)).Count(ctx) return int64(count), err } func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) { - count, err := r.client.ApiKey.Query().Where(apikey.KeyEQ(key)).Count(ctx) + count, err := r.activeQuery().Where(apikey.KeyEQ(key)).Count(ctx) return count > 0, err } func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { - q := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID)) + q := r.activeQuery().Where(apikey.GroupIDEQ(groupID)) total, err := q.Count(ctx) if err != nil { @@ -187,7 +224,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par // SearchApiKeys searches API keys by user ID and/or keyword (name) func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { - q := r.client.ApiKey.Query() + q := r.activeQuery() if userID > 0 { q = q.Where(apikey.UserIDEQ(userID)) } @@ -211,7 +248,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw // ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { n, err := r.client.ApiKey.Update(). - Where(apikey.GroupIDEQ(groupID)). + Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()). ClearGroupID(). Save(ctx) return int64(n), err @@ -219,7 +256,7 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in // CountByGroupID 获取分组的 API Key 数量 func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { - count, err := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID)).Count(ctx) + count, err := r.activeQuery().Where(apikey.GroupIDEQ(groupID)).Count(ctx) return int64(count), err } diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go index 0916fcc5..79564ff0 100644 --- a/backend/internal/repository/api_key_repo_integration_test.go +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -21,9 +21,9 @@ type ApiKeyRepoSuite struct { func (s *ApiKeyRepoSuite) SetupTest() { s.ctx = context.Background() - entClient, _ := testEntSQLTx(s.T()) - s.client = entClient - s.repo = NewApiKeyRepository(entClient).(*apiKeyRepository) + tx := testEntTx(s.T()) + s.client = tx.Client() + s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository) } func TestApiKeyRepoSuite(t *testing.T) { diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index f54588ce..41ee84fe 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -3,6 +3,7 @@ package repository import ( "context" "database/sql" + "errors" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" @@ -10,26 +11,16 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" - - "entgo.io/ent/dialect" - entsql "entgo.io/ent/dialect/sql" ) type sqlExecutor interface { ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) - QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row -} - -type sqlBeginner interface { - sqlExecutor - BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } type groupRepository struct { client *dbent.Client sql sqlExecutor - begin sqlBeginner } func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupRepository { @@ -37,11 +28,7 @@ func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupReposi } func newGroupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *groupRepository { - var beginner sqlBeginner - if b, ok := sqlq.(sqlBeginner); ok { - beginner = b - } - return &groupRepository{client: client, sql: sqlq, begin: beginner} + return &groupRepository{client: client, sql: sqlq} } func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) error { @@ -214,7 +201,7 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { var count int64 - if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", groupID).Scan(&count); err != nil { + if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil { return 0, err } return count, nil @@ -236,31 +223,44 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, } groupSvc := groupEntityToService(g) - exec := r.sql + // 使用 ent 事务统一包裹:避免手工基于 *sql.Tx 构造 ent client 带来的驱动断言问题, + // 同时保证级联删除的原子性。 + tx, err := r.client.Tx(ctx) + if err != nil && !errors.Is(err, dbent.ErrTxStarted) { + return nil, err + } + exec := r.client txClient := r.client - var sqlTx *sql.Tx - - if r.begin != nil { - sqlTx, err = r.begin.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - exec = sqlTx - txClient = entClientFromSQLTx(sqlTx) - // 注意:不能调用 txClient.Close(),因为基于事务的 ent client - // 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB,但实际是 *sql.Tx - // 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成 - defer func() { _ = sqlTx.Rollback() }() + if err == nil { + defer func() { _ = tx.Rollback() }() + exec = tx.Client() + txClient = exec + } else { + // 已处于外部事务中(ErrTxStarted),复用当前 client 参与同一事务。 } // Lock the group row to avoid concurrent writes while we cascade. - var lockedID int64 - if err := exec.QueryRowContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id).Scan(&lockedID); err != nil { - if errorsIsNoRows(err) { - return nil, service.ErrGroupNotFound - } + // 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分“未找到”与其他错误。 + rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id) + if err != nil { return nil, err } + var lockedID int64 + if rows.Next() { + if err := rows.Scan(&lockedID); err != nil { + _ = rows.Close() + return nil, err + } + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + if lockedID == 0 { + return nil, service.ErrGroupNotFound + } var affectedUserIDs []int64 if groupSvc.IsSubscriptionType() { @@ -319,8 +319,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, return nil, err } - if sqlTx != nil { - if err := sqlTx.Commit(); err != nil { + if tx != nil { + if err := tx.Commit(); err != nil { return nil, err } } @@ -359,11 +359,6 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6 return counts, nil } -func entClientFromSQLTx(tx *sql.Tx) *dbent.Client { - drv := entsql.NewDriver(dialect.Postgres, entsql.Conn{ExecQuerier: tx}) - return dbent.NewClient(dbent.Driver(drv)) -} - func errorsIsNoRows(err error) bool { return err == sql.ErrNoRows } diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index 90a898d1..a02c5f8f 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -4,9 +4,9 @@ package repository import ( "context" - "database/sql" "testing" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" @@ -15,15 +15,15 @@ import ( type GroupRepoSuite struct { suite.Suite ctx context.Context - tx *sql.Tx + tx *dbent.Tx repo *groupRepository } func (s *GroupRepoSuite) SetupTest() { s.ctx = context.Background() - entClient, tx := testEntSQLTx(s.T()) + tx := testEntTx(s.T()) s.tx = tx - s.repo = newGroupRepositoryWithSQL(entClient, tx) + s.repo = newGroupRepositoryWithSQL(tx.Client(), tx) } func TestGroupRepoSuite(t *testing.T) { @@ -99,6 +99,9 @@ func (s *GroupRepoSuite) TestDelete() { // --- List / ListWithFilters --- func (s *GroupRepoSuite) TestList() { + baseGroups, basePage, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) + s.Require().NoError(err, "List base") + s.Require().NoError(s.repo.Create(s.ctx, &service.Group{ Name: "g1", Platform: service.PlatformAnthropic, @@ -118,12 +121,20 @@ func (s *GroupRepoSuite) TestList() { groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) s.Require().NoError(err, "List") - // 3 default groups + 2 test groups = 5 total - s.Require().Len(groups, 5) - s.Require().Equal(int64(5), page.Total) + s.Require().Len(groups, len(baseGroups)+2) + s.Require().Equal(basePage.Total+2, page.Total) } func (s *GroupRepoSuite) TestListWithFilters_Platform() { + baseGroups, _, err := s.repo.ListWithFilters( + s.ctx, + pagination.PaginationParams{Page: 1, PageSize: 10}, + service.PlatformOpenAI, + "", + nil, + ) + s.Require().NoError(err, "ListWithFilters base") + s.Require().NoError(s.repo.Create(s.ctx, &service.Group{ Name: "g1", Platform: service.PlatformAnthropic, @@ -143,8 +154,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() { groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil) s.Require().NoError(err) - // 1 default openai group + 1 test openai group = 2 total - s.Require().Len(groups, 2) + s.Require().Len(groups, len(baseGroups)+1) // Verify all groups are OpenAI platform for _, g := range groups { s.Require().Equal(service.PlatformOpenAI, g.Platform) @@ -221,11 +231,13 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { s.Require().NoError(s.repo.Create(s.ctx, g2)) var accountID int64 - s.Require().NoError(s.tx.QueryRowContext( + s.Require().NoError(scanSingleRow( s.ctx, + s.tx, "INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id", - "acc1", service.PlatformAnthropic, service.AccountTypeOAuth, - ).Scan(&accountID)) + []any{"acc1", service.PlatformAnthropic, service.AccountTypeOAuth}, + &accountID, + )) _, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g1.ID, 1) s.Require().NoError(err) _, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g2.ID, 1) @@ -243,6 +255,9 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { // --- ListActive / ListActiveByPlatform --- func (s *GroupRepoSuite) TestListActive() { + baseGroups, err := s.repo.ListActive(s.ctx) + s.Require().NoError(err, "ListActive base") + s.Require().NoError(s.repo.Create(s.ctx, &service.Group{ Name: "active1", Platform: service.PlatformAnthropic, @@ -262,8 +277,7 @@ func (s *GroupRepoSuite) TestListActive() { groups, err := s.repo.ListActive(s.ctx) s.Require().NoError(err, "ListActive") - // 3 default groups (all active) + 1 test active group = 4 total - s.Require().Len(groups, 4) + s.Require().Len(groups, len(baseGroups)+1) // Verify our test group is in the results var found bool for _, g := range groups { @@ -351,17 +365,21 @@ func (s *GroupRepoSuite) TestGetAccountCount() { s.Require().NoError(s.repo.Create(s.ctx, group)) var a1 int64 - s.Require().NoError(s.tx.QueryRowContext( + s.Require().NoError(scanSingleRow( s.ctx, + s.tx, "INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id", - "a1", service.PlatformAnthropic, service.AccountTypeOAuth, - ).Scan(&a1)) + []any{"a1", service.PlatformAnthropic, service.AccountTypeOAuth}, + &a1, + )) var a2 int64 - s.Require().NoError(s.tx.QueryRowContext( + s.Require().NoError(scanSingleRow( s.ctx, + s.tx, "INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id", - "a2", service.PlatformAnthropic, service.AccountTypeOAuth, - ).Scan(&a2)) + []any{"a2", service.PlatformAnthropic, service.AccountTypeOAuth}, + &a2, + )) _, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, group.ID, 1) s.Require().NoError(err) @@ -402,11 +420,13 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() { } s.Require().NoError(s.repo.Create(s.ctx, g)) var accountID int64 - s.Require().NoError(s.tx.QueryRowContext( + s.Require().NoError(scanSingleRow( s.ctx, + s.tx, "INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id", - "acc-del", service.PlatformAnthropic, service.AccountTypeOAuth, - ).Scan(&accountID)) + []any{"acc-del", service.PlatformAnthropic, service.AccountTypeOAuth}, + &accountID, + )) _, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g.ID, 1) s.Require().NoError(err) @@ -432,11 +452,13 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() { insertAccount := func(name string) int64 { var id int64 - s.Require().NoError(s.tx.QueryRowContext( + s.Require().NoError(scanSingleRow( s.ctx, + s.tx, "INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id", - name, service.PlatformAnthropic, service.AccountTypeOAuth, - ).Scan(&id)) + []any{name, service.PlatformAnthropic, service.AccountTypeOAuth}, + &id, + )) return id } a1 := insertAccount("a1") diff --git a/backend/internal/repository/integration_harness_test.go b/backend/internal/repository/integration_harness_test.go index 315cb86a..553a581a 100644 --- a/backend/internal/repository/integration_harness_test.go +++ b/backend/internal/repository/integration_harness_test.go @@ -36,8 +36,9 @@ const ( ) var ( - integrationDB *sql.DB - integrationRedis *redisclient.Client + integrationDB *sql.DB + integrationEntClient *dbent.Client + integrationRedis *redisclient.Client redisNamespaceSeq uint64 ) @@ -101,6 +102,10 @@ func TestMain(m *testing.M) { os.Exit(1) } + // 创建 ent client 用于集成测试 + drv := entsql.OpenDB(dialect.Postgres, integrationDB) + integrationEntClient = dbent.NewClient(dbent.Driver(drv)) + redisHost, err := redisContainer.Host(ctx) if err != nil { log.Printf("failed to get redis host: %v", err) @@ -123,6 +128,7 @@ func TestMain(m *testing.M) { code := m.Run() + _ = integrationEntClient.Close() _ = integrationRedis.Close() _ = integrationDB.Close() @@ -193,18 +199,38 @@ func testTx(t *testing.T) *sql.Tx { return tx } +// testEntClient 返回全局的 ent client,用于测试需要内部管理事务的代码(如 Create/Update 方法)。 +// 注意:此 client 的操作会真正写入数据库,测试结束后不会自动回滚。 +func testEntClient(t *testing.T) *dbent.Client { + t.Helper() + return integrationEntClient +} + +// testEntTx 返回一个 ent 事务,用于需要事务隔离的测试。 +// 测试结束后会自动回滚,不会影响数据库状态。 +func testEntTx(t *testing.T) *dbent.Tx { + t.Helper() + + tx, err := integrationEntClient.Tx(context.Background()) + require.NoError(t, err, "begin ent tx") + t.Cleanup(func() { + _ = tx.Rollback() + }) + return tx +} + +// testEntSQLTx 已弃用:不要在新测试中使用此函数。 +// 基于 *sql.Tx 创建的 ent client 在调用 client.Tx() 时会 panic。 +// 对于需要测试内部使用事务的代码,请使用 testEntClient。 +// 对于需要事务隔离的测试,请使用 testEntTx。 +// +// Deprecated: Use testEntClient or testEntTx instead. func testEntSQLTx(t *testing.T) (*dbent.Client, *sql.Tx) { t.Helper() - tx := testTx(t) - drv := entsql.NewDriver(dialect.Postgres, entsql.Conn{ExecQuerier: tx}) - client := dbent.NewClient(dbent.Driver(drv)) - - t.Cleanup(func() { - _ = client.Close() - }) - - return client, tx + // 直接失败,避免旧测试误用导致的事务嵌套 panic。 + t.Fatalf("testEntSQLTx 已弃用:请使用 testEntClient 或 testEntTx") + return nil, nil } func testRedis(t *testing.T) *redisclient.Client { @@ -363,13 +389,16 @@ type IntegrationDBSuite struct { suite.Suite ctx context.Context client *dbent.Client - tx *sql.Tx + tx *dbent.Tx } // SetupTest initializes ctx and client for each test method. func (s *IntegrationDBSuite) SetupTest() { s.ctx = context.Background() - s.client, s.tx = testEntSQLTx(s.T()) + // 统一使用 ent.Tx,确保每个测试都有独立事务并自动回滚。 + tx := testEntTx(s.T()) + s.tx = tx + s.client = tx.Client() } // RequireNoError is a convenience method wrapping require.NoError with s.T(). diff --git a/backend/internal/repository/proxy_repo.go b/backend/internal/repository/proxy_repo.go index 26290a79..adbc2dfb 100644 --- a/backend/internal/repository/proxy_repo.go +++ b/backend/internal/repository/proxy_repo.go @@ -13,7 +13,6 @@ import ( type sqlQuerier interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) - QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row } type proxyRepository struct { @@ -170,9 +169,8 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, // CountAccountsByProxyID returns the number of accounts using a specific proxy func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { - row := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", proxyID) var count int64 - if err := row.Scan(&count); err != nil { + if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", []any{proxyID}, &count); err != nil { return 0, err } return count, nil diff --git a/backend/internal/repository/proxy_repo_integration_test.go b/backend/internal/repository/proxy_repo_integration_test.go index 6f88528f..8f5ef01e 100644 --- a/backend/internal/repository/proxy_repo_integration_test.go +++ b/backend/internal/repository/proxy_repo_integration_test.go @@ -4,10 +4,10 @@ package repository import ( "context" - "database/sql" "testing" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" @@ -15,16 +15,16 @@ import ( type ProxyRepoSuite struct { suite.Suite - ctx context.Context - sqlTx *sql.Tx - repo *proxyRepository + ctx context.Context + tx *dbent.Tx + repo *proxyRepository } func (s *ProxyRepoSuite) SetupTest() { s.ctx = context.Background() - entClient, sqlTx := testEntSQLTx(s.T()) - s.sqlTx = sqlTx - s.repo = newProxyRepositoryWithSQL(entClient, sqlTx) + tx := testEntTx(s.T()) + s.tx = tx + s.repo = newProxyRepositoryWithSQL(tx.Client(), tx) } func TestProxyRepoSuite(t *testing.T) { @@ -306,7 +306,7 @@ func (s *ProxyRepoSuite) mustCreateProxyWithTimes(name, status string, createdAt Port: 8080, Status: status, }) - _, err := s.sqlTx.ExecContext(s.ctx, "UPDATE proxies SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, p.ID) + _, err := s.tx.ExecContext(s.ctx, "UPDATE proxies SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, p.ID) s.Require().NoError(err, "update proxy timestamps") return p } @@ -317,7 +317,7 @@ func (s *ProxyRepoSuite) mustInsertAccount(name string, proxyID *int64) { if proxyID != nil { pid = *proxyID } - _, err := s.sqlTx.ExecContext( + _, err := s.tx.ExecContext( s.ctx, "INSERT INTO accounts (name, platform, type, proxy_id) VALUES ($1, $2, $3, $4)", name, diff --git a/backend/internal/repository/redeem_code_repo_integration_test.go b/backend/internal/repository/redeem_code_repo_integration_test.go index ee9f79ed..39674b52 100644 --- a/backend/internal/repository/redeem_code_repo_integration_test.go +++ b/backend/internal/repository/redeem_code_repo_integration_test.go @@ -22,9 +22,9 @@ type RedeemCodeRepoSuite struct { func (s *RedeemCodeRepoSuite) SetupTest() { s.ctx = context.Background() - entClient, _ := testEntSQLTx(s.T()) - s.client = entClient - s.repo = NewRedeemCodeRepository(entClient).(*redeemCodeRepository) + tx := testEntTx(s.T()) + s.client = tx.Client() + s.repo = NewRedeemCodeRepository(s.client).(*redeemCodeRepository) } func TestRedeemCodeRepoSuite(t *testing.T) { diff --git a/backend/internal/repository/setting_repo_integration_test.go b/backend/internal/repository/setting_repo_integration_test.go index 71fac0b2..784124f4 100644 --- a/backend/internal/repository/setting_repo_integration_test.go +++ b/backend/internal/repository/setting_repo_integration_test.go @@ -18,8 +18,8 @@ type SettingRepoSuite struct { func (s *SettingRepoSuite) SetupTest() { s.ctx = context.Background() - entClient, _ := testEntSQLTx(s.T()) - s.repo = NewSettingRepository(entClient).(*settingRepository) + tx := testEntTx(s.T()) + s.repo = NewSettingRepository(tx.Client()).(*settingRepository) } func TestSettingRepoSuite(t *testing.T) { diff --git a/backend/internal/repository/soft_delete_ent_integration_test.go b/backend/internal/repository/soft_delete_ent_integration_test.go index e1e7a35a..02176f90 100644 --- a/backend/internal/repository/soft_delete_ent_integration_test.go +++ b/backend/internal/repository/soft_delete_ent_integration_test.go @@ -34,7 +34,8 @@ func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, emai func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) { ctx := context.Background() - client, _ := testEntSQLTx(t) + // 使用全局 ent client,确保软删除验证在实际持久化数据上进行。 + client := testEntClient(t) u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com") @@ -65,7 +66,8 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) { func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) { ctx := context.Background() - client, _ := testEntSQLTx(t) + // 使用全局 ent client,避免事务回滚影响幂等性验证。 + client := testEntClient(t) u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com") @@ -84,7 +86,8 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) { func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) { ctx := context.Background() - client, _ := testEntSQLTx(t) + // 使用全局 ent client,确保 SkipSoftDelete 的硬删除语义可验证。 + client := testEntClient(t) u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com") diff --git a/backend/internal/repository/sql_scan.go b/backend/internal/repository/sql_scan.go new file mode 100644 index 00000000..f683f50d --- /dev/null +++ b/backend/internal/repository/sql_scan.go @@ -0,0 +1,33 @@ +package repository + +import ( + "context" + "database/sql" +) + +type sqlQueryer interface { + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) +} + +// scanSingleRow executes a query and scans the first row into dest. +// If no rows are returned, sql.ErrNoRows is returned. +// 设计目的:仅依赖 QueryContext,避免 QueryRowContext 对 *sql.Tx 的强绑定, +// 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。 +func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) error { + rows, err := q.QueryContext(ctx, query, args...) + if err != nil { + return err + } + defer rows.Close() + + if !rows.Next() { + if err := rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + if err := rows.Scan(dest...); err != nil { + return err + } + return rows.Err() +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 5939c827..eeaaa12c 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -3,7 +3,6 @@ package repository import ( "context" "database/sql" - "errors" "fmt" "strings" "time" @@ -33,6 +32,7 @@ func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLog } func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository { + // 使用 scanSingleRow 替代 QueryRowContext,保证 ent.Tx 作为 sqlExecutor 可用。 return &usageLogRepository{client: client, sql: sqlq} } @@ -53,7 +53,7 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int var requestCount int64 var tokenCount int64 - if err := r.sql.QueryRowContext(ctx, query, args...).Scan(&requestCount, &tokenCount); err != nil { + if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil { return 0, 0, err } return requestCount / 5, tokenCount / 5, nil @@ -114,9 +114,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) duration := nullInt(log.DurationMs) firstToken := nullInt(log.FirstTokenMs) - row := r.sql.QueryRowContext( - ctx, - query, + args := []any{ log.UserID, log.ApiKeyID, log.AccountID, @@ -142,9 +140,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) duration, firstToken, createdAt, - ) - - if err := row.Scan(&log.ID, &log.CreatedAt); err != nil { + } + if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil { return err } log.RateMultiplier = rateMultiplier @@ -153,11 +150,22 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) { query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1" - log, err := scanUsageLog(r.sql.QueryRowContext(ctx, query, id)) + rows, err := r.sql.QueryContext(ctx, query, id) if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, service.ErrUsageLogNotFound + return nil, err + } + defer rows.Close() + if !rows.Next() { + if err := rows.Err(); err != nil { + return nil, err } + return nil, service.ErrUsageLogNotFound + } + log, err := scanUsageLog(rows) + if err != nil { + return nil, err + } + if err := rows.Err(); err != nil { return nil, err } return log, nil @@ -195,8 +203,18 @@ func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, sta ` stats := &UserStats{} - if err := r.sql.QueryRowContext(ctx, query, userID, startTime, endTime). - Scan(&stats.TotalRequests, &stats.TotalTokens, &stats.TotalCost, &stats.InputTokens, &stats.OutputTokens, &stats.CacheReadTokens); err != nil { + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{userID, startTime, endTime}, + &stats.TotalRequests, + &stats.TotalTokens, + &stats.TotalCost, + &stats.InputTokens, + &stats.OutputTokens, + &stats.CacheReadTokens, + ); err != nil { return nil, err } return stats, nil @@ -219,8 +237,15 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS FROM users WHERE deleted_at IS NULL ` - if err := r.sql.QueryRowContext(ctx, userStatsQuery, today, today). - Scan(&stats.TotalUsers, &stats.TodayNewUsers, &stats.ActiveUsers); err != nil { + if err := scanSingleRow( + ctx, + r.sql, + userStatsQuery, + []any{today, today}, + &stats.TotalUsers, + &stats.TodayNewUsers, + &stats.ActiveUsers, + ); err != nil { return nil, err } @@ -232,8 +257,14 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS FROM api_keys WHERE deleted_at IS NULL ` - if err := r.sql.QueryRowContext(ctx, apiKeyStatsQuery, service.StatusActive). - Scan(&stats.TotalApiKeys, &stats.ActiveApiKeys); err != nil { + if err := scanSingleRow( + ctx, + r.sql, + apiKeyStatsQuery, + []any{service.StatusActive}, + &stats.TotalApiKeys, + &stats.ActiveApiKeys, + ); err != nil { return nil, err } @@ -248,8 +279,17 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS FROM accounts WHERE deleted_at IS NULL ` - if err := r.sql.QueryRowContext(ctx, accountStatsQuery, service.StatusActive, service.StatusError, now, now). - Scan(&stats.TotalAccounts, &stats.NormalAccounts, &stats.ErrorAccounts, &stats.RateLimitAccounts, &stats.OverloadAccounts); err != nil { + if err := scanSingleRow( + ctx, + r.sql, + accountStatsQuery, + []any{service.StatusActive, service.StatusError, now, now}, + &stats.TotalAccounts, + &stats.NormalAccounts, + &stats.ErrorAccounts, + &stats.RateLimitAccounts, + &stats.OverloadAccounts, + ); err != nil { return nil, err } @@ -266,17 +306,20 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS COALESCE(AVG(duration_ms), 0) as avg_duration_ms FROM usage_logs ` - if err := r.sql.QueryRowContext(ctx, totalStatsQuery). - Scan( - &stats.TotalRequests, - &stats.TotalInputTokens, - &stats.TotalOutputTokens, - &stats.TotalCacheCreationTokens, - &stats.TotalCacheReadTokens, - &stats.TotalCost, - &stats.TotalActualCost, - &stats.AverageDurationMs, - ); err != nil { + if err := scanSingleRow( + ctx, + r.sql, + totalStatsQuery, + nil, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheCreationTokens, + &stats.TotalCacheReadTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { return nil, err } stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens @@ -294,16 +337,19 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS FROM usage_logs WHERE created_at >= $1 ` - if err := r.sql.QueryRowContext(ctx, todayStatsQuery, today). - Scan( - &stats.TodayRequests, - &stats.TodayInputTokens, - &stats.TodayOutputTokens, - &stats.TodayCacheCreationTokens, - &stats.TodayCacheReadTokens, - &stats.TodayCost, - &stats.TodayActualCost, - ); err != nil { + if err := scanSingleRow( + ctx, + r.sql, + todayStatsQuery, + []any{today}, + &stats.TodayRequests, + &stats.TodayInputTokens, + &stats.TodayOutputTokens, + &stats.TodayCacheCreationTokens, + &stats.TodayCacheReadTokens, + &stats.TodayCost, + &stats.TodayActualCost, + ); err != nil { return nil, err } stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens @@ -345,16 +391,19 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID ` var stats usagestats.UsageStats - if err := r.sql.QueryRowContext(ctx, query, userID, startTime, endTime). - Scan( - &stats.TotalRequests, - &stats.TotalInputTokens, - &stats.TotalOutputTokens, - &stats.TotalCacheTokens, - &stats.TotalCost, - &stats.TotalActualCost, - &stats.AverageDurationMs, - ); err != nil { + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{userID, startTime, endTime}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { return nil, err } stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens @@ -377,16 +426,19 @@ func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKe ` var stats usagestats.UsageStats - if err := r.sql.QueryRowContext(ctx, query, apiKeyID, startTime, endTime). - Scan( - &stats.TotalRequests, - &stats.TotalInputTokens, - &stats.TotalOutputTokens, - &stats.TotalCacheTokens, - &stats.TotalCost, - &stats.TotalActualCost, - &stats.AverageDurationMs, - ); err != nil { + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{apiKeyID, startTime, endTime}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { return nil, err } stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens @@ -430,8 +482,15 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID ` stats := &usagestats.AccountStats{} - if err := r.sql.QueryRowContext(ctx, query, accountID, today). - Scan(&stats.Requests, &stats.Tokens, &stats.Cost); err != nil { + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{accountID, today}, + &stats.Requests, + &stats.Tokens, + &stats.Cost, + ); err != nil { return nil, err } return stats, nil @@ -449,8 +508,15 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI ` stats := &usagestats.AccountStats{} - if err := r.sql.QueryRowContext(ctx, query, accountID, startTime). - Scan(&stats.Requests, &stats.Tokens, &stats.Cost); err != nil { + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{accountID, startTime}, + &stats.Requests, + &stats.Tokens, + &stats.Cost, + ); err != nil { return nil, err } return stats, nil @@ -581,12 +647,22 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i today := timezone.Today() // API Key 统计 - if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL", userID). - Scan(&stats.TotalApiKeys); err != nil { + if err := scanSingleRow( + ctx, + r.sql, + "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL", + []any{userID}, + &stats.TotalApiKeys, + ); err != nil { return nil, err } - if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL", userID, service.StatusActive). - Scan(&stats.ActiveApiKeys); err != nil { + if err := scanSingleRow( + ctx, + r.sql, + "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL", + []any{userID, service.StatusActive}, + &stats.ActiveApiKeys, + ); err != nil { return nil, err } @@ -604,17 +680,20 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i FROM usage_logs WHERE user_id = $1 ` - if err := r.sql.QueryRowContext(ctx, totalStatsQuery, userID). - Scan( - &stats.TotalRequests, - &stats.TotalInputTokens, - &stats.TotalOutputTokens, - &stats.TotalCacheCreationTokens, - &stats.TotalCacheReadTokens, - &stats.TotalCost, - &stats.TotalActualCost, - &stats.AverageDurationMs, - ); err != nil { + if err := scanSingleRow( + ctx, + r.sql, + totalStatsQuery, + []any{userID}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheCreationTokens, + &stats.TotalCacheReadTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { return nil, err } stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens @@ -632,16 +711,19 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i FROM usage_logs WHERE user_id = $1 AND created_at >= $2 ` - if err := r.sql.QueryRowContext(ctx, todayStatsQuery, userID, today). - Scan( - &stats.TodayRequests, - &stats.TodayInputTokens, - &stats.TodayOutputTokens, - &stats.TodayCacheCreationTokens, - &stats.TodayCacheReadTokens, - &stats.TodayCost, - &stats.TodayActualCost, - ); err != nil { + if err := scanSingleRow( + ctx, + r.sql, + todayStatsQuery, + []any{userID, today}, + &stats.TodayRequests, + &stats.TodayInputTokens, + &stats.TodayOutputTokens, + &stats.TodayCacheCreationTokens, + &stats.TodayCacheReadTokens, + &stats.TodayCost, + &stats.TodayActualCost, + ); err != nil { return nil, err } stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens @@ -1007,16 +1089,19 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT ` stats := &UsageStats{} - if err := r.sql.QueryRowContext(ctx, query, startTime, endTime). - Scan( - &stats.TotalRequests, - &stats.TotalInputTokens, - &stats.TotalOutputTokens, - &stats.TotalCacheTokens, - &stats.TotalCost, - &stats.TotalActualCost, - &stats.AverageDurationMs, - ); err != nil { + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{startTime, endTime}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { return nil, err } stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens @@ -1108,7 +1193,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID avgQuery := "SELECT COALESCE(AVG(duration_ms), 0) as avg_duration_ms FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3" var avgDuration float64 - if err := r.sql.QueryRowContext(ctx, avgQuery, accountID, startTime, endTime).Scan(&avgDuration); err != nil { + if err := scanSingleRow(ctx, r.sql, avgQuery, []any{accountID, startTime, endTime}, &avgDuration); err != nil { return nil, err } @@ -1186,7 +1271,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { countQuery := "SELECT COUNT(*) FROM usage_logs " + whereClause var total int64 - if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { + if err := scanSingleRow(ctx, r.sql, countQuery, args, &total); err != nil { return nil, nil, err } diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 4ef5fa56..ef03ada7 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -4,7 +4,6 @@ package repository import ( "context" - "database/sql" "testing" "time" @@ -19,17 +18,17 @@ import ( type UsageLogRepoSuite struct { suite.Suite ctx context.Context - tx *sql.Tx + tx *dbent.Tx client *dbent.Client repo *usageLogRepository } func (s *UsageLogRepoSuite) SetupTest() { s.ctx = context.Background() - client, tx := testEntSQLTx(s.T()) - s.client = client + tx := testEntTx(s.T()) s.tx = tx - s.repo = newUsageLogRepositoryWithSQL(client, tx) + s.client = tx.Client() + s.repo = newUsageLogRepositoryWithSQL(s.client, tx) } func TestUsageLogRepoSuite(t *testing.T) { @@ -197,6 +196,8 @@ func (s *UsageLogRepoSuite) TestListWithFilters() { func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { now := time.Now() todayStart := timezone.Today() + baseStats, err := s.repo.GetDashboardStats(s.ctx) + s.Require().NoError(err, "GetDashboardStats base") userToday := mustCreateUser(s.T(), s.client, &service.User{ Email: "today@example.com", @@ -268,24 +269,24 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { stats, err := s.repo.GetDashboardStats(s.ctx) s.Require().NoError(err, "GetDashboardStats") - s.Require().Equal(int64(2), stats.TotalUsers, "TotalUsers mismatch") - s.Require().Equal(int64(1), stats.TodayNewUsers, "TodayNewUsers mismatch") - s.Require().Equal(int64(1), stats.ActiveUsers, "ActiveUsers mismatch") - s.Require().Equal(int64(2), stats.TotalApiKeys, "TotalApiKeys mismatch") - s.Require().Equal(int64(1), stats.ActiveApiKeys, "ActiveApiKeys mismatch") - s.Require().Equal(int64(4), stats.TotalAccounts, "TotalAccounts mismatch") - s.Require().Equal(int64(1), stats.ErrorAccounts, "ErrorAccounts mismatch") - s.Require().Equal(int64(1), stats.RateLimitAccounts, "RateLimitAccounts mismatch") - s.Require().Equal(int64(1), stats.OverloadAccounts, "OverloadAccounts mismatch") + s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch") + s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch") + s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch") + s.Require().Equal(baseStats.TotalApiKeys+2, stats.TotalApiKeys, "TotalApiKeys mismatch") + s.Require().Equal(baseStats.ActiveApiKeys+1, stats.ActiveApiKeys, "ActiveApiKeys mismatch") + s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch") + s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch") + s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch") + s.Require().Equal(baseStats.OverloadAccounts+1, stats.OverloadAccounts, "OverloadAccounts mismatch") - s.Require().Equal(int64(3), stats.TotalRequests, "TotalRequests mismatch") - s.Require().Equal(int64(16), stats.TotalInputTokens, "TotalInputTokens mismatch") - s.Require().Equal(int64(28), stats.TotalOutputTokens, "TotalOutputTokens mismatch") - s.Require().Equal(int64(3), stats.TotalCacheCreationTokens, "TotalCacheCreationTokens mismatch") - s.Require().Equal(int64(4), stats.TotalCacheReadTokens, "TotalCacheReadTokens mismatch") - s.Require().Equal(int64(51), stats.TotalTokens, "TotalTokens mismatch") - s.Require().Equal(2.3, stats.TotalCost, "TotalCost mismatch") - s.Require().Equal(2.0, stats.TotalActualCost, "TotalActualCost mismatch") + s.Require().Equal(baseStats.TotalRequests+3, stats.TotalRequests, "TotalRequests mismatch") + s.Require().Equal(baseStats.TotalInputTokens+int64(16), stats.TotalInputTokens, "TotalInputTokens mismatch") + s.Require().Equal(baseStats.TotalOutputTokens+int64(28), stats.TotalOutputTokens, "TotalOutputTokens mismatch") + s.Require().Equal(baseStats.TotalCacheCreationTokens+int64(3), stats.TotalCacheCreationTokens, "TotalCacheCreationTokens mismatch") + s.Require().Equal(baseStats.TotalCacheReadTokens+int64(4), stats.TotalCacheReadTokens, "TotalCacheReadTokens mismatch") + s.Require().Equal(baseStats.TotalTokens+int64(51), stats.TotalTokens, "TotalTokens mismatch") + s.Require().Equal(baseStats.TotalCost+2.3, stats.TotalCost, "TotalCost mismatch") + s.Require().Equal(baseStats.TotalActualCost+2.0, stats.TotalActualCost, "TotalActualCost mismatch") s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1") s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0") diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 3ffb00f3..b88c47d3 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -3,6 +3,7 @@ package repository import ( "context" "database/sql" + "errors" "sort" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -17,7 +18,6 @@ import ( type userRepository struct { client *dbent.Client sql sqlExecutor - begin sqlBeginner } func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository { @@ -25,11 +25,7 @@ func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserReposito } func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository { - var beginner sqlBeginner - if b, ok := sqlq.(sqlBeginner); ok { - beginner = b - } - return &userRepository{client: client, sql: sqlq, begin: beginner} + return &userRepository{client: client, sql: sqlq} } func (r *userRepository) Create(ctx context.Context, userIn *service.User) error { @@ -37,22 +33,20 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error return nil } - exec := r.sql - txClient := r.client - var sqlTx *sql.Tx + // 统一使用 ent 的事务:保证用户与允许分组的更新原子化, + // 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。 + tx, err := r.client.Tx(ctx) + if err != nil && !errors.Is(err, dbent.ErrTxStarted) { + return err + } - if r.begin != nil { - var err error - sqlTx, err = r.begin.BeginTx(ctx, nil) - if err != nil { - return err - } - exec = sqlTx - txClient = entClientFromSQLTx(sqlTx) - // 注意:不能调用 txClient.Close(),因为基于事务的 ent client - // 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB,但实际是 *sql.Tx - // 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成 - defer func() { _ = sqlTx.Rollback() }() + var txClient *dbent.Client + if err == nil { + defer func() { _ = tx.Rollback() }() + txClient = tx.Client() + } else { + // 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。 + txClient = r.client } created, err := txClient.User.Create(). @@ -70,12 +64,12 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error return translatePersistenceError(err, nil, service.ErrEmailExists) } - if err := r.syncUserAllowedGroups(ctx, txClient, exec, created.ID, userIn.AllowedGroups); err != nil { + if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil { return err } - if sqlTx != nil { - if err := sqlTx.Commit(); err != nil { + if tx != nil { + if err := tx.Commit(); err != nil { return err } } @@ -121,22 +115,19 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error return nil } - exec := r.sql - txClient := r.client - var sqlTx *sql.Tx + // 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。 + tx, err := r.client.Tx(ctx) + if err != nil && !errors.Is(err, dbent.ErrTxStarted) { + return err + } - if r.begin != nil { - var err error - sqlTx, err = r.begin.BeginTx(ctx, nil) - if err != nil { - return err - } - exec = sqlTx - txClient = entClientFromSQLTx(sqlTx) - // 注意:不能调用 txClient.Close(),因为基于事务的 ent client - // 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB,但实际是 *sql.Tx - // 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成 - defer func() { _ = sqlTx.Rollback() }() + var txClient *dbent.Client + if err == nil { + defer func() { _ = tx.Rollback() }() + txClient = tx.Client() + } else { + // 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。 + txClient = r.client } updated, err := txClient.User.UpdateOneID(userIn.ID). @@ -154,12 +145,12 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists) } - if err := r.syncUserAllowedGroups(ctx, txClient, exec, updated.ID, userIn.AllowedGroups); err != nil { + if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil { return err } - if sqlTx != nil { - if err := sqlTx.Commit(); err != nil { + if tx != nil { + if err := tx.Commit(); err != nil { return err } } @@ -289,8 +280,10 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, } func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { - if r.sql == nil { - return 0, nil + exec := r.sql + if exec == nil { + // 未注入 sqlExecutor 时,退回到 ent client 的 ExecContext(支持事务)。 + exec = r.client } joinAffected, err := r.client.UserAllowedGroup.Delete(). @@ -300,7 +293,7 @@ func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group return 0, err } - arrayRes, err := r.sql.ExecContext( + arrayRes, err := exec.ExecContext( ctx, "UPDATE users SET allowed_groups = array_remove(allowed_groups, $1), updated_at = NOW() WHERE $1 = ANY(allowed_groups)", groupID, @@ -362,6 +355,56 @@ func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64) return out, nil } +// syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组: +// 1) 以 user_allowed_groups 为读写源,确保新旧逻辑一致; +// 2) 额外更新 users.allowed_groups(历史字段)以保持兼容。 +func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error { + if client == nil { + return nil + } + + // Keep join table as the source of truth for reads. + if _, err := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID)).Exec(ctx); err != nil { + return err + } + + unique := make(map[int64]struct{}, len(groupIDs)) + for _, id := range groupIDs { + if id <= 0 { + continue + } + unique[id] = struct{}{} + } + + legacyGroups := make([]int64, 0, len(unique)) + if len(unique) > 0 { + creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique)) + for groupID := range unique { + creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID)) + legacyGroups = append(legacyGroups, groupID) + } + if err := client.UserAllowedGroup. + CreateBulk(creates...). + OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID). + DoNothing(). + Exec(ctx); err != nil { + return err + } + } + + // Phase 1 兼容:保持 users.allowed_groups(数组字段)同步,避免旧查询路径读取到过期数据。 + var legacy any + if len(legacyGroups) > 0 { + sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] }) + legacy = pq.Array(legacyGroups) + } + if _, err := client.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil { + return err + } + + return nil +} + func (r *userRepository) syncUserAllowedGroups(ctx context.Context, client *dbent.Client, exec sqlExecutor, userID int64, groupIDs []int64) error { if client == nil || exec == nil { return nil diff --git a/backend/internal/repository/user_repo_integration_test.go b/backend/internal/repository/user_repo_integration_test.go index afb1fb6a..a59d2312 100644 --- a/backend/internal/repository/user_repo_integration_test.go +++ b/backend/internal/repository/user_repo_integration_test.go @@ -4,7 +4,6 @@ package repository import ( "context" - "database/sql" "testing" "time" @@ -17,17 +16,19 @@ import ( type UserRepoSuite struct { suite.Suite ctx context.Context - tx *sql.Tx client *dbent.Client repo *userRepository } func (s *UserRepoSuite) SetupTest() { s.ctx = context.Background() - entClient, tx := testEntSQLTx(s.T()) - s.tx = tx - s.client = entClient - s.repo = newUserRepositoryWithSQL(entClient, tx) + s.client = testEntClient(s.T()) + s.repo = newUserRepositoryWithSQL(s.client, integrationDB) + + // 清理测试数据,确保每个测试从干净状态开始 + _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions") + _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups") + _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users") } func TestUserRepoSuite(t *testing.T) { diff --git a/backend/internal/repository/user_subscription_repo_integration_test.go b/backend/internal/repository/user_subscription_repo_integration_test.go index e9859012..282b9673 100644 --- a/backend/internal/repository/user_subscription_repo_integration_test.go +++ b/backend/internal/repository/user_subscription_repo_integration_test.go @@ -22,8 +22,8 @@ type UserSubscriptionRepoSuite struct { func (s *UserSubscriptionRepoSuite) SetupTest() { s.ctx = context.Background() - client, _ := testEntSQLTx(s.T()) - s.client = client + tx := testEntTx(s.T()) + s.client = tx.Client() s.repo = NewUserSubscriptionRepository(s.client).(*userSubscriptionRepository) } @@ -66,8 +66,8 @@ func (s *UserSubscriptionRepoSuite) mustCreateSubscription(userID, groupID int64 create := s.client.UserSubscription.Create(). SetUserID(userID). SetGroupID(groupID). - SetStartsAt(now.Add(-1*time.Hour)). - SetExpiresAt(now.Add(24*time.Hour)). + SetStartsAt(now.Add(-1 * time.Hour)). + SetExpiresAt(now.Add(24 * time.Hour)). SetStatus(service.SubscriptionStatusActive). SetAssignedAt(now). SetNotes("") @@ -631,4 +631,3 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba s.Require().NoError(err, "GetByID expired") s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired") } -