From aacbc98aeca78b988764c190d420899c73276aa3 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Tue, 30 Dec 2025 16:41:45 +0800 Subject: [PATCH] =?UTF-8?q?fix(=E4=BB=93=E5=82=A8):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E6=9F=A5=E8=AF=A2=E5=85=B3=E9=97=AD=E9=94=99=E8=AF=AF=E5=B9=B6?= =?UTF-8?q?=E8=BF=81=E7=A7=BB=E9=9B=86=E6=88=90=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复 rows.Close 失败时的错误返回逻辑 迁移网关路由集成测试到 ent 事务基建 补齐仓储接口变更对应的测试桩方法 新增 backend/Makefile 统一测试命令 测试: GOTOOLCHAIN=go1.24.11 go test ./... 测试: golangci-lint run ./... --timeout=5m 测试: make test-integration --- backend/Makefile | 13 +++++ backend/ent/schema/account.go | 20 ++++---- backend/ent/schema/mixins/soft_delete.go | 2 +- backend/ent/schema/user_subscription.go | 1 - .../gateway_routing_integration_test.go | 42 ++++++++--------- backend/internal/repository/group_repo.go | 22 ++++----- backend/internal/repository/proxy_repo.go | 15 ++++-- backend/internal/repository/sql_scan.go | 17 +++++-- backend/internal/repository/usage_log_repo.go | 18 +++---- backend/internal/repository/user_repo.go | 47 ------------------- backend/internal/server/api_contract_test.go | 8 ++++ .../server/middleware/api_key_auth_test.go | 4 ++ .../service/account_service_delete_test.go | 8 ++++ .../service/gateway_multiplatform_test.go | 8 ++++ .../service/gemini_multiplatform_test.go | 8 ++++ 15 files changed, 123 insertions(+), 110 deletions(-) create mode 100644 backend/Makefile diff --git a/backend/Makefile b/backend/Makefile new file mode 100644 index 00000000..ae4b84f7 --- /dev/null +++ b/backend/Makefile @@ -0,0 +1,13 @@ +.PHONY: build test-unit test-integration test-e2e + +build: + go build -o bin/server ./cmd/server + +test-unit: + go test -tags=unit ./... + +test-integration: + go test -tags=integration ./... + +test-e2e: + go test -tags=e2e ./... diff --git a/backend/ent/schema/account.go b/backend/ent/schema/account.go index bd929693..c1dd64af 100644 --- a/backend/ent/schema/account.go +++ b/backend/ent/schema/account.go @@ -175,16 +175,16 @@ func (Account) Edges() []ent.Edge { // 每个索引对应一个常用的查询条件。 func (Account) Indexes() []ent.Index { return []ent.Index{ - index.Fields("platform"), // 按平台筛选 - index.Fields("type"), // 按认证类型筛选 - index.Fields("status"), // 按状态筛选 - index.Fields("proxy_id"), // 按代理筛选 - index.Fields("priority"), // 按优先级排序 - index.Fields("last_used_at"), // 按最后使用时间排序 - index.Fields("schedulable"), // 筛选可调度账户 - index.Fields("rate_limited_at"), // 筛选速率限制账户 + index.Fields("platform"), // 按平台筛选 + index.Fields("type"), // 按认证类型筛选 + index.Fields("status"), // 按状态筛选 + index.Fields("proxy_id"), // 按代理筛选 + index.Fields("priority"), // 按优先级排序 + index.Fields("last_used_at"), // 按最后使用时间排序 + index.Fields("schedulable"), // 筛选可调度账户 + index.Fields("rate_limited_at"), // 筛选速率限制账户 index.Fields("rate_limit_reset_at"), // 筛选速率限制解除时间 - index.Fields("overload_until"), // 筛选过载账户 - index.Fields("deleted_at"), // 软删除查询优化 + index.Fields("overload_until"), // 筛选过载账户 + index.Fields("deleted_at"), // 软删除查询优化 } } diff --git a/backend/ent/schema/mixins/soft_delete.go b/backend/ent/schema/mixins/soft_delete.go index d62cf4a9..00ef77a6 100644 --- a/backend/ent/schema/mixins/soft_delete.go +++ b/backend/ent/schema/mixins/soft_delete.go @@ -7,12 +7,12 @@ import ( "fmt" "time" - "github.com/Wei-Shaw/sub2api/ent/intercept" "entgo.io/ent" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/schema/field" "entgo.io/ent/schema/mixin" + "github.com/Wei-Shaw/sub2api/ent/intercept" ) // SoftDeleteMixin 实现基于 deleted_at 时间戳的软删除功能。 diff --git a/backend/ent/schema/user_subscription.go b/backend/ent/schema/user_subscription.go index a87e4c39..bcb0da71 100644 --- a/backend/ent/schema/user_subscription.go +++ b/backend/ent/schema/user_subscription.go @@ -110,4 +110,3 @@ func (UserSubscription) Indexes() []ent.Index { index.Fields("user_id", "group_id").Unique(), } } - diff --git a/backend/internal/repository/gateway_routing_integration_test.go b/backend/internal/repository/gateway_routing_integration_test.go index 46a22f9c..5566d2e9 100644 --- a/backend/internal/repository/gateway_routing_integration_test.go +++ b/backend/internal/repository/gateway_routing_integration_test.go @@ -6,10 +6,9 @@ import ( "context" "testing" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" - "gorm.io/datatypes" - "gorm.io/gorm" ) // GatewayRoutingSuite 测试网关路由相关的数据库查询 @@ -17,14 +16,15 @@ import ( type GatewayRoutingSuite struct { suite.Suite ctx context.Context - db *gorm.DB + client *dbent.Client accountRepo *accountRepository } func (s *GatewayRoutingSuite) SetupTest() { s.ctx = context.Background() - s.db = testTx(s.T()) - s.accountRepo = NewAccountRepository(s.db).(*accountRepository) + tx := testEntTx(s.T()) + s.client = tx.Client() + s.accountRepo = newAccountRepositoryWithSQL(s.client, tx) } func TestGatewayRoutingSuite(t *testing.T) { @@ -34,7 +34,7 @@ func TestGatewayRoutingSuite(t *testing.T) { // TestListSchedulableByPlatforms_GeminiAndAntigravity 验证多平台账户查询 func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravity() { // 创建各平台账户 - geminiAcc := mustCreateAccount(s.T(), s.db, &accountModel{ + geminiAcc := mustCreateAccount(s.T(), s.client, &service.Account{ Name: "gemini-oauth", Platform: service.PlatformGemini, Type: service.AccountTypeOAuth, @@ -43,14 +43,14 @@ func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravit Priority: 1, }) - antigravityAcc := mustCreateAccount(s.T(), s.db, &accountModel{ + antigravityAcc := mustCreateAccount(s.T(), s.client, &service.Account{ Name: "antigravity-oauth", Platform: service.PlatformAntigravity, Type: service.AccountTypeOAuth, Status: service.StatusActive, Schedulable: true, Priority: 2, - Credentials: datatypes.JSONMap{ + Credentials: map[string]any{ "access_token": "test-token", "refresh_token": "test-refresh", "project_id": "test-project", @@ -58,7 +58,7 @@ func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravit }) // 创建不应被选中的 anthropic 账户 - mustCreateAccount(s.T(), s.db, &accountModel{ + mustCreateAccount(s.T(), s.client, &service.Account{ Name: "anthropic-oauth", Platform: service.PlatformAnthropic, Type: service.AccountTypeOAuth, @@ -97,20 +97,20 @@ func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravit // TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding 验证按分组过滤 func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding() { // 创建 gemini 分组 - group := mustCreateGroup(s.T(), s.db, &groupModel{ + group := mustCreateGroup(s.T(), s.client, &service.Group{ Name: "gemini-group", Platform: service.PlatformGemini, Status: service.StatusActive, }) // 创建账户 - boundAcc := mustCreateAccount(s.T(), s.db, &accountModel{ + boundAcc := mustCreateAccount(s.T(), s.client, &service.Account{ Name: "bound-antigravity", Platform: service.PlatformAntigravity, Status: service.StatusActive, Schedulable: true, }) - unboundAcc := mustCreateAccount(s.T(), s.db, &accountModel{ + unboundAcc := mustCreateAccount(s.T(), s.client, &service.Account{ Name: "unbound-antigravity", Platform: service.PlatformAntigravity, Status: service.StatusActive, @@ -118,7 +118,7 @@ func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroup }) // 只绑定一个账户到分组 - mustBindAccountToGroup(s.T(), s.db, boundAcc.ID, group.ID, 1) + mustBindAccountToGroup(s.T(), s.client, boundAcc.ID, group.ID, 1) // 查询分组内的账户 accounts, err := s.accountRepo.ListSchedulableByGroupIDAndPlatforms(s.ctx, group.ID, []string{ @@ -139,14 +139,14 @@ func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroup // TestListSchedulableByPlatform_Antigravity 验证单平台查询 func (s *GatewayRoutingSuite) TestListSchedulableByPlatform_Antigravity() { // 创建多种平台账户 - mustCreateAccount(s.T(), s.db, &accountModel{ + mustCreateAccount(s.T(), s.client, &service.Account{ Name: "gemini-1", Platform: service.PlatformGemini, Status: service.StatusActive, Schedulable: true, }) - antigravity := mustCreateAccount(s.T(), s.db, &accountModel{ + antigravity := mustCreateAccount(s.T(), s.client, &service.Account{ Name: "antigravity-1", Platform: service.PlatformAntigravity, Status: service.StatusActive, @@ -165,7 +165,7 @@ func (s *GatewayRoutingSuite) TestListSchedulableByPlatform_Antigravity() { // TestSchedulableFilter_ExcludesInactive 验证不可调度账户被过滤 func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() { // 创建可调度账户 - activeAcc := mustCreateAccount(s.T(), s.db, &accountModel{ + activeAcc := mustCreateAccount(s.T(), s.client, &service.Account{ Name: "active-antigravity", Platform: service.PlatformAntigravity, Status: service.StatusActive, @@ -173,15 +173,15 @@ func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() { }) // 创建不可调度账户(需要先创建再更新,因为 fixture 默认设置 Schedulable=true) - inactiveAcc := mustCreateAccount(s.T(), s.db, &accountModel{ + inactiveAcc := mustCreateAccount(s.T(), s.client, &service.Account{ Name: "inactive-antigravity", Platform: service.PlatformAntigravity, Status: service.StatusActive, }) - s.Require().NoError(s.db.Model(&accountModel{}).Where("id = ?", inactiveAcc.ID).Update("schedulable", false).Error) + s.Require().NoError(s.client.Account.UpdateOneID(inactiveAcc.ID).SetSchedulable(false).Exec(s.ctx)) // 创建错误状态账户 - mustCreateAccount(s.T(), s.db, &accountModel{ + mustCreateAccount(s.T(), s.client, &service.Account{ Name: "error-antigravity", Platform: service.PlatformAntigravity, Status: service.StatusError, @@ -199,14 +199,14 @@ func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() { // 这个测试模拟 Handler 层在选择账户后的路由决策逻辑 func (s *GatewayRoutingSuite) TestPlatformRoutingDecision() { // 创建两种平台的账户 - geminiAcc := mustCreateAccount(s.T(), s.db, &accountModel{ + geminiAcc := mustCreateAccount(s.T(), s.client, &service.Account{ Name: "gemini-route-test", Platform: service.PlatformGemini, Status: service.StatusActive, Schedulable: true, }) - antigravityAcc := mustCreateAccount(s.T(), s.db, &accountModel{ + antigravityAcc := mustCreateAccount(s.T(), s.client, &service.Account{ Name: "antigravity-route-test", Platform: service.PlatformAntigravity, Status: service.StatusActive, diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index a6c7e3cc..5670a69b 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -235,9 +235,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, defer func() { _ = tx.Rollback() }() exec = tx.Client() txClient = exec - } else { - // 已处于外部事务中(ErrTxStarted),复用当前 client 参与同一事务。 } + // err 为 dbent.ErrTxStarted 时,复用当前 client 参与同一事务。 // Lock the group row to avoid concurrent writes while we cascade. // 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分“未找到”与其他错误。 @@ -330,8 +329,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, return affectedUserIDs, nil } -func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (map[int64]int64, error) { - counts := make(map[int64]int64, len(groupIDs)) +func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) { + counts = make(map[int64]int64, len(groupIDs)) if len(groupIDs) == 0 { return counts, nil } @@ -344,23 +343,24 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6 if err != nil { return nil, err } - defer rows.Close() + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + counts = nil + } + }() for rows.Next() { var groupID int64 var count int64 - if err := rows.Scan(&groupID, &count); err != nil { + if err = rows.Scan(&groupID, &count); err != nil { return nil, err } counts[groupID] = count } - if err := rows.Err(); err != nil { + if err = rows.Err(); err != nil { return nil, err } return counts, nil } - -func errorsIsNoRows(err error) bool { - return err == sql.ErrNoRows -} diff --git a/backend/internal/repository/proxy_repo.go b/backend/internal/repository/proxy_repo.go index adbc2dfb..f9315525 100644 --- a/backend/internal/repository/proxy_repo.go +++ b/backend/internal/repository/proxy_repo.go @@ -177,22 +177,27 @@ func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in } // GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies -func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) { +func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) { rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL GROUP BY proxy_id") if err != nil { return nil, err } - defer rows.Close() + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + counts = nil + } + }() - counts := make(map[int64]int64) + counts = make(map[int64]int64) for rows.Next() { var proxyID, count int64 - if err := rows.Scan(&proxyID, &count); err != nil { + if err = rows.Scan(&proxyID, &count); err != nil { return nil, err } counts[proxyID] = count } - if err := rows.Err(); err != nil { + if err = rows.Err(); err != nil { return nil, err } return counts, nil diff --git a/backend/internal/repository/sql_scan.go b/backend/internal/repository/sql_scan.go index f683f50d..e734ea82 100644 --- a/backend/internal/repository/sql_scan.go +++ b/backend/internal/repository/sql_scan.go @@ -13,21 +13,28 @@ type sqlQueryer interface { // If no rows are returned, sql.ErrNoRows is returned. // 设计目的:仅依赖 QueryContext,避免 QueryRowContext 对 *sql.Tx 的强绑定, // 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。 -func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) error { +func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) (err error) { rows, err := q.QueryContext(ctx, query, args...) if err != nil { return err } - defer rows.Close() + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + } + }() if !rows.Next() { - if err := rows.Err(); err != nil { + if err = rows.Err(); err != nil { return err } return sql.ErrNoRows } - if err := rows.Scan(dest...); err != nil { + if err = rows.Scan(dest...); err != nil { return err } - return rows.Err() + if err = rows.Err(); err != nil { + return err + } + return nil } diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index eeaaa12c..4b9694c1 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -154,7 +154,7 @@ func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.Us if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() if !rows.Next() { if err := rows.Err(); err != nil { return nil, err @@ -568,7 +568,7 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() results := make([]ApiKeyUsageTrendPoint, 0) for rows.Next() { @@ -621,7 +621,7 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() results := make([]UserUsageTrendPoint, 0) for rows.Next() { @@ -766,7 +766,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() return scanTrendRows(rows) } @@ -792,7 +792,7 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64 if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() return scanModelStatsRows(rows) } @@ -1029,7 +1029,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() return scanTrendRows(rows) } @@ -1068,7 +1068,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() return scanModelStatsRows(rows) } @@ -1141,7 +1141,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() history := make([]AccountUsageHistory, 0) for rows.Next() { @@ -1291,7 +1291,7 @@ func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, a if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() logs := make([]service.UsageLog, 0) for rows.Next() { diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index b88c47d3..7766fe98 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -405,53 +405,6 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl return nil } -func (r *userRepository) syncUserAllowedGroups(ctx context.Context, client *dbent.Client, exec sqlExecutor, userID int64, groupIDs []int64) error { - if client == nil || exec == nil { - return nil - } - - // Keep join table as the source of truth for reads. - if _, err := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID)).Exec(ctx); err != nil { - return err - } - - unique := make(map[int64]struct{}, len(groupIDs)) - for _, id := range groupIDs { - if id <= 0 { - continue - } - unique[id] = struct{}{} - } - - legacyGroups := make([]int64, 0, len(unique)) - if len(unique) > 0 { - creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique)) - for groupID := range unique { - creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID)) - legacyGroups = append(legacyGroups, groupID) - } - if err := client.UserAllowedGroup. - CreateBulk(creates...). - OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID). - DoNothing(). - Exec(ctx); err != nil { - return err - } - } - - // Phase 1 compatibility: keep legacy users.allowed_groups array updated for existing raw SQL paths. - var legacy any - if len(legacyGroups) > 0 { - sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] }) - legacy = pq.Array(legacyGroups) - } - if _, err := exec.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil { - return err - } - - return nil -} - func applyUserEntityToService(dst *service.User, src *dbent.User) { if dst == nil || src == nil { return diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 31e62861..8d5ace96 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -717,6 +717,14 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey return &clone, nil } +func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { + key, ok := r.byID[id] + if !ok { + return 0, service.ErrApiKeyNotFound + } + return key.UserID, nil +} + func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { found, ok := r.byKey[key] if !ok { diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index a9d22ede..841edd07 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -131,6 +131,10 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey return nil, errors.New("not implemented") } +func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { + return 0, errors.New("not implemented") +} + func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { if r.getByKey != nil { return r.getByKey(ctx, key) diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index 93888bf5..2648b828 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -119,6 +119,14 @@ func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Contex panic("unexpected ListSchedulableByGroupIDAndPlatform call") } +func (s *accountRepoStub) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + panic("unexpected ListSchedulableByPlatforms call") +} + +func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) { + panic("unexpected ListSchedulableByGroupIDAndPlatforms call") +} + func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { panic("unexpected SetRateLimited call") } diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index d66aa6f1..d779bcfa 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -32,6 +32,14 @@ func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Ac return nil, errors.New("account not found") } +func (m *mockAccountRepoForPlatform) ExistsByID(ctx context.Context, id int64) (bool, error) { + if m.accountsByID == nil { + return false, nil + } + _, ok := m.accountsByID[id] + return ok, nil +} + func (m *mockAccountRepoForPlatform) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) { if m.listPlatformFunc != nil { return m.listPlatformFunc(ctx, platform) diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 43e4ccfe..dcc945eb 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -25,6 +25,14 @@ func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Acco return nil, errors.New("account not found") } +func (m *mockAccountRepoForGemini) ExistsByID(ctx context.Context, id int64) (bool, error) { + if m.accountsByID == nil { + return false, nil + } + _, ok := m.accountsByID[id] + return ok, nil +} + func (m *mockAccountRepoForGemini) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) { var result []Account for _, acc := range m.accounts {