fix(仓储): 修复软删除过滤与事务测试

修复软删除拦截器使用错误,确保默认查询过滤已删记录
仓储层改用 ent.Tx 与扫描辅助,避免 sql.Tx 断言问题
同步更新集成测试以覆盖事务与统计变动
This commit is contained in:
yangjianbo
2025-12-29 19:23:49 +08:00
parent b436da7249
commit ae191f72a4
20 changed files with 565 additions and 326 deletions

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/Wei-Shaw/sub2api/ent/intercept"
"entgo.io/ent" "entgo.io/ent"
"entgo.io/ent/dialect" "entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
@@ -79,16 +80,13 @@ func SkipSoftDelete(parent context.Context) context.Context {
// 确保软删除的记录不会出现在普通查询结果中。 // 确保软删除的记录不会出现在普通查询结果中。
func (d SoftDeleteMixin) Interceptors() []ent.Interceptor { func (d SoftDeleteMixin) Interceptors() []ent.Interceptor {
return []ent.Interceptor{ return []ent.Interceptor{
ent.TraverseFunc(func(ctx context.Context, q ent.Query) error { intercept.TraverseFunc(func(ctx context.Context, q intercept.Query) error {
// 检查是否需要跳过软删除过滤 // 检查是否需要跳过软删除过滤
if skip, _ := ctx.Value(softDeleteKey{}).(bool); skip { if skip, _ := ctx.Value(softDeleteKey{}).(bool); skip {
return nil return nil
} }
// 为查询添加 deleted_at IS NULL 条件 // 为查询添加 deleted_at IS NULL 条件
w, ok := q.(interface{ WhereP(...func(*sql.Selector)) }) d.applyPredicate(q)
if ok {
d.applyPredicate(w)
}
return nil return nil
}), }),
} }

View File

@@ -28,6 +28,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
entsql "entgo.io/ent/dialect/sql" entsql "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqljson"
) )
// accountRepository 实现 service.AccountRepository 接口。 // accountRepository 实现 service.AccountRepository 接口。
@@ -36,11 +37,9 @@ import (
// 设计说明: // 设计说明:
// - client: Ent 客户端,用于类型安全的 ORM 操作 // - client: Ent 客户端,用于类型安全的 ORM 操作
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作 // - sql: 原生 SQL 执行器,用于复杂查询和批量操作
// - begin: SQL 事务开启器,用于需要事务的操作
type accountRepository struct { type accountRepository struct {
client *dbent.Client // Ent ORM 客户端 client *dbent.Client // Ent ORM 客户端
sql sqlExecutor // 原生 SQL 执行接口 sql sqlExecutor // 原生 SQL 执行接口
begin sqlBeginner // 事务开启接口
} }
// NewAccountRepository 创建账户仓储实例。 // NewAccountRepository 创建账户仓储实例。
@@ -52,11 +51,7 @@ func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRe
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。 // newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
// 这种设计便于单元测试时注入 mock 对象。 // 这种设计便于单元测试时注入 mock 对象。
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository { func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
var beginner sqlBeginner return &accountRepository{client: client, sql: sqlq}
if b, ok := sqlq.(sqlBeginner); ok {
beginner = b
}
return &accountRepository{client: client, sql: sqlq, begin: beginner}
} }
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error { func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
@@ -146,9 +141,10 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
return nil, nil return nil, nil
} }
// 使用 sqljson.ValueEQ 生成 JSON 路径过滤,避免手写 SQL 片段导致语法兼容问题。
m, err := r.client.Account.Query(). m, err := r.client.Account.Query().
Where(func(s *entsql.Selector) { Where(func(s *entsql.Selector) {
s.Where(entsql.ExprP("extra->>'crs_account_id' = ?", crsAccountID)) s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, crsAccountID, sqljson.Path("crs_account_id")))
}). }).
Only(ctx) Only(ctx)
if err != nil { if err != nil {

View File

@@ -4,7 +4,6 @@ package repository
import ( import (
"context" "context"
"database/sql"
"testing" "testing"
"time" "time"
@@ -17,18 +16,16 @@ import (
type AccountRepoSuite struct { type AccountRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
tx *sql.Tx
client *dbent.Client client *dbent.Client
repo *accountRepository repo *accountRepository
} }
func (s *AccountRepoSuite) SetupTest() { func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
client, tx := testEntSQLTx(s.T()) tx := testEntTx(s.T())
s.client = client s.client = tx.Client()
s.tx = tx s.repo = newAccountRepositoryWithSQL(s.client, tx)
s.repo = newAccountRepositoryWithSQL(client, tx)
} }
func TestAccountRepoSuite(t *testing.T) { func TestAccountRepoSuite(t *testing.T) {
@@ -175,7 +172,8 @@ func (s *AccountRepoSuite) TestListWithFilters() {
for _, tt := range tests { for _, tt := range tests {
s.Run(tt.name, func() { s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源 // 每个 case 重新获取隔离资源
client, tx := testEntSQLTx(s.T()) tx := testEntTx(s.T())
client := tx.Client()
repo := newAccountRepositoryWithSQL(client, tx) repo := newAccountRepositoryWithSQL(client, tx)
ctx := context.Background() ctx := context.Background()

View File

@@ -20,7 +20,8 @@ func uniqueTestValue(t *testing.T, prefix string) string {
func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) { func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) {
ctx := context.Background() ctx := context.Background()
entClient, sqlTx := testEntSQLTx(t) tx := testEntTx(t)
entClient := tx.Client()
targetGroup, err := entClient.Group.Create(). targetGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "target-group")). SetName(uniqueTestValue(t, "target-group")).
@@ -33,7 +34,7 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
Save(ctx) Save(ctx)
require.NoError(t, err) require.NoError(t, err)
repo := newUserRepositoryWithSQL(entClient, sqlTx) repo := newUserRepositoryWithSQL(entClient, tx)
u1 := &service.User{ u1 := &service.User{
Email: uniqueTestValue(t, "u1") + "@example.com", Email: uniqueTestValue(t, "u1") + "@example.com",
@@ -81,7 +82,8 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) { func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
ctx := context.Background() ctx := context.Background()
entClient, sqlTx := testEntSQLTx(t) tx := testEntTx(t)
entClient := tx.Client()
targetGroup, err := entClient.Group.Create(). targetGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "delete-cascade-target")). SetName(uniqueTestValue(t, "delete-cascade-target")).
@@ -94,8 +96,8 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
Save(ctx) Save(ctx)
require.NoError(t, err) require.NoError(t, err)
userRepo := newUserRepositoryWithSQL(entClient, sqlTx) userRepo := newUserRepositoryWithSQL(entClient, tx)
groupRepo := newGroupRepositoryWithSQL(entClient, sqlTx) groupRepo := newGroupRepositoryWithSQL(entClient, tx)
apiKeyRepo := NewApiKeyRepository(entClient) apiKeyRepo := NewApiKeyRepository(entClient)
u := &service.User{ u := &service.User{
@@ -141,4 +143,3 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, keyAfter.GroupID) require.Nil(t, keyAfter.GroupID)
} }

View File

@@ -2,9 +2,11 @@ package repository
import ( import (
"context" "context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -18,6 +20,11 @@ func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
return &apiKeyRepository{client: client} return &apiKeyRepository{client: client}
} }
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
// 默认过滤已软删除记录,避免删除后仍被查询到。
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
}
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error { func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
created, err := r.client.ApiKey.Create(). created, err := r.client.ApiKey.Create().
SetUserID(key.UserID). SetUserID(key.UserID).
@@ -35,7 +42,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) erro
} }
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
m, err := r.client.ApiKey.Query(). m, err := r.activeQuery().
Where(apikey.IDEQ(id)). Where(apikey.IDEQ(id)).
WithUser(). WithUser().
WithGroup(). WithGroup().
@@ -55,7 +62,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
// - 不加载完整的 ApiKey 实体及其关联数据User、Group 等) // - 不加载完整的 ApiKey 实体及其关联数据User、Group 等)
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查) // - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) { func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
m, err := r.client.ApiKey.Query(). m, err := r.activeQuery().
Where(apikey.IDEQ(id)). Where(apikey.IDEQ(id)).
Select(apikey.FieldUserID). Select(apikey.FieldUserID).
Only(ctx) Only(ctx)
@@ -69,7 +76,7 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err
} }
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
m, err := r.client.ApiKey.Query(). m, err := r.activeQuery().
Where(apikey.KeyEQ(key)). Where(apikey.KeyEQ(key)).
WithUser(). WithUser().
WithGroup(). WithGroup().
@@ -84,6 +91,14 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
} }
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error { func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
exists, err := r.activeQuery().Where(apikey.IDEQ(key.ID)).Exist(ctx)
if err != nil {
return err
}
if !exists {
return service.ErrApiKeyNotFound
}
builder := r.client.ApiKey.UpdateOneID(key.ID). builder := r.client.ApiKey.UpdateOneID(key.ID).
SetName(key.Name). SetName(key.Name).
SetStatus(key.Status) SetStatus(key.Status)
@@ -105,12 +120,34 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
} }
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error { func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
_, err := r.client.ApiKey.Delete().Where(apikey.IDEQ(id)).Exec(ctx) // 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
return err affected, err := r.client.ApiKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetDeletedAt(time.Now()).
Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrApiKeyNotFound
}
return err
}
if affected == 0 {
exists, err := r.client.ApiKey.Query().
Where(apikey.IDEQ(id)).
Exist(mixins.SkipSoftDelete(ctx))
if err != nil {
return err
}
if exists {
return nil
}
return service.ErrApiKeyNotFound
}
return nil
} }
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
q := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID)) q := r.activeQuery().Where(apikey.UserIDEQ(userID))
total, err := q.Count(ctx) total, err := q.Count(ctx)
if err != nil { if err != nil {
@@ -141,7 +178,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
} }
ids, err := r.client.ApiKey.Query(). ids, err := r.client.ApiKey.Query().
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...)). Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
IDs(ctx) IDs(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -150,17 +187,17 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
} }
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) { func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
count, err := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID)).Count(ctx) count, err := r.activeQuery().Where(apikey.UserIDEQ(userID)).Count(ctx)
return int64(count), err return int64(count), err
} }
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) { func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
count, err := r.client.ApiKey.Query().Where(apikey.KeyEQ(key)).Count(ctx) count, err := r.activeQuery().Where(apikey.KeyEQ(key)).Count(ctx)
return count > 0, err return count > 0, err
} }
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
q := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID)) q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
total, err := q.Count(ctx) total, err := q.Count(ctx)
if err != nil { if err != nil {
@@ -187,7 +224,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
// SearchApiKeys searches API keys by user ID and/or keyword (name) // SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
q := r.client.ApiKey.Query() q := r.activeQuery()
if userID > 0 { if userID > 0 {
q = q.Where(apikey.UserIDEQ(userID)) q = q.Where(apikey.UserIDEQ(userID))
} }
@@ -211,7 +248,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil // ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
n, err := r.client.ApiKey.Update(). n, err := r.client.ApiKey.Update().
Where(apikey.GroupIDEQ(groupID)). Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
ClearGroupID(). ClearGroupID().
Save(ctx) Save(ctx)
return int64(n), err return int64(n), err
@@ -219,7 +256,7 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
// CountByGroupID 获取分组的 API Key 数量 // CountByGroupID 获取分组的 API Key 数量
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
count, err := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID)).Count(ctx) count, err := r.activeQuery().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
return int64(count), err return int64(count), err
} }

View File

@@ -21,9 +21,9 @@ type ApiKeyRepoSuite struct {
func (s *ApiKeyRepoSuite) SetupTest() { func (s *ApiKeyRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
entClient, _ := testEntSQLTx(s.T()) tx := testEntTx(s.T())
s.client = entClient s.client = tx.Client()
s.repo = NewApiKeyRepository(entClient).(*apiKeyRepository) s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository)
} }
func TestApiKeyRepoSuite(t *testing.T) { func TestApiKeyRepoSuite(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
@@ -10,26 +11,16 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq" "github.com/lib/pq"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
) )
type sqlExecutor interface { type sqlExecutor interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
type sqlBeginner interface {
sqlExecutor
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
} }
type groupRepository struct { type groupRepository struct {
client *dbent.Client client *dbent.Client
sql sqlExecutor sql sqlExecutor
begin sqlBeginner
} }
func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupRepository { func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupRepository {
@@ -37,11 +28,7 @@ func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupReposi
} }
func newGroupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *groupRepository { func newGroupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *groupRepository {
var beginner sqlBeginner return &groupRepository{client: client, sql: sqlq}
if b, ok := sqlq.(sqlBeginner); ok {
beginner = b
}
return &groupRepository{client: client, sql: sqlq, begin: beginner}
} }
func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) error { func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) error {
@@ -214,7 +201,7 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool,
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", groupID).Scan(&count); err != nil { if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
return 0, err return 0, err
} }
return count, nil return count, nil
@@ -236,31 +223,44 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
} }
groupSvc := groupEntityToService(g) groupSvc := groupEntityToService(g)
exec := r.sql // 使用 ent 事务统一包裹:避免手工基于 *sql.Tx 构造 ent client 带来的驱动断言问题,
// 同时保证级联删除的原子性。
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return nil, err
}
exec := r.client
txClient := r.client txClient := r.client
var sqlTx *sql.Tx if err == nil {
defer func() { _ = tx.Rollback() }()
if r.begin != nil { exec = tx.Client()
sqlTx, err = r.begin.BeginTx(ctx, nil) txClient = exec
if err != nil { } else {
return nil, err // 已处于外部事务中ErrTxStarted复用当前 client 参与同一事务。
}
exec = sqlTx
txClient = entClientFromSQLTx(sqlTx)
// 注意:不能调用 txClient.Close(),因为基于事务的 ent client
// 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB但实际是 *sql.Tx
// 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成
defer func() { _ = sqlTx.Rollback() }()
} }
// Lock the group row to avoid concurrent writes while we cascade. // Lock the group row to avoid concurrent writes while we cascade.
var lockedID int64 // 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分“未找到”与其他错误。
if err := exec.QueryRowContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id).Scan(&lockedID); err != nil { rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id)
if errorsIsNoRows(err) { if err != nil {
return nil, service.ErrGroupNotFound
}
return nil, err return nil, err
} }
var lockedID int64
if rows.Next() {
if err := rows.Scan(&lockedID); err != nil {
_ = rows.Close()
return nil, err
}
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
if lockedID == 0 {
return nil, service.ErrGroupNotFound
}
var affectedUserIDs []int64 var affectedUserIDs []int64
if groupSvc.IsSubscriptionType() { if groupSvc.IsSubscriptionType() {
@@ -319,8 +319,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return nil, err return nil, err
} }
if sqlTx != nil { if tx != nil {
if err := sqlTx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return nil, err return nil, err
} }
} }
@@ -359,11 +359,6 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
return counts, nil return counts, nil
} }
func entClientFromSQLTx(tx *sql.Tx) *dbent.Client {
drv := entsql.NewDriver(dialect.Postgres, entsql.Conn{ExecQuerier: tx})
return dbent.NewClient(dbent.Driver(drv))
}
func errorsIsNoRows(err error) bool { func errorsIsNoRows(err error) bool {
return err == sql.ErrNoRows return err == sql.ErrNoRows
} }

View File

@@ -4,9 +4,9 @@ package repository
import ( import (
"context" "context"
"database/sql"
"testing" "testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@@ -15,15 +15,15 @@ import (
type GroupRepoSuite struct { type GroupRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
tx *sql.Tx tx *dbent.Tx
repo *groupRepository repo *groupRepository
} }
func (s *GroupRepoSuite) SetupTest() { func (s *GroupRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
entClient, tx := testEntSQLTx(s.T()) tx := testEntTx(s.T())
s.tx = tx s.tx = tx
s.repo = newGroupRepositoryWithSQL(entClient, tx) s.repo = newGroupRepositoryWithSQL(tx.Client(), tx)
} }
func TestGroupRepoSuite(t *testing.T) { func TestGroupRepoSuite(t *testing.T) {
@@ -99,6 +99,9 @@ func (s *GroupRepoSuite) TestDelete() {
// --- List / ListWithFilters --- // --- List / ListWithFilters ---
func (s *GroupRepoSuite) TestList() { func (s *GroupRepoSuite) TestList() {
baseGroups, basePage, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List base")
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{ s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1", Name: "g1",
Platform: service.PlatformAnthropic, Platform: service.PlatformAnthropic,
@@ -118,12 +121,20 @@ func (s *GroupRepoSuite) TestList() {
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
// 3 default groups + 2 test groups = 5 total s.Require().Len(groups, len(baseGroups)+2)
s.Require().Len(groups, 5) s.Require().Equal(basePage.Total+2, page.Total)
s.Require().Equal(int64(5), page.Total)
} }
func (s *GroupRepoSuite) TestListWithFilters_Platform() { func (s *GroupRepoSuite) TestListWithFilters_Platform() {
baseGroups, _, err := s.repo.ListWithFilters(
s.ctx,
pagination.PaginationParams{Page: 1, PageSize: 10},
service.PlatformOpenAI,
"",
nil,
)
s.Require().NoError(err, "ListWithFilters base")
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{ s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1", Name: "g1",
Platform: service.PlatformAnthropic, Platform: service.PlatformAnthropic,
@@ -143,8 +154,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
s.Require().NoError(err) s.Require().NoError(err)
// 1 default openai group + 1 test openai group = 2 total s.Require().Len(groups, len(baseGroups)+1)
s.Require().Len(groups, 2)
// Verify all groups are OpenAI platform // Verify all groups are OpenAI platform
for _, g := range groups { for _, g := range groups {
s.Require().Equal(service.PlatformOpenAI, g.Platform) s.Require().Equal(service.PlatformOpenAI, g.Platform)
@@ -221,11 +231,13 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
s.Require().NoError(s.repo.Create(s.ctx, g2)) s.Require().NoError(s.repo.Create(s.ctx, g2))
var accountID int64 var accountID int64
s.Require().NoError(s.tx.QueryRowContext( s.Require().NoError(scanSingleRow(
s.ctx, s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id", "INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
"acc1", service.PlatformAnthropic, service.AccountTypeOAuth, []any{"acc1", service.PlatformAnthropic, service.AccountTypeOAuth},
).Scan(&accountID)) &accountID,
))
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g1.ID, 1) _, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g1.ID, 1)
s.Require().NoError(err) s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g2.ID, 1) _, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g2.ID, 1)
@@ -243,6 +255,9 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
// --- ListActive / ListActiveByPlatform --- // --- ListActive / ListActiveByPlatform ---
func (s *GroupRepoSuite) TestListActive() { func (s *GroupRepoSuite) TestListActive() {
baseGroups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive base")
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{ s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "active1", Name: "active1",
Platform: service.PlatformAnthropic, Platform: service.PlatformAnthropic,
@@ -262,8 +277,7 @@ func (s *GroupRepoSuite) TestListActive() {
groups, err := s.repo.ListActive(s.ctx) groups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive") s.Require().NoError(err, "ListActive")
// 3 default groups (all active) + 1 test active group = 4 total s.Require().Len(groups, len(baseGroups)+1)
s.Require().Len(groups, 4)
// Verify our test group is in the results // Verify our test group is in the results
var found bool var found bool
for _, g := range groups { for _, g := range groups {
@@ -351,17 +365,21 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
s.Require().NoError(s.repo.Create(s.ctx, group)) s.Require().NoError(s.repo.Create(s.ctx, group))
var a1 int64 var a1 int64
s.Require().NoError(s.tx.QueryRowContext( s.Require().NoError(scanSingleRow(
s.ctx, s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id", "INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
"a1", service.PlatformAnthropic, service.AccountTypeOAuth, []any{"a1", service.PlatformAnthropic, service.AccountTypeOAuth},
).Scan(&a1)) &a1,
))
var a2 int64 var a2 int64
s.Require().NoError(s.tx.QueryRowContext( s.Require().NoError(scanSingleRow(
s.ctx, s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id", "INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
"a2", service.PlatformAnthropic, service.AccountTypeOAuth, []any{"a2", service.PlatformAnthropic, service.AccountTypeOAuth},
).Scan(&a2)) &a2,
))
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, group.ID, 1) _, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, group.ID, 1)
s.Require().NoError(err) s.Require().NoError(err)
@@ -402,11 +420,13 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
} }
s.Require().NoError(s.repo.Create(s.ctx, g)) s.Require().NoError(s.repo.Create(s.ctx, g))
var accountID int64 var accountID int64
s.Require().NoError(s.tx.QueryRowContext( s.Require().NoError(scanSingleRow(
s.ctx, s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id", "INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
"acc-del", service.PlatformAnthropic, service.AccountTypeOAuth, []any{"acc-del", service.PlatformAnthropic, service.AccountTypeOAuth},
).Scan(&accountID)) &accountID,
))
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g.ID, 1) _, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g.ID, 1)
s.Require().NoError(err) s.Require().NoError(err)
@@ -432,11 +452,13 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
insertAccount := func(name string) int64 { insertAccount := func(name string) int64 {
var id int64 var id int64
s.Require().NoError(s.tx.QueryRowContext( s.Require().NoError(scanSingleRow(
s.ctx, s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id", "INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
name, service.PlatformAnthropic, service.AccountTypeOAuth, []any{name, service.PlatformAnthropic, service.AccountTypeOAuth},
).Scan(&id)) &id,
))
return id return id
} }
a1 := insertAccount("a1") a1 := insertAccount("a1")

View File

@@ -36,8 +36,9 @@ const (
) )
var ( var (
integrationDB *sql.DB integrationDB *sql.DB
integrationRedis *redisclient.Client integrationEntClient *dbent.Client
integrationRedis *redisclient.Client
redisNamespaceSeq uint64 redisNamespaceSeq uint64
) )
@@ -101,6 +102,10 @@ func TestMain(m *testing.M) {
os.Exit(1) os.Exit(1)
} }
// 创建 ent client 用于集成测试
drv := entsql.OpenDB(dialect.Postgres, integrationDB)
integrationEntClient = dbent.NewClient(dbent.Driver(drv))
redisHost, err := redisContainer.Host(ctx) redisHost, err := redisContainer.Host(ctx)
if err != nil { if err != nil {
log.Printf("failed to get redis host: %v", err) log.Printf("failed to get redis host: %v", err)
@@ -123,6 +128,7 @@ func TestMain(m *testing.M) {
code := m.Run() code := m.Run()
_ = integrationEntClient.Close()
_ = integrationRedis.Close() _ = integrationRedis.Close()
_ = integrationDB.Close() _ = integrationDB.Close()
@@ -193,18 +199,38 @@ func testTx(t *testing.T) *sql.Tx {
return tx return tx
} }
// testEntClient 返回全局的 ent client用于测试需要内部管理事务的代码如 Create/Update 方法)。
// 注意:此 client 的操作会真正写入数据库,测试结束后不会自动回滚。
func testEntClient(t *testing.T) *dbent.Client {
t.Helper()
return integrationEntClient
}
// testEntTx 返回一个 ent 事务,用于需要事务隔离的测试。
// 测试结束后会自动回滚,不会影响数据库状态。
func testEntTx(t *testing.T) *dbent.Tx {
t.Helper()
tx, err := integrationEntClient.Tx(context.Background())
require.NoError(t, err, "begin ent tx")
t.Cleanup(func() {
_ = tx.Rollback()
})
return tx
}
// testEntSQLTx 已弃用:不要在新测试中使用此函数。
// 基于 *sql.Tx 创建的 ent client 在调用 client.Tx() 时会 panic。
// 对于需要测试内部使用事务的代码,请使用 testEntClient。
// 对于需要事务隔离的测试,请使用 testEntTx。
//
// Deprecated: Use testEntClient or testEntTx instead.
func testEntSQLTx(t *testing.T) (*dbent.Client, *sql.Tx) { func testEntSQLTx(t *testing.T) (*dbent.Client, *sql.Tx) {
t.Helper() t.Helper()
tx := testTx(t) // 直接失败,避免旧测试误用导致的事务嵌套 panic。
drv := entsql.NewDriver(dialect.Postgres, entsql.Conn{ExecQuerier: tx}) t.Fatalf("testEntSQLTx 已弃用:请使用 testEntClient 或 testEntTx")
client := dbent.NewClient(dbent.Driver(drv)) return nil, nil
t.Cleanup(func() {
_ = client.Close()
})
return client, tx
} }
func testRedis(t *testing.T) *redisclient.Client { func testRedis(t *testing.T) *redisclient.Client {
@@ -363,13 +389,16 @@ type IntegrationDBSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
client *dbent.Client client *dbent.Client
tx *sql.Tx tx *dbent.Tx
} }
// SetupTest initializes ctx and client for each test method. // SetupTest initializes ctx and client for each test method.
func (s *IntegrationDBSuite) SetupTest() { func (s *IntegrationDBSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.client, s.tx = testEntSQLTx(s.T()) // 统一使用 ent.Tx确保每个测试都有独立事务并自动回滚。
tx := testEntTx(s.T())
s.tx = tx
s.client = tx.Client()
} }
// RequireNoError is a convenience method wrapping require.NoError with s.T(). // RequireNoError is a convenience method wrapping require.NoError with s.T().

View File

@@ -13,7 +13,6 @@ import (
type sqlQuerier interface { type sqlQuerier interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
} }
type proxyRepository struct { type proxyRepository struct {
@@ -170,9 +169,8 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
// CountAccountsByProxyID returns the number of accounts using a specific proxy // CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
row := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", proxyID)
var count int64 var count int64
if err := row.Scan(&count); err != nil { if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", []any{proxyID}, &count); err != nil {
return 0, err return 0, err
} }
return count, nil return count, nil

View File

@@ -4,10 +4,10 @@ package repository
import ( import (
"context" "context"
"database/sql"
"testing" "testing"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@@ -15,16 +15,16 @@ import (
type ProxyRepoSuite struct { type ProxyRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
sqlTx *sql.Tx tx *dbent.Tx
repo *proxyRepository repo *proxyRepository
} }
func (s *ProxyRepoSuite) SetupTest() { func (s *ProxyRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
entClient, sqlTx := testEntSQLTx(s.T()) tx := testEntTx(s.T())
s.sqlTx = sqlTx s.tx = tx
s.repo = newProxyRepositoryWithSQL(entClient, sqlTx) s.repo = newProxyRepositoryWithSQL(tx.Client(), tx)
} }
func TestProxyRepoSuite(t *testing.T) { func TestProxyRepoSuite(t *testing.T) {
@@ -306,7 +306,7 @@ func (s *ProxyRepoSuite) mustCreateProxyWithTimes(name, status string, createdAt
Port: 8080, Port: 8080,
Status: status, Status: status,
}) })
_, err := s.sqlTx.ExecContext(s.ctx, "UPDATE proxies SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, p.ID) _, err := s.tx.ExecContext(s.ctx, "UPDATE proxies SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, p.ID)
s.Require().NoError(err, "update proxy timestamps") s.Require().NoError(err, "update proxy timestamps")
return p return p
} }
@@ -317,7 +317,7 @@ func (s *ProxyRepoSuite) mustInsertAccount(name string, proxyID *int64) {
if proxyID != nil { if proxyID != nil {
pid = *proxyID pid = *proxyID
} }
_, err := s.sqlTx.ExecContext( _, err := s.tx.ExecContext(
s.ctx, s.ctx,
"INSERT INTO accounts (name, platform, type, proxy_id) VALUES ($1, $2, $3, $4)", "INSERT INTO accounts (name, platform, type, proxy_id) VALUES ($1, $2, $3, $4)",
name, name,

View File

@@ -22,9 +22,9 @@ type RedeemCodeRepoSuite struct {
func (s *RedeemCodeRepoSuite) SetupTest() { func (s *RedeemCodeRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
entClient, _ := testEntSQLTx(s.T()) tx := testEntTx(s.T())
s.client = entClient s.client = tx.Client()
s.repo = NewRedeemCodeRepository(entClient).(*redeemCodeRepository) s.repo = NewRedeemCodeRepository(s.client).(*redeemCodeRepository)
} }
func TestRedeemCodeRepoSuite(t *testing.T) { func TestRedeemCodeRepoSuite(t *testing.T) {

View File

@@ -18,8 +18,8 @@ type SettingRepoSuite struct {
func (s *SettingRepoSuite) SetupTest() { func (s *SettingRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
entClient, _ := testEntSQLTx(s.T()) tx := testEntTx(s.T())
s.repo = NewSettingRepository(entClient).(*settingRepository) s.repo = NewSettingRepository(tx.Client()).(*settingRepository)
} }
func TestSettingRepoSuite(t *testing.T) { func TestSettingRepoSuite(t *testing.T) {

View File

@@ -34,7 +34,8 @@ func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, emai
func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) { func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
ctx := context.Background() ctx := context.Background()
client, _ := testEntSQLTx(t) // 使用全局 ent client确保软删除验证在实际持久化数据上进行。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com") u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
@@ -65,7 +66,8 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) { func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
ctx := context.Background() ctx := context.Background()
client, _ := testEntSQLTx(t) // 使用全局 ent client避免事务回滚影响幂等性验证。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com") u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
@@ -84,7 +86,8 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) { func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
ctx := context.Background() ctx := context.Background()
client, _ := testEntSQLTx(t) // 使用全局 ent client确保 SkipSoftDelete 的硬删除语义可验证。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com") u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")

View File

@@ -0,0 +1,33 @@
package repository
import (
"context"
"database/sql"
)
type sqlQueryer interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
// scanSingleRow executes a query and scans the first row into dest.
// If no rows are returned, sql.ErrNoRows is returned.
// 设计目的:仅依赖 QueryContext避免 QueryRowContext 对 *sql.Tx 的强绑定,
// 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) error {
rows, err := q.QueryContext(ctx, query, args...)
if err != nil {
return err
}
defer rows.Close()
if !rows.Next() {
if err := rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
if err := rows.Scan(dest...); err != nil {
return err
}
return rows.Err()
}

View File

@@ -3,7 +3,6 @@ package repository
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@@ -33,6 +32,7 @@ func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLog
} }
func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository { func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository {
// 使用 scanSingleRow 替代 QueryRowContext保证 ent.Tx 作为 sqlExecutor 可用。
return &usageLogRepository{client: client, sql: sqlq} return &usageLogRepository{client: client, sql: sqlq}
} }
@@ -53,7 +53,7 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
var requestCount int64 var requestCount int64
var tokenCount int64 var tokenCount int64
if err := r.sql.QueryRowContext(ctx, query, args...).Scan(&requestCount, &tokenCount); err != nil { if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil {
return 0, 0, err return 0, 0, err
} }
return requestCount / 5, tokenCount / 5, nil return requestCount / 5, tokenCount / 5, nil
@@ -114,9 +114,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration := nullInt(log.DurationMs) duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs) firstToken := nullInt(log.FirstTokenMs)
row := r.sql.QueryRowContext( args := []any{
ctx,
query,
log.UserID, log.UserID,
log.ApiKeyID, log.ApiKeyID,
log.AccountID, log.AccountID,
@@ -142,9 +140,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration, duration,
firstToken, firstToken,
createdAt, createdAt,
) }
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
if err := row.Scan(&log.ID, &log.CreatedAt); err != nil {
return err return err
} }
log.RateMultiplier = rateMultiplier log.RateMultiplier = rateMultiplier
@@ -153,11 +150,22 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) { func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1" query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
log, err := scanUsageLog(r.sql.QueryRowContext(ctx, query, id)) rows, err := r.sql.QueryContext(ctx, query, id)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { return nil, err
return nil, service.ErrUsageLogNotFound }
defer rows.Close()
if !rows.Next() {
if err := rows.Err(); err != nil {
return nil, err
} }
return nil, service.ErrUsageLogNotFound
}
log, err := scanUsageLog(rows)
if err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err return nil, err
} }
return log, nil return log, nil
@@ -195,8 +203,18 @@ func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
` `
stats := &UserStats{} stats := &UserStats{}
if err := r.sql.QueryRowContext(ctx, query, userID, startTime, endTime). if err := scanSingleRow(
Scan(&stats.TotalRequests, &stats.TotalTokens, &stats.TotalCost, &stats.InputTokens, &stats.OutputTokens, &stats.CacheReadTokens); err != nil { ctx,
r.sql,
query,
[]any{userID, startTime, endTime},
&stats.TotalRequests,
&stats.TotalTokens,
&stats.TotalCost,
&stats.InputTokens,
&stats.OutputTokens,
&stats.CacheReadTokens,
); err != nil {
return nil, err return nil, err
} }
return stats, nil return stats, nil
@@ -219,8 +237,15 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
FROM users FROM users
WHERE deleted_at IS NULL WHERE deleted_at IS NULL
` `
if err := r.sql.QueryRowContext(ctx, userStatsQuery, today, today). if err := scanSingleRow(
Scan(&stats.TotalUsers, &stats.TodayNewUsers, &stats.ActiveUsers); err != nil { ctx,
r.sql,
userStatsQuery,
[]any{today, today},
&stats.TotalUsers,
&stats.TodayNewUsers,
&stats.ActiveUsers,
); err != nil {
return nil, err return nil, err
} }
@@ -232,8 +257,14 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
FROM api_keys FROM api_keys
WHERE deleted_at IS NULL WHERE deleted_at IS NULL
` `
if err := r.sql.QueryRowContext(ctx, apiKeyStatsQuery, service.StatusActive). if err := scanSingleRow(
Scan(&stats.TotalApiKeys, &stats.ActiveApiKeys); err != nil { ctx,
r.sql,
apiKeyStatsQuery,
[]any{service.StatusActive},
&stats.TotalApiKeys,
&stats.ActiveApiKeys,
); err != nil {
return nil, err return nil, err
} }
@@ -248,8 +279,17 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
FROM accounts FROM accounts
WHERE deleted_at IS NULL WHERE deleted_at IS NULL
` `
if err := r.sql.QueryRowContext(ctx, accountStatsQuery, service.StatusActive, service.StatusError, now, now). if err := scanSingleRow(
Scan(&stats.TotalAccounts, &stats.NormalAccounts, &stats.ErrorAccounts, &stats.RateLimitAccounts, &stats.OverloadAccounts); err != nil { ctx,
r.sql,
accountStatsQuery,
[]any{service.StatusActive, service.StatusError, now, now},
&stats.TotalAccounts,
&stats.NormalAccounts,
&stats.ErrorAccounts,
&stats.RateLimitAccounts,
&stats.OverloadAccounts,
); err != nil {
return nil, err return nil, err
} }
@@ -266,17 +306,20 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
COALESCE(AVG(duration_ms), 0) as avg_duration_ms COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs FROM usage_logs
` `
if err := r.sql.QueryRowContext(ctx, totalStatsQuery). if err := scanSingleRow(
Scan( ctx,
&stats.TotalRequests, r.sql,
&stats.TotalInputTokens, totalStatsQuery,
&stats.TotalOutputTokens, nil,
&stats.TotalCacheCreationTokens, &stats.TotalRequests,
&stats.TotalCacheReadTokens, &stats.TotalInputTokens,
&stats.TotalCost, &stats.TotalOutputTokens,
&stats.TotalActualCost, &stats.TotalCacheCreationTokens,
&stats.AverageDurationMs, &stats.TotalCacheReadTokens,
); err != nil { &stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err return nil, err
} }
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
@@ -294,16 +337,19 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
FROM usage_logs FROM usage_logs
WHERE created_at >= $1 WHERE created_at >= $1
` `
if err := r.sql.QueryRowContext(ctx, todayStatsQuery, today). if err := scanSingleRow(
Scan( ctx,
&stats.TodayRequests, r.sql,
&stats.TodayInputTokens, todayStatsQuery,
&stats.TodayOutputTokens, []any{today},
&stats.TodayCacheCreationTokens, &stats.TodayRequests,
&stats.TodayCacheReadTokens, &stats.TodayInputTokens,
&stats.TodayCost, &stats.TodayOutputTokens,
&stats.TodayActualCost, &stats.TodayCacheCreationTokens,
); err != nil { &stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
); err != nil {
return nil, err return nil, err
} }
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
@@ -345,16 +391,19 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
` `
var stats usagestats.UsageStats var stats usagestats.UsageStats
if err := r.sql.QueryRowContext(ctx, query, userID, startTime, endTime). if err := scanSingleRow(
Scan( ctx,
&stats.TotalRequests, r.sql,
&stats.TotalInputTokens, query,
&stats.TotalOutputTokens, []any{userID, startTime, endTime},
&stats.TotalCacheTokens, &stats.TotalRequests,
&stats.TotalCost, &stats.TotalInputTokens,
&stats.TotalActualCost, &stats.TotalOutputTokens,
&stats.AverageDurationMs, &stats.TotalCacheTokens,
); err != nil { &stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err return nil, err
} }
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
@@ -377,16 +426,19 @@ func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKe
` `
var stats usagestats.UsageStats var stats usagestats.UsageStats
if err := r.sql.QueryRowContext(ctx, query, apiKeyID, startTime, endTime). if err := scanSingleRow(
Scan( ctx,
&stats.TotalRequests, r.sql,
&stats.TotalInputTokens, query,
&stats.TotalOutputTokens, []any{apiKeyID, startTime, endTime},
&stats.TotalCacheTokens, &stats.TotalRequests,
&stats.TotalCost, &stats.TotalInputTokens,
&stats.TotalActualCost, &stats.TotalOutputTokens,
&stats.AverageDurationMs, &stats.TotalCacheTokens,
); err != nil { &stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err return nil, err
} }
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
@@ -430,8 +482,15 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
` `
stats := &usagestats.AccountStats{} stats := &usagestats.AccountStats{}
if err := r.sql.QueryRowContext(ctx, query, accountID, today). if err := scanSingleRow(
Scan(&stats.Requests, &stats.Tokens, &stats.Cost); err != nil { ctx,
r.sql,
query,
[]any{accountID, today},
&stats.Requests,
&stats.Tokens,
&stats.Cost,
); err != nil {
return nil, err return nil, err
} }
return stats, nil return stats, nil
@@ -449,8 +508,15 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
` `
stats := &usagestats.AccountStats{} stats := &usagestats.AccountStats{}
if err := r.sql.QueryRowContext(ctx, query, accountID, startTime). if err := scanSingleRow(
Scan(&stats.Requests, &stats.Tokens, &stats.Cost); err != nil { ctx,
r.sql,
query,
[]any{accountID, startTime},
&stats.Requests,
&stats.Tokens,
&stats.Cost,
); err != nil {
return nil, err return nil, err
} }
return stats, nil return stats, nil
@@ -581,12 +647,22 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
today := timezone.Today() today := timezone.Today()
// API Key 统计 // API Key 统计
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL", userID). if err := scanSingleRow(
Scan(&stats.TotalApiKeys); err != nil { ctx,
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
[]any{userID},
&stats.TotalApiKeys,
); err != nil {
return nil, err return nil, err
} }
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL", userID, service.StatusActive). if err := scanSingleRow(
Scan(&stats.ActiveApiKeys); err != nil { ctx,
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
[]any{userID, service.StatusActive},
&stats.ActiveApiKeys,
); err != nil {
return nil, err return nil, err
} }
@@ -604,17 +680,20 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
FROM usage_logs FROM usage_logs
WHERE user_id = $1 WHERE user_id = $1
` `
if err := r.sql.QueryRowContext(ctx, totalStatsQuery, userID). if err := scanSingleRow(
Scan( ctx,
&stats.TotalRequests, r.sql,
&stats.TotalInputTokens, totalStatsQuery,
&stats.TotalOutputTokens, []any{userID},
&stats.TotalCacheCreationTokens, &stats.TotalRequests,
&stats.TotalCacheReadTokens, &stats.TotalInputTokens,
&stats.TotalCost, &stats.TotalOutputTokens,
&stats.TotalActualCost, &stats.TotalCacheCreationTokens,
&stats.AverageDurationMs, &stats.TotalCacheReadTokens,
); err != nil { &stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err return nil, err
} }
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
@@ -632,16 +711,19 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
FROM usage_logs FROM usage_logs
WHERE user_id = $1 AND created_at >= $2 WHERE user_id = $1 AND created_at >= $2
` `
if err := r.sql.QueryRowContext(ctx, todayStatsQuery, userID, today). if err := scanSingleRow(
Scan( ctx,
&stats.TodayRequests, r.sql,
&stats.TodayInputTokens, todayStatsQuery,
&stats.TodayOutputTokens, []any{userID, today},
&stats.TodayCacheCreationTokens, &stats.TodayRequests,
&stats.TodayCacheReadTokens, &stats.TodayInputTokens,
&stats.TodayCost, &stats.TodayOutputTokens,
&stats.TodayActualCost, &stats.TodayCacheCreationTokens,
); err != nil { &stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
); err != nil {
return nil, err return nil, err
} }
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
@@ -1007,16 +1089,19 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
` `
stats := &UsageStats{} stats := &UsageStats{}
if err := r.sql.QueryRowContext(ctx, query, startTime, endTime). if err := scanSingleRow(
Scan( ctx,
&stats.TotalRequests, r.sql,
&stats.TotalInputTokens, query,
&stats.TotalOutputTokens, []any{startTime, endTime},
&stats.TotalCacheTokens, &stats.TotalRequests,
&stats.TotalCost, &stats.TotalInputTokens,
&stats.TotalActualCost, &stats.TotalOutputTokens,
&stats.AverageDurationMs, &stats.TotalCacheTokens,
); err != nil { &stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err return nil, err
} }
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
@@ -1108,7 +1193,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
avgQuery := "SELECT COALESCE(AVG(duration_ms), 0) as avg_duration_ms FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3" avgQuery := "SELECT COALESCE(AVG(duration_ms), 0) as avg_duration_ms FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3"
var avgDuration float64 var avgDuration float64
if err := r.sql.QueryRowContext(ctx, avgQuery, accountID, startTime, endTime).Scan(&avgDuration); err != nil { if err := scanSingleRow(ctx, r.sql, avgQuery, []any{accountID, startTime, endTime}, &avgDuration); err != nil {
return nil, err return nil, err
} }
@@ -1186,7 +1271,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
countQuery := "SELECT COUNT(*) FROM usage_logs " + whereClause countQuery := "SELECT COUNT(*) FROM usage_logs " + whereClause
var total int64 var total int64
if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { if err := scanSingleRow(ctx, r.sql, countQuery, args, &total); err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@@ -4,7 +4,6 @@ package repository
import ( import (
"context" "context"
"database/sql"
"testing" "testing"
"time" "time"
@@ -19,17 +18,17 @@ import (
type UsageLogRepoSuite struct { type UsageLogRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
tx *sql.Tx tx *dbent.Tx
client *dbent.Client client *dbent.Client
repo *usageLogRepository repo *usageLogRepository
} }
func (s *UsageLogRepoSuite) SetupTest() { func (s *UsageLogRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
client, tx := testEntSQLTx(s.T()) tx := testEntTx(s.T())
s.client = client
s.tx = tx s.tx = tx
s.repo = newUsageLogRepositoryWithSQL(client, tx) s.client = tx.Client()
s.repo = newUsageLogRepositoryWithSQL(s.client, tx)
} }
func TestUsageLogRepoSuite(t *testing.T) { func TestUsageLogRepoSuite(t *testing.T) {
@@ -197,6 +196,8 @@ func (s *UsageLogRepoSuite) TestListWithFilters() {
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
now := time.Now() now := time.Now()
todayStart := timezone.Today() todayStart := timezone.Today()
baseStats, err := s.repo.GetDashboardStats(s.ctx)
s.Require().NoError(err, "GetDashboardStats base")
userToday := mustCreateUser(s.T(), s.client, &service.User{ userToday := mustCreateUser(s.T(), s.client, &service.User{
Email: "today@example.com", Email: "today@example.com",
@@ -268,24 +269,24 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
stats, err := s.repo.GetDashboardStats(s.ctx) stats, err := s.repo.GetDashboardStats(s.ctx)
s.Require().NoError(err, "GetDashboardStats") s.Require().NoError(err, "GetDashboardStats")
s.Require().Equal(int64(2), stats.TotalUsers, "TotalUsers mismatch") s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch")
s.Require().Equal(int64(1), stats.TodayNewUsers, "TodayNewUsers mismatch") s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch")
s.Require().Equal(int64(1), stats.ActiveUsers, "ActiveUsers mismatch") s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch")
s.Require().Equal(int64(2), stats.TotalApiKeys, "TotalApiKeys mismatch") s.Require().Equal(baseStats.TotalApiKeys+2, stats.TotalApiKeys, "TotalApiKeys mismatch")
s.Require().Equal(int64(1), stats.ActiveApiKeys, "ActiveApiKeys mismatch") s.Require().Equal(baseStats.ActiveApiKeys+1, stats.ActiveApiKeys, "ActiveApiKeys mismatch")
s.Require().Equal(int64(4), stats.TotalAccounts, "TotalAccounts mismatch") s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch")
s.Require().Equal(int64(1), stats.ErrorAccounts, "ErrorAccounts mismatch") s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch")
s.Require().Equal(int64(1), stats.RateLimitAccounts, "RateLimitAccounts mismatch") s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch")
s.Require().Equal(int64(1), stats.OverloadAccounts, "OverloadAccounts mismatch") s.Require().Equal(baseStats.OverloadAccounts+1, stats.OverloadAccounts, "OverloadAccounts mismatch")
s.Require().Equal(int64(3), stats.TotalRequests, "TotalRequests mismatch") s.Require().Equal(baseStats.TotalRequests+3, stats.TotalRequests, "TotalRequests mismatch")
s.Require().Equal(int64(16), stats.TotalInputTokens, "TotalInputTokens mismatch") s.Require().Equal(baseStats.TotalInputTokens+int64(16), stats.TotalInputTokens, "TotalInputTokens mismatch")
s.Require().Equal(int64(28), stats.TotalOutputTokens, "TotalOutputTokens mismatch") s.Require().Equal(baseStats.TotalOutputTokens+int64(28), stats.TotalOutputTokens, "TotalOutputTokens mismatch")
s.Require().Equal(int64(3), stats.TotalCacheCreationTokens, "TotalCacheCreationTokens mismatch") s.Require().Equal(baseStats.TotalCacheCreationTokens+int64(3), stats.TotalCacheCreationTokens, "TotalCacheCreationTokens mismatch")
s.Require().Equal(int64(4), stats.TotalCacheReadTokens, "TotalCacheReadTokens mismatch") s.Require().Equal(baseStats.TotalCacheReadTokens+int64(4), stats.TotalCacheReadTokens, "TotalCacheReadTokens mismatch")
s.Require().Equal(int64(51), stats.TotalTokens, "TotalTokens mismatch") s.Require().Equal(baseStats.TotalTokens+int64(51), stats.TotalTokens, "TotalTokens mismatch")
s.Require().Equal(2.3, stats.TotalCost, "TotalCost mismatch") s.Require().Equal(baseStats.TotalCost+2.3, stats.TotalCost, "TotalCost mismatch")
s.Require().Equal(2.0, stats.TotalActualCost, "TotalActualCost mismatch") s.Require().Equal(baseStats.TotalActualCost+2.0, stats.TotalActualCost, "TotalActualCost mismatch")
s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1") s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1")
s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0") s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0")

View File

@@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"sort" "sort"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -17,7 +18,6 @@ import (
type userRepository struct { type userRepository struct {
client *dbent.Client client *dbent.Client
sql sqlExecutor sql sqlExecutor
begin sqlBeginner
} }
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository { func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
@@ -25,11 +25,7 @@ func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserReposito
} }
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository { func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
var beginner sqlBeginner return &userRepository{client: client, sql: sqlq}
if b, ok := sqlq.(sqlBeginner); ok {
beginner = b
}
return &userRepository{client: client, sql: sqlq, begin: beginner}
} }
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error { func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
@@ -37,22 +33,20 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
return nil return nil
} }
exec := r.sql // 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
txClient := r.client // 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
var sqlTx *sql.Tx tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return err
}
if r.begin != nil { var txClient *dbent.Client
var err error if err == nil {
sqlTx, err = r.begin.BeginTx(ctx, nil) defer func() { _ = tx.Rollback() }()
if err != nil { txClient = tx.Client()
return err } else {
} // 已处于外部事务中ErrTxStarted复用当前 client 并由调用方负责提交/回滚。
exec = sqlTx txClient = r.client
txClient = entClientFromSQLTx(sqlTx)
// 注意:不能调用 txClient.Close(),因为基于事务的 ent client
// 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB但实际是 *sql.Tx
// 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成
defer func() { _ = sqlTx.Rollback() }()
} }
created, err := txClient.User.Create(). created, err := txClient.User.Create().
@@ -70,12 +64,12 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
return translatePersistenceError(err, nil, service.ErrEmailExists) return translatePersistenceError(err, nil, service.ErrEmailExists)
} }
if err := r.syncUserAllowedGroups(ctx, txClient, exec, created.ID, userIn.AllowedGroups); err != nil { if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
return err return err
} }
if sqlTx != nil { if tx != nil {
if err := sqlTx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return err return err
} }
} }
@@ -121,22 +115,19 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
return nil return nil
} }
exec := r.sql // 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
txClient := r.client tx, err := r.client.Tx(ctx)
var sqlTx *sql.Tx if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return err
}
if r.begin != nil { var txClient *dbent.Client
var err error if err == nil {
sqlTx, err = r.begin.BeginTx(ctx, nil) defer func() { _ = tx.Rollback() }()
if err != nil { txClient = tx.Client()
return err } else {
} // 已处于外部事务中ErrTxStarted复用当前 client 并由调用方负责提交/回滚。
exec = sqlTx txClient = r.client
txClient = entClientFromSQLTx(sqlTx)
// 注意:不能调用 txClient.Close(),因为基于事务的 ent client
// 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB但实际是 *sql.Tx
// 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成
defer func() { _ = sqlTx.Rollback() }()
} }
updated, err := txClient.User.UpdateOneID(userIn.ID). updated, err := txClient.User.UpdateOneID(userIn.ID).
@@ -154,12 +145,12 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists) return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
} }
if err := r.syncUserAllowedGroups(ctx, txClient, exec, updated.ID, userIn.AllowedGroups); err != nil { if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
return err return err
} }
if sqlTx != nil { if tx != nil {
if err := sqlTx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return err return err
} }
} }
@@ -289,8 +280,10 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
} }
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
if r.sql == nil { exec := r.sql
return 0, nil if exec == nil {
// 未注入 sqlExecutor 时,退回到 ent client 的 ExecContext支持事务
exec = r.client
} }
joinAffected, err := r.client.UserAllowedGroup.Delete(). joinAffected, err := r.client.UserAllowedGroup.Delete().
@@ -300,7 +293,7 @@ func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
return 0, err return 0, err
} }
arrayRes, err := r.sql.ExecContext( arrayRes, err := exec.ExecContext(
ctx, ctx,
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1), updated_at = NOW() WHERE $1 = ANY(allowed_groups)", "UPDATE users SET allowed_groups = array_remove(allowed_groups, $1), updated_at = NOW() WHERE $1 = ANY(allowed_groups)",
groupID, groupID,
@@ -362,6 +355,56 @@ func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64)
return out, nil return out, nil
} }
// syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组:
// 1) 以 user_allowed_groups 为读写源,确保新旧逻辑一致;
// 2) 额外更新 users.allowed_groups历史字段以保持兼容。
func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error {
if client == nil {
return nil
}
// Keep join table as the source of truth for reads.
if _, err := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID)).Exec(ctx); err != nil {
return err
}
unique := make(map[int64]struct{}, len(groupIDs))
for _, id := range groupIDs {
if id <= 0 {
continue
}
unique[id] = struct{}{}
}
legacyGroups := make([]int64, 0, len(unique))
if len(unique) > 0 {
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
for groupID := range unique {
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
legacyGroups = append(legacyGroups, groupID)
}
if err := client.UserAllowedGroup.
CreateBulk(creates...).
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx); err != nil {
return err
}
}
// Phase 1 兼容:保持 users.allowed_groups数组字段同步避免旧查询路径读取到过期数据。
var legacy any
if len(legacyGroups) > 0 {
sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] })
legacy = pq.Array(legacyGroups)
}
if _, err := client.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil {
return err
}
return nil
}
func (r *userRepository) syncUserAllowedGroups(ctx context.Context, client *dbent.Client, exec sqlExecutor, userID int64, groupIDs []int64) error { func (r *userRepository) syncUserAllowedGroups(ctx context.Context, client *dbent.Client, exec sqlExecutor, userID int64, groupIDs []int64) error {
if client == nil || exec == nil { if client == nil || exec == nil {
return nil return nil

View File

@@ -4,7 +4,6 @@ package repository
import ( import (
"context" "context"
"database/sql"
"testing" "testing"
"time" "time"
@@ -17,17 +16,19 @@ import (
type UserRepoSuite struct { type UserRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
tx *sql.Tx
client *dbent.Client client *dbent.Client
repo *userRepository repo *userRepository
} }
func (s *UserRepoSuite) SetupTest() { func (s *UserRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
entClient, tx := testEntSQLTx(s.T()) s.client = testEntClient(s.T())
s.tx = tx s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
s.client = entClient
s.repo = newUserRepositoryWithSQL(entClient, tx) // 清理测试数据,确保每个测试从干净状态开始
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users")
} }
func TestUserRepoSuite(t *testing.T) { func TestUserRepoSuite(t *testing.T) {

View File

@@ -22,8 +22,8 @@ type UserSubscriptionRepoSuite struct {
func (s *UserSubscriptionRepoSuite) SetupTest() { func (s *UserSubscriptionRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
client, _ := testEntSQLTx(s.T()) tx := testEntTx(s.T())
s.client = client s.client = tx.Client()
s.repo = NewUserSubscriptionRepository(s.client).(*userSubscriptionRepository) s.repo = NewUserSubscriptionRepository(s.client).(*userSubscriptionRepository)
} }
@@ -66,8 +66,8 @@ func (s *UserSubscriptionRepoSuite) mustCreateSubscription(userID, groupID int64
create := s.client.UserSubscription.Create(). create := s.client.UserSubscription.Create().
SetUserID(userID). SetUserID(userID).
SetGroupID(groupID). SetGroupID(groupID).
SetStartsAt(now.Add(-1*time.Hour)). SetStartsAt(now.Add(-1 * time.Hour)).
SetExpiresAt(now.Add(24*time.Hour)). SetExpiresAt(now.Add(24 * time.Hour)).
SetStatus(service.SubscriptionStatusActive). SetStatus(service.SubscriptionStatusActive).
SetAssignedAt(now). SetAssignedAt(now).
SetNotes("") SetNotes("")
@@ -631,4 +631,3 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
s.Require().NoError(err, "GetByID expired") s.Require().NoError(err, "GetByID expired")
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired") s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
} }