From 530a16291cf50942cc8254721d72dbb90ca4d422 Mon Sep 17 00:00:00 2001 From: QTom Date: Tue, 3 Mar 2026 13:10:26 +0800 Subject: [PATCH] =?UTF-8?q?fix(gateway):=20=E5=88=86=E7=BB=84=E9=9A=94?= =?UTF-8?q?=E7=A6=BB=20=E2=80=94=20=E7=A6=81=E6=AD=A2=E6=9C=AA=E5=88=86?= =?UTF-8?q?=E7=BB=84=E8=B4=A6=E5=8F=B7=E8=A2=AB=E8=B7=A8=E7=BB=84=E8=B0=83?= =?UTF-8?q?=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 当 API Key 无分组时,调度仅从未分组账号池中选取。 修复 isAccountInGroup 在 groupID==nil 时的逻辑, 同时补全 scheduler_snapshot_service 和 gemini_compat_service 中的 SimpleMode 保护,确保分组隔离在所有调度路径生效。 新增 ListSchedulableUngroupedByPlatform/s 方法, 使用 Ent 的 Not(HasAccountGroups()) 谓词实现未分组账号隔离。 新增 17 个单元和端到端隔离测试,覆盖所有分支和边界条件。 --- .../handler/sora_client_handler_test.go | 6 + .../handler/sora_gateway_handler_test.go | 6 + backend/internal/repository/account_repo.go | 45 +++ backend/internal/server/api_contract_test.go | 8 + backend/internal/service/account_service.go | 2 + .../service/account_service_delete_test.go | 8 + .../service/gateway_group_isolation_test.go | 363 ++++++++++++++++++ .../service/gateway_multiplatform_test.go | 6 + backend/internal/service/gateway_service.go | 15 +- .../service/gemini_messages_compat_service.go | 5 +- .../service/gemini_multiplatform_test.go | 6 + .../service/openai_gateway_service.go | 2 +- .../service/openai_gateway_service_test.go | 4 + .../service/scheduler_snapshot_service.go | 9 +- 14 files changed, 475 insertions(+), 10 deletions(-) create mode 100644 backend/internal/service/gateway_group_isolation_test.go diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index 5df7fa0a..c2284ce2 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -2089,6 +2089,12 @@ func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) { return r.accounts, nil } +func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]service.Account, error) { + return r.accounts, nil +} func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error { return nil } diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 355cdb7a..59ac34b1 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -182,6 +182,12 @@ func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platfo func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { return r.ListSchedulableByPlatforms(ctx, platforms) } +func (r *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return r.ListSchedulableByPlatform(ctx, platform) +} +func (r *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + return r.ListSchedulableByPlatforms(ctx, platforms) +} func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 4aa74928..0669cbbd 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -829,6 +829,51 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat return r.accountsToService(ctx, accounts) } +func (r *accountRepository) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + now := time.Now() + accounts, err := r.client.Account.Query(). + Where( + dbaccount.PlatformEQ(platform), + dbaccount.StatusEQ(service.StatusActive), + dbaccount.SchedulableEQ(true), + dbaccount.Not(dbaccount.HasAccountGroups()), + tempUnschedulablePredicate(), + notExpiredPredicate(now), + dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), + dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), + ). + Order(dbent.Asc(dbaccount.FieldPriority)). + All(ctx) + if err != nil { + return nil, err + } + return r.accountsToService(ctx, accounts) +} + +func (r *accountRepository) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + if len(platforms) == 0 { + return nil, nil + } + now := time.Now() + accounts, err := r.client.Account.Query(). + Where( + dbaccount.PlatformIn(platforms...), + dbaccount.StatusEQ(service.StatusActive), + dbaccount.SchedulableEQ(true), + dbaccount.Not(dbaccount.HasAccountGroups()), + tempUnschedulablePredicate(), + notExpiredPredicate(now), + dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), + dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), + ). + Order(dbent.Asc(dbaccount.FieldPriority)). + All(ctx) + if err != nil { + return nil, err + } + return r.accountsToService(ctx, accounts) +} + func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { if len(platforms) == 0 { return nil, nil diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index f15a2074..446ee20d 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -1027,6 +1027,14 @@ func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte return nil, errors.New("not implemented") } +func (s *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return errors.New("not implemented") } diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index a3707184..18a70c5c 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -54,6 +54,8 @@ type AccountRepository interface { ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) + ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) + ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index a466b68a..768cf7b7 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -147,6 +147,14 @@ func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte panic("unexpected ListSchedulableByGroupIDAndPlatforms call") } +func (s *accountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + panic("unexpected ListSchedulableUngroupedByPlatform call") +} + +func (s *accountRepoStub) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + panic("unexpected ListSchedulableUngroupedByPlatforms call") +} + func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { panic("unexpected SetRateLimited call") } diff --git a/backend/internal/service/gateway_group_isolation_test.go b/backend/internal/service/gateway_group_isolation_test.go new file mode 100644 index 00000000..00508f0e --- /dev/null +++ b/backend/internal/service/gateway_group_isolation_test.go @@ -0,0 +1,363 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// ============================================================================ +// Part 1: isAccountInGroup 单元测试 +// ============================================================================ + +func TestIsAccountInGroup(t *testing.T) { + svc := &GatewayService{} + groupID100 := int64(100) + groupID200 := int64(200) + + tests := []struct { + name string + account *Account + groupID *int64 + expected bool + }{ + // groupID == nil(无分组 API Key) + { + "nil_groupID_ungrouped_account_nil_groups", + &Account{ID: 1, AccountGroups: nil}, + nil, true, + }, + { + "nil_groupID_ungrouped_account_empty_slice", + &Account{ID: 2, AccountGroups: []AccountGroup{}}, + nil, true, + }, + { + "nil_groupID_grouped_account_single", + &Account{ID: 3, AccountGroups: []AccountGroup{{GroupID: 100}}}, + nil, false, + }, + { + "nil_groupID_grouped_account_multiple", + &Account{ID: 4, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}}, + nil, false, + }, + // groupID != nil(有分组 API Key) + { + "with_groupID_account_in_group", + &Account{ID: 5, AccountGroups: []AccountGroup{{GroupID: 100}}}, + &groupID100, true, + }, + { + "with_groupID_account_not_in_group", + &Account{ID: 6, AccountGroups: []AccountGroup{{GroupID: 200}}}, + &groupID100, false, + }, + { + "with_groupID_ungrouped_account", + &Account{ID: 7, AccountGroups: nil}, + &groupID100, false, + }, + { + "with_groupID_multi_group_account_match_one", + &Account{ID: 8, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}}, + &groupID200, true, + }, + { + "with_groupID_multi_group_account_no_match", + &Account{ID: 9, AccountGroups: []AccountGroup{{GroupID: 300}, {GroupID: 400}}}, + &groupID100, false, + }, + // 防御性边界 + { + "nil_account_nil_groupID", + nil, + nil, false, + }, + { + "nil_account_with_groupID", + nil, + &groupID100, false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.isAccountInGroup(tt.account, tt.groupID) + require.Equal(t, tt.expected, got, "isAccountInGroup 结果不符预期") + }) + } +} + +// ============================================================================ +// Part 2: 分组隔离端到端调度测试 +// ============================================================================ + +// groupAwareMockAccountRepo 嵌入 mockAccountRepoForPlatform,覆写分组隔离相关方法。 +// allAccounts 存储所有账号,分组查询方法按 AccountGroups 字段进行真实过滤。 +type groupAwareMockAccountRepo struct { + *mockAccountRepoForPlatform + allAccounts []Account +} + +// ListSchedulableUngroupedByPlatform 仅返回未分组账号(AccountGroups 为空) +func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + var result []Account + for _, acc := range m.allAccounts { + if acc.Platform == platform && acc.IsSchedulable() && len(acc.AccountGroups) == 0 { + result = append(result, acc) + } + } + return result, nil +} + +// ListSchedulableUngroupedByPlatforms 仅返回未分组账号(多平台版本) +func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + platformSet := make(map[string]bool, len(platforms)) + for _, p := range platforms { + platformSet[p] = true + } + var result []Account + for _, acc := range m.allAccounts { + if platformSet[acc.Platform] && acc.IsSchedulable() && len(acc.AccountGroups) == 0 { + result = append(result, acc) + } + } + return result, nil +} + +// ListSchedulableByGroupIDAndPlatform 返回属于指定分组的账号 +func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) { + var result []Account + for _, acc := range m.allAccounts { + if acc.Platform == platform && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) { + result = append(result, acc) + } + } + return result, nil +} + +// ListSchedulableByGroupIDAndPlatforms 返回属于指定分组的账号(多平台版本) +func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) { + platformSet := make(map[string]bool, len(platforms)) + for _, p := range platforms { + platformSet[p] = true + } + var result []Account + for _, acc := range m.allAccounts { + if platformSet[acc.Platform] && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) { + result = append(result, acc) + } + } + return result, nil +} + +// accountBelongsToGroup 检查账号是否属于指定分组 +func accountBelongsToGroup(acc Account, groupID int64) bool { + for _, ag := range acc.AccountGroups { + if ag.GroupID == groupID { + return true + } + } + return false +} + +// Verify interface implementation +var _ AccountRepository = (*groupAwareMockAccountRepo)(nil) + +// newGroupAwareMockRepo 创建分组感知的 mock repo +func newGroupAwareMockRepo(accounts []Account) *groupAwareMockAccountRepo { + byID := make(map[int64]*Account, len(accounts)) + for i := range accounts { + byID[accounts[i].ID] = &accounts[i] + } + return &groupAwareMockAccountRepo{ + mockAccountRepoForPlatform: &mockAccountRepoForPlatform{ + accounts: accounts, + accountsByID: byID, + }, + allAccounts: accounts, + } +} + +func TestGroupIsolation_UngroupedKey_ShouldNotScheduleGroupedAccounts(t *testing.T) { + // 场景:无分组 API Key(groupID=nil),池中只有已分组账号 → 应返回错误 + ctx := context.Background() + + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, + {ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 200}}}, + } + repo := newGroupAwareMockRepo(accounts) + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI) + require.Error(t, err, "无分组 Key 不应调度到已分组账号") + require.Nil(t, acc) +} + +func TestGroupIsolation_GroupedKey_ShouldNotScheduleUngroupedAccounts(t *testing.T) { + // 场景:有分组 API Key(groupID=100),池中只有未分组账号 → 应返回错误 + ctx := context.Background() + groupID := int64(100) + + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: nil}, + {ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{}}, + } + repo := newGroupAwareMockRepo(accounts) + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI) + require.Error(t, err, "有分组 Key 不应调度到未分组账号") + require.Nil(t, acc) +} + +func TestGroupIsolation_UngroupedKey_ShouldOnlyScheduleUngroupedAccounts(t *testing.T) { + // 场景:无分组 API Key(groupID=nil),池中有未分组和已分组账号 → 应只选中未分组的 + ctx := context.Background() + + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组,不应被选中 + {ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: nil}, // 未分组,应被选中 + {ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 200}}}, // 已分组,不应被选中 + } + repo := newGroupAwareMockRepo(accounts) + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI) + require.NoError(t, err, "应成功调度未分组账号") + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "应选中未分组的账号 ID=2") +} + +func TestGroupIsolation_GroupedKey_ShouldOnlyScheduleMatchingGroupAccounts(t *testing.T) { + // 场景:有分组 API Key(groupID=100),池中有未分组和多个分组账号 → 应只选中分组 100 内的 + ctx := context.Background() + groupID := int64(100) + + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: nil}, // 未分组,不应被选中 + {ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 200}}}, // 属于分组 200,不应被选中 + {ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, // 属于分组 100,应被选中 + } + repo := newGroupAwareMockRepo(accounts) + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI) + require.NoError(t, err, "应成功调度分组内账号") + require.NotNil(t, acc) + require.Equal(t, int64(3), acc.ID, "应选中分组 100 内的账号 ID=3") +} + +// ============================================================================ +// Part 3: SimpleMode 旁路测试 +// ============================================================================ + +func TestGroupIsolation_SimpleMode_SkipsGroupIsolation(t *testing.T) { + // SimpleMode 应跳过分组隔离,使用 ListSchedulableByPlatform 返回所有账号。 + // 测试非 useMixed 路径(platform=openai,不会触发 mixed 调度逻辑)。 + ctx := context.Background() + + // 混合未分组和已分组账号,SimpleMode 下应全部可调度 + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组 + {ID: 2, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: nil}, // 未分组 + } + + // 使用基础 mock(ListSchedulableByPlatform 返回所有匹配平台的账号,不做分组过滤) + byID := make(map[int64]*Account, len(accounts)) + for i := range accounts { + byID[accounts[i].ID] = &accounts[i] + } + repo := &mockAccountRepoForPlatform{ + accounts: accounts, + accountsByID: byID, + } + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: &config.Config{RunMode: config.RunModeSimple}, + } + + // groupID=nil 时,SimpleMode 应使用 ListSchedulableByPlatform(不过滤分组) + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI) + require.NoError(t, err, "SimpleMode 应跳过分组隔离直接返回账号") + require.NotNil(t, acc) + // 应选择优先级最高的账号(Priority=1, ID=2),即使它未分组 + require.Equal(t, int64(2), acc.ID, "SimpleMode 应按优先级选择,不考虑分组") +} + +func TestGroupIsolation_SimpleMode_GroupedAccountAlsoSchedulable(t *testing.T) { + // SimpleMode + groupID=nil 时,已分组账号也应该可被调度 + ctx := context.Background() + + // 只有已分组账号,在 standard 模式下 groupID=nil 会报错,但 simple 模式应正常 + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, + } + + byID := make(map[int64]*Account, len(accounts)) + for i := range accounts { + byID[accounts[i].ID] = &accounts[i] + } + repo := &mockAccountRepoForPlatform{ + accounts: accounts, + accountsByID: byID, + } + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: &config.Config{RunMode: config.RunModeSimple}, + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI) + require.NoError(t, err, "SimpleMode 下已分组账号也应可调度") + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID, "SimpleMode 应能调度已分组账号") +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 067a0e08..1cb3c61e 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -147,6 +147,12 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByPlatforms(ctx context.Cont func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) { return m.ListSchedulableByPlatforms(ctx, platforms) } +func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return m.ListSchedulableByPlatform(ctx, platform) +} +func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + return m.ListSchedulableByPlatforms(ctx, platforms) +} func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 48c69881..fa9a3cb1 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1782,8 +1782,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i var err error if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) - } else { + } else if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) + } else { + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms) } if err != nil { slog.Debug("account_scheduling_list_failed", @@ -1824,7 +1826,7 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) // 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询 } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, platform) } if err != nil { slog.Debug("account_scheduling_list_failed", @@ -1964,14 +1966,15 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte } // isAccountInGroup checks if the account belongs to the specified group. -// Returns true if groupID is nil (no group restriction) or account belongs to the group. +// When groupID is nil, returns true only for ungrouped accounts (no group assignments). func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool { - if groupID == nil { - return true // 无分组限制 - } if account == nil { return false } + if groupID == nil { + // 无分组的 API Key 只能使用未分组的账号 + return len(account.AccountGroups) == 0 + } for _, ag := range account.AccountGroups { if ag.GroupID == *groupID { return true diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 1c38b6c2..a003f636 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -431,7 +431,10 @@ func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Co if groupID != nil { return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms) } - return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms) + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms) + } + return s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, queryPlatforms) } func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) { diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 86bc9476..9476e984 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -138,6 +138,12 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont } return m.ListSchedulableByPlatforms(ctx, platforms) } +func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return m.ListSchedulableByPlatform(ctx, platform) +} +func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + return m.ListSchedulableByPlatforms(ctx, platforms) +} func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f624d92a..02db384f 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1343,7 +1343,7 @@ func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, grou } else if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, PlatformOpenAI) } if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 89443b69..4f5f7f3c 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -57,6 +57,10 @@ func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, pl return result, nil } +func (r stubOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return r.ListSchedulableByPlatform(ctx, platform) +} + type stubConcurrencyCache struct { ConcurrencyCache loadBatchErr error diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index 9f8fa14a..4c9540f1 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -605,8 +605,10 @@ func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke var err error if groupID > 0 { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, groupID, platforms) - } else { + } else if s.isRunModeSimple() { accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) + } else { + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms) } if err != nil { return nil, err @@ -624,7 +626,10 @@ func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke if groupID > 0 { return s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, groupID, bucket.Platform) } - return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform) + if s.isRunModeSimple() { + return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform) + } + return s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, bucket.Platform) } func (s *SchedulerSnapshotService) bucketFor(groupID *int64, platform string, mode string) SchedulerBucket {