fix(仓储): 修复查询关闭错误并迁移集成测试
修复 rows.Close 失败时的错误返回逻辑 迁移网关路由集成测试到 ent 事务基建 补齐仓储接口变更对应的测试桩方法 新增 backend/Makefile 统一测试命令 测试: GOTOOLCHAIN=go1.24.11 go test ./... 测试: golangci-lint run ./... --timeout=5m 测试: make test-integration
This commit is contained in:
13
backend/Makefile
Normal file
13
backend/Makefile
Normal file
@@ -0,0 +1,13 @@
|
||||
.PHONY: build test-unit test-integration test-e2e
|
||||
|
||||
build:
|
||||
go build -o bin/server ./cmd/server
|
||||
|
||||
test-unit:
|
||||
go test -tags=unit ./...
|
||||
|
||||
test-integration:
|
||||
go test -tags=integration ./...
|
||||
|
||||
test-e2e:
|
||||
go test -tags=e2e ./...
|
||||
@@ -175,16 +175,16 @@ func (Account) Edges() []ent.Edge {
|
||||
// 每个索引对应一个常用的查询条件。
|
||||
func (Account) Indexes() []ent.Index {
|
||||
return []ent.Index{
|
||||
index.Fields("platform"), // 按平台筛选
|
||||
index.Fields("type"), // 按认证类型筛选
|
||||
index.Fields("status"), // 按状态筛选
|
||||
index.Fields("proxy_id"), // 按代理筛选
|
||||
index.Fields("priority"), // 按优先级排序
|
||||
index.Fields("last_used_at"), // 按最后使用时间排序
|
||||
index.Fields("schedulable"), // 筛选可调度账户
|
||||
index.Fields("rate_limited_at"), // 筛选速率限制账户
|
||||
index.Fields("platform"), // 按平台筛选
|
||||
index.Fields("type"), // 按认证类型筛选
|
||||
index.Fields("status"), // 按状态筛选
|
||||
index.Fields("proxy_id"), // 按代理筛选
|
||||
index.Fields("priority"), // 按优先级排序
|
||||
index.Fields("last_used_at"), // 按最后使用时间排序
|
||||
index.Fields("schedulable"), // 筛选可调度账户
|
||||
index.Fields("rate_limited_at"), // 筛选速率限制账户
|
||||
index.Fields("rate_limit_reset_at"), // 筛选速率限制解除时间
|
||||
index.Fields("overload_until"), // 筛选过载账户
|
||||
index.Fields("deleted_at"), // 软删除查询优化
|
||||
index.Fields("overload_until"), // 筛选过载账户
|
||||
index.Fields("deleted_at"), // 软删除查询优化
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,12 +7,12 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent/intercept"
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/schema/field"
|
||||
"entgo.io/ent/schema/mixin"
|
||||
"github.com/Wei-Shaw/sub2api/ent/intercept"
|
||||
)
|
||||
|
||||
// SoftDeleteMixin 实现基于 deleted_at 时间戳的软删除功能。
|
||||
|
||||
@@ -110,4 +110,3 @@ func (UserSubscription) Indexes() []ent.Index {
|
||||
index.Fields("user_id", "group_id").Unique(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,10 +6,9 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GatewayRoutingSuite 测试网关路由相关的数据库查询
|
||||
@@ -17,14 +16,15 @@ import (
|
||||
type GatewayRoutingSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
client *dbent.Client
|
||||
accountRepo *accountRepository
|
||||
}
|
||||
|
||||
func (s *GatewayRoutingSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.accountRepo = NewAccountRepository(s.db).(*accountRepository)
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx)
|
||||
}
|
||||
|
||||
func TestGatewayRoutingSuite(t *testing.T) {
|
||||
@@ -34,7 +34,7 @@ func TestGatewayRoutingSuite(t *testing.T) {
|
||||
// TestListSchedulableByPlatforms_GeminiAndAntigravity 验证多平台账户查询
|
||||
func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravity() {
|
||||
// 创建各平台账户
|
||||
geminiAcc := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
geminiAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "gemini-oauth",
|
||||
Platform: service.PlatformGemini,
|
||||
Type: service.AccountTypeOAuth,
|
||||
@@ -43,14 +43,14 @@ func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravit
|
||||
Priority: 1,
|
||||
})
|
||||
|
||||
antigravityAcc := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
antigravityAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "antigravity-oauth",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Priority: 2,
|
||||
Credentials: datatypes.JSONMap{
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test-token",
|
||||
"refresh_token": "test-refresh",
|
||||
"project_id": "test-project",
|
||||
@@ -58,7 +58,7 @@ func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravit
|
||||
})
|
||||
|
||||
// 创建不应被选中的 anthropic 账户
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "anthropic-oauth",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Type: service.AccountTypeOAuth,
|
||||
@@ -97,20 +97,20 @@ func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravit
|
||||
// TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding 验证按分组过滤
|
||||
func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding() {
|
||||
// 创建 gemini 分组
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{
|
||||
Name: "gemini-group",
|
||||
Platform: service.PlatformGemini,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
// 创建账户
|
||||
boundAcc := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
boundAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "bound-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
unboundAcc := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
unboundAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "unbound-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
@@ -118,7 +118,7 @@ func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroup
|
||||
})
|
||||
|
||||
// 只绑定一个账户到分组
|
||||
mustBindAccountToGroup(s.T(), s.db, boundAcc.ID, group.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.client, boundAcc.ID, group.ID, 1)
|
||||
|
||||
// 查询分组内的账户
|
||||
accounts, err := s.accountRepo.ListSchedulableByGroupIDAndPlatforms(s.ctx, group.ID, []string{
|
||||
@@ -139,14 +139,14 @@ func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroup
|
||||
// TestListSchedulableByPlatform_Antigravity 验证单平台查询
|
||||
func (s *GatewayRoutingSuite) TestListSchedulableByPlatform_Antigravity() {
|
||||
// 创建多种平台账户
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "gemini-1",
|
||||
Platform: service.PlatformGemini,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
antigravity := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
antigravity := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "antigravity-1",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
@@ -165,7 +165,7 @@ func (s *GatewayRoutingSuite) TestListSchedulableByPlatform_Antigravity() {
|
||||
// TestSchedulableFilter_ExcludesInactive 验证不可调度账户被过滤
|
||||
func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() {
|
||||
// 创建可调度账户
|
||||
activeAcc := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
activeAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "active-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
@@ -173,15 +173,15 @@ func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() {
|
||||
})
|
||||
|
||||
// 创建不可调度账户(需要先创建再更新,因为 fixture 默认设置 Schedulable=true)
|
||||
inactiveAcc := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
inactiveAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "inactive-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
s.Require().NoError(s.db.Model(&accountModel{}).Where("id = ?", inactiveAcc.ID).Update("schedulable", false).Error)
|
||||
s.Require().NoError(s.client.Account.UpdateOneID(inactiveAcc.ID).SetSchedulable(false).Exec(s.ctx))
|
||||
|
||||
// 创建错误状态账户
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "error-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusError,
|
||||
@@ -199,14 +199,14 @@ func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() {
|
||||
// 这个测试模拟 Handler 层在选择账户后的路由决策逻辑
|
||||
func (s *GatewayRoutingSuite) TestPlatformRoutingDecision() {
|
||||
// 创建两种平台的账户
|
||||
geminiAcc := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
geminiAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "gemini-route-test",
|
||||
Platform: service.PlatformGemini,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
antigravityAcc := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
antigravityAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "antigravity-route-test",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
|
||||
@@ -235,9 +235,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
exec = tx.Client()
|
||||
txClient = exec
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前 client 参与同一事务。
|
||||
}
|
||||
// err 为 dbent.ErrTxStarted 时,复用当前 client 参与同一事务。
|
||||
|
||||
// Lock the group row to avoid concurrent writes while we cascade.
|
||||
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分“未找到”与其他错误。
|
||||
@@ -330,8 +329,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
return affectedUserIDs, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (map[int64]int64, error) {
|
||||
counts := make(map[int64]int64, len(groupIDs))
|
||||
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) {
|
||||
counts = make(map[int64]int64, len(groupIDs))
|
||||
if len(groupIDs) == 0 {
|
||||
return counts, nil
|
||||
}
|
||||
@@ -344,23 +343,24 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
counts = nil
|
||||
}
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
var groupID int64
|
||||
var count int64
|
||||
if err := rows.Scan(&groupID, &count); err != nil {
|
||||
if err = rows.Scan(&groupID, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[groupID] = count
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func errorsIsNoRows(err error) bool {
|
||||
return err == sql.ErrNoRows
|
||||
}
|
||||
|
||||
@@ -177,22 +177,27 @@ func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in
|
||||
}
|
||||
|
||||
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
|
||||
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
|
||||
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
|
||||
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL GROUP BY proxy_id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
counts = nil
|
||||
}
|
||||
}()
|
||||
|
||||
counts := make(map[int64]int64)
|
||||
counts = make(map[int64]int64)
|
||||
for rows.Next() {
|
||||
var proxyID, count int64
|
||||
if err := rows.Scan(&proxyID, &count); err != nil {
|
||||
if err = rows.Scan(&proxyID, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[proxyID] = count
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return counts, nil
|
||||
|
||||
@@ -13,21 +13,28 @@ type sqlQueryer interface {
|
||||
// If no rows are returned, sql.ErrNoRows is returned.
|
||||
// 设计目的:仅依赖 QueryContext,避免 QueryRowContext 对 *sql.Tx 的强绑定,
|
||||
// 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。
|
||||
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) error {
|
||||
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) (err error) {
|
||||
rows, err := q.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
if err = rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
if err := rows.Scan(dest...); err != nil {
|
||||
if err = rows.Scan(dest...); err != nil {
|
||||
return err
|
||||
}
|
||||
return rows.Err()
|
||||
if err = rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -154,7 +154,7 @@ func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.Us
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
defer func() { _ = rows.Close() }()
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
@@ -568,7 +568,7 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
results := make([]ApiKeyUsageTrendPoint, 0)
|
||||
for rows.Next() {
|
||||
@@ -621,7 +621,7 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
results := make([]UserUsageTrendPoint, 0)
|
||||
for rows.Next() {
|
||||
@@ -766,7 +766,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
return scanTrendRows(rows)
|
||||
}
|
||||
@@ -792,7 +792,7 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
return scanModelStatsRows(rows)
|
||||
}
|
||||
@@ -1029,7 +1029,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
return scanTrendRows(rows)
|
||||
}
|
||||
@@ -1068,7 +1068,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
return scanModelStatsRows(rows)
|
||||
}
|
||||
@@ -1141,7 +1141,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
history := make([]AccountUsageHistory, 0)
|
||||
for rows.Next() {
|
||||
@@ -1291,7 +1291,7 @@ func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, a
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
logs := make([]service.UsageLog, 0)
|
||||
for rows.Next() {
|
||||
|
||||
@@ -405,53 +405,6 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) syncUserAllowedGroups(ctx context.Context, client *dbent.Client, exec sqlExecutor, userID int64, groupIDs []int64) error {
|
||||
if client == nil || exec == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keep join table as the source of truth for reads.
|
||||
if _, err := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID)).Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unique := make(map[int64]struct{}, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
unique[id] = struct{}{}
|
||||
}
|
||||
|
||||
legacyGroups := make([]int64, 0, len(unique))
|
||||
if len(unique) > 0 {
|
||||
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
|
||||
for groupID := range unique {
|
||||
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
|
||||
legacyGroups = append(legacyGroups, groupID)
|
||||
}
|
||||
if err := client.UserAllowedGroup.
|
||||
CreateBulk(creates...).
|
||||
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 1 compatibility: keep legacy users.allowed_groups array updated for existing raw SQL paths.
|
||||
var legacy any
|
||||
if len(legacyGroups) > 0 {
|
||||
sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] })
|
||||
legacy = pq.Array(legacyGroups)
|
||||
}
|
||||
if _, err := exec.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyUserEntityToService(dst *service.User, src *dbent.User) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
|
||||
@@ -717,6 +717,14 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey
|
||||
return &clone, nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
key, ok := r.byID[id]
|
||||
if !ok {
|
||||
return 0, service.ErrApiKeyNotFound
|
||||
}
|
||||
return key.UserID, nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
found, ok := r.byKey[key]
|
||||
if !ok {
|
||||
|
||||
@@ -131,6 +131,10 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
if r.getByKey != nil {
|
||||
return r.getByKey(ctx, key)
|
||||
|
||||
@@ -119,6 +119,14 @@ func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Contex
|
||||
panic("unexpected ListSchedulableByGroupIDAndPlatform call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
panic("unexpected ListSchedulableByPlatforms call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
panic("unexpected ListSchedulableByGroupIDAndPlatforms call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
panic("unexpected SetRateLimited call")
|
||||
}
|
||||
|
||||
@@ -32,6 +32,14 @@ func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Ac
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
||||
if m.accountsByID == nil {
|
||||
return false, nil
|
||||
}
|
||||
_, ok := m.accountsByID[id]
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
if m.listPlatformFunc != nil {
|
||||
return m.listPlatformFunc(ctx, platform)
|
||||
|
||||
@@ -25,6 +25,14 @@ func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Acco
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
||||
if m.accountsByID == nil {
|
||||
return false, nil
|
||||
}
|
||||
_, ok := m.accountsByID[id]
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
var result []Account
|
||||
for _, acc := range m.accounts {
|
||||
|
||||
Reference in New Issue
Block a user