fix(仓储): 修复软删除过滤与事务测试
修复软删除拦截器使用错误,确保默认查询过滤已删记录 仓储层改用 ent.Tx 与扫描辅助,避免 sql.Tx 断言问题 同步更新集成测试以覆盖事务与统计变动
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent/intercept"
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
@@ -79,16 +80,13 @@ func SkipSoftDelete(parent context.Context) context.Context {
|
||||
// 确保软删除的记录不会出现在普通查询结果中。
|
||||
func (d SoftDeleteMixin) Interceptors() []ent.Interceptor {
|
||||
return []ent.Interceptor{
|
||||
ent.TraverseFunc(func(ctx context.Context, q ent.Query) error {
|
||||
intercept.TraverseFunc(func(ctx context.Context, q intercept.Query) error {
|
||||
// 检查是否需要跳过软删除过滤
|
||||
if skip, _ := ctx.Value(softDeleteKey{}).(bool); skip {
|
||||
return nil
|
||||
}
|
||||
// 为查询添加 deleted_at IS NULL 条件
|
||||
w, ok := q.(interface{ WhereP(...func(*sql.Selector)) })
|
||||
if ok {
|
||||
d.applyPredicate(w)
|
||||
}
|
||||
d.applyPredicate(q)
|
||||
return nil
|
||||
}),
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqljson"
|
||||
)
|
||||
|
||||
// accountRepository 实现 service.AccountRepository 接口。
|
||||
@@ -36,11 +37,9 @@ import (
|
||||
// 设计说明:
|
||||
// - client: Ent 客户端,用于类型安全的 ORM 操作
|
||||
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
|
||||
// - begin: SQL 事务开启器,用于需要事务的操作
|
||||
type accountRepository struct {
|
||||
client *dbent.Client // Ent ORM 客户端
|
||||
sql sqlExecutor // 原生 SQL 执行接口
|
||||
begin sqlBeginner // 事务开启接口
|
||||
client *dbent.Client // Ent ORM 客户端
|
||||
sql sqlExecutor // 原生 SQL 执行接口
|
||||
}
|
||||
|
||||
// NewAccountRepository 创建账户仓储实例。
|
||||
@@ -52,11 +51,7 @@ func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRe
|
||||
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
|
||||
// 这种设计便于单元测试时注入 mock 对象。
|
||||
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
|
||||
var beginner sqlBeginner
|
||||
if b, ok := sqlq.(sqlBeginner); ok {
|
||||
beginner = b
|
||||
}
|
||||
return &accountRepository{client: client, sql: sqlq, begin: beginner}
|
||||
return &accountRepository{client: client, sql: sqlq}
|
||||
}
|
||||
|
||||
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
|
||||
@@ -146,9 +141,10 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 使用 sqljson.ValueEQ 生成 JSON 路径过滤,避免手写 SQL 片段导致语法兼容问题。
|
||||
m, err := r.client.Account.Query().
|
||||
Where(func(s *entsql.Selector) {
|
||||
s.Where(entsql.ExprP("extra->>'crs_account_id' = ?", crsAccountID))
|
||||
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, crsAccountID, sqljson.Path("crs_account_id")))
|
||||
}).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -4,7 +4,6 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -17,18 +16,16 @@ import (
|
||||
|
||||
type AccountRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
tx *sql.Tx
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *accountRepository
|
||||
repo *accountRepository
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
client, tx := testEntSQLTx(s.T())
|
||||
s.client = client
|
||||
s.tx = tx
|
||||
s.repo = newAccountRepositoryWithSQL(client, tx)
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = newAccountRepositoryWithSQL(s.client, tx)
|
||||
}
|
||||
|
||||
func TestAccountRepoSuite(t *testing.T) {
|
||||
@@ -175,7 +172,8 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// 每个 case 重新获取隔离资源
|
||||
client, tx := testEntSQLTx(s.T())
|
||||
tx := testEntTx(s.T())
|
||||
client := tx.Client()
|
||||
repo := newAccountRepositoryWithSQL(client, tx)
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
@@ -20,7 +20,8 @@ func uniqueTestValue(t *testing.T, prefix string) string {
|
||||
|
||||
func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
entClient, sqlTx := testEntSQLTx(t)
|
||||
tx := testEntTx(t)
|
||||
entClient := tx.Client()
|
||||
|
||||
targetGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "target-group")).
|
||||
@@ -33,7 +34,7 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
repo := newUserRepositoryWithSQL(entClient, sqlTx)
|
||||
repo := newUserRepositoryWithSQL(entClient, tx)
|
||||
|
||||
u1 := &service.User{
|
||||
Email: uniqueTestValue(t, "u1") + "@example.com",
|
||||
@@ -81,7 +82,8 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
|
||||
|
||||
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
entClient, sqlTx := testEntSQLTx(t)
|
||||
tx := testEntTx(t)
|
||||
entClient := tx.Client()
|
||||
|
||||
targetGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "delete-cascade-target")).
|
||||
@@ -94,8 +96,8 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
userRepo := newUserRepositoryWithSQL(entClient, sqlTx)
|
||||
groupRepo := newGroupRepositoryWithSQL(entClient, sqlTx)
|
||||
userRepo := newUserRepositoryWithSQL(entClient, tx)
|
||||
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
|
||||
apiKeyRepo := NewApiKeyRepository(entClient)
|
||||
|
||||
u := &service.User{
|
||||
@@ -141,4 +143,3 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, keyAfter.GroupID)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,9 +2,11 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
@@ -18,6 +20,11 @@ func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
|
||||
return &apiKeyRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
|
||||
// 默认过滤已软删除记录,避免删除后仍被查询到。
|
||||
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
created, err := r.client.ApiKey.Create().
|
||||
SetUserID(key.UserID).
|
||||
@@ -35,7 +42,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) erro
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
m, err := r.client.ApiKey.Query().
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
@@ -55,7 +62,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
|
||||
// - 不加载完整的 ApiKey 实体及其关联数据(User、Group 等)
|
||||
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
|
||||
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
m, err := r.client.ApiKey.Query().
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Select(apikey.FieldUserID).
|
||||
Only(ctx)
|
||||
@@ -69,7 +76,7 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
m, err := r.client.ApiKey.Query().
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.KeyEQ(key)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
@@ -84,6 +91,14 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
exists, err := r.activeQuery().Where(apikey.IDEQ(key.ID)).Exist(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exists {
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
|
||||
builder := r.client.ApiKey.UpdateOneID(key.ID).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status)
|
||||
@@ -105,12 +120,34 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
_, err := r.client.ApiKey.Delete().Where(apikey.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
|
||||
affected, err := r.client.ApiKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetDeletedAt(time.Now()).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
exists, err := r.client.ApiKey.Query().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Exist(mixins.SkipSoftDelete(ctx))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
q := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID))
|
||||
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
@@ -141,7 +178,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
|
||||
}
|
||||
|
||||
ids, err := r.client.ApiKey.Query().
|
||||
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...)).
|
||||
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
|
||||
IDs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -150,17 +187,17 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
count, err := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID)).Count(ctx)
|
||||
count, err := r.activeQuery().Where(apikey.UserIDEQ(userID)).Count(ctx)
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
count, err := r.client.ApiKey.Query().Where(apikey.KeyEQ(key)).Count(ctx)
|
||||
count, err := r.activeQuery().Where(apikey.KeyEQ(key)).Count(ctx)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
q := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID))
|
||||
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
@@ -187,7 +224,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
|
||||
// SearchApiKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
q := r.client.ApiKey.Query()
|
||||
q := r.activeQuery()
|
||||
if userID > 0 {
|
||||
q = q.Where(apikey.UserIDEQ(userID))
|
||||
}
|
||||
@@ -211,7 +248,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
|
||||
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
||||
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
n, err := r.client.ApiKey.Update().
|
||||
Where(apikey.GroupIDEQ(groupID)).
|
||||
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
|
||||
ClearGroupID().
|
||||
Save(ctx)
|
||||
return int64(n), err
|
||||
@@ -219,7 +256,7 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
|
||||
|
||||
// CountByGroupID 获取分组的 API Key 数量
|
||||
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
count, err := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
|
||||
count, err := r.activeQuery().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
|
||||
@@ -21,9 +21,9 @@ type ApiKeyRepoSuite struct {
|
||||
|
||||
func (s *ApiKeyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
entClient, _ := testEntSQLTx(s.T())
|
||||
s.client = entClient
|
||||
s.repo = NewApiKeyRepository(entClient).(*apiKeyRepository)
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository)
|
||||
}
|
||||
|
||||
func TestApiKeyRepoSuite(t *testing.T) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
@@ -10,26 +11,16 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
type sqlExecutor interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
}
|
||||
|
||||
type sqlBeginner interface {
|
||||
sqlExecutor
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||
}
|
||||
|
||||
type groupRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
begin sqlBeginner
|
||||
}
|
||||
|
||||
func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupRepository {
|
||||
@@ -37,11 +28,7 @@ func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupReposi
|
||||
}
|
||||
|
||||
func newGroupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *groupRepository {
|
||||
var beginner sqlBeginner
|
||||
if b, ok := sqlq.(sqlBeginner); ok {
|
||||
beginner = b
|
||||
}
|
||||
return &groupRepository{client: client, sql: sqlq, begin: beginner}
|
||||
return &groupRepository{client: client, sql: sqlq}
|
||||
}
|
||||
|
||||
func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) error {
|
||||
@@ -214,7 +201,7 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool,
|
||||
|
||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", groupID).Scan(&count); err != nil {
|
||||
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
@@ -236,31 +223,44 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
}
|
||||
groupSvc := groupEntityToService(g)
|
||||
|
||||
exec := r.sql
|
||||
// 使用 ent 事务统一包裹:避免手工基于 *sql.Tx 构造 ent client 带来的驱动断言问题,
|
||||
// 同时保证级联删除的原子性。
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
return nil, err
|
||||
}
|
||||
exec := r.client
|
||||
txClient := r.client
|
||||
var sqlTx *sql.Tx
|
||||
|
||||
if r.begin != nil {
|
||||
sqlTx, err = r.begin.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exec = sqlTx
|
||||
txClient = entClientFromSQLTx(sqlTx)
|
||||
// 注意:不能调用 txClient.Close(),因为基于事务的 ent client
|
||||
// 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB,但实际是 *sql.Tx
|
||||
// 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成
|
||||
defer func() { _ = sqlTx.Rollback() }()
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
exec = tx.Client()
|
||||
txClient = exec
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前 client 参与同一事务。
|
||||
}
|
||||
|
||||
// Lock the group row to avoid concurrent writes while we cascade.
|
||||
var lockedID int64
|
||||
if err := exec.QueryRowContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id).Scan(&lockedID); err != nil {
|
||||
if errorsIsNoRows(err) {
|
||||
return nil, service.ErrGroupNotFound
|
||||
}
|
||||
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分“未找到”与其他错误。
|
||||
rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var lockedID int64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&lockedID); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lockedID == 0 {
|
||||
return nil, service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
var affectedUserIDs []int64
|
||||
if groupSvc.IsSubscriptionType() {
|
||||
@@ -319,8 +319,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sqlTx != nil {
|
||||
if err := sqlTx.Commit(); err != nil {
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -359,11 +359,6 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func entClientFromSQLTx(tx *sql.Tx) *dbent.Client {
|
||||
drv := entsql.NewDriver(dialect.Postgres, entsql.Conn{ExecQuerier: tx})
|
||||
return dbent.NewClient(dbent.Driver(drv))
|
||||
}
|
||||
|
||||
func errorsIsNoRows(err error) bool {
|
||||
return err == sql.ErrNoRows
|
||||
}
|
||||
|
||||
@@ -4,9 +4,9 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@@ -15,15 +15,15 @@ import (
|
||||
type GroupRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
tx *sql.Tx
|
||||
tx *dbent.Tx
|
||||
repo *groupRepository
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
entClient, tx := testEntSQLTx(s.T())
|
||||
tx := testEntTx(s.T())
|
||||
s.tx = tx
|
||||
s.repo = newGroupRepositoryWithSQL(entClient, tx)
|
||||
s.repo = newGroupRepositoryWithSQL(tx.Client(), tx)
|
||||
}
|
||||
|
||||
func TestGroupRepoSuite(t *testing.T) {
|
||||
@@ -99,6 +99,9 @@ func (s *GroupRepoSuite) TestDelete() {
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *GroupRepoSuite) TestList() {
|
||||
baseGroups, basePage, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List base")
|
||||
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
@@ -118,12 +121,20 @@ func (s *GroupRepoSuite) TestList() {
|
||||
|
||||
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
// 3 default groups + 2 test groups = 5 total
|
||||
s.Require().Len(groups, 5)
|
||||
s.Require().Equal(int64(5), page.Total)
|
||||
s.Require().Len(groups, len(baseGroups)+2)
|
||||
s.Require().Equal(basePage.Total+2, page.Total)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||
baseGroups, _, err := s.repo.ListWithFilters(
|
||||
s.ctx,
|
||||
pagination.PaginationParams{Page: 1, PageSize: 10},
|
||||
service.PlatformOpenAI,
|
||||
"",
|
||||
nil,
|
||||
)
|
||||
s.Require().NoError(err, "ListWithFilters base")
|
||||
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
@@ -143,8 +154,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
|
||||
s.Require().NoError(err)
|
||||
// 1 default openai group + 1 test openai group = 2 total
|
||||
s.Require().Len(groups, 2)
|
||||
s.Require().Len(groups, len(baseGroups)+1)
|
||||
// Verify all groups are OpenAI platform
|
||||
for _, g := range groups {
|
||||
s.Require().Equal(service.PlatformOpenAI, g.Platform)
|
||||
@@ -221,11 +231,13 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g2))
|
||||
|
||||
var accountID int64
|
||||
s.Require().NoError(s.tx.QueryRowContext(
|
||||
s.Require().NoError(scanSingleRow(
|
||||
s.ctx,
|
||||
s.tx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
"acc1", service.PlatformAnthropic, service.AccountTypeOAuth,
|
||||
).Scan(&accountID))
|
||||
[]any{"acc1", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&accountID,
|
||||
))
|
||||
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g1.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g2.ID, 1)
|
||||
@@ -243,6 +255,9 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||
// --- ListActive / ListActiveByPlatform ---
|
||||
|
||||
func (s *GroupRepoSuite) TestListActive() {
|
||||
baseGroups, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive base")
|
||||
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "active1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
@@ -262,8 +277,7 @@ func (s *GroupRepoSuite) TestListActive() {
|
||||
|
||||
groups, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
// 3 default groups (all active) + 1 test active group = 4 total
|
||||
s.Require().Len(groups, 4)
|
||||
s.Require().Len(groups, len(baseGroups)+1)
|
||||
// Verify our test group is in the results
|
||||
var found bool
|
||||
for _, g := range groups {
|
||||
@@ -351,17 +365,21 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
var a1 int64
|
||||
s.Require().NoError(s.tx.QueryRowContext(
|
||||
s.Require().NoError(scanSingleRow(
|
||||
s.ctx,
|
||||
s.tx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
"a1", service.PlatformAnthropic, service.AccountTypeOAuth,
|
||||
).Scan(&a1))
|
||||
[]any{"a1", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&a1,
|
||||
))
|
||||
var a2 int64
|
||||
s.Require().NoError(s.tx.QueryRowContext(
|
||||
s.Require().NoError(scanSingleRow(
|
||||
s.ctx,
|
||||
s.tx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
"a2", service.PlatformAnthropic, service.AccountTypeOAuth,
|
||||
).Scan(&a2))
|
||||
[]any{"a2", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&a2,
|
||||
))
|
||||
|
||||
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, group.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
@@ -402,11 +420,13 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g))
|
||||
var accountID int64
|
||||
s.Require().NoError(s.tx.QueryRowContext(
|
||||
s.Require().NoError(scanSingleRow(
|
||||
s.ctx,
|
||||
s.tx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
"acc-del", service.PlatformAnthropic, service.AccountTypeOAuth,
|
||||
).Scan(&accountID))
|
||||
[]any{"acc-del", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&accountID,
|
||||
))
|
||||
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
@@ -432,11 +452,13 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
|
||||
|
||||
insertAccount := func(name string) int64 {
|
||||
var id int64
|
||||
s.Require().NoError(s.tx.QueryRowContext(
|
||||
s.Require().NoError(scanSingleRow(
|
||||
s.ctx,
|
||||
s.tx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
name, service.PlatformAnthropic, service.AccountTypeOAuth,
|
||||
).Scan(&id))
|
||||
[]any{name, service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&id,
|
||||
))
|
||||
return id
|
||||
}
|
||||
a1 := insertAccount("a1")
|
||||
|
||||
@@ -36,8 +36,9 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
integrationDB *sql.DB
|
||||
integrationRedis *redisclient.Client
|
||||
integrationDB *sql.DB
|
||||
integrationEntClient *dbent.Client
|
||||
integrationRedis *redisclient.Client
|
||||
|
||||
redisNamespaceSeq uint64
|
||||
)
|
||||
@@ -101,6 +102,10 @@ func TestMain(m *testing.M) {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 创建 ent client 用于集成测试
|
||||
drv := entsql.OpenDB(dialect.Postgres, integrationDB)
|
||||
integrationEntClient = dbent.NewClient(dbent.Driver(drv))
|
||||
|
||||
redisHost, err := redisContainer.Host(ctx)
|
||||
if err != nil {
|
||||
log.Printf("failed to get redis host: %v", err)
|
||||
@@ -123,6 +128,7 @@ func TestMain(m *testing.M) {
|
||||
|
||||
code := m.Run()
|
||||
|
||||
_ = integrationEntClient.Close()
|
||||
_ = integrationRedis.Close()
|
||||
_ = integrationDB.Close()
|
||||
|
||||
@@ -193,18 +199,38 @@ func testTx(t *testing.T) *sql.Tx {
|
||||
return tx
|
||||
}
|
||||
|
||||
// testEntClient 返回全局的 ent client,用于测试需要内部管理事务的代码(如 Create/Update 方法)。
|
||||
// 注意:此 client 的操作会真正写入数据库,测试结束后不会自动回滚。
|
||||
func testEntClient(t *testing.T) *dbent.Client {
|
||||
t.Helper()
|
||||
return integrationEntClient
|
||||
}
|
||||
|
||||
// testEntTx 返回一个 ent 事务,用于需要事务隔离的测试。
|
||||
// 测试结束后会自动回滚,不会影响数据库状态。
|
||||
func testEntTx(t *testing.T) *dbent.Tx {
|
||||
t.Helper()
|
||||
|
||||
tx, err := integrationEntClient.Tx(context.Background())
|
||||
require.NoError(t, err, "begin ent tx")
|
||||
t.Cleanup(func() {
|
||||
_ = tx.Rollback()
|
||||
})
|
||||
return tx
|
||||
}
|
||||
|
||||
// testEntSQLTx 已弃用:不要在新测试中使用此函数。
|
||||
// 基于 *sql.Tx 创建的 ent client 在调用 client.Tx() 时会 panic。
|
||||
// 对于需要测试内部使用事务的代码,请使用 testEntClient。
|
||||
// 对于需要事务隔离的测试,请使用 testEntTx。
|
||||
//
|
||||
// Deprecated: Use testEntClient or testEntTx instead.
|
||||
func testEntSQLTx(t *testing.T) (*dbent.Client, *sql.Tx) {
|
||||
t.Helper()
|
||||
|
||||
tx := testTx(t)
|
||||
drv := entsql.NewDriver(dialect.Postgres, entsql.Conn{ExecQuerier: tx})
|
||||
client := dbent.NewClient(dbent.Driver(drv))
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = client.Close()
|
||||
})
|
||||
|
||||
return client, tx
|
||||
// 直接失败,避免旧测试误用导致的事务嵌套 panic。
|
||||
t.Fatalf("testEntSQLTx 已弃用:请使用 testEntClient 或 testEntTx")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func testRedis(t *testing.T) *redisclient.Client {
|
||||
@@ -363,13 +389,16 @@ type IntegrationDBSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
tx *sql.Tx
|
||||
tx *dbent.Tx
|
||||
}
|
||||
|
||||
// SetupTest initializes ctx and client for each test method.
|
||||
func (s *IntegrationDBSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.client, s.tx = testEntSQLTx(s.T())
|
||||
// 统一使用 ent.Tx,确保每个测试都有独立事务并自动回滚。
|
||||
tx := testEntTx(s.T())
|
||||
s.tx = tx
|
||||
s.client = tx.Client()
|
||||
}
|
||||
|
||||
// RequireNoError is a convenience method wrapping require.NoError with s.T().
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
|
||||
type sqlQuerier interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
}
|
||||
|
||||
type proxyRepository struct {
|
||||
@@ -170,9 +169,8 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
|
||||
|
||||
// CountAccountsByProxyID returns the number of accounts using a specific proxy
|
||||
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||
row := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", proxyID)
|
||||
var count int64
|
||||
if err := row.Scan(&count); err != nil {
|
||||
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", []any{proxyID}, &count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
|
||||
@@ -4,10 +4,10 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@@ -15,16 +15,16 @@ import (
|
||||
|
||||
type ProxyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
sqlTx *sql.Tx
|
||||
repo *proxyRepository
|
||||
ctx context.Context
|
||||
tx *dbent.Tx
|
||||
repo *proxyRepository
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
entClient, sqlTx := testEntSQLTx(s.T())
|
||||
s.sqlTx = sqlTx
|
||||
s.repo = newProxyRepositoryWithSQL(entClient, sqlTx)
|
||||
tx := testEntTx(s.T())
|
||||
s.tx = tx
|
||||
s.repo = newProxyRepositoryWithSQL(tx.Client(), tx)
|
||||
}
|
||||
|
||||
func TestProxyRepoSuite(t *testing.T) {
|
||||
@@ -306,7 +306,7 @@ func (s *ProxyRepoSuite) mustCreateProxyWithTimes(name, status string, createdAt
|
||||
Port: 8080,
|
||||
Status: status,
|
||||
})
|
||||
_, err := s.sqlTx.ExecContext(s.ctx, "UPDATE proxies SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, p.ID)
|
||||
_, err := s.tx.ExecContext(s.ctx, "UPDATE proxies SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, p.ID)
|
||||
s.Require().NoError(err, "update proxy timestamps")
|
||||
return p
|
||||
}
|
||||
@@ -317,7 +317,7 @@ func (s *ProxyRepoSuite) mustInsertAccount(name string, proxyID *int64) {
|
||||
if proxyID != nil {
|
||||
pid = *proxyID
|
||||
}
|
||||
_, err := s.sqlTx.ExecContext(
|
||||
_, err := s.tx.ExecContext(
|
||||
s.ctx,
|
||||
"INSERT INTO accounts (name, platform, type, proxy_id) VALUES ($1, $2, $3, $4)",
|
||||
name,
|
||||
|
||||
@@ -22,9 +22,9 @@ type RedeemCodeRepoSuite struct {
|
||||
|
||||
func (s *RedeemCodeRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
entClient, _ := testEntSQLTx(s.T())
|
||||
s.client = entClient
|
||||
s.repo = NewRedeemCodeRepository(entClient).(*redeemCodeRepository)
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = NewRedeemCodeRepository(s.client).(*redeemCodeRepository)
|
||||
}
|
||||
|
||||
func TestRedeemCodeRepoSuite(t *testing.T) {
|
||||
|
||||
@@ -18,8 +18,8 @@ type SettingRepoSuite struct {
|
||||
|
||||
func (s *SettingRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
entClient, _ := testEntSQLTx(s.T())
|
||||
s.repo = NewSettingRepository(entClient).(*settingRepository)
|
||||
tx := testEntTx(s.T())
|
||||
s.repo = NewSettingRepository(tx.Client()).(*settingRepository)
|
||||
}
|
||||
|
||||
func TestSettingRepoSuite(t *testing.T) {
|
||||
|
||||
@@ -34,7 +34,8 @@ func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, emai
|
||||
|
||||
func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client, _ := testEntSQLTx(t)
|
||||
// 使用全局 ent client,确保软删除验证在实际持久化数据上进行。
|
||||
client := testEntClient(t)
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
|
||||
|
||||
@@ -65,7 +66,8 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
|
||||
|
||||
func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client, _ := testEntSQLTx(t)
|
||||
// 使用全局 ent client,避免事务回滚影响幂等性验证。
|
||||
client := testEntClient(t)
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
|
||||
|
||||
@@ -84,7 +86,8 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
|
||||
|
||||
func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client, _ := testEntSQLTx(t)
|
||||
// 使用全局 ent client,确保 SkipSoftDelete 的硬删除语义可验证。
|
||||
client := testEntClient(t)
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
|
||||
|
||||
|
||||
33
backend/internal/repository/sql_scan.go
Normal file
33
backend/internal/repository/sql_scan.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
type sqlQueryer interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
// scanSingleRow executes a query and scans the first row into dest.
|
||||
// If no rows are returned, sql.ErrNoRows is returned.
|
||||
// 设计目的:仅依赖 QueryContext,避免 QueryRowContext 对 *sql.Tx 的强绑定,
|
||||
// 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。
|
||||
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) error {
|
||||
rows, err := q.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
if err := rows.Scan(dest...); err != nil {
|
||||
return err
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -33,6 +32,7 @@ func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLog
|
||||
}
|
||||
|
||||
func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository {
|
||||
// 使用 scanSingleRow 替代 QueryRowContext,保证 ent.Tx 作为 sqlExecutor 可用。
|
||||
return &usageLogRepository{client: client, sql: sqlq}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
|
||||
|
||||
var requestCount int64
|
||||
var tokenCount int64
|
||||
if err := r.sql.QueryRowContext(ctx, query, args...).Scan(&requestCount, &tokenCount); err != nil {
|
||||
if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return requestCount / 5, tokenCount / 5, nil
|
||||
@@ -114,9 +114,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
duration := nullInt(log.DurationMs)
|
||||
firstToken := nullInt(log.FirstTokenMs)
|
||||
|
||||
row := r.sql.QueryRowContext(
|
||||
ctx,
|
||||
query,
|
||||
args := []any{
|
||||
log.UserID,
|
||||
log.ApiKeyID,
|
||||
log.AccountID,
|
||||
@@ -142,9 +140,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
duration,
|
||||
firstToken,
|
||||
createdAt,
|
||||
)
|
||||
|
||||
if err := row.Scan(&log.ID, &log.CreatedAt); err != nil {
|
||||
}
|
||||
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
log.RateMultiplier = rateMultiplier
|
||||
@@ -153,11 +150,22 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
|
||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
|
||||
log, err := scanUsageLog(r.sql.QueryRowContext(ctx, query, id))
|
||||
rows, err := r.sql.QueryContext(ctx, query, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, service.ErrUsageLogNotFound
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, service.ErrUsageLogNotFound
|
||||
}
|
||||
log, err := scanUsageLog(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return log, nil
|
||||
@@ -195,8 +203,18 @@ func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
|
||||
`
|
||||
|
||||
stats := &UserStats{}
|
||||
if err := r.sql.QueryRowContext(ctx, query, userID, startTime, endTime).
|
||||
Scan(&stats.TotalRequests, &stats.TotalTokens, &stats.TotalCost, &stats.InputTokens, &stats.OutputTokens, &stats.CacheReadTokens); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{userID, startTime, endTime},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.InputTokens,
|
||||
&stats.OutputTokens,
|
||||
&stats.CacheReadTokens,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stats, nil
|
||||
@@ -219,8 +237,15 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
FROM users
|
||||
WHERE deleted_at IS NULL
|
||||
`
|
||||
if err := r.sql.QueryRowContext(ctx, userStatsQuery, today, today).
|
||||
Scan(&stats.TotalUsers, &stats.TodayNewUsers, &stats.ActiveUsers); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
userStatsQuery,
|
||||
[]any{today, today},
|
||||
&stats.TotalUsers,
|
||||
&stats.TodayNewUsers,
|
||||
&stats.ActiveUsers,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -232,8 +257,14 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
FROM api_keys
|
||||
WHERE deleted_at IS NULL
|
||||
`
|
||||
if err := r.sql.QueryRowContext(ctx, apiKeyStatsQuery, service.StatusActive).
|
||||
Scan(&stats.TotalApiKeys, &stats.ActiveApiKeys); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
apiKeyStatsQuery,
|
||||
[]any{service.StatusActive},
|
||||
&stats.TotalApiKeys,
|
||||
&stats.ActiveApiKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -248,8 +279,17 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
FROM accounts
|
||||
WHERE deleted_at IS NULL
|
||||
`
|
||||
if err := r.sql.QueryRowContext(ctx, accountStatsQuery, service.StatusActive, service.StatusError, now, now).
|
||||
Scan(&stats.TotalAccounts, &stats.NormalAccounts, &stats.ErrorAccounts, &stats.RateLimitAccounts, &stats.OverloadAccounts); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
accountStatsQuery,
|
||||
[]any{service.StatusActive, service.StatusError, now, now},
|
||||
&stats.TotalAccounts,
|
||||
&stats.NormalAccounts,
|
||||
&stats.ErrorAccounts,
|
||||
&stats.RateLimitAccounts,
|
||||
&stats.OverloadAccounts,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -266,17 +306,20 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
`
|
||||
if err := r.sql.QueryRowContext(ctx, totalStatsQuery).
|
||||
Scan(
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheCreationTokens,
|
||||
&stats.TotalCacheReadTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
totalStatsQuery,
|
||||
nil,
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheCreationTokens,
|
||||
&stats.TotalCacheReadTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||
@@ -294,16 +337,19 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1
|
||||
`
|
||||
if err := r.sql.QueryRowContext(ctx, todayStatsQuery, today).
|
||||
Scan(
|
||||
&stats.TodayRequests,
|
||||
&stats.TodayInputTokens,
|
||||
&stats.TodayOutputTokens,
|
||||
&stats.TodayCacheCreationTokens,
|
||||
&stats.TodayCacheReadTokens,
|
||||
&stats.TodayCost,
|
||||
&stats.TodayActualCost,
|
||||
); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
todayStatsQuery,
|
||||
[]any{today},
|
||||
&stats.TodayRequests,
|
||||
&stats.TodayInputTokens,
|
||||
&stats.TodayOutputTokens,
|
||||
&stats.TodayCacheCreationTokens,
|
||||
&stats.TodayCacheReadTokens,
|
||||
&stats.TodayCost,
|
||||
&stats.TodayActualCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||
@@ -345,16 +391,19 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
|
||||
`
|
||||
|
||||
var stats usagestats.UsageStats
|
||||
if err := r.sql.QueryRowContext(ctx, query, userID, startTime, endTime).
|
||||
Scan(
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{userID, startTime, endTime},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||
@@ -377,16 +426,19 @@ func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKe
|
||||
`
|
||||
|
||||
var stats usagestats.UsageStats
|
||||
if err := r.sql.QueryRowContext(ctx, query, apiKeyID, startTime, endTime).
|
||||
Scan(
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{apiKeyID, startTime, endTime},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||
@@ -430,8 +482,15 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
||||
`
|
||||
|
||||
stats := &usagestats.AccountStats{}
|
||||
if err := r.sql.QueryRowContext(ctx, query, accountID, today).
|
||||
Scan(&stats.Requests, &stats.Tokens, &stats.Cost); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{accountID, today},
|
||||
&stats.Requests,
|
||||
&stats.Tokens,
|
||||
&stats.Cost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stats, nil
|
||||
@@ -449,8 +508,15 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
||||
`
|
||||
|
||||
stats := &usagestats.AccountStats{}
|
||||
if err := r.sql.QueryRowContext(ctx, query, accountID, startTime).
|
||||
Scan(&stats.Requests, &stats.Tokens, &stats.Cost); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{accountID, startTime},
|
||||
&stats.Requests,
|
||||
&stats.Tokens,
|
||||
&stats.Cost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stats, nil
|
||||
@@ -581,12 +647,22 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
today := timezone.Today()
|
||||
|
||||
// API Key 统计
|
||||
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL", userID).
|
||||
Scan(&stats.TotalApiKeys); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
|
||||
[]any{userID},
|
||||
&stats.TotalApiKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL", userID, service.StatusActive).
|
||||
Scan(&stats.ActiveApiKeys); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
|
||||
[]any{userID, service.StatusActive},
|
||||
&stats.ActiveApiKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -604,17 +680,20 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
FROM usage_logs
|
||||
WHERE user_id = $1
|
||||
`
|
||||
if err := r.sql.QueryRowContext(ctx, totalStatsQuery, userID).
|
||||
Scan(
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheCreationTokens,
|
||||
&stats.TotalCacheReadTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
totalStatsQuery,
|
||||
[]any{userID},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheCreationTokens,
|
||||
&stats.TotalCacheReadTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||
@@ -632,16 +711,19 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
FROM usage_logs
|
||||
WHERE user_id = $1 AND created_at >= $2
|
||||
`
|
||||
if err := r.sql.QueryRowContext(ctx, todayStatsQuery, userID, today).
|
||||
Scan(
|
||||
&stats.TodayRequests,
|
||||
&stats.TodayInputTokens,
|
||||
&stats.TodayOutputTokens,
|
||||
&stats.TodayCacheCreationTokens,
|
||||
&stats.TodayCacheReadTokens,
|
||||
&stats.TodayCost,
|
||||
&stats.TodayActualCost,
|
||||
); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
todayStatsQuery,
|
||||
[]any{userID, today},
|
||||
&stats.TodayRequests,
|
||||
&stats.TodayInputTokens,
|
||||
&stats.TodayOutputTokens,
|
||||
&stats.TodayCacheCreationTokens,
|
||||
&stats.TodayCacheReadTokens,
|
||||
&stats.TodayCost,
|
||||
&stats.TodayActualCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||
@@ -1007,16 +1089,19 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
|
||||
`
|
||||
|
||||
stats := &UsageStats{}
|
||||
if err := r.sql.QueryRowContext(ctx, query, startTime, endTime).
|
||||
Scan(
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{startTime, endTime},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||
@@ -1108,7 +1193,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
|
||||
avgQuery := "SELECT COALESCE(AVG(duration_ms), 0) as avg_duration_ms FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3"
|
||||
var avgDuration float64
|
||||
if err := r.sql.QueryRowContext(ctx, avgQuery, accountID, startTime, endTime).Scan(&avgDuration); err != nil {
|
||||
if err := scanSingleRow(ctx, r.sql, avgQuery, []any{accountID, startTime, endTime}, &avgDuration); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1186,7 +1271,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
countQuery := "SELECT COUNT(*) FROM usage_logs " + whereClause
|
||||
var total int64
|
||||
if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
if err := scanSingleRow(ctx, r.sql, countQuery, args, &total); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -19,17 +18,17 @@ import (
|
||||
type UsageLogRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
tx *sql.Tx
|
||||
tx *dbent.Tx
|
||||
client *dbent.Client
|
||||
repo *usageLogRepository
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
client, tx := testEntSQLTx(s.T())
|
||||
s.client = client
|
||||
tx := testEntTx(s.T())
|
||||
s.tx = tx
|
||||
s.repo = newUsageLogRepositoryWithSQL(client, tx)
|
||||
s.client = tx.Client()
|
||||
s.repo = newUsageLogRepositoryWithSQL(s.client, tx)
|
||||
}
|
||||
|
||||
func TestUsageLogRepoSuite(t *testing.T) {
|
||||
@@ -197,6 +196,8 @@ func (s *UsageLogRepoSuite) TestListWithFilters() {
|
||||
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
now := time.Now()
|
||||
todayStart := timezone.Today()
|
||||
baseStats, err := s.repo.GetDashboardStats(s.ctx)
|
||||
s.Require().NoError(err, "GetDashboardStats base")
|
||||
|
||||
userToday := mustCreateUser(s.T(), s.client, &service.User{
|
||||
Email: "today@example.com",
|
||||
@@ -268,24 +269,24 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
stats, err := s.repo.GetDashboardStats(s.ctx)
|
||||
s.Require().NoError(err, "GetDashboardStats")
|
||||
|
||||
s.Require().Equal(int64(2), stats.TotalUsers, "TotalUsers mismatch")
|
||||
s.Require().Equal(int64(1), stats.TodayNewUsers, "TodayNewUsers mismatch")
|
||||
s.Require().Equal(int64(1), stats.ActiveUsers, "ActiveUsers mismatch")
|
||||
s.Require().Equal(int64(2), stats.TotalApiKeys, "TotalApiKeys mismatch")
|
||||
s.Require().Equal(int64(1), stats.ActiveApiKeys, "ActiveApiKeys mismatch")
|
||||
s.Require().Equal(int64(4), stats.TotalAccounts, "TotalAccounts mismatch")
|
||||
s.Require().Equal(int64(1), stats.ErrorAccounts, "ErrorAccounts mismatch")
|
||||
s.Require().Equal(int64(1), stats.RateLimitAccounts, "RateLimitAccounts mismatch")
|
||||
s.Require().Equal(int64(1), stats.OverloadAccounts, "OverloadAccounts mismatch")
|
||||
s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch")
|
||||
s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch")
|
||||
s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch")
|
||||
s.Require().Equal(baseStats.TotalApiKeys+2, stats.TotalApiKeys, "TotalApiKeys mismatch")
|
||||
s.Require().Equal(baseStats.ActiveApiKeys+1, stats.ActiveApiKeys, "ActiveApiKeys mismatch")
|
||||
s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch")
|
||||
s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch")
|
||||
s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch")
|
||||
s.Require().Equal(baseStats.OverloadAccounts+1, stats.OverloadAccounts, "OverloadAccounts mismatch")
|
||||
|
||||
s.Require().Equal(int64(3), stats.TotalRequests, "TotalRequests mismatch")
|
||||
s.Require().Equal(int64(16), stats.TotalInputTokens, "TotalInputTokens mismatch")
|
||||
s.Require().Equal(int64(28), stats.TotalOutputTokens, "TotalOutputTokens mismatch")
|
||||
s.Require().Equal(int64(3), stats.TotalCacheCreationTokens, "TotalCacheCreationTokens mismatch")
|
||||
s.Require().Equal(int64(4), stats.TotalCacheReadTokens, "TotalCacheReadTokens mismatch")
|
||||
s.Require().Equal(int64(51), stats.TotalTokens, "TotalTokens mismatch")
|
||||
s.Require().Equal(2.3, stats.TotalCost, "TotalCost mismatch")
|
||||
s.Require().Equal(2.0, stats.TotalActualCost, "TotalActualCost mismatch")
|
||||
s.Require().Equal(baseStats.TotalRequests+3, stats.TotalRequests, "TotalRequests mismatch")
|
||||
s.Require().Equal(baseStats.TotalInputTokens+int64(16), stats.TotalInputTokens, "TotalInputTokens mismatch")
|
||||
s.Require().Equal(baseStats.TotalOutputTokens+int64(28), stats.TotalOutputTokens, "TotalOutputTokens mismatch")
|
||||
s.Require().Equal(baseStats.TotalCacheCreationTokens+int64(3), stats.TotalCacheCreationTokens, "TotalCacheCreationTokens mismatch")
|
||||
s.Require().Equal(baseStats.TotalCacheReadTokens+int64(4), stats.TotalCacheReadTokens, "TotalCacheReadTokens mismatch")
|
||||
s.Require().Equal(baseStats.TotalTokens+int64(51), stats.TotalTokens, "TotalTokens mismatch")
|
||||
s.Require().Equal(baseStats.TotalCost+2.3, stats.TotalCost, "TotalCost mismatch")
|
||||
s.Require().Equal(baseStats.TotalActualCost+2.0, stats.TotalActualCost, "TotalActualCost mismatch")
|
||||
s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1")
|
||||
s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0")
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"sort"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -17,7 +18,6 @@ import (
|
||||
type userRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
begin sqlBeginner
|
||||
}
|
||||
|
||||
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
|
||||
@@ -25,11 +25,7 @@ func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserReposito
|
||||
}
|
||||
|
||||
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
|
||||
var beginner sqlBeginner
|
||||
if b, ok := sqlq.(sqlBeginner); ok {
|
||||
beginner = b
|
||||
}
|
||||
return &userRepository{client: client, sql: sqlq, begin: beginner}
|
||||
return &userRepository{client: client, sql: sqlq}
|
||||
}
|
||||
|
||||
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
|
||||
@@ -37,22 +33,20 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
return nil
|
||||
}
|
||||
|
||||
exec := r.sql
|
||||
txClient := r.client
|
||||
var sqlTx *sql.Tx
|
||||
// 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
|
||||
// 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.begin != nil {
|
||||
var err error
|
||||
sqlTx, err = r.begin.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
exec = sqlTx
|
||||
txClient = entClientFromSQLTx(sqlTx)
|
||||
// 注意:不能调用 txClient.Close(),因为基于事务的 ent client
|
||||
// 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB,但实际是 *sql.Tx
|
||||
// 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成
|
||||
defer func() { _ = sqlTx.Rollback() }()
|
||||
var txClient *dbent.Client
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txClient = tx.Client()
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
|
||||
txClient = r.client
|
||||
}
|
||||
|
||||
created, err := txClient.User.Create().
|
||||
@@ -70,12 +64,12 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
if err := r.syncUserAllowedGroups(ctx, txClient, exec, created.ID, userIn.AllowedGroups); err != nil {
|
||||
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sqlTx != nil {
|
||||
if err := sqlTx.Commit(); err != nil {
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -121,22 +115,19 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
return nil
|
||||
}
|
||||
|
||||
exec := r.sql
|
||||
txClient := r.client
|
||||
var sqlTx *sql.Tx
|
||||
// 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.begin != nil {
|
||||
var err error
|
||||
sqlTx, err = r.begin.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
exec = sqlTx
|
||||
txClient = entClientFromSQLTx(sqlTx)
|
||||
// 注意:不能调用 txClient.Close(),因为基于事务的 ent client
|
||||
// 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB,但实际是 *sql.Tx
|
||||
// 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成
|
||||
defer func() { _ = sqlTx.Rollback() }()
|
||||
var txClient *dbent.Client
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txClient = tx.Client()
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
|
||||
txClient = r.client
|
||||
}
|
||||
|
||||
updated, err := txClient.User.UpdateOneID(userIn.ID).
|
||||
@@ -154,12 +145,12 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
if err := r.syncUserAllowedGroups(ctx, txClient, exec, updated.ID, userIn.AllowedGroups); err != nil {
|
||||
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sqlTx != nil {
|
||||
if err := sqlTx.Commit(); err != nil {
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -289,8 +280,10 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
|
||||
}
|
||||
|
||||
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||
if r.sql == nil {
|
||||
return 0, nil
|
||||
exec := r.sql
|
||||
if exec == nil {
|
||||
// 未注入 sqlExecutor 时,退回到 ent client 的 ExecContext(支持事务)。
|
||||
exec = r.client
|
||||
}
|
||||
|
||||
joinAffected, err := r.client.UserAllowedGroup.Delete().
|
||||
@@ -300,7 +293,7 @@ func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
|
||||
return 0, err
|
||||
}
|
||||
|
||||
arrayRes, err := r.sql.ExecContext(
|
||||
arrayRes, err := exec.ExecContext(
|
||||
ctx,
|
||||
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1), updated_at = NOW() WHERE $1 = ANY(allowed_groups)",
|
||||
groupID,
|
||||
@@ -362,6 +355,56 @@ func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组:
|
||||
// 1) 以 user_allowed_groups 为读写源,确保新旧逻辑一致;
|
||||
// 2) 额外更新 users.allowed_groups(历史字段)以保持兼容。
|
||||
func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error {
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keep join table as the source of truth for reads.
|
||||
if _, err := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID)).Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unique := make(map[int64]struct{}, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
unique[id] = struct{}{}
|
||||
}
|
||||
|
||||
legacyGroups := make([]int64, 0, len(unique))
|
||||
if len(unique) > 0 {
|
||||
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
|
||||
for groupID := range unique {
|
||||
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
|
||||
legacyGroups = append(legacyGroups, groupID)
|
||||
}
|
||||
if err := client.UserAllowedGroup.
|
||||
CreateBulk(creates...).
|
||||
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 1 兼容:保持 users.allowed_groups(数组字段)同步,避免旧查询路径读取到过期数据。
|
||||
var legacy any
|
||||
if len(legacyGroups) > 0 {
|
||||
sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] })
|
||||
legacy = pq.Array(legacyGroups)
|
||||
}
|
||||
if _, err := client.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) syncUserAllowedGroups(ctx context.Context, client *dbent.Client, exec sqlExecutor, userID int64, groupIDs []int64) error {
|
||||
if client == nil || exec == nil {
|
||||
return nil
|
||||
|
||||
@@ -4,7 +4,6 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -17,17 +16,19 @@ import (
|
||||
type UserRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
tx *sql.Tx
|
||||
client *dbent.Client
|
||||
repo *userRepository
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
entClient, tx := testEntSQLTx(s.T())
|
||||
s.tx = tx
|
||||
s.client = entClient
|
||||
s.repo = newUserRepositoryWithSQL(entClient, tx)
|
||||
s.client = testEntClient(s.T())
|
||||
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
|
||||
|
||||
// 清理测试数据,确保每个测试从干净状态开始
|
||||
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions")
|
||||
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups")
|
||||
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users")
|
||||
}
|
||||
|
||||
func TestUserRepoSuite(t *testing.T) {
|
||||
|
||||
@@ -22,8 +22,8 @@ type UserSubscriptionRepoSuite struct {
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
client, _ := testEntSQLTx(s.T())
|
||||
s.client = client
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = NewUserSubscriptionRepository(s.client).(*userSubscriptionRepository)
|
||||
}
|
||||
|
||||
@@ -66,8 +66,8 @@ func (s *UserSubscriptionRepoSuite) mustCreateSubscription(userID, groupID int64
|
||||
create := s.client.UserSubscription.Create().
|
||||
SetUserID(userID).
|
||||
SetGroupID(groupID).
|
||||
SetStartsAt(now.Add(-1*time.Hour)).
|
||||
SetExpiresAt(now.Add(24*time.Hour)).
|
||||
SetStartsAt(now.Add(-1 * time.Hour)).
|
||||
SetExpiresAt(now.Add(24 * time.Hour)).
|
||||
SetStatus(service.SubscriptionStatusActive).
|
||||
SetAssignedAt(now).
|
||||
SetNotes("")
|
||||
@@ -631,4 +631,3 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
|
||||
s.Require().NoError(err, "GetByID expired")
|
||||
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user