fix(仓储): 修复软删除过滤与事务测试

修复软删除拦截器使用错误,确保默认查询过滤已删记录
仓储层改用 ent.Tx 与扫描辅助,避免 sql.Tx 断言问题
同步更新集成测试以覆盖事务与统计变动
This commit is contained in:
yangjianbo
2025-12-29 19:23:49 +08:00
parent b436da7249
commit ae191f72a4
20 changed files with 565 additions and 326 deletions

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/ent/intercept"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
@@ -79,16 +80,13 @@ func SkipSoftDelete(parent context.Context) context.Context {
// 确保软删除的记录不会出现在普通查询结果中。
func (d SoftDeleteMixin) Interceptors() []ent.Interceptor {
return []ent.Interceptor{
ent.TraverseFunc(func(ctx context.Context, q ent.Query) error {
intercept.TraverseFunc(func(ctx context.Context, q intercept.Query) error {
// 检查是否需要跳过软删除过滤
if skip, _ := ctx.Value(softDeleteKey{}).(bool); skip {
return nil
}
// 为查询添加 deleted_at IS NULL 条件
w, ok := q.(interface{ WhereP(...func(*sql.Selector)) })
if ok {
d.applyPredicate(w)
}
d.applyPredicate(q)
return nil
}),
}

View File

@@ -28,6 +28,7 @@ import (
"github.com/lib/pq"
entsql "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqljson"
)
// accountRepository 实现 service.AccountRepository 接口。
@@ -36,11 +37,9 @@ import (
// 设计说明:
// - client: Ent 客户端,用于类型安全的 ORM 操作
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
// - begin: SQL 事务开启器,用于需要事务的操作
type accountRepository struct {
client *dbent.Client // Ent ORM 客户端
sql sqlExecutor // 原生 SQL 执行接口
begin sqlBeginner // 事务开启接口
client *dbent.Client // Ent ORM 客户端
sql sqlExecutor // 原生 SQL 执行接口
}
// NewAccountRepository 创建账户仓储实例。
@@ -52,11 +51,7 @@ func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRe
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
// 这种设计便于单元测试时注入 mock 对象。
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
var beginner sqlBeginner
if b, ok := sqlq.(sqlBeginner); ok {
beginner = b
}
return &accountRepository{client: client, sql: sqlq, begin: beginner}
return &accountRepository{client: client, sql: sqlq}
}
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
@@ -146,9 +141,10 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
return nil, nil
}
// 使用 sqljson.ValueEQ 生成 JSON 路径过滤,避免手写 SQL 片段导致语法兼容问题。
m, err := r.client.Account.Query().
Where(func(s *entsql.Selector) {
s.Where(entsql.ExprP("extra->>'crs_account_id' = ?", crsAccountID))
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, crsAccountID, sqljson.Path("crs_account_id")))
}).
Only(ctx)
if err != nil {

View File

@@ -4,7 +4,6 @@ package repository
import (
"context"
"database/sql"
"testing"
"time"
@@ -17,18 +16,16 @@ import (
type AccountRepoSuite struct {
suite.Suite
ctx context.Context
tx *sql.Tx
ctx context.Context
client *dbent.Client
repo *accountRepository
repo *accountRepository
}
func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background()
client, tx := testEntSQLTx(s.T())
s.client = client
s.tx = tx
s.repo = newAccountRepositoryWithSQL(client, tx)
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = newAccountRepositoryWithSQL(s.client, tx)
}
func TestAccountRepoSuite(t *testing.T) {
@@ -175,7 +172,8 @@ func (s *AccountRepoSuite) TestListWithFilters() {
for _, tt := range tests {
s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源
client, tx := testEntSQLTx(s.T())
tx := testEntTx(s.T())
client := tx.Client()
repo := newAccountRepositoryWithSQL(client, tx)
ctx := context.Background()

View File

@@ -20,7 +20,8 @@ func uniqueTestValue(t *testing.T, prefix string) string {
func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) {
ctx := context.Background()
entClient, sqlTx := testEntSQLTx(t)
tx := testEntTx(t)
entClient := tx.Client()
targetGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "target-group")).
@@ -33,7 +34,7 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
Save(ctx)
require.NoError(t, err)
repo := newUserRepositoryWithSQL(entClient, sqlTx)
repo := newUserRepositoryWithSQL(entClient, tx)
u1 := &service.User{
Email: uniqueTestValue(t, "u1") + "@example.com",
@@ -81,7 +82,8 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
ctx := context.Background()
entClient, sqlTx := testEntSQLTx(t)
tx := testEntTx(t)
entClient := tx.Client()
targetGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "delete-cascade-target")).
@@ -94,8 +96,8 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
Save(ctx)
require.NoError(t, err)
userRepo := newUserRepositoryWithSQL(entClient, sqlTx)
groupRepo := newGroupRepositoryWithSQL(entClient, sqlTx)
userRepo := newUserRepositoryWithSQL(entClient, tx)
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
apiKeyRepo := NewApiKeyRepository(entClient)
u := &service.User{
@@ -141,4 +143,3 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
require.NoError(t, err)
require.Nil(t, keyAfter.GroupID)
}

View File

@@ -2,9 +2,11 @@ package repository
import (
"context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -18,6 +20,11 @@ func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
return &apiKeyRepository{client: client}
}
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
// 默认过滤已软删除记录,避免删除后仍被查询到。
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
}
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
created, err := r.client.ApiKey.Create().
SetUserID(key.UserID).
@@ -35,7 +42,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) erro
}
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
m, err := r.client.ApiKey.Query().
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
WithUser().
WithGroup().
@@ -55,7 +62,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
// - 不加载完整的 ApiKey 实体及其关联数据User、Group 等)
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
m, err := r.client.ApiKey.Query().
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
Select(apikey.FieldUserID).
Only(ctx)
@@ -69,7 +76,7 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err
}
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
m, err := r.client.ApiKey.Query().
m, err := r.activeQuery().
Where(apikey.KeyEQ(key)).
WithUser().
WithGroup().
@@ -84,6 +91,14 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
}
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
exists, err := r.activeQuery().Where(apikey.IDEQ(key.ID)).Exist(ctx)
if err != nil {
return err
}
if !exists {
return service.ErrApiKeyNotFound
}
builder := r.client.ApiKey.UpdateOneID(key.ID).
SetName(key.Name).
SetStatus(key.Status)
@@ -105,12 +120,34 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
}
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
_, err := r.client.ApiKey.Delete().Where(apikey.IDEQ(id)).Exec(ctx)
return err
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
affected, err := r.client.ApiKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetDeletedAt(time.Now()).
Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrApiKeyNotFound
}
return err
}
if affected == 0 {
exists, err := r.client.ApiKey.Query().
Where(apikey.IDEQ(id)).
Exist(mixins.SkipSoftDelete(ctx))
if err != nil {
return err
}
if exists {
return nil
}
return service.ErrApiKeyNotFound
}
return nil
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
q := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID))
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
total, err := q.Count(ctx)
if err != nil {
@@ -141,7 +178,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
}
ids, err := r.client.ApiKey.Query().
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...)).
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
IDs(ctx)
if err != nil {
return nil, err
@@ -150,17 +187,17 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
}
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
count, err := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID)).Count(ctx)
count, err := r.activeQuery().Where(apikey.UserIDEQ(userID)).Count(ctx)
return int64(count), err
}
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
count, err := r.client.ApiKey.Query().Where(apikey.KeyEQ(key)).Count(ctx)
count, err := r.activeQuery().Where(apikey.KeyEQ(key)).Count(ctx)
return count > 0, err
}
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
q := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID))
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
total, err := q.Count(ctx)
if err != nil {
@@ -187,7 +224,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
q := r.client.ApiKey.Query()
q := r.activeQuery()
if userID > 0 {
q = q.Where(apikey.UserIDEQ(userID))
}
@@ -211,7 +248,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
n, err := r.client.ApiKey.Update().
Where(apikey.GroupIDEQ(groupID)).
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
ClearGroupID().
Save(ctx)
return int64(n), err
@@ -219,7 +256,7 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
// CountByGroupID 获取分组的 API Key 数量
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
count, err := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
count, err := r.activeQuery().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
return int64(count), err
}

View File

@@ -21,9 +21,9 @@ type ApiKeyRepoSuite struct {
func (s *ApiKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
entClient, _ := testEntSQLTx(s.T())
s.client = entClient
s.repo = NewApiKeyRepository(entClient).(*apiKeyRepository)
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository)
}
func TestApiKeyRepoSuite(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"database/sql"
"errors"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
@@ -10,26 +11,16 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
)
type sqlExecutor interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
type sqlBeginner interface {
sqlExecutor
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
type groupRepository struct {
client *dbent.Client
sql sqlExecutor
begin sqlBeginner
}
func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupRepository {
@@ -37,11 +28,7 @@ func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupReposi
}
func newGroupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *groupRepository {
var beginner sqlBeginner
if b, ok := sqlq.(sqlBeginner); ok {
beginner = b
}
return &groupRepository{client: client, sql: sqlq, begin: beginner}
return &groupRepository{client: client, sql: sqlq}
}
func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) error {
@@ -214,7 +201,7 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool,
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", groupID).Scan(&count); err != nil {
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
return 0, err
}
return count, nil
@@ -236,31 +223,44 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
}
groupSvc := groupEntityToService(g)
exec := r.sql
// 使用 ent 事务统一包裹:避免手工基于 *sql.Tx 构造 ent client 带来的驱动断言问题,
// 同时保证级联删除的原子性。
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return nil, err
}
exec := r.client
txClient := r.client
var sqlTx *sql.Tx
if r.begin != nil {
sqlTx, err = r.begin.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
exec = sqlTx
txClient = entClientFromSQLTx(sqlTx)
// 注意:不能调用 txClient.Close(),因为基于事务的 ent client
// 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB但实际是 *sql.Tx
// 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成
defer func() { _ = sqlTx.Rollback() }()
if err == nil {
defer func() { _ = tx.Rollback() }()
exec = tx.Client()
txClient = exec
} else {
// 已处于外部事务中ErrTxStarted复用当前 client 参与同一事务。
}
// Lock the group row to avoid concurrent writes while we cascade.
var lockedID int64
if err := exec.QueryRowContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id).Scan(&lockedID); err != nil {
if errorsIsNoRows(err) {
return nil, service.ErrGroupNotFound
}
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分“未找到”与其他错误。
rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id)
if err != nil {
return nil, err
}
var lockedID int64
if rows.Next() {
if err := rows.Scan(&lockedID); err != nil {
_ = rows.Close()
return nil, err
}
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
if lockedID == 0 {
return nil, service.ErrGroupNotFound
}
var affectedUserIDs []int64
if groupSvc.IsSubscriptionType() {
@@ -319,8 +319,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return nil, err
}
if sqlTx != nil {
if err := sqlTx.Commit(); err != nil {
if tx != nil {
if err := tx.Commit(); err != nil {
return nil, err
}
}
@@ -359,11 +359,6 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
return counts, nil
}
func entClientFromSQLTx(tx *sql.Tx) *dbent.Client {
drv := entsql.NewDriver(dialect.Postgres, entsql.Conn{ExecQuerier: tx})
return dbent.NewClient(dbent.Driver(drv))
}
func errorsIsNoRows(err error) bool {
return err == sql.ErrNoRows
}

View File

@@ -4,9 +4,9 @@ package repository
import (
"context"
"database/sql"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
@@ -15,15 +15,15 @@ import (
type GroupRepoSuite struct {
suite.Suite
ctx context.Context
tx *sql.Tx
tx *dbent.Tx
repo *groupRepository
}
func (s *GroupRepoSuite) SetupTest() {
s.ctx = context.Background()
entClient, tx := testEntSQLTx(s.T())
tx := testEntTx(s.T())
s.tx = tx
s.repo = newGroupRepositoryWithSQL(entClient, tx)
s.repo = newGroupRepositoryWithSQL(tx.Client(), tx)
}
func TestGroupRepoSuite(t *testing.T) {
@@ -99,6 +99,9 @@ func (s *GroupRepoSuite) TestDelete() {
// --- List / ListWithFilters ---
func (s *GroupRepoSuite) TestList() {
baseGroups, basePage, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List base")
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
@@ -118,12 +121,20 @@ func (s *GroupRepoSuite) TestList() {
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
// 3 default groups + 2 test groups = 5 total
s.Require().Len(groups, 5)
s.Require().Equal(int64(5), page.Total)
s.Require().Len(groups, len(baseGroups)+2)
s.Require().Equal(basePage.Total+2, page.Total)
}
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
baseGroups, _, err := s.repo.ListWithFilters(
s.ctx,
pagination.PaginationParams{Page: 1, PageSize: 10},
service.PlatformOpenAI,
"",
nil,
)
s.Require().NoError(err, "ListWithFilters base")
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
@@ -143,8 +154,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
s.Require().NoError(err)
// 1 default openai group + 1 test openai group = 2 total
s.Require().Len(groups, 2)
s.Require().Len(groups, len(baseGroups)+1)
// Verify all groups are OpenAI platform
for _, g := range groups {
s.Require().Equal(service.PlatformOpenAI, g.Platform)
@@ -221,11 +231,13 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
s.Require().NoError(s.repo.Create(s.ctx, g2))
var accountID int64
s.Require().NoError(s.tx.QueryRowContext(
s.Require().NoError(scanSingleRow(
s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
"acc1", service.PlatformAnthropic, service.AccountTypeOAuth,
).Scan(&accountID))
[]any{"acc1", service.PlatformAnthropic, service.AccountTypeOAuth},
&accountID,
))
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g1.ID, 1)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g2.ID, 1)
@@ -243,6 +255,9 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
// --- ListActive / ListActiveByPlatform ---
func (s *GroupRepoSuite) TestListActive() {
baseGroups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive base")
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "active1",
Platform: service.PlatformAnthropic,
@@ -262,8 +277,7 @@ func (s *GroupRepoSuite) TestListActive() {
groups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
// 3 default groups (all active) + 1 test active group = 4 total
s.Require().Len(groups, 4)
s.Require().Len(groups, len(baseGroups)+1)
// Verify our test group is in the results
var found bool
for _, g := range groups {
@@ -351,17 +365,21 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
s.Require().NoError(s.repo.Create(s.ctx, group))
var a1 int64
s.Require().NoError(s.tx.QueryRowContext(
s.Require().NoError(scanSingleRow(
s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
"a1", service.PlatformAnthropic, service.AccountTypeOAuth,
).Scan(&a1))
[]any{"a1", service.PlatformAnthropic, service.AccountTypeOAuth},
&a1,
))
var a2 int64
s.Require().NoError(s.tx.QueryRowContext(
s.Require().NoError(scanSingleRow(
s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
"a2", service.PlatformAnthropic, service.AccountTypeOAuth,
).Scan(&a2))
[]any{"a2", service.PlatformAnthropic, service.AccountTypeOAuth},
&a2,
))
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, group.ID, 1)
s.Require().NoError(err)
@@ -402,11 +420,13 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
}
s.Require().NoError(s.repo.Create(s.ctx, g))
var accountID int64
s.Require().NoError(s.tx.QueryRowContext(
s.Require().NoError(scanSingleRow(
s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
"acc-del", service.PlatformAnthropic, service.AccountTypeOAuth,
).Scan(&accountID))
[]any{"acc-del", service.PlatformAnthropic, service.AccountTypeOAuth},
&accountID,
))
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g.ID, 1)
s.Require().NoError(err)
@@ -432,11 +452,13 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
insertAccount := func(name string) int64 {
var id int64
s.Require().NoError(s.tx.QueryRowContext(
s.Require().NoError(scanSingleRow(
s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
name, service.PlatformAnthropic, service.AccountTypeOAuth,
).Scan(&id))
[]any{name, service.PlatformAnthropic, service.AccountTypeOAuth},
&id,
))
return id
}
a1 := insertAccount("a1")

View File

@@ -36,8 +36,9 @@ const (
)
var (
integrationDB *sql.DB
integrationRedis *redisclient.Client
integrationDB *sql.DB
integrationEntClient *dbent.Client
integrationRedis *redisclient.Client
redisNamespaceSeq uint64
)
@@ -101,6 +102,10 @@ func TestMain(m *testing.M) {
os.Exit(1)
}
// 创建 ent client 用于集成测试
drv := entsql.OpenDB(dialect.Postgres, integrationDB)
integrationEntClient = dbent.NewClient(dbent.Driver(drv))
redisHost, err := redisContainer.Host(ctx)
if err != nil {
log.Printf("failed to get redis host: %v", err)
@@ -123,6 +128,7 @@ func TestMain(m *testing.M) {
code := m.Run()
_ = integrationEntClient.Close()
_ = integrationRedis.Close()
_ = integrationDB.Close()
@@ -193,18 +199,38 @@ func testTx(t *testing.T) *sql.Tx {
return tx
}
// testEntClient 返回全局的 ent client用于测试需要内部管理事务的代码如 Create/Update 方法)。
// 注意:此 client 的操作会真正写入数据库,测试结束后不会自动回滚。
func testEntClient(t *testing.T) *dbent.Client {
t.Helper()
return integrationEntClient
}
// testEntTx 返回一个 ent 事务,用于需要事务隔离的测试。
// 测试结束后会自动回滚,不会影响数据库状态。
func testEntTx(t *testing.T) *dbent.Tx {
t.Helper()
tx, err := integrationEntClient.Tx(context.Background())
require.NoError(t, err, "begin ent tx")
t.Cleanup(func() {
_ = tx.Rollback()
})
return tx
}
// testEntSQLTx 已弃用:不要在新测试中使用此函数。
// 基于 *sql.Tx 创建的 ent client 在调用 client.Tx() 时会 panic。
// 对于需要测试内部使用事务的代码,请使用 testEntClient。
// 对于需要事务隔离的测试,请使用 testEntTx。
//
// Deprecated: Use testEntClient or testEntTx instead.
func testEntSQLTx(t *testing.T) (*dbent.Client, *sql.Tx) {
t.Helper()
tx := testTx(t)
drv := entsql.NewDriver(dialect.Postgres, entsql.Conn{ExecQuerier: tx})
client := dbent.NewClient(dbent.Driver(drv))
t.Cleanup(func() {
_ = client.Close()
})
return client, tx
// 直接失败,避免旧测试误用导致的事务嵌套 panic。
t.Fatalf("testEntSQLTx 已弃用:请使用 testEntClient 或 testEntTx")
return nil, nil
}
func testRedis(t *testing.T) *redisclient.Client {
@@ -363,13 +389,16 @@ type IntegrationDBSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
tx *sql.Tx
tx *dbent.Tx
}
// SetupTest initializes ctx and client for each test method.
func (s *IntegrationDBSuite) SetupTest() {
s.ctx = context.Background()
s.client, s.tx = testEntSQLTx(s.T())
// 统一使用 ent.Tx确保每个测试都有独立事务并自动回滚。
tx := testEntTx(s.T())
s.tx = tx
s.client = tx.Client()
}
// RequireNoError is a convenience method wrapping require.NoError with s.T().

View File

@@ -13,7 +13,6 @@ import (
type sqlQuerier interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
type proxyRepository struct {
@@ -170,9 +169,8 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
// CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
row := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", proxyID)
var count int64
if err := row.Scan(&count); err != nil {
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", []any{proxyID}, &count); err != nil {
return 0, err
}
return count, nil

View File

@@ -4,10 +4,10 @@ package repository
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
@@ -15,16 +15,16 @@ import (
type ProxyRepoSuite struct {
suite.Suite
ctx context.Context
sqlTx *sql.Tx
repo *proxyRepository
ctx context.Context
tx *dbent.Tx
repo *proxyRepository
}
func (s *ProxyRepoSuite) SetupTest() {
s.ctx = context.Background()
entClient, sqlTx := testEntSQLTx(s.T())
s.sqlTx = sqlTx
s.repo = newProxyRepositoryWithSQL(entClient, sqlTx)
tx := testEntTx(s.T())
s.tx = tx
s.repo = newProxyRepositoryWithSQL(tx.Client(), tx)
}
func TestProxyRepoSuite(t *testing.T) {
@@ -306,7 +306,7 @@ func (s *ProxyRepoSuite) mustCreateProxyWithTimes(name, status string, createdAt
Port: 8080,
Status: status,
})
_, err := s.sqlTx.ExecContext(s.ctx, "UPDATE proxies SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, p.ID)
_, err := s.tx.ExecContext(s.ctx, "UPDATE proxies SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, p.ID)
s.Require().NoError(err, "update proxy timestamps")
return p
}
@@ -317,7 +317,7 @@ func (s *ProxyRepoSuite) mustInsertAccount(name string, proxyID *int64) {
if proxyID != nil {
pid = *proxyID
}
_, err := s.sqlTx.ExecContext(
_, err := s.tx.ExecContext(
s.ctx,
"INSERT INTO accounts (name, platform, type, proxy_id) VALUES ($1, $2, $3, $4)",
name,

View File

@@ -22,9 +22,9 @@ type RedeemCodeRepoSuite struct {
func (s *RedeemCodeRepoSuite) SetupTest() {
s.ctx = context.Background()
entClient, _ := testEntSQLTx(s.T())
s.client = entClient
s.repo = NewRedeemCodeRepository(entClient).(*redeemCodeRepository)
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = NewRedeemCodeRepository(s.client).(*redeemCodeRepository)
}
func TestRedeemCodeRepoSuite(t *testing.T) {

View File

@@ -18,8 +18,8 @@ type SettingRepoSuite struct {
func (s *SettingRepoSuite) SetupTest() {
s.ctx = context.Background()
entClient, _ := testEntSQLTx(s.T())
s.repo = NewSettingRepository(entClient).(*settingRepository)
tx := testEntTx(s.T())
s.repo = NewSettingRepository(tx.Client()).(*settingRepository)
}
func TestSettingRepoSuite(t *testing.T) {

View File

@@ -34,7 +34,8 @@ func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, emai
func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
ctx := context.Background()
client, _ := testEntSQLTx(t)
// 使用全局 ent client确保软删除验证在实际持久化数据上进行。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
@@ -65,7 +66,8 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
ctx := context.Background()
client, _ := testEntSQLTx(t)
// 使用全局 ent client避免事务回滚影响幂等性验证。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
@@ -84,7 +86,8 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
ctx := context.Background()
client, _ := testEntSQLTx(t)
// 使用全局 ent client确保 SkipSoftDelete 的硬删除语义可验证。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")

View File

@@ -0,0 +1,33 @@
package repository
import (
"context"
"database/sql"
)
type sqlQueryer interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
// scanSingleRow executes a query and scans the first row into dest.
// If no rows are returned, sql.ErrNoRows is returned.
// 设计目的:仅依赖 QueryContext避免 QueryRowContext 对 *sql.Tx 的强绑定,
// 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) error {
rows, err := q.QueryContext(ctx, query, args...)
if err != nil {
return err
}
defer rows.Close()
if !rows.Next() {
if err := rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
if err := rows.Scan(dest...); err != nil {
return err
}
return rows.Err()
}

View File

@@ -3,7 +3,6 @@ package repository
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
@@ -33,6 +32,7 @@ func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLog
}
func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository {
// 使用 scanSingleRow 替代 QueryRowContext保证 ent.Tx 作为 sqlExecutor 可用。
return &usageLogRepository{client: client, sql: sqlq}
}
@@ -53,7 +53,7 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
var requestCount int64
var tokenCount int64
if err := r.sql.QueryRowContext(ctx, query, args...).Scan(&requestCount, &tokenCount); err != nil {
if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil {
return 0, 0, err
}
return requestCount / 5, tokenCount / 5, nil
@@ -114,9 +114,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs)
row := r.sql.QueryRowContext(
ctx,
query,
args := []any{
log.UserID,
log.ApiKeyID,
log.AccountID,
@@ -142,9 +140,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration,
firstToken,
createdAt,
)
if err := row.Scan(&log.ID, &log.CreatedAt); err != nil {
}
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
return err
}
log.RateMultiplier = rateMultiplier
@@ -153,11 +150,22 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
log, err := scanUsageLog(r.sql.QueryRowContext(ctx, query, id))
rows, err := r.sql.QueryContext(ctx, query, id)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, service.ErrUsageLogNotFound
return nil, err
}
defer rows.Close()
if !rows.Next() {
if err := rows.Err(); err != nil {
return nil, err
}
return nil, service.ErrUsageLogNotFound
}
log, err := scanUsageLog(rows)
if err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return log, nil
@@ -195,8 +203,18 @@ func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
`
stats := &UserStats{}
if err := r.sql.QueryRowContext(ctx, query, userID, startTime, endTime).
Scan(&stats.TotalRequests, &stats.TotalTokens, &stats.TotalCost, &stats.InputTokens, &stats.OutputTokens, &stats.CacheReadTokens); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{userID, startTime, endTime},
&stats.TotalRequests,
&stats.TotalTokens,
&stats.TotalCost,
&stats.InputTokens,
&stats.OutputTokens,
&stats.CacheReadTokens,
); err != nil {
return nil, err
}
return stats, nil
@@ -219,8 +237,15 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
FROM users
WHERE deleted_at IS NULL
`
if err := r.sql.QueryRowContext(ctx, userStatsQuery, today, today).
Scan(&stats.TotalUsers, &stats.TodayNewUsers, &stats.ActiveUsers); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
userStatsQuery,
[]any{today, today},
&stats.TotalUsers,
&stats.TodayNewUsers,
&stats.ActiveUsers,
); err != nil {
return nil, err
}
@@ -232,8 +257,14 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
FROM api_keys
WHERE deleted_at IS NULL
`
if err := r.sql.QueryRowContext(ctx, apiKeyStatsQuery, service.StatusActive).
Scan(&stats.TotalApiKeys, &stats.ActiveApiKeys); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
apiKeyStatsQuery,
[]any{service.StatusActive},
&stats.TotalApiKeys,
&stats.ActiveApiKeys,
); err != nil {
return nil, err
}
@@ -248,8 +279,17 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
FROM accounts
WHERE deleted_at IS NULL
`
if err := r.sql.QueryRowContext(ctx, accountStatsQuery, service.StatusActive, service.StatusError, now, now).
Scan(&stats.TotalAccounts, &stats.NormalAccounts, &stats.ErrorAccounts, &stats.RateLimitAccounts, &stats.OverloadAccounts); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
accountStatsQuery,
[]any{service.StatusActive, service.StatusError, now, now},
&stats.TotalAccounts,
&stats.NormalAccounts,
&stats.ErrorAccounts,
&stats.RateLimitAccounts,
&stats.OverloadAccounts,
); err != nil {
return nil, err
}
@@ -266,17 +306,20 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
`
if err := r.sql.QueryRowContext(ctx, totalStatsQuery).
Scan(
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheCreationTokens,
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
totalStatsQuery,
nil,
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheCreationTokens,
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
@@ -294,16 +337,19 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
FROM usage_logs
WHERE created_at >= $1
`
if err := r.sql.QueryRowContext(ctx, todayStatsQuery, today).
Scan(
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
&stats.TodayCacheCreationTokens,
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
todayStatsQuery,
[]any{today},
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
&stats.TodayCacheCreationTokens,
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
); err != nil {
return nil, err
}
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
@@ -345,16 +391,19 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
`
var stats usagestats.UsageStats
if err := r.sql.QueryRowContext(ctx, query, userID, startTime, endTime).
Scan(
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{userID, startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
@@ -377,16 +426,19 @@ func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKe
`
var stats usagestats.UsageStats
if err := r.sql.QueryRowContext(ctx, query, apiKeyID, startTime, endTime).
Scan(
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{apiKeyID, startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
@@ -430,8 +482,15 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
`
stats := &usagestats.AccountStats{}
if err := r.sql.QueryRowContext(ctx, query, accountID, today).
Scan(&stats.Requests, &stats.Tokens, &stats.Cost); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{accountID, today},
&stats.Requests,
&stats.Tokens,
&stats.Cost,
); err != nil {
return nil, err
}
return stats, nil
@@ -449,8 +508,15 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
`
stats := &usagestats.AccountStats{}
if err := r.sql.QueryRowContext(ctx, query, accountID, startTime).
Scan(&stats.Requests, &stats.Tokens, &stats.Cost); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{accountID, startTime},
&stats.Requests,
&stats.Tokens,
&stats.Cost,
); err != nil {
return nil, err
}
return stats, nil
@@ -581,12 +647,22 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
today := timezone.Today()
// API Key 统计
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL", userID).
Scan(&stats.TotalApiKeys); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
[]any{userID},
&stats.TotalApiKeys,
); err != nil {
return nil, err
}
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL", userID, service.StatusActive).
Scan(&stats.ActiveApiKeys); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
[]any{userID, service.StatusActive},
&stats.ActiveApiKeys,
); err != nil {
return nil, err
}
@@ -604,17 +680,20 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
FROM usage_logs
WHERE user_id = $1
`
if err := r.sql.QueryRowContext(ctx, totalStatsQuery, userID).
Scan(
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheCreationTokens,
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
totalStatsQuery,
[]any{userID},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheCreationTokens,
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
@@ -632,16 +711,19 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
FROM usage_logs
WHERE user_id = $1 AND created_at >= $2
`
if err := r.sql.QueryRowContext(ctx, todayStatsQuery, userID, today).
Scan(
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
&stats.TodayCacheCreationTokens,
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
todayStatsQuery,
[]any{userID, today},
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
&stats.TodayCacheCreationTokens,
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
); err != nil {
return nil, err
}
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
@@ -1007,16 +1089,19 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
`
stats := &UsageStats{}
if err := r.sql.QueryRowContext(ctx, query, startTime, endTime).
Scan(
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
@@ -1108,7 +1193,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
avgQuery := "SELECT COALESCE(AVG(duration_ms), 0) as avg_duration_ms FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3"
var avgDuration float64
if err := r.sql.QueryRowContext(ctx, avgQuery, accountID, startTime, endTime).Scan(&avgDuration); err != nil {
if err := scanSingleRow(ctx, r.sql, avgQuery, []any{accountID, startTime, endTime}, &avgDuration); err != nil {
return nil, err
}
@@ -1186,7 +1271,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
countQuery := "SELECT COUNT(*) FROM usage_logs " + whereClause
var total int64
if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
if err := scanSingleRow(ctx, r.sql, countQuery, args, &total); err != nil {
return nil, nil, err
}

View File

@@ -4,7 +4,6 @@ package repository
import (
"context"
"database/sql"
"testing"
"time"
@@ -19,17 +18,17 @@ import (
type UsageLogRepoSuite struct {
suite.Suite
ctx context.Context
tx *sql.Tx
tx *dbent.Tx
client *dbent.Client
repo *usageLogRepository
}
func (s *UsageLogRepoSuite) SetupTest() {
s.ctx = context.Background()
client, tx := testEntSQLTx(s.T())
s.client = client
tx := testEntTx(s.T())
s.tx = tx
s.repo = newUsageLogRepositoryWithSQL(client, tx)
s.client = tx.Client()
s.repo = newUsageLogRepositoryWithSQL(s.client, tx)
}
func TestUsageLogRepoSuite(t *testing.T) {
@@ -197,6 +196,8 @@ func (s *UsageLogRepoSuite) TestListWithFilters() {
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
now := time.Now()
todayStart := timezone.Today()
baseStats, err := s.repo.GetDashboardStats(s.ctx)
s.Require().NoError(err, "GetDashboardStats base")
userToday := mustCreateUser(s.T(), s.client, &service.User{
Email: "today@example.com",
@@ -268,24 +269,24 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
stats, err := s.repo.GetDashboardStats(s.ctx)
s.Require().NoError(err, "GetDashboardStats")
s.Require().Equal(int64(2), stats.TotalUsers, "TotalUsers mismatch")
s.Require().Equal(int64(1), stats.TodayNewUsers, "TodayNewUsers mismatch")
s.Require().Equal(int64(1), stats.ActiveUsers, "ActiveUsers mismatch")
s.Require().Equal(int64(2), stats.TotalApiKeys, "TotalApiKeys mismatch")
s.Require().Equal(int64(1), stats.ActiveApiKeys, "ActiveApiKeys mismatch")
s.Require().Equal(int64(4), stats.TotalAccounts, "TotalAccounts mismatch")
s.Require().Equal(int64(1), stats.ErrorAccounts, "ErrorAccounts mismatch")
s.Require().Equal(int64(1), stats.RateLimitAccounts, "RateLimitAccounts mismatch")
s.Require().Equal(int64(1), stats.OverloadAccounts, "OverloadAccounts mismatch")
s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch")
s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch")
s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch")
s.Require().Equal(baseStats.TotalApiKeys+2, stats.TotalApiKeys, "TotalApiKeys mismatch")
s.Require().Equal(baseStats.ActiveApiKeys+1, stats.ActiveApiKeys, "ActiveApiKeys mismatch")
s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch")
s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch")
s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch")
s.Require().Equal(baseStats.OverloadAccounts+1, stats.OverloadAccounts, "OverloadAccounts mismatch")
s.Require().Equal(int64(3), stats.TotalRequests, "TotalRequests mismatch")
s.Require().Equal(int64(16), stats.TotalInputTokens, "TotalInputTokens mismatch")
s.Require().Equal(int64(28), stats.TotalOutputTokens, "TotalOutputTokens mismatch")
s.Require().Equal(int64(3), stats.TotalCacheCreationTokens, "TotalCacheCreationTokens mismatch")
s.Require().Equal(int64(4), stats.TotalCacheReadTokens, "TotalCacheReadTokens mismatch")
s.Require().Equal(int64(51), stats.TotalTokens, "TotalTokens mismatch")
s.Require().Equal(2.3, stats.TotalCost, "TotalCost mismatch")
s.Require().Equal(2.0, stats.TotalActualCost, "TotalActualCost mismatch")
s.Require().Equal(baseStats.TotalRequests+3, stats.TotalRequests, "TotalRequests mismatch")
s.Require().Equal(baseStats.TotalInputTokens+int64(16), stats.TotalInputTokens, "TotalInputTokens mismatch")
s.Require().Equal(baseStats.TotalOutputTokens+int64(28), stats.TotalOutputTokens, "TotalOutputTokens mismatch")
s.Require().Equal(baseStats.TotalCacheCreationTokens+int64(3), stats.TotalCacheCreationTokens, "TotalCacheCreationTokens mismatch")
s.Require().Equal(baseStats.TotalCacheReadTokens+int64(4), stats.TotalCacheReadTokens, "TotalCacheReadTokens mismatch")
s.Require().Equal(baseStats.TotalTokens+int64(51), stats.TotalTokens, "TotalTokens mismatch")
s.Require().Equal(baseStats.TotalCost+2.3, stats.TotalCost, "TotalCost mismatch")
s.Require().Equal(baseStats.TotalActualCost+2.0, stats.TotalActualCost, "TotalActualCost mismatch")
s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1")
s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0")

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"database/sql"
"errors"
"sort"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -17,7 +18,6 @@ import (
type userRepository struct {
client *dbent.Client
sql sqlExecutor
begin sqlBeginner
}
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
@@ -25,11 +25,7 @@ func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserReposito
}
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
var beginner sqlBeginner
if b, ok := sqlq.(sqlBeginner); ok {
beginner = b
}
return &userRepository{client: client, sql: sqlq, begin: beginner}
return &userRepository{client: client, sql: sqlq}
}
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
@@ -37,22 +33,20 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
return nil
}
exec := r.sql
txClient := r.client
var sqlTx *sql.Tx
// 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
// 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return err
}
if r.begin != nil {
var err error
sqlTx, err = r.begin.BeginTx(ctx, nil)
if err != nil {
return err
}
exec = sqlTx
txClient = entClientFromSQLTx(sqlTx)
// 注意:不能调用 txClient.Close(),因为基于事务的 ent client
// 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB但实际是 *sql.Tx
// 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成
defer func() { _ = sqlTx.Rollback() }()
var txClient *dbent.Client
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// 已处于外部事务中ErrTxStarted复用当前 client 并由调用方负责提交/回滚。
txClient = r.client
}
created, err := txClient.User.Create().
@@ -70,12 +64,12 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
if err := r.syncUserAllowedGroups(ctx, txClient, exec, created.ID, userIn.AllowedGroups); err != nil {
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
return err
}
if sqlTx != nil {
if err := sqlTx.Commit(); err != nil {
if tx != nil {
if err := tx.Commit(); err != nil {
return err
}
}
@@ -121,22 +115,19 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
return nil
}
exec := r.sql
txClient := r.client
var sqlTx *sql.Tx
// 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return err
}
if r.begin != nil {
var err error
sqlTx, err = r.begin.BeginTx(ctx, nil)
if err != nil {
return err
}
exec = sqlTx
txClient = entClientFromSQLTx(sqlTx)
// 注意:不能调用 txClient.Close(),因为基于事务的 ent client
// 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB但实际是 *sql.Tx
// 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成
defer func() { _ = sqlTx.Rollback() }()
var txClient *dbent.Client
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// 已处于外部事务中ErrTxStarted复用当前 client 并由调用方负责提交/回滚。
txClient = r.client
}
updated, err := txClient.User.UpdateOneID(userIn.ID).
@@ -154,12 +145,12 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
}
if err := r.syncUserAllowedGroups(ctx, txClient, exec, updated.ID, userIn.AllowedGroups); err != nil {
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
return err
}
if sqlTx != nil {
if err := sqlTx.Commit(); err != nil {
if tx != nil {
if err := tx.Commit(); err != nil {
return err
}
}
@@ -289,8 +280,10 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
}
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
if r.sql == nil {
return 0, nil
exec := r.sql
if exec == nil {
// 未注入 sqlExecutor 时,退回到 ent client 的 ExecContext支持事务
exec = r.client
}
joinAffected, err := r.client.UserAllowedGroup.Delete().
@@ -300,7 +293,7 @@ func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
return 0, err
}
arrayRes, err := r.sql.ExecContext(
arrayRes, err := exec.ExecContext(
ctx,
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1), updated_at = NOW() WHERE $1 = ANY(allowed_groups)",
groupID,
@@ -362,6 +355,56 @@ func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64)
return out, nil
}
// syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组:
// 1) 以 user_allowed_groups 为读写源,确保新旧逻辑一致;
// 2) 额外更新 users.allowed_groups历史字段以保持兼容。
func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error {
if client == nil {
return nil
}
// Keep join table as the source of truth for reads.
if _, err := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID)).Exec(ctx); err != nil {
return err
}
unique := make(map[int64]struct{}, len(groupIDs))
for _, id := range groupIDs {
if id <= 0 {
continue
}
unique[id] = struct{}{}
}
legacyGroups := make([]int64, 0, len(unique))
if len(unique) > 0 {
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
for groupID := range unique {
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
legacyGroups = append(legacyGroups, groupID)
}
if err := client.UserAllowedGroup.
CreateBulk(creates...).
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx); err != nil {
return err
}
}
// Phase 1 兼容:保持 users.allowed_groups数组字段同步避免旧查询路径读取到过期数据。
var legacy any
if len(legacyGroups) > 0 {
sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] })
legacy = pq.Array(legacyGroups)
}
if _, err := client.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil {
return err
}
return nil
}
func (r *userRepository) syncUserAllowedGroups(ctx context.Context, client *dbent.Client, exec sqlExecutor, userID int64, groupIDs []int64) error {
if client == nil || exec == nil {
return nil

View File

@@ -4,7 +4,6 @@ package repository
import (
"context"
"database/sql"
"testing"
"time"
@@ -17,17 +16,19 @@ import (
type UserRepoSuite struct {
suite.Suite
ctx context.Context
tx *sql.Tx
client *dbent.Client
repo *userRepository
}
func (s *UserRepoSuite) SetupTest() {
s.ctx = context.Background()
entClient, tx := testEntSQLTx(s.T())
s.tx = tx
s.client = entClient
s.repo = newUserRepositoryWithSQL(entClient, tx)
s.client = testEntClient(s.T())
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
// 清理测试数据,确保每个测试从干净状态开始
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users")
}
func TestUserRepoSuite(t *testing.T) {

View File

@@ -22,8 +22,8 @@ type UserSubscriptionRepoSuite struct {
func (s *UserSubscriptionRepoSuite) SetupTest() {
s.ctx = context.Background()
client, _ := testEntSQLTx(s.T())
s.client = client
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = NewUserSubscriptionRepository(s.client).(*userSubscriptionRepository)
}
@@ -66,8 +66,8 @@ func (s *UserSubscriptionRepoSuite) mustCreateSubscription(userID, groupID int64
create := s.client.UserSubscription.Create().
SetUserID(userID).
SetGroupID(groupID).
SetStartsAt(now.Add(-1*time.Hour)).
SetExpiresAt(now.Add(24*time.Hour)).
SetStartsAt(now.Add(-1 * time.Hour)).
SetExpiresAt(now.Add(24 * time.Hour)).
SetStatus(service.SubscriptionStatusActive).
SetAssignedAt(now).
SetNotes("")
@@ -631,4 +631,3 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
s.Require().NoError(err, "GetByID expired")
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
}