fix(仓储): 修复软删除过滤与事务测试
修复软删除拦截器使用错误,确保默认查询过滤已删记录 仓储层改用 ent.Tx 与扫描辅助,避免 sql.Tx 断言问题 同步更新集成测试以覆盖事务与统计变动
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/intercept"
|
||||||
"entgo.io/ent"
|
"entgo.io/ent"
|
||||||
"entgo.io/ent/dialect"
|
"entgo.io/ent/dialect"
|
||||||
"entgo.io/ent/dialect/sql"
|
"entgo.io/ent/dialect/sql"
|
||||||
@@ -79,16 +80,13 @@ func SkipSoftDelete(parent context.Context) context.Context {
|
|||||||
// 确保软删除的记录不会出现在普通查询结果中。
|
// 确保软删除的记录不会出现在普通查询结果中。
|
||||||
func (d SoftDeleteMixin) Interceptors() []ent.Interceptor {
|
func (d SoftDeleteMixin) Interceptors() []ent.Interceptor {
|
||||||
return []ent.Interceptor{
|
return []ent.Interceptor{
|
||||||
ent.TraverseFunc(func(ctx context.Context, q ent.Query) error {
|
intercept.TraverseFunc(func(ctx context.Context, q intercept.Query) error {
|
||||||
// 检查是否需要跳过软删除过滤
|
// 检查是否需要跳过软删除过滤
|
||||||
if skip, _ := ctx.Value(softDeleteKey{}).(bool); skip {
|
if skip, _ := ctx.Value(softDeleteKey{}).(bool); skip {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// 为查询添加 deleted_at IS NULL 条件
|
// 为查询添加 deleted_at IS NULL 条件
|
||||||
w, ok := q.(interface{ WhereP(...func(*sql.Selector)) })
|
d.applyPredicate(q)
|
||||||
if ok {
|
|
||||||
d.applyPredicate(w)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import (
|
|||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
|
||||||
entsql "entgo.io/ent/dialect/sql"
|
entsql "entgo.io/ent/dialect/sql"
|
||||||
|
"entgo.io/ent/dialect/sql/sqljson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// accountRepository 实现 service.AccountRepository 接口。
|
// accountRepository 实现 service.AccountRepository 接口。
|
||||||
@@ -36,11 +37,9 @@ import (
|
|||||||
// 设计说明:
|
// 设计说明:
|
||||||
// - client: Ent 客户端,用于类型安全的 ORM 操作
|
// - client: Ent 客户端,用于类型安全的 ORM 操作
|
||||||
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
|
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
|
||||||
// - begin: SQL 事务开启器,用于需要事务的操作
|
|
||||||
type accountRepository struct {
|
type accountRepository struct {
|
||||||
client *dbent.Client // Ent ORM 客户端
|
client *dbent.Client // Ent ORM 客户端
|
||||||
sql sqlExecutor // 原生 SQL 执行接口
|
sql sqlExecutor // 原生 SQL 执行接口
|
||||||
begin sqlBeginner // 事务开启接口
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountRepository 创建账户仓储实例。
|
// NewAccountRepository 创建账户仓储实例。
|
||||||
@@ -52,11 +51,7 @@ func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRe
|
|||||||
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
|
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
|
||||||
// 这种设计便于单元测试时注入 mock 对象。
|
// 这种设计便于单元测试时注入 mock 对象。
|
||||||
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
|
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
|
||||||
var beginner sqlBeginner
|
return &accountRepository{client: client, sql: sqlq}
|
||||||
if b, ok := sqlq.(sqlBeginner); ok {
|
|
||||||
beginner = b
|
|
||||||
}
|
|
||||||
return &accountRepository{client: client, sql: sqlq, begin: beginner}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
|
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
|
||||||
@@ -146,9 +141,10 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 使用 sqljson.ValueEQ 生成 JSON 路径过滤,避免手写 SQL 片段导致语法兼容问题。
|
||||||
m, err := r.client.Account.Query().
|
m, err := r.client.Account.Query().
|
||||||
Where(func(s *entsql.Selector) {
|
Where(func(s *entsql.Selector) {
|
||||||
s.Where(entsql.ExprP("extra->>'crs_account_id' = ?", crsAccountID))
|
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, crsAccountID, sqljson.Path("crs_account_id")))
|
||||||
}).
|
}).
|
||||||
Only(ctx)
|
Only(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -17,18 +16,16 @@ import (
|
|||||||
|
|
||||||
type AccountRepoSuite struct {
|
type AccountRepoSuite struct {
|
||||||
suite.Suite
|
suite.Suite
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
tx *sql.Tx
|
|
||||||
client *dbent.Client
|
client *dbent.Client
|
||||||
repo *accountRepository
|
repo *accountRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AccountRepoSuite) SetupTest() {
|
func (s *AccountRepoSuite) SetupTest() {
|
||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
client, tx := testEntSQLTx(s.T())
|
tx := testEntTx(s.T())
|
||||||
s.client = client
|
s.client = tx.Client()
|
||||||
s.tx = tx
|
s.repo = newAccountRepositoryWithSQL(s.client, tx)
|
||||||
s.repo = newAccountRepositoryWithSQL(client, tx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountRepoSuite(t *testing.T) {
|
func TestAccountRepoSuite(t *testing.T) {
|
||||||
@@ -175,7 +172,8 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
s.Run(tt.name, func() {
|
s.Run(tt.name, func() {
|
||||||
// 每个 case 重新获取隔离资源
|
// 每个 case 重新获取隔离资源
|
||||||
client, tx := testEntSQLTx(s.T())
|
tx := testEntTx(s.T())
|
||||||
|
client := tx.Client()
|
||||||
repo := newAccountRepositoryWithSQL(client, tx)
|
repo := newAccountRepositoryWithSQL(client, tx)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,8 @@ func uniqueTestValue(t *testing.T, prefix string) string {
|
|||||||
|
|
||||||
func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) {
|
func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
entClient, sqlTx := testEntSQLTx(t)
|
tx := testEntTx(t)
|
||||||
|
entClient := tx.Client()
|
||||||
|
|
||||||
targetGroup, err := entClient.Group.Create().
|
targetGroup, err := entClient.Group.Create().
|
||||||
SetName(uniqueTestValue(t, "target-group")).
|
SetName(uniqueTestValue(t, "target-group")).
|
||||||
@@ -33,7 +34,7 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
|
|||||||
Save(ctx)
|
Save(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
repo := newUserRepositoryWithSQL(entClient, sqlTx)
|
repo := newUserRepositoryWithSQL(entClient, tx)
|
||||||
|
|
||||||
u1 := &service.User{
|
u1 := &service.User{
|
||||||
Email: uniqueTestValue(t, "u1") + "@example.com",
|
Email: uniqueTestValue(t, "u1") + "@example.com",
|
||||||
@@ -81,7 +82,8 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
|
|||||||
|
|
||||||
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
|
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
entClient, sqlTx := testEntSQLTx(t)
|
tx := testEntTx(t)
|
||||||
|
entClient := tx.Client()
|
||||||
|
|
||||||
targetGroup, err := entClient.Group.Create().
|
targetGroup, err := entClient.Group.Create().
|
||||||
SetName(uniqueTestValue(t, "delete-cascade-target")).
|
SetName(uniqueTestValue(t, "delete-cascade-target")).
|
||||||
@@ -94,8 +96,8 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
|
|||||||
Save(ctx)
|
Save(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
userRepo := newUserRepositoryWithSQL(entClient, sqlTx)
|
userRepo := newUserRepositoryWithSQL(entClient, tx)
|
||||||
groupRepo := newGroupRepositoryWithSQL(entClient, sqlTx)
|
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
|
||||||
apiKeyRepo := NewApiKeyRepository(entClient)
|
apiKeyRepo := NewApiKeyRepository(entClient)
|
||||||
|
|
||||||
u := &service.User{
|
u := &service.User{
|
||||||
@@ -141,4 +143,3 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Nil(t, keyAfter.GroupID)
|
require.Nil(t, keyAfter.GroupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,11 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
@@ -18,6 +20,11 @@ func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
|
|||||||
return &apiKeyRepository{client: client}
|
return &apiKeyRepository{client: client}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
|
||||||
|
// 默认过滤已软删除记录,避免删除后仍被查询到。
|
||||||
|
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
|
||||||
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
|
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
|
||||||
created, err := r.client.ApiKey.Create().
|
created, err := r.client.ApiKey.Create().
|
||||||
SetUserID(key.UserID).
|
SetUserID(key.UserID).
|
||||||
@@ -35,7 +42,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||||
m, err := r.client.ApiKey.Query().
|
m, err := r.activeQuery().
|
||||||
Where(apikey.IDEQ(id)).
|
Where(apikey.IDEQ(id)).
|
||||||
WithUser().
|
WithUser().
|
||||||
WithGroup().
|
WithGroup().
|
||||||
@@ -55,7 +62,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
|
|||||||
// - 不加载完整的 ApiKey 实体及其关联数据(User、Group 等)
|
// - 不加载完整的 ApiKey 实体及其关联数据(User、Group 等)
|
||||||
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
|
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
|
||||||
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||||
m, err := r.client.ApiKey.Query().
|
m, err := r.activeQuery().
|
||||||
Where(apikey.IDEQ(id)).
|
Where(apikey.IDEQ(id)).
|
||||||
Select(apikey.FieldUserID).
|
Select(apikey.FieldUserID).
|
||||||
Only(ctx)
|
Only(ctx)
|
||||||
@@ -69,7 +76,7 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||||
m, err := r.client.ApiKey.Query().
|
m, err := r.activeQuery().
|
||||||
Where(apikey.KeyEQ(key)).
|
Where(apikey.KeyEQ(key)).
|
||||||
WithUser().
|
WithUser().
|
||||||
WithGroup().
|
WithGroup().
|
||||||
@@ -84,6 +91,14 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
|
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
|
||||||
|
exists, err := r.activeQuery().Where(apikey.IDEQ(key.ID)).Exist(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
return service.ErrApiKeyNotFound
|
||||||
|
}
|
||||||
|
|
||||||
builder := r.client.ApiKey.UpdateOneID(key.ID).
|
builder := r.client.ApiKey.UpdateOneID(key.ID).
|
||||||
SetName(key.Name).
|
SetName(key.Name).
|
||||||
SetStatus(key.Status)
|
SetStatus(key.Status)
|
||||||
@@ -105,12 +120,34 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||||
_, err := r.client.ApiKey.Delete().Where(apikey.IDEQ(id)).Exec(ctx)
|
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
|
||||||
return err
|
affected, err := r.client.ApiKey.Update().
|
||||||
|
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||||
|
SetDeletedAt(time.Now()).
|
||||||
|
Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if dbent.IsNotFound(err) {
|
||||||
|
return service.ErrApiKeyNotFound
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if affected == 0 {
|
||||||
|
exists, err := r.client.ApiKey.Query().
|
||||||
|
Where(apikey.IDEQ(id)).
|
||||||
|
Exist(mixins.SkipSoftDelete(ctx))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return service.ErrApiKeyNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||||
q := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID))
|
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
|
||||||
|
|
||||||
total, err := q.Count(ctx)
|
total, err := q.Count(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -141,7 +178,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
|
|||||||
}
|
}
|
||||||
|
|
||||||
ids, err := r.client.ApiKey.Query().
|
ids, err := r.client.ApiKey.Query().
|
||||||
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...)).
|
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
|
||||||
IDs(ctx)
|
IDs(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -150,17 +187,17 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||||
count, err := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID)).Count(ctx)
|
count, err := r.activeQuery().Where(apikey.UserIDEQ(userID)).Count(ctx)
|
||||||
return int64(count), err
|
return int64(count), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||||
count, err := r.client.ApiKey.Query().Where(apikey.KeyEQ(key)).Count(ctx)
|
count, err := r.activeQuery().Where(apikey.KeyEQ(key)).Count(ctx)
|
||||||
return count > 0, err
|
return count > 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||||
q := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID))
|
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
|
||||||
|
|
||||||
total, err := q.Count(ctx)
|
total, err := q.Count(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -187,7 +224,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
|||||||
|
|
||||||
// SearchApiKeys searches API keys by user ID and/or keyword (name)
|
// SearchApiKeys searches API keys by user ID and/or keyword (name)
|
||||||
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||||
q := r.client.ApiKey.Query()
|
q := r.activeQuery()
|
||||||
if userID > 0 {
|
if userID > 0 {
|
||||||
q = q.Where(apikey.UserIDEQ(userID))
|
q = q.Where(apikey.UserIDEQ(userID))
|
||||||
}
|
}
|
||||||
@@ -211,7 +248,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
|
|||||||
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
||||||
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
n, err := r.client.ApiKey.Update().
|
n, err := r.client.ApiKey.Update().
|
||||||
Where(apikey.GroupIDEQ(groupID)).
|
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
|
||||||
ClearGroupID().
|
ClearGroupID().
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
return int64(n), err
|
return int64(n), err
|
||||||
@@ -219,7 +256,7 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
|
|||||||
|
|
||||||
// CountByGroupID 获取分组的 API Key 数量
|
// CountByGroupID 获取分组的 API Key 数量
|
||||||
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
count, err := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
|
count, err := r.activeQuery().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
|
||||||
return int64(count), err
|
return int64(count), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ type ApiKeyRepoSuite struct {
|
|||||||
|
|
||||||
func (s *ApiKeyRepoSuite) SetupTest() {
|
func (s *ApiKeyRepoSuite) SetupTest() {
|
||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
entClient, _ := testEntSQLTx(s.T())
|
tx := testEntTx(s.T())
|
||||||
s.client = entClient
|
s.client = tx.Client()
|
||||||
s.repo = NewApiKeyRepository(entClient).(*apiKeyRepository)
|
s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApiKeyRepoSuite(t *testing.T) {
|
func TestApiKeyRepoSuite(t *testing.T) {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package repository
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||||
@@ -10,26 +11,16 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
|
||||||
"entgo.io/ent/dialect"
|
|
||||||
entsql "entgo.io/ent/dialect/sql"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type sqlExecutor interface {
|
type sqlExecutor interface {
|
||||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
|
||||||
}
|
|
||||||
|
|
||||||
type sqlBeginner interface {
|
|
||||||
sqlExecutor
|
|
||||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type groupRepository struct {
|
type groupRepository struct {
|
||||||
client *dbent.Client
|
client *dbent.Client
|
||||||
sql sqlExecutor
|
sql sqlExecutor
|
||||||
begin sqlBeginner
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupRepository {
|
func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupRepository {
|
||||||
@@ -37,11 +28,7 @@ func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupReposi
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newGroupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *groupRepository {
|
func newGroupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *groupRepository {
|
||||||
var beginner sqlBeginner
|
return &groupRepository{client: client, sql: sqlq}
|
||||||
if b, ok := sqlq.(sqlBeginner); ok {
|
|
||||||
beginner = b
|
|
||||||
}
|
|
||||||
return &groupRepository{client: client, sql: sqlq, begin: beginner}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) error {
|
func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) error {
|
||||||
@@ -214,7 +201,7 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool,
|
|||||||
|
|
||||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||||
var count int64
|
var count int64
|
||||||
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", groupID).Scan(&count); err != nil {
|
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return count, nil
|
return count, nil
|
||||||
@@ -236,31 +223,44 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
|||||||
}
|
}
|
||||||
groupSvc := groupEntityToService(g)
|
groupSvc := groupEntityToService(g)
|
||||||
|
|
||||||
exec := r.sql
|
// 使用 ent 事务统一包裹:避免手工基于 *sql.Tx 构造 ent client 带来的驱动断言问题,
|
||||||
|
// 同时保证级联删除的原子性。
|
||||||
|
tx, err := r.client.Tx(ctx)
|
||||||
|
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
exec := r.client
|
||||||
txClient := r.client
|
txClient := r.client
|
||||||
var sqlTx *sql.Tx
|
if err == nil {
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
if r.begin != nil {
|
exec = tx.Client()
|
||||||
sqlTx, err = r.begin.BeginTx(ctx, nil)
|
txClient = exec
|
||||||
if err != nil {
|
} else {
|
||||||
return nil, err
|
// 已处于外部事务中(ErrTxStarted),复用当前 client 参与同一事务。
|
||||||
}
|
|
||||||
exec = sqlTx
|
|
||||||
txClient = entClientFromSQLTx(sqlTx)
|
|
||||||
// 注意:不能调用 txClient.Close(),因为基于事务的 ent client
|
|
||||||
// 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB,但实际是 *sql.Tx
|
|
||||||
// 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成
|
|
||||||
defer func() { _ = sqlTx.Rollback() }()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lock the group row to avoid concurrent writes while we cascade.
|
// Lock the group row to avoid concurrent writes while we cascade.
|
||||||
var lockedID int64
|
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分“未找到”与其他错误。
|
||||||
if err := exec.QueryRowContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id).Scan(&lockedID); err != nil {
|
rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id)
|
||||||
if errorsIsNoRows(err) {
|
if err != nil {
|
||||||
return nil, service.ErrGroupNotFound
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
var lockedID int64
|
||||||
|
if rows.Next() {
|
||||||
|
if err := rows.Scan(&lockedID); err != nil {
|
||||||
|
_ = rows.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if lockedID == 0 {
|
||||||
|
return nil, service.ErrGroupNotFound
|
||||||
|
}
|
||||||
|
|
||||||
var affectedUserIDs []int64
|
var affectedUserIDs []int64
|
||||||
if groupSvc.IsSubscriptionType() {
|
if groupSvc.IsSubscriptionType() {
|
||||||
@@ -319,8 +319,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if sqlTx != nil {
|
if tx != nil {
|
||||||
if err := sqlTx.Commit(); err != nil {
|
if err := tx.Commit(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -359,11 +359,6 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
|
|||||||
return counts, nil
|
return counts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func entClientFromSQLTx(tx *sql.Tx) *dbent.Client {
|
|
||||||
drv := entsql.NewDriver(dialect.Postgres, entsql.Conn{ExecQuerier: tx})
|
|
||||||
return dbent.NewClient(dbent.Driver(drv))
|
|
||||||
}
|
|
||||||
|
|
||||||
func errorsIsNoRows(err error) bool {
|
func errorsIsNoRows(err error) bool {
|
||||||
return err == sql.ErrNoRows
|
return err == sql.ErrNoRows
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,9 +4,9 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
@@ -15,15 +15,15 @@ import (
|
|||||||
type GroupRepoSuite struct {
|
type GroupRepoSuite struct {
|
||||||
suite.Suite
|
suite.Suite
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
tx *sql.Tx
|
tx *dbent.Tx
|
||||||
repo *groupRepository
|
repo *groupRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GroupRepoSuite) SetupTest() {
|
func (s *GroupRepoSuite) SetupTest() {
|
||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
entClient, tx := testEntSQLTx(s.T())
|
tx := testEntTx(s.T())
|
||||||
s.tx = tx
|
s.tx = tx
|
||||||
s.repo = newGroupRepositoryWithSQL(entClient, tx)
|
s.repo = newGroupRepositoryWithSQL(tx.Client(), tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGroupRepoSuite(t *testing.T) {
|
func TestGroupRepoSuite(t *testing.T) {
|
||||||
@@ -99,6 +99,9 @@ func (s *GroupRepoSuite) TestDelete() {
|
|||||||
// --- List / ListWithFilters ---
|
// --- List / ListWithFilters ---
|
||||||
|
|
||||||
func (s *GroupRepoSuite) TestList() {
|
func (s *GroupRepoSuite) TestList() {
|
||||||
|
baseGroups, basePage, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||||
|
s.Require().NoError(err, "List base")
|
||||||
|
|
||||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||||
Name: "g1",
|
Name: "g1",
|
||||||
Platform: service.PlatformAnthropic,
|
Platform: service.PlatformAnthropic,
|
||||||
@@ -118,12 +121,20 @@ func (s *GroupRepoSuite) TestList() {
|
|||||||
|
|
||||||
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||||
s.Require().NoError(err, "List")
|
s.Require().NoError(err, "List")
|
||||||
// 3 default groups + 2 test groups = 5 total
|
s.Require().Len(groups, len(baseGroups)+2)
|
||||||
s.Require().Len(groups, 5)
|
s.Require().Equal(basePage.Total+2, page.Total)
|
||||||
s.Require().Equal(int64(5), page.Total)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||||
|
baseGroups, _, err := s.repo.ListWithFilters(
|
||||||
|
s.ctx,
|
||||||
|
pagination.PaginationParams{Page: 1, PageSize: 10},
|
||||||
|
service.PlatformOpenAI,
|
||||||
|
"",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
s.Require().NoError(err, "ListWithFilters base")
|
||||||
|
|
||||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||||
Name: "g1",
|
Name: "g1",
|
||||||
Platform: service.PlatformAnthropic,
|
Platform: service.PlatformAnthropic,
|
||||||
@@ -143,8 +154,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
|||||||
|
|
||||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
|
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
// 1 default openai group + 1 test openai group = 2 total
|
s.Require().Len(groups, len(baseGroups)+1)
|
||||||
s.Require().Len(groups, 2)
|
|
||||||
// Verify all groups are OpenAI platform
|
// Verify all groups are OpenAI platform
|
||||||
for _, g := range groups {
|
for _, g := range groups {
|
||||||
s.Require().Equal(service.PlatformOpenAI, g.Platform)
|
s.Require().Equal(service.PlatformOpenAI, g.Platform)
|
||||||
@@ -221,11 +231,13 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
|||||||
s.Require().NoError(s.repo.Create(s.ctx, g2))
|
s.Require().NoError(s.repo.Create(s.ctx, g2))
|
||||||
|
|
||||||
var accountID int64
|
var accountID int64
|
||||||
s.Require().NoError(s.tx.QueryRowContext(
|
s.Require().NoError(scanSingleRow(
|
||||||
s.ctx,
|
s.ctx,
|
||||||
|
s.tx,
|
||||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||||
"acc1", service.PlatformAnthropic, service.AccountTypeOAuth,
|
[]any{"acc1", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||||
).Scan(&accountID))
|
&accountID,
|
||||||
|
))
|
||||||
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g1.ID, 1)
|
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g1.ID, 1)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g2.ID, 1)
|
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g2.ID, 1)
|
||||||
@@ -243,6 +255,9 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
|||||||
// --- ListActive / ListActiveByPlatform ---
|
// --- ListActive / ListActiveByPlatform ---
|
||||||
|
|
||||||
func (s *GroupRepoSuite) TestListActive() {
|
func (s *GroupRepoSuite) TestListActive() {
|
||||||
|
baseGroups, err := s.repo.ListActive(s.ctx)
|
||||||
|
s.Require().NoError(err, "ListActive base")
|
||||||
|
|
||||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||||
Name: "active1",
|
Name: "active1",
|
||||||
Platform: service.PlatformAnthropic,
|
Platform: service.PlatformAnthropic,
|
||||||
@@ -262,8 +277,7 @@ func (s *GroupRepoSuite) TestListActive() {
|
|||||||
|
|
||||||
groups, err := s.repo.ListActive(s.ctx)
|
groups, err := s.repo.ListActive(s.ctx)
|
||||||
s.Require().NoError(err, "ListActive")
|
s.Require().NoError(err, "ListActive")
|
||||||
// 3 default groups (all active) + 1 test active group = 4 total
|
s.Require().Len(groups, len(baseGroups)+1)
|
||||||
s.Require().Len(groups, 4)
|
|
||||||
// Verify our test group is in the results
|
// Verify our test group is in the results
|
||||||
var found bool
|
var found bool
|
||||||
for _, g := range groups {
|
for _, g := range groups {
|
||||||
@@ -351,17 +365,21 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
|
|||||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||||
|
|
||||||
var a1 int64
|
var a1 int64
|
||||||
s.Require().NoError(s.tx.QueryRowContext(
|
s.Require().NoError(scanSingleRow(
|
||||||
s.ctx,
|
s.ctx,
|
||||||
|
s.tx,
|
||||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||||
"a1", service.PlatformAnthropic, service.AccountTypeOAuth,
|
[]any{"a1", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||||
).Scan(&a1))
|
&a1,
|
||||||
|
))
|
||||||
var a2 int64
|
var a2 int64
|
||||||
s.Require().NoError(s.tx.QueryRowContext(
|
s.Require().NoError(scanSingleRow(
|
||||||
s.ctx,
|
s.ctx,
|
||||||
|
s.tx,
|
||||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||||
"a2", service.PlatformAnthropic, service.AccountTypeOAuth,
|
[]any{"a2", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||||
).Scan(&a2))
|
&a2,
|
||||||
|
))
|
||||||
|
|
||||||
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, group.ID, 1)
|
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, group.ID, 1)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
@@ -402,11 +420,13 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
|
|||||||
}
|
}
|
||||||
s.Require().NoError(s.repo.Create(s.ctx, g))
|
s.Require().NoError(s.repo.Create(s.ctx, g))
|
||||||
var accountID int64
|
var accountID int64
|
||||||
s.Require().NoError(s.tx.QueryRowContext(
|
s.Require().NoError(scanSingleRow(
|
||||||
s.ctx,
|
s.ctx,
|
||||||
|
s.tx,
|
||||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||||
"acc-del", service.PlatformAnthropic, service.AccountTypeOAuth,
|
[]any{"acc-del", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||||
).Scan(&accountID))
|
&accountID,
|
||||||
|
))
|
||||||
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g.ID, 1)
|
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g.ID, 1)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
@@ -432,11 +452,13 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
|
|||||||
|
|
||||||
insertAccount := func(name string) int64 {
|
insertAccount := func(name string) int64 {
|
||||||
var id int64
|
var id int64
|
||||||
s.Require().NoError(s.tx.QueryRowContext(
|
s.Require().NoError(scanSingleRow(
|
||||||
s.ctx,
|
s.ctx,
|
||||||
|
s.tx,
|
||||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||||
name, service.PlatformAnthropic, service.AccountTypeOAuth,
|
[]any{name, service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||||
).Scan(&id))
|
&id,
|
||||||
|
))
|
||||||
return id
|
return id
|
||||||
}
|
}
|
||||||
a1 := insertAccount("a1")
|
a1 := insertAccount("a1")
|
||||||
|
|||||||
@@ -36,8 +36,9 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
integrationDB *sql.DB
|
integrationDB *sql.DB
|
||||||
integrationRedis *redisclient.Client
|
integrationEntClient *dbent.Client
|
||||||
|
integrationRedis *redisclient.Client
|
||||||
|
|
||||||
redisNamespaceSeq uint64
|
redisNamespaceSeq uint64
|
||||||
)
|
)
|
||||||
@@ -101,6 +102,10 @@ func TestMain(m *testing.M) {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 创建 ent client 用于集成测试
|
||||||
|
drv := entsql.OpenDB(dialect.Postgres, integrationDB)
|
||||||
|
integrationEntClient = dbent.NewClient(dbent.Driver(drv))
|
||||||
|
|
||||||
redisHost, err := redisContainer.Host(ctx)
|
redisHost, err := redisContainer.Host(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to get redis host: %v", err)
|
log.Printf("failed to get redis host: %v", err)
|
||||||
@@ -123,6 +128,7 @@ func TestMain(m *testing.M) {
|
|||||||
|
|
||||||
code := m.Run()
|
code := m.Run()
|
||||||
|
|
||||||
|
_ = integrationEntClient.Close()
|
||||||
_ = integrationRedis.Close()
|
_ = integrationRedis.Close()
|
||||||
_ = integrationDB.Close()
|
_ = integrationDB.Close()
|
||||||
|
|
||||||
@@ -193,18 +199,38 @@ func testTx(t *testing.T) *sql.Tx {
|
|||||||
return tx
|
return tx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testEntClient 返回全局的 ent client,用于测试需要内部管理事务的代码(如 Create/Update 方法)。
|
||||||
|
// 注意:此 client 的操作会真正写入数据库,测试结束后不会自动回滚。
|
||||||
|
func testEntClient(t *testing.T) *dbent.Client {
|
||||||
|
t.Helper()
|
||||||
|
return integrationEntClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// testEntTx 返回一个 ent 事务,用于需要事务隔离的测试。
|
||||||
|
// 测试结束后会自动回滚,不会影响数据库状态。
|
||||||
|
func testEntTx(t *testing.T) *dbent.Tx {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tx, err := integrationEntClient.Tx(context.Background())
|
||||||
|
require.NoError(t, err, "begin ent tx")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
})
|
||||||
|
return tx
|
||||||
|
}
|
||||||
|
|
||||||
|
// testEntSQLTx 已弃用:不要在新测试中使用此函数。
|
||||||
|
// 基于 *sql.Tx 创建的 ent client 在调用 client.Tx() 时会 panic。
|
||||||
|
// 对于需要测试内部使用事务的代码,请使用 testEntClient。
|
||||||
|
// 对于需要事务隔离的测试,请使用 testEntTx。
|
||||||
|
//
|
||||||
|
// Deprecated: Use testEntClient or testEntTx instead.
|
||||||
func testEntSQLTx(t *testing.T) (*dbent.Client, *sql.Tx) {
|
func testEntSQLTx(t *testing.T) (*dbent.Client, *sql.Tx) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
tx := testTx(t)
|
// 直接失败,避免旧测试误用导致的事务嵌套 panic。
|
||||||
drv := entsql.NewDriver(dialect.Postgres, entsql.Conn{ExecQuerier: tx})
|
t.Fatalf("testEntSQLTx 已弃用:请使用 testEntClient 或 testEntTx")
|
||||||
client := dbent.NewClient(dbent.Driver(drv))
|
return nil, nil
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
_ = client.Close()
|
|
||||||
})
|
|
||||||
|
|
||||||
return client, tx
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func testRedis(t *testing.T) *redisclient.Client {
|
func testRedis(t *testing.T) *redisclient.Client {
|
||||||
@@ -363,13 +389,16 @@ type IntegrationDBSuite struct {
|
|||||||
suite.Suite
|
suite.Suite
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
client *dbent.Client
|
client *dbent.Client
|
||||||
tx *sql.Tx
|
tx *dbent.Tx
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupTest initializes ctx and client for each test method.
|
// SetupTest initializes ctx and client for each test method.
|
||||||
func (s *IntegrationDBSuite) SetupTest() {
|
func (s *IntegrationDBSuite) SetupTest() {
|
||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
s.client, s.tx = testEntSQLTx(s.T())
|
// 统一使用 ent.Tx,确保每个测试都有独立事务并自动回滚。
|
||||||
|
tx := testEntTx(s.T())
|
||||||
|
s.tx = tx
|
||||||
|
s.client = tx.Client()
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequireNoError is a convenience method wrapping require.NoError with s.T().
|
// RequireNoError is a convenience method wrapping require.NoError with s.T().
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
|
|
||||||
type sqlQuerier interface {
|
type sqlQuerier interface {
|
||||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type proxyRepository struct {
|
type proxyRepository struct {
|
||||||
@@ -170,9 +169,8 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
|
|||||||
|
|
||||||
// CountAccountsByProxyID returns the number of accounts using a specific proxy
|
// CountAccountsByProxyID returns the number of accounts using a specific proxy
|
||||||
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||||
row := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", proxyID)
|
|
||||||
var count int64
|
var count int64
|
||||||
if err := row.Scan(&count); err != nil {
|
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", []any{proxyID}, &count); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return count, nil
|
return count, nil
|
||||||
|
|||||||
@@ -4,10 +4,10 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
@@ -15,16 +15,16 @@ import (
|
|||||||
|
|
||||||
type ProxyRepoSuite struct {
|
type ProxyRepoSuite struct {
|
||||||
suite.Suite
|
suite.Suite
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
sqlTx *sql.Tx
|
tx *dbent.Tx
|
||||||
repo *proxyRepository
|
repo *proxyRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyRepoSuite) SetupTest() {
|
func (s *ProxyRepoSuite) SetupTest() {
|
||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
entClient, sqlTx := testEntSQLTx(s.T())
|
tx := testEntTx(s.T())
|
||||||
s.sqlTx = sqlTx
|
s.tx = tx
|
||||||
s.repo = newProxyRepositoryWithSQL(entClient, sqlTx)
|
s.repo = newProxyRepositoryWithSQL(tx.Client(), tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyRepoSuite(t *testing.T) {
|
func TestProxyRepoSuite(t *testing.T) {
|
||||||
@@ -306,7 +306,7 @@ func (s *ProxyRepoSuite) mustCreateProxyWithTimes(name, status string, createdAt
|
|||||||
Port: 8080,
|
Port: 8080,
|
||||||
Status: status,
|
Status: status,
|
||||||
})
|
})
|
||||||
_, err := s.sqlTx.ExecContext(s.ctx, "UPDATE proxies SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, p.ID)
|
_, err := s.tx.ExecContext(s.ctx, "UPDATE proxies SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, p.ID)
|
||||||
s.Require().NoError(err, "update proxy timestamps")
|
s.Require().NoError(err, "update proxy timestamps")
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
@@ -317,7 +317,7 @@ func (s *ProxyRepoSuite) mustInsertAccount(name string, proxyID *int64) {
|
|||||||
if proxyID != nil {
|
if proxyID != nil {
|
||||||
pid = *proxyID
|
pid = *proxyID
|
||||||
}
|
}
|
||||||
_, err := s.sqlTx.ExecContext(
|
_, err := s.tx.ExecContext(
|
||||||
s.ctx,
|
s.ctx,
|
||||||
"INSERT INTO accounts (name, platform, type, proxy_id) VALUES ($1, $2, $3, $4)",
|
"INSERT INTO accounts (name, platform, type, proxy_id) VALUES ($1, $2, $3, $4)",
|
||||||
name,
|
name,
|
||||||
|
|||||||
@@ -22,9 +22,9 @@ type RedeemCodeRepoSuite struct {
|
|||||||
|
|
||||||
func (s *RedeemCodeRepoSuite) SetupTest() {
|
func (s *RedeemCodeRepoSuite) SetupTest() {
|
||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
entClient, _ := testEntSQLTx(s.T())
|
tx := testEntTx(s.T())
|
||||||
s.client = entClient
|
s.client = tx.Client()
|
||||||
s.repo = NewRedeemCodeRepository(entClient).(*redeemCodeRepository)
|
s.repo = NewRedeemCodeRepository(s.client).(*redeemCodeRepository)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRedeemCodeRepoSuite(t *testing.T) {
|
func TestRedeemCodeRepoSuite(t *testing.T) {
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ type SettingRepoSuite struct {
|
|||||||
|
|
||||||
func (s *SettingRepoSuite) SetupTest() {
|
func (s *SettingRepoSuite) SetupTest() {
|
||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
entClient, _ := testEntSQLTx(s.T())
|
tx := testEntTx(s.T())
|
||||||
s.repo = NewSettingRepository(entClient).(*settingRepository)
|
s.repo = NewSettingRepository(tx.Client()).(*settingRepository)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSettingRepoSuite(t *testing.T) {
|
func TestSettingRepoSuite(t *testing.T) {
|
||||||
|
|||||||
@@ -34,7 +34,8 @@ func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, emai
|
|||||||
|
|
||||||
func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
|
func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
client, _ := testEntSQLTx(t)
|
// 使用全局 ent client,确保软删除验证在实际持久化数据上进行。
|
||||||
|
client := testEntClient(t)
|
||||||
|
|
||||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
|
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
|
||||||
|
|
||||||
@@ -65,7 +66,8 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
|
|||||||
|
|
||||||
func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
|
func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
client, _ := testEntSQLTx(t)
|
// 使用全局 ent client,避免事务回滚影响幂等性验证。
|
||||||
|
client := testEntClient(t)
|
||||||
|
|
||||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
|
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
|
||||||
|
|
||||||
@@ -84,7 +86,8 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
|
|||||||
|
|
||||||
func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
|
func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
client, _ := testEntSQLTx(t)
|
// 使用全局 ent client,确保 SkipSoftDelete 的硬删除语义可验证。
|
||||||
|
client := testEntClient(t)
|
||||||
|
|
||||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
|
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
|
||||||
|
|
||||||
|
|||||||
33
backend/internal/repository/sql_scan.go
Normal file
33
backend/internal/repository/sql_scan.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sqlQueryer interface {
|
||||||
|
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanSingleRow executes a query and scans the first row into dest.
|
||||||
|
// If no rows are returned, sql.ErrNoRows is returned.
|
||||||
|
// 设计目的:仅依赖 QueryContext,避免 QueryRowContext 对 *sql.Tx 的强绑定,
|
||||||
|
// 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。
|
||||||
|
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) error {
|
||||||
|
rows, err := q.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return sql.ErrNoRows
|
||||||
|
}
|
||||||
|
if err := rows.Scan(dest...); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return rows.Err()
|
||||||
|
}
|
||||||
@@ -3,7 +3,6 @@ package repository
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -33,6 +32,7 @@ func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLog
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository {
|
func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository {
|
||||||
|
// 使用 scanSingleRow 替代 QueryRowContext,保证 ent.Tx 作为 sqlExecutor 可用。
|
||||||
return &usageLogRepository{client: client, sql: sqlq}
|
return &usageLogRepository{client: client, sql: sqlq}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,7 +53,7 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
|
|||||||
|
|
||||||
var requestCount int64
|
var requestCount int64
|
||||||
var tokenCount int64
|
var tokenCount int64
|
||||||
if err := r.sql.QueryRowContext(ctx, query, args...).Scan(&requestCount, &tokenCount); err != nil {
|
if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil {
|
||||||
return 0, 0, err
|
return 0, 0, err
|
||||||
}
|
}
|
||||||
return requestCount / 5, tokenCount / 5, nil
|
return requestCount / 5, tokenCount / 5, nil
|
||||||
@@ -114,9 +114,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
duration := nullInt(log.DurationMs)
|
duration := nullInt(log.DurationMs)
|
||||||
firstToken := nullInt(log.FirstTokenMs)
|
firstToken := nullInt(log.FirstTokenMs)
|
||||||
|
|
||||||
row := r.sql.QueryRowContext(
|
args := []any{
|
||||||
ctx,
|
|
||||||
query,
|
|
||||||
log.UserID,
|
log.UserID,
|
||||||
log.ApiKeyID,
|
log.ApiKeyID,
|
||||||
log.AccountID,
|
log.AccountID,
|
||||||
@@ -142,9 +140,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
duration,
|
duration,
|
||||||
firstToken,
|
firstToken,
|
||||||
createdAt,
|
createdAt,
|
||||||
)
|
}
|
||||||
|
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||||
if err := row.Scan(&log.ID, &log.CreatedAt); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.RateMultiplier = rateMultiplier
|
log.RateMultiplier = rateMultiplier
|
||||||
@@ -153,11 +150,22 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
|
|
||||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
|
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
|
||||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
|
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
|
||||||
log, err := scanUsageLog(r.sql.QueryRowContext(ctx, query, id))
|
rows, err := r.sql.QueryContext(ctx, query, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
return nil, err
|
||||||
return nil, service.ErrUsageLogNotFound
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
if !rows.Next() {
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return nil, service.ErrUsageLogNotFound
|
||||||
|
}
|
||||||
|
log, err := scanUsageLog(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return log, nil
|
return log, nil
|
||||||
@@ -195,8 +203,18 @@ func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
|
|||||||
`
|
`
|
||||||
|
|
||||||
stats := &UserStats{}
|
stats := &UserStats{}
|
||||||
if err := r.sql.QueryRowContext(ctx, query, userID, startTime, endTime).
|
if err := scanSingleRow(
|
||||||
Scan(&stats.TotalRequests, &stats.TotalTokens, &stats.TotalCost, &stats.InputTokens, &stats.OutputTokens, &stats.CacheReadTokens); err != nil {
|
ctx,
|
||||||
|
r.sql,
|
||||||
|
query,
|
||||||
|
[]any{userID, startTime, endTime},
|
||||||
|
&stats.TotalRequests,
|
||||||
|
&stats.TotalTokens,
|
||||||
|
&stats.TotalCost,
|
||||||
|
&stats.InputTokens,
|
||||||
|
&stats.OutputTokens,
|
||||||
|
&stats.CacheReadTokens,
|
||||||
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return stats, nil
|
return stats, nil
|
||||||
@@ -219,8 +237,15 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
FROM users
|
FROM users
|
||||||
WHERE deleted_at IS NULL
|
WHERE deleted_at IS NULL
|
||||||
`
|
`
|
||||||
if err := r.sql.QueryRowContext(ctx, userStatsQuery, today, today).
|
if err := scanSingleRow(
|
||||||
Scan(&stats.TotalUsers, &stats.TodayNewUsers, &stats.ActiveUsers); err != nil {
|
ctx,
|
||||||
|
r.sql,
|
||||||
|
userStatsQuery,
|
||||||
|
[]any{today, today},
|
||||||
|
&stats.TotalUsers,
|
||||||
|
&stats.TodayNewUsers,
|
||||||
|
&stats.ActiveUsers,
|
||||||
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -232,8 +257,14 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
FROM api_keys
|
FROM api_keys
|
||||||
WHERE deleted_at IS NULL
|
WHERE deleted_at IS NULL
|
||||||
`
|
`
|
||||||
if err := r.sql.QueryRowContext(ctx, apiKeyStatsQuery, service.StatusActive).
|
if err := scanSingleRow(
|
||||||
Scan(&stats.TotalApiKeys, &stats.ActiveApiKeys); err != nil {
|
ctx,
|
||||||
|
r.sql,
|
||||||
|
apiKeyStatsQuery,
|
||||||
|
[]any{service.StatusActive},
|
||||||
|
&stats.TotalApiKeys,
|
||||||
|
&stats.ActiveApiKeys,
|
||||||
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -248,8 +279,17 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
FROM accounts
|
FROM accounts
|
||||||
WHERE deleted_at IS NULL
|
WHERE deleted_at IS NULL
|
||||||
`
|
`
|
||||||
if err := r.sql.QueryRowContext(ctx, accountStatsQuery, service.StatusActive, service.StatusError, now, now).
|
if err := scanSingleRow(
|
||||||
Scan(&stats.TotalAccounts, &stats.NormalAccounts, &stats.ErrorAccounts, &stats.RateLimitAccounts, &stats.OverloadAccounts); err != nil {
|
ctx,
|
||||||
|
r.sql,
|
||||||
|
accountStatsQuery,
|
||||||
|
[]any{service.StatusActive, service.StatusError, now, now},
|
||||||
|
&stats.TotalAccounts,
|
||||||
|
&stats.NormalAccounts,
|
||||||
|
&stats.ErrorAccounts,
|
||||||
|
&stats.RateLimitAccounts,
|
||||||
|
&stats.OverloadAccounts,
|
||||||
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -266,17 +306,20 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
`
|
`
|
||||||
if err := r.sql.QueryRowContext(ctx, totalStatsQuery).
|
if err := scanSingleRow(
|
||||||
Scan(
|
ctx,
|
||||||
&stats.TotalRequests,
|
r.sql,
|
||||||
&stats.TotalInputTokens,
|
totalStatsQuery,
|
||||||
&stats.TotalOutputTokens,
|
nil,
|
||||||
&stats.TotalCacheCreationTokens,
|
&stats.TotalRequests,
|
||||||
&stats.TotalCacheReadTokens,
|
&stats.TotalInputTokens,
|
||||||
&stats.TotalCost,
|
&stats.TotalOutputTokens,
|
||||||
&stats.TotalActualCost,
|
&stats.TotalCacheCreationTokens,
|
||||||
&stats.AverageDurationMs,
|
&stats.TotalCacheReadTokens,
|
||||||
); err != nil {
|
&stats.TotalCost,
|
||||||
|
&stats.TotalActualCost,
|
||||||
|
&stats.AverageDurationMs,
|
||||||
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||||
@@ -294,16 +337,19 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE created_at >= $1
|
WHERE created_at >= $1
|
||||||
`
|
`
|
||||||
if err := r.sql.QueryRowContext(ctx, todayStatsQuery, today).
|
if err := scanSingleRow(
|
||||||
Scan(
|
ctx,
|
||||||
&stats.TodayRequests,
|
r.sql,
|
||||||
&stats.TodayInputTokens,
|
todayStatsQuery,
|
||||||
&stats.TodayOutputTokens,
|
[]any{today},
|
||||||
&stats.TodayCacheCreationTokens,
|
&stats.TodayRequests,
|
||||||
&stats.TodayCacheReadTokens,
|
&stats.TodayInputTokens,
|
||||||
&stats.TodayCost,
|
&stats.TodayOutputTokens,
|
||||||
&stats.TodayActualCost,
|
&stats.TodayCacheCreationTokens,
|
||||||
); err != nil {
|
&stats.TodayCacheReadTokens,
|
||||||
|
&stats.TodayCost,
|
||||||
|
&stats.TodayActualCost,
|
||||||
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||||
@@ -345,16 +391,19 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
|
|||||||
`
|
`
|
||||||
|
|
||||||
var stats usagestats.UsageStats
|
var stats usagestats.UsageStats
|
||||||
if err := r.sql.QueryRowContext(ctx, query, userID, startTime, endTime).
|
if err := scanSingleRow(
|
||||||
Scan(
|
ctx,
|
||||||
&stats.TotalRequests,
|
r.sql,
|
||||||
&stats.TotalInputTokens,
|
query,
|
||||||
&stats.TotalOutputTokens,
|
[]any{userID, startTime, endTime},
|
||||||
&stats.TotalCacheTokens,
|
&stats.TotalRequests,
|
||||||
&stats.TotalCost,
|
&stats.TotalInputTokens,
|
||||||
&stats.TotalActualCost,
|
&stats.TotalOutputTokens,
|
||||||
&stats.AverageDurationMs,
|
&stats.TotalCacheTokens,
|
||||||
); err != nil {
|
&stats.TotalCost,
|
||||||
|
&stats.TotalActualCost,
|
||||||
|
&stats.AverageDurationMs,
|
||||||
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||||
@@ -377,16 +426,19 @@ func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKe
|
|||||||
`
|
`
|
||||||
|
|
||||||
var stats usagestats.UsageStats
|
var stats usagestats.UsageStats
|
||||||
if err := r.sql.QueryRowContext(ctx, query, apiKeyID, startTime, endTime).
|
if err := scanSingleRow(
|
||||||
Scan(
|
ctx,
|
||||||
&stats.TotalRequests,
|
r.sql,
|
||||||
&stats.TotalInputTokens,
|
query,
|
||||||
&stats.TotalOutputTokens,
|
[]any{apiKeyID, startTime, endTime},
|
||||||
&stats.TotalCacheTokens,
|
&stats.TotalRequests,
|
||||||
&stats.TotalCost,
|
&stats.TotalInputTokens,
|
||||||
&stats.TotalActualCost,
|
&stats.TotalOutputTokens,
|
||||||
&stats.AverageDurationMs,
|
&stats.TotalCacheTokens,
|
||||||
); err != nil {
|
&stats.TotalCost,
|
||||||
|
&stats.TotalActualCost,
|
||||||
|
&stats.AverageDurationMs,
|
||||||
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||||
@@ -430,8 +482,15 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
|||||||
`
|
`
|
||||||
|
|
||||||
stats := &usagestats.AccountStats{}
|
stats := &usagestats.AccountStats{}
|
||||||
if err := r.sql.QueryRowContext(ctx, query, accountID, today).
|
if err := scanSingleRow(
|
||||||
Scan(&stats.Requests, &stats.Tokens, &stats.Cost); err != nil {
|
ctx,
|
||||||
|
r.sql,
|
||||||
|
query,
|
||||||
|
[]any{accountID, today},
|
||||||
|
&stats.Requests,
|
||||||
|
&stats.Tokens,
|
||||||
|
&stats.Cost,
|
||||||
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return stats, nil
|
return stats, nil
|
||||||
@@ -449,8 +508,15 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
|||||||
`
|
`
|
||||||
|
|
||||||
stats := &usagestats.AccountStats{}
|
stats := &usagestats.AccountStats{}
|
||||||
if err := r.sql.QueryRowContext(ctx, query, accountID, startTime).
|
if err := scanSingleRow(
|
||||||
Scan(&stats.Requests, &stats.Tokens, &stats.Cost); err != nil {
|
ctx,
|
||||||
|
r.sql,
|
||||||
|
query,
|
||||||
|
[]any{accountID, startTime},
|
||||||
|
&stats.Requests,
|
||||||
|
&stats.Tokens,
|
||||||
|
&stats.Cost,
|
||||||
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return stats, nil
|
return stats, nil
|
||||||
@@ -581,12 +647,22 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
|||||||
today := timezone.Today()
|
today := timezone.Today()
|
||||||
|
|
||||||
// API Key 统计
|
// API Key 统计
|
||||||
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL", userID).
|
if err := scanSingleRow(
|
||||||
Scan(&stats.TotalApiKeys); err != nil {
|
ctx,
|
||||||
|
r.sql,
|
||||||
|
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
|
||||||
|
[]any{userID},
|
||||||
|
&stats.TotalApiKeys,
|
||||||
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL", userID, service.StatusActive).
|
if err := scanSingleRow(
|
||||||
Scan(&stats.ActiveApiKeys); err != nil {
|
ctx,
|
||||||
|
r.sql,
|
||||||
|
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
|
||||||
|
[]any{userID, service.StatusActive},
|
||||||
|
&stats.ActiveApiKeys,
|
||||||
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -604,17 +680,20 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
|||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
`
|
`
|
||||||
if err := r.sql.QueryRowContext(ctx, totalStatsQuery, userID).
|
if err := scanSingleRow(
|
||||||
Scan(
|
ctx,
|
||||||
&stats.TotalRequests,
|
r.sql,
|
||||||
&stats.TotalInputTokens,
|
totalStatsQuery,
|
||||||
&stats.TotalOutputTokens,
|
[]any{userID},
|
||||||
&stats.TotalCacheCreationTokens,
|
&stats.TotalRequests,
|
||||||
&stats.TotalCacheReadTokens,
|
&stats.TotalInputTokens,
|
||||||
&stats.TotalCost,
|
&stats.TotalOutputTokens,
|
||||||
&stats.TotalActualCost,
|
&stats.TotalCacheCreationTokens,
|
||||||
&stats.AverageDurationMs,
|
&stats.TotalCacheReadTokens,
|
||||||
); err != nil {
|
&stats.TotalCost,
|
||||||
|
&stats.TotalActualCost,
|
||||||
|
&stats.AverageDurationMs,
|
||||||
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||||
@@ -632,16 +711,19 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
|||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE user_id = $1 AND created_at >= $2
|
WHERE user_id = $1 AND created_at >= $2
|
||||||
`
|
`
|
||||||
if err := r.sql.QueryRowContext(ctx, todayStatsQuery, userID, today).
|
if err := scanSingleRow(
|
||||||
Scan(
|
ctx,
|
||||||
&stats.TodayRequests,
|
r.sql,
|
||||||
&stats.TodayInputTokens,
|
todayStatsQuery,
|
||||||
&stats.TodayOutputTokens,
|
[]any{userID, today},
|
||||||
&stats.TodayCacheCreationTokens,
|
&stats.TodayRequests,
|
||||||
&stats.TodayCacheReadTokens,
|
&stats.TodayInputTokens,
|
||||||
&stats.TodayCost,
|
&stats.TodayOutputTokens,
|
||||||
&stats.TodayActualCost,
|
&stats.TodayCacheCreationTokens,
|
||||||
); err != nil {
|
&stats.TodayCacheReadTokens,
|
||||||
|
&stats.TodayCost,
|
||||||
|
&stats.TodayActualCost,
|
||||||
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||||
@@ -1007,16 +1089,19 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
|
|||||||
`
|
`
|
||||||
|
|
||||||
stats := &UsageStats{}
|
stats := &UsageStats{}
|
||||||
if err := r.sql.QueryRowContext(ctx, query, startTime, endTime).
|
if err := scanSingleRow(
|
||||||
Scan(
|
ctx,
|
||||||
&stats.TotalRequests,
|
r.sql,
|
||||||
&stats.TotalInputTokens,
|
query,
|
||||||
&stats.TotalOutputTokens,
|
[]any{startTime, endTime},
|
||||||
&stats.TotalCacheTokens,
|
&stats.TotalRequests,
|
||||||
&stats.TotalCost,
|
&stats.TotalInputTokens,
|
||||||
&stats.TotalActualCost,
|
&stats.TotalOutputTokens,
|
||||||
&stats.AverageDurationMs,
|
&stats.TotalCacheTokens,
|
||||||
); err != nil {
|
&stats.TotalCost,
|
||||||
|
&stats.TotalActualCost,
|
||||||
|
&stats.AverageDurationMs,
|
||||||
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||||
@@ -1108,7 +1193,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
|
|
||||||
avgQuery := "SELECT COALESCE(AVG(duration_ms), 0) as avg_duration_ms FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3"
|
avgQuery := "SELECT COALESCE(AVG(duration_ms), 0) as avg_duration_ms FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3"
|
||||||
var avgDuration float64
|
var avgDuration float64
|
||||||
if err := r.sql.QueryRowContext(ctx, avgQuery, accountID, startTime, endTime).Scan(&avgDuration); err != nil {
|
if err := scanSingleRow(ctx, r.sql, avgQuery, []any{accountID, startTime, endTime}, &avgDuration); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1186,7 +1271,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
countQuery := "SELECT COUNT(*) FROM usage_logs " + whereClause
|
countQuery := "SELECT COUNT(*) FROM usage_logs " + whereClause
|
||||||
var total int64
|
var total int64
|
||||||
if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
if err := scanSingleRow(ctx, r.sql, countQuery, args, &total); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -19,17 +18,17 @@ import (
|
|||||||
type UsageLogRepoSuite struct {
|
type UsageLogRepoSuite struct {
|
||||||
suite.Suite
|
suite.Suite
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
tx *sql.Tx
|
tx *dbent.Tx
|
||||||
client *dbent.Client
|
client *dbent.Client
|
||||||
repo *usageLogRepository
|
repo *usageLogRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) SetupTest() {
|
func (s *UsageLogRepoSuite) SetupTest() {
|
||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
client, tx := testEntSQLTx(s.T())
|
tx := testEntTx(s.T())
|
||||||
s.client = client
|
|
||||||
s.tx = tx
|
s.tx = tx
|
||||||
s.repo = newUsageLogRepositoryWithSQL(client, tx)
|
s.client = tx.Client()
|
||||||
|
s.repo = newUsageLogRepositoryWithSQL(s.client, tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUsageLogRepoSuite(t *testing.T) {
|
func TestUsageLogRepoSuite(t *testing.T) {
|
||||||
@@ -197,6 +196,8 @@ func (s *UsageLogRepoSuite) TestListWithFilters() {
|
|||||||
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
todayStart := timezone.Today()
|
todayStart := timezone.Today()
|
||||||
|
baseStats, err := s.repo.GetDashboardStats(s.ctx)
|
||||||
|
s.Require().NoError(err, "GetDashboardStats base")
|
||||||
|
|
||||||
userToday := mustCreateUser(s.T(), s.client, &service.User{
|
userToday := mustCreateUser(s.T(), s.client, &service.User{
|
||||||
Email: "today@example.com",
|
Email: "today@example.com",
|
||||||
@@ -268,24 +269,24 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
|||||||
stats, err := s.repo.GetDashboardStats(s.ctx)
|
stats, err := s.repo.GetDashboardStats(s.ctx)
|
||||||
s.Require().NoError(err, "GetDashboardStats")
|
s.Require().NoError(err, "GetDashboardStats")
|
||||||
|
|
||||||
s.Require().Equal(int64(2), stats.TotalUsers, "TotalUsers mismatch")
|
s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch")
|
||||||
s.Require().Equal(int64(1), stats.TodayNewUsers, "TodayNewUsers mismatch")
|
s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch")
|
||||||
s.Require().Equal(int64(1), stats.ActiveUsers, "ActiveUsers mismatch")
|
s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch")
|
||||||
s.Require().Equal(int64(2), stats.TotalApiKeys, "TotalApiKeys mismatch")
|
s.Require().Equal(baseStats.TotalApiKeys+2, stats.TotalApiKeys, "TotalApiKeys mismatch")
|
||||||
s.Require().Equal(int64(1), stats.ActiveApiKeys, "ActiveApiKeys mismatch")
|
s.Require().Equal(baseStats.ActiveApiKeys+1, stats.ActiveApiKeys, "ActiveApiKeys mismatch")
|
||||||
s.Require().Equal(int64(4), stats.TotalAccounts, "TotalAccounts mismatch")
|
s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch")
|
||||||
s.Require().Equal(int64(1), stats.ErrorAccounts, "ErrorAccounts mismatch")
|
s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch")
|
||||||
s.Require().Equal(int64(1), stats.RateLimitAccounts, "RateLimitAccounts mismatch")
|
s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch")
|
||||||
s.Require().Equal(int64(1), stats.OverloadAccounts, "OverloadAccounts mismatch")
|
s.Require().Equal(baseStats.OverloadAccounts+1, stats.OverloadAccounts, "OverloadAccounts mismatch")
|
||||||
|
|
||||||
s.Require().Equal(int64(3), stats.TotalRequests, "TotalRequests mismatch")
|
s.Require().Equal(baseStats.TotalRequests+3, stats.TotalRequests, "TotalRequests mismatch")
|
||||||
s.Require().Equal(int64(16), stats.TotalInputTokens, "TotalInputTokens mismatch")
|
s.Require().Equal(baseStats.TotalInputTokens+int64(16), stats.TotalInputTokens, "TotalInputTokens mismatch")
|
||||||
s.Require().Equal(int64(28), stats.TotalOutputTokens, "TotalOutputTokens mismatch")
|
s.Require().Equal(baseStats.TotalOutputTokens+int64(28), stats.TotalOutputTokens, "TotalOutputTokens mismatch")
|
||||||
s.Require().Equal(int64(3), stats.TotalCacheCreationTokens, "TotalCacheCreationTokens mismatch")
|
s.Require().Equal(baseStats.TotalCacheCreationTokens+int64(3), stats.TotalCacheCreationTokens, "TotalCacheCreationTokens mismatch")
|
||||||
s.Require().Equal(int64(4), stats.TotalCacheReadTokens, "TotalCacheReadTokens mismatch")
|
s.Require().Equal(baseStats.TotalCacheReadTokens+int64(4), stats.TotalCacheReadTokens, "TotalCacheReadTokens mismatch")
|
||||||
s.Require().Equal(int64(51), stats.TotalTokens, "TotalTokens mismatch")
|
s.Require().Equal(baseStats.TotalTokens+int64(51), stats.TotalTokens, "TotalTokens mismatch")
|
||||||
s.Require().Equal(2.3, stats.TotalCost, "TotalCost mismatch")
|
s.Require().Equal(baseStats.TotalCost+2.3, stats.TotalCost, "TotalCost mismatch")
|
||||||
s.Require().Equal(2.0, stats.TotalActualCost, "TotalActualCost mismatch")
|
s.Require().Equal(baseStats.TotalActualCost+2.0, stats.TotalActualCost, "TotalActualCost mismatch")
|
||||||
s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1")
|
s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1")
|
||||||
s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0")
|
s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0")
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package repository
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
@@ -17,7 +18,6 @@ import (
|
|||||||
type userRepository struct {
|
type userRepository struct {
|
||||||
client *dbent.Client
|
client *dbent.Client
|
||||||
sql sqlExecutor
|
sql sqlExecutor
|
||||||
begin sqlBeginner
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
|
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
|
||||||
@@ -25,11 +25,7 @@ func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserReposito
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
|
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
|
||||||
var beginner sqlBeginner
|
return &userRepository{client: client, sql: sqlq}
|
||||||
if b, ok := sqlq.(sqlBeginner); ok {
|
|
||||||
beginner = b
|
|
||||||
}
|
|
||||||
return &userRepository{client: client, sql: sqlq, begin: beginner}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
|
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
|
||||||
@@ -37,22 +33,20 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
exec := r.sql
|
// 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
|
||||||
txClient := r.client
|
// 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
|
||||||
var sqlTx *sql.Tx
|
tx, err := r.client.Tx(ctx)
|
||||||
|
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if r.begin != nil {
|
var txClient *dbent.Client
|
||||||
var err error
|
if err == nil {
|
||||||
sqlTx, err = r.begin.BeginTx(ctx, nil)
|
defer func() { _ = tx.Rollback() }()
|
||||||
if err != nil {
|
txClient = tx.Client()
|
||||||
return err
|
} else {
|
||||||
}
|
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
|
||||||
exec = sqlTx
|
txClient = r.client
|
||||||
txClient = entClientFromSQLTx(sqlTx)
|
|
||||||
// 注意:不能调用 txClient.Close(),因为基于事务的 ent client
|
|
||||||
// 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB,但实际是 *sql.Tx
|
|
||||||
// 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成
|
|
||||||
defer func() { _ = sqlTx.Rollback() }()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := txClient.User.Create().
|
created, err := txClient.User.Create().
|
||||||
@@ -70,12 +64,12 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
|||||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.syncUserAllowedGroups(ctx, txClient, exec, created.ID, userIn.AllowedGroups); err != nil {
|
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if sqlTx != nil {
|
if tx != nil {
|
||||||
if err := sqlTx.Commit(); err != nil {
|
if err := tx.Commit(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -121,22 +115,19 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
exec := r.sql
|
// 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
|
||||||
txClient := r.client
|
tx, err := r.client.Tx(ctx)
|
||||||
var sqlTx *sql.Tx
|
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if r.begin != nil {
|
var txClient *dbent.Client
|
||||||
var err error
|
if err == nil {
|
||||||
sqlTx, err = r.begin.BeginTx(ctx, nil)
|
defer func() { _ = tx.Rollback() }()
|
||||||
if err != nil {
|
txClient = tx.Client()
|
||||||
return err
|
} else {
|
||||||
}
|
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
|
||||||
exec = sqlTx
|
txClient = r.client
|
||||||
txClient = entClientFromSQLTx(sqlTx)
|
|
||||||
// 注意:不能调用 txClient.Close(),因为基于事务的 ent client
|
|
||||||
// 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB,但实际是 *sql.Tx
|
|
||||||
// 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成
|
|
||||||
defer func() { _ = sqlTx.Rollback() }()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := txClient.User.UpdateOneID(userIn.ID).
|
updated, err := txClient.User.UpdateOneID(userIn.ID).
|
||||||
@@ -154,12 +145,12 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
|||||||
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.syncUserAllowedGroups(ctx, txClient, exec, updated.ID, userIn.AllowedGroups); err != nil {
|
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if sqlTx != nil {
|
if tx != nil {
|
||||||
if err := sqlTx.Commit(); err != nil {
|
if err := tx.Commit(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -289,8 +280,10 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||||
if r.sql == nil {
|
exec := r.sql
|
||||||
return 0, nil
|
if exec == nil {
|
||||||
|
// 未注入 sqlExecutor 时,退回到 ent client 的 ExecContext(支持事务)。
|
||||||
|
exec = r.client
|
||||||
}
|
}
|
||||||
|
|
||||||
joinAffected, err := r.client.UserAllowedGroup.Delete().
|
joinAffected, err := r.client.UserAllowedGroup.Delete().
|
||||||
@@ -300,7 +293,7 @@ func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
arrayRes, err := r.sql.ExecContext(
|
arrayRes, err := exec.ExecContext(
|
||||||
ctx,
|
ctx,
|
||||||
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1), updated_at = NOW() WHERE $1 = ANY(allowed_groups)",
|
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1), updated_at = NOW() WHERE $1 = ANY(allowed_groups)",
|
||||||
groupID,
|
groupID,
|
||||||
@@ -362,6 +355,56 @@ func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64)
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组:
|
||||||
|
// 1) 以 user_allowed_groups 为读写源,确保新旧逻辑一致;
|
||||||
|
// 2) 额外更新 users.allowed_groups(历史字段)以保持兼容。
|
||||||
|
func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error {
|
||||||
|
if client == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep join table as the source of truth for reads.
|
||||||
|
if _, err := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID)).Exec(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
unique := make(map[int64]struct{}, len(groupIDs))
|
||||||
|
for _, id := range groupIDs {
|
||||||
|
if id <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
unique[id] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
legacyGroups := make([]int64, 0, len(unique))
|
||||||
|
if len(unique) > 0 {
|
||||||
|
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
|
||||||
|
for groupID := range unique {
|
||||||
|
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
|
||||||
|
legacyGroups = append(legacyGroups, groupID)
|
||||||
|
}
|
||||||
|
if err := client.UserAllowedGroup.
|
||||||
|
CreateBulk(creates...).
|
||||||
|
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
|
||||||
|
DoNothing().
|
||||||
|
Exec(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 1 兼容:保持 users.allowed_groups(数组字段)同步,避免旧查询路径读取到过期数据。
|
||||||
|
var legacy any
|
||||||
|
if len(legacyGroups) > 0 {
|
||||||
|
sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] })
|
||||||
|
legacy = pq.Array(legacyGroups)
|
||||||
|
}
|
||||||
|
if _, err := client.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *userRepository) syncUserAllowedGroups(ctx context.Context, client *dbent.Client, exec sqlExecutor, userID int64, groupIDs []int64) error {
|
func (r *userRepository) syncUserAllowedGroups(ctx context.Context, client *dbent.Client, exec sqlExecutor, userID int64, groupIDs []int64) error {
|
||||||
if client == nil || exec == nil {
|
if client == nil || exec == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -17,17 +16,19 @@ import (
|
|||||||
type UserRepoSuite struct {
|
type UserRepoSuite struct {
|
||||||
suite.Suite
|
suite.Suite
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
tx *sql.Tx
|
|
||||||
client *dbent.Client
|
client *dbent.Client
|
||||||
repo *userRepository
|
repo *userRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserRepoSuite) SetupTest() {
|
func (s *UserRepoSuite) SetupTest() {
|
||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
entClient, tx := testEntSQLTx(s.T())
|
s.client = testEntClient(s.T())
|
||||||
s.tx = tx
|
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
|
||||||
s.client = entClient
|
|
||||||
s.repo = newUserRepositoryWithSQL(entClient, tx)
|
// 清理测试数据,确保每个测试从干净状态开始
|
||||||
|
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions")
|
||||||
|
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups")
|
||||||
|
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUserRepoSuite(t *testing.T) {
|
func TestUserRepoSuite(t *testing.T) {
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ type UserSubscriptionRepoSuite struct {
|
|||||||
|
|
||||||
func (s *UserSubscriptionRepoSuite) SetupTest() {
|
func (s *UserSubscriptionRepoSuite) SetupTest() {
|
||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
client, _ := testEntSQLTx(s.T())
|
tx := testEntTx(s.T())
|
||||||
s.client = client
|
s.client = tx.Client()
|
||||||
s.repo = NewUserSubscriptionRepository(s.client).(*userSubscriptionRepository)
|
s.repo = NewUserSubscriptionRepository(s.client).(*userSubscriptionRepository)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,8 +66,8 @@ func (s *UserSubscriptionRepoSuite) mustCreateSubscription(userID, groupID int64
|
|||||||
create := s.client.UserSubscription.Create().
|
create := s.client.UserSubscription.Create().
|
||||||
SetUserID(userID).
|
SetUserID(userID).
|
||||||
SetGroupID(groupID).
|
SetGroupID(groupID).
|
||||||
SetStartsAt(now.Add(-1*time.Hour)).
|
SetStartsAt(now.Add(-1 * time.Hour)).
|
||||||
SetExpiresAt(now.Add(24*time.Hour)).
|
SetExpiresAt(now.Add(24 * time.Hour)).
|
||||||
SetStatus(service.SubscriptionStatusActive).
|
SetStatus(service.SubscriptionStatusActive).
|
||||||
SetAssignedAt(now).
|
SetAssignedAt(now).
|
||||||
SetNotes("")
|
SetNotes("")
|
||||||
@@ -631,4 +631,3 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
|
|||||||
s.Require().NoError(err, "GetByID expired")
|
s.Require().NoError(err, "GetByID expired")
|
||||||
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
|
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user