fix(仓储): 修复查询关闭错误并迁移集成测试

修复 rows.Close 失败时的错误返回逻辑
迁移网关路由集成测试到 ent 事务基建
补齐仓储接口变更对应的测试桩方法
新增 backend/Makefile 统一测试命令
测试: GOTOOLCHAIN=go1.24.11 go test ./...
测试: golangci-lint run ./... --timeout=5m
测试: make test-integration
This commit is contained in:
yangjianbo
2025-12-30 16:41:45 +08:00
parent b6fec590a7
commit aacbc98aec
15 changed files with 123 additions and 110 deletions

13
backend/Makefile Normal file
View File

@@ -0,0 +1,13 @@
.PHONY: build test-unit test-integration test-e2e
build:
go build -o bin/server ./cmd/server
test-unit:
go test -tags=unit ./...
test-integration:
go test -tags=integration ./...
test-e2e:
go test -tags=e2e ./...

View File

@@ -175,16 +175,16 @@ func (Account) Edges() []ent.Edge {
// 每个索引对应一个常用的查询条件。
func (Account) Indexes() []ent.Index {
return []ent.Index{
index.Fields("platform"), // 按平台筛选
index.Fields("type"), // 按认证类型筛选
index.Fields("status"), // 按状态筛选
index.Fields("proxy_id"), // 按代理筛选
index.Fields("priority"), // 按优先级排序
index.Fields("last_used_at"), // 按最后使用时间排序
index.Fields("schedulable"), // 筛选可调度账户
index.Fields("rate_limited_at"), // 筛选速率限制账户
index.Fields("platform"), // 按平台筛选
index.Fields("type"), // 按认证类型筛选
index.Fields("status"), // 按状态筛选
index.Fields("proxy_id"), // 按代理筛选
index.Fields("priority"), // 按优先级排序
index.Fields("last_used_at"), // 按最后使用时间排序
index.Fields("schedulable"), // 筛选可调度账户
index.Fields("rate_limited_at"), // 筛选速率限制账户
index.Fields("rate_limit_reset_at"), // 筛选速率限制解除时间
index.Fields("overload_until"), // 筛选过载账户
index.Fields("deleted_at"), // 软删除查询优化
index.Fields("overload_until"), // 筛选过载账户
index.Fields("deleted_at"), // 软删除查询优化
}
}

View File

@@ -7,12 +7,12 @@ import (
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/ent/intercept"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/mixin"
"github.com/Wei-Shaw/sub2api/ent/intercept"
)
// SoftDeleteMixin 实现基于 deleted_at 时间戳的软删除功能。

View File

@@ -110,4 +110,3 @@ func (UserSubscription) Indexes() []ent.Index {
index.Fields("user_id", "group_id").Unique(),
}
}

View File

@@ -6,10 +6,9 @@ import (
"context"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/datatypes"
"gorm.io/gorm"
)
// GatewayRoutingSuite 测试网关路由相关的数据库查询
@@ -17,14 +16,15 @@ import (
type GatewayRoutingSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
client *dbent.Client
accountRepo *accountRepository
}
func (s *GatewayRoutingSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.accountRepo = NewAccountRepository(s.db).(*accountRepository)
tx := testEntTx(s.T())
s.client = tx.Client()
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx)
}
func TestGatewayRoutingSuite(t *testing.T) {
@@ -34,7 +34,7 @@ func TestGatewayRoutingSuite(t *testing.T) {
// TestListSchedulableByPlatforms_GeminiAndAntigravity 验证多平台账户查询
func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravity() {
// 创建各平台账户
geminiAcc := mustCreateAccount(s.T(), s.db, &accountModel{
geminiAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "gemini-oauth",
Platform: service.PlatformGemini,
Type: service.AccountTypeOAuth,
@@ -43,14 +43,14 @@ func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravit
Priority: 1,
})
antigravityAcc := mustCreateAccount(s.T(), s.db, &accountModel{
antigravityAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "antigravity-oauth",
Platform: service.PlatformAntigravity,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Schedulable: true,
Priority: 2,
Credentials: datatypes.JSONMap{
Credentials: map[string]any{
"access_token": "test-token",
"refresh_token": "test-refresh",
"project_id": "test-project",
@@ -58,7 +58,7 @@ func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravit
})
// 创建不应被选中的 anthropic 账户
mustCreateAccount(s.T(), s.db, &accountModel{
mustCreateAccount(s.T(), s.client, &service.Account{
Name: "anthropic-oauth",
Platform: service.PlatformAnthropic,
Type: service.AccountTypeOAuth,
@@ -97,20 +97,20 @@ func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravit
// TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding 验证按分组过滤
func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding() {
// 创建 gemini 分组
group := mustCreateGroup(s.T(), s.db, &groupModel{
group := mustCreateGroup(s.T(), s.client, &service.Group{
Name: "gemini-group",
Platform: service.PlatformGemini,
Status: service.StatusActive,
})
// 创建账户
boundAcc := mustCreateAccount(s.T(), s.db, &accountModel{
boundAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "bound-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
unboundAcc := mustCreateAccount(s.T(), s.db, &accountModel{
unboundAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "unbound-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
@@ -118,7 +118,7 @@ func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroup
})
// 只绑定一个账户到分组
mustBindAccountToGroup(s.T(), s.db, boundAcc.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.client, boundAcc.ID, group.ID, 1)
// 查询分组内的账户
accounts, err := s.accountRepo.ListSchedulableByGroupIDAndPlatforms(s.ctx, group.ID, []string{
@@ -139,14 +139,14 @@ func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroup
// TestListSchedulableByPlatform_Antigravity 验证单平台查询
func (s *GatewayRoutingSuite) TestListSchedulableByPlatform_Antigravity() {
// 创建多种平台账户
mustCreateAccount(s.T(), s.db, &accountModel{
mustCreateAccount(s.T(), s.client, &service.Account{
Name: "gemini-1",
Platform: service.PlatformGemini,
Status: service.StatusActive,
Schedulable: true,
})
antigravity := mustCreateAccount(s.T(), s.db, &accountModel{
antigravity := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "antigravity-1",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
@@ -165,7 +165,7 @@ func (s *GatewayRoutingSuite) TestListSchedulableByPlatform_Antigravity() {
// TestSchedulableFilter_ExcludesInactive 验证不可调度账户被过滤
func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() {
// 创建可调度账户
activeAcc := mustCreateAccount(s.T(), s.db, &accountModel{
activeAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "active-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
@@ -173,15 +173,15 @@ func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() {
})
// 创建不可调度账户(需要先创建再更新,因为 fixture 默认设置 Schedulable=true
inactiveAcc := mustCreateAccount(s.T(), s.db, &accountModel{
inactiveAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "inactive-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
})
s.Require().NoError(s.db.Model(&accountModel{}).Where("id = ?", inactiveAcc.ID).Update("schedulable", false).Error)
s.Require().NoError(s.client.Account.UpdateOneID(inactiveAcc.ID).SetSchedulable(false).Exec(s.ctx))
// 创建错误状态账户
mustCreateAccount(s.T(), s.db, &accountModel{
mustCreateAccount(s.T(), s.client, &service.Account{
Name: "error-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusError,
@@ -199,14 +199,14 @@ func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() {
// 这个测试模拟 Handler 层在选择账户后的路由决策逻辑
func (s *GatewayRoutingSuite) TestPlatformRoutingDecision() {
// 创建两种平台的账户
geminiAcc := mustCreateAccount(s.T(), s.db, &accountModel{
geminiAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "gemini-route-test",
Platform: service.PlatformGemini,
Status: service.StatusActive,
Schedulable: true,
})
antigravityAcc := mustCreateAccount(s.T(), s.db, &accountModel{
antigravityAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "antigravity-route-test",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,

View File

@@ -235,9 +235,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
defer func() { _ = tx.Rollback() }()
exec = tx.Client()
txClient = exec
} else {
// 已处于外部事务中ErrTxStarted复用当前 client 参与同一事务。
}
// err 为 dbent.ErrTxStarted 时,复用当前 client 参与同一事务。
// Lock the group row to avoid concurrent writes while we cascade.
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分“未找到”与其他错误。
@@ -330,8 +329,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return affectedUserIDs, nil
}
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (map[int64]int64, error) {
counts := make(map[int64]int64, len(groupIDs))
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) {
counts = make(map[int64]int64, len(groupIDs))
if len(groupIDs) == 0 {
return counts, nil
}
@@ -344,23 +343,24 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
if err != nil {
return nil, err
}
defer rows.Close()
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
counts = nil
}
}()
for rows.Next() {
var groupID int64
var count int64
if err := rows.Scan(&groupID, &count); err != nil {
if err = rows.Scan(&groupID, &count); err != nil {
return nil, err
}
counts[groupID] = count
}
if err := rows.Err(); err != nil {
if err = rows.Err(); err != nil {
return nil, err
}
return counts, nil
}
func errorsIsNoRows(err error) bool {
return err == sql.ErrNoRows
}

View File

@@ -177,22 +177,27 @@ func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in
}
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL GROUP BY proxy_id")
if err != nil {
return nil, err
}
defer rows.Close()
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
counts = nil
}
}()
counts := make(map[int64]int64)
counts = make(map[int64]int64)
for rows.Next() {
var proxyID, count int64
if err := rows.Scan(&proxyID, &count); err != nil {
if err = rows.Scan(&proxyID, &count); err != nil {
return nil, err
}
counts[proxyID] = count
}
if err := rows.Err(); err != nil {
if err = rows.Err(); err != nil {
return nil, err
}
return counts, nil

View File

@@ -13,21 +13,28 @@ type sqlQueryer interface {
// If no rows are returned, sql.ErrNoRows is returned.
// 设计目的:仅依赖 QueryContext避免 QueryRowContext 对 *sql.Tx 的强绑定,
// 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) error {
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) (err error) {
rows, err := q.QueryContext(ctx, query, args...)
if err != nil {
return err
}
defer rows.Close()
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
}
}()
if !rows.Next() {
if err := rows.Err(); err != nil {
if err = rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
if err := rows.Scan(dest...); err != nil {
if err = rows.Scan(dest...); err != nil {
return err
}
return rows.Err()
if err = rows.Err(); err != nil {
return err
}
return nil
}

View File

@@ -154,7 +154,7 @@ func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.Us
if err != nil {
return nil, err
}
defer rows.Close()
defer func() { _ = rows.Close() }()
if !rows.Next() {
if err := rows.Err(); err != nil {
return nil, err
@@ -568,7 +568,7 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
if err != nil {
return nil, err
}
defer rows.Close()
defer func() { _ = rows.Close() }()
results := make([]ApiKeyUsageTrendPoint, 0)
for rows.Next() {
@@ -621,7 +621,7 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
if err != nil {
return nil, err
}
defer rows.Close()
defer func() { _ = rows.Close() }()
results := make([]UserUsageTrendPoint, 0)
for rows.Next() {
@@ -766,7 +766,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
if err != nil {
return nil, err
}
defer rows.Close()
defer func() { _ = rows.Close() }()
return scanTrendRows(rows)
}
@@ -792,7 +792,7 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
if err != nil {
return nil, err
}
defer rows.Close()
defer func() { _ = rows.Close() }()
return scanModelStatsRows(rows)
}
@@ -1029,7 +1029,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
if err != nil {
return nil, err
}
defer rows.Close()
defer func() { _ = rows.Close() }()
return scanTrendRows(rows)
}
@@ -1068,7 +1068,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
if err != nil {
return nil, err
}
defer rows.Close()
defer func() { _ = rows.Close() }()
return scanModelStatsRows(rows)
}
@@ -1141,7 +1141,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
if err != nil {
return nil, err
}
defer rows.Close()
defer func() { _ = rows.Close() }()
history := make([]AccountUsageHistory, 0)
for rows.Next() {
@@ -1291,7 +1291,7 @@ func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, a
if err != nil {
return nil, err
}
defer rows.Close()
defer func() { _ = rows.Close() }()
logs := make([]service.UsageLog, 0)
for rows.Next() {

View File

@@ -405,53 +405,6 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
return nil
}
func (r *userRepository) syncUserAllowedGroups(ctx context.Context, client *dbent.Client, exec sqlExecutor, userID int64, groupIDs []int64) error {
if client == nil || exec == nil {
return nil
}
// Keep join table as the source of truth for reads.
if _, err := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID)).Exec(ctx); err != nil {
return err
}
unique := make(map[int64]struct{}, len(groupIDs))
for _, id := range groupIDs {
if id <= 0 {
continue
}
unique[id] = struct{}{}
}
legacyGroups := make([]int64, 0, len(unique))
if len(unique) > 0 {
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
for groupID := range unique {
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
legacyGroups = append(legacyGroups, groupID)
}
if err := client.UserAllowedGroup.
CreateBulk(creates...).
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx); err != nil {
return err
}
}
// Phase 1 compatibility: keep legacy users.allowed_groups array updated for existing raw SQL paths.
var legacy any
if len(legacyGroups) > 0 {
sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] })
legacy = pq.Array(legacyGroups)
}
if _, err := exec.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil {
return err
}
return nil
}
func applyUserEntityToService(dst *service.User, src *dbent.User) {
if dst == nil || src == nil {
return

View File

@@ -717,6 +717,14 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey
return &clone, nil
}
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
key, ok := r.byID[id]
if !ok {
return 0, service.ErrApiKeyNotFound
}
return key.UserID, nil
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
found, ok := r.byKey[key]
if !ok {

View File

@@ -131,6 +131,10 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
if r.getByKey != nil {
return r.getByKey(ctx, key)

View File

@@ -119,6 +119,14 @@ func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Contex
panic("unexpected ListSchedulableByGroupIDAndPlatform call")
}
func (s *accountRepoStub) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
panic("unexpected ListSchedulableByPlatforms call")
}
func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
panic("unexpected ListSchedulableByGroupIDAndPlatforms call")
}
func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
panic("unexpected SetRateLimited call")
}

View File

@@ -32,6 +32,14 @@ func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Ac
return nil, errors.New("account not found")
}
func (m *mockAccountRepoForPlatform) ExistsByID(ctx context.Context, id int64) (bool, error) {
if m.accountsByID == nil {
return false, nil
}
_, ok := m.accountsByID[id]
return ok, nil
}
func (m *mockAccountRepoForPlatform) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
if m.listPlatformFunc != nil {
return m.listPlatformFunc(ctx, platform)

View File

@@ -25,6 +25,14 @@ func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Acco
return nil, errors.New("account not found")
}
func (m *mockAccountRepoForGemini) ExistsByID(ctx context.Context, id int64) (bool, error) {
if m.accountsByID == nil {
return false, nil
}
_, ok := m.accountsByID[id]
return ok, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
var result []Account
for _, acc := range m.accounts {