fix(gateway): 分组隔离 — 禁止未分组账号被跨组调度

当 API Key 无分组时,调度仅从未分组账号池中选取。
修复 isAccountInGroup 在 groupID==nil 时的逻辑,
同时补全 scheduler_snapshot_service 和 gemini_compat_service
中的 SimpleMode 保护,确保分组隔离在所有调度路径生效。

新增 ListSchedulableUngroupedByPlatform/s 方法,
使用 Ent 的 Not(HasAccountGroups()) 谓词实现未分组账号隔离。
新增 17 个单元和端到端隔离测试,覆盖所有分支和边界条件。
This commit is contained in:
QTom
2026-03-03 13:10:26 +08:00
parent 9792b17597
commit 530a16291c
14 changed files with 475 additions and 10 deletions

View File

@@ -2089,6 +2089,12 @@ func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context,
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) { func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) {
return r.accounts, nil return r.accounts, nil
} }
func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error { func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error {
return nil return nil
} }

View File

@@ -182,6 +182,12 @@ func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platfo
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
return r.ListSchedulableByPlatforms(ctx, platforms) return r.ListSchedulableByPlatforms(ctx, platforms)
} }
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return r.ListSchedulableByPlatform(ctx, platform)
}
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
return r.ListSchedulableByPlatforms(ctx, platforms)
}
func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil return nil
} }

View File

@@ -829,6 +829,51 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
return r.accountsToService(ctx, accounts) return r.accountsToService(ctx, accounts)
} }
func (r *accountRepository) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
now := time.Now()
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformEQ(platform),
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
dbaccount.Not(dbaccount.HasAccountGroups()),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
if len(platforms) == 0 {
return nil, nil
}
now := time.Now()
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformIn(platforms...),
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
dbaccount.Not(dbaccount.HasAccountGroups()),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
if len(platforms) == 0 { if len(platforms) == 0 {
return nil, nil return nil, nil

View File

@@ -1027,6 +1027,14 @@ func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (s *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return errors.New("not implemented") return errors.New("not implemented")
} }

View File

@@ -54,6 +54,8 @@ type AccountRepository interface {
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error)
ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error

View File

@@ -147,6 +147,14 @@ func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte
panic("unexpected ListSchedulableByGroupIDAndPlatforms call") panic("unexpected ListSchedulableByGroupIDAndPlatforms call")
} }
func (s *accountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
panic("unexpected ListSchedulableUngroupedByPlatform call")
}
func (s *accountRepoStub) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
panic("unexpected ListSchedulableUngroupedByPlatforms call")
}
func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
panic("unexpected SetRateLimited call") panic("unexpected SetRateLimited call")
} }

View File

@@ -0,0 +1,363 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ============================================================================
// Part 1: isAccountInGroup 单元测试
// ============================================================================
func TestIsAccountInGroup(t *testing.T) {
svc := &GatewayService{}
groupID100 := int64(100)
groupID200 := int64(200)
tests := []struct {
name string
account *Account
groupID *int64
expected bool
}{
// groupID == nil无分组 API Key
{
"nil_groupID_ungrouped_account_nil_groups",
&Account{ID: 1, AccountGroups: nil},
nil, true,
},
{
"nil_groupID_ungrouped_account_empty_slice",
&Account{ID: 2, AccountGroups: []AccountGroup{}},
nil, true,
},
{
"nil_groupID_grouped_account_single",
&Account{ID: 3, AccountGroups: []AccountGroup{{GroupID: 100}}},
nil, false,
},
{
"nil_groupID_grouped_account_multiple",
&Account{ID: 4, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}},
nil, false,
},
// groupID != nil有分组 API Key
{
"with_groupID_account_in_group",
&Account{ID: 5, AccountGroups: []AccountGroup{{GroupID: 100}}},
&groupID100, true,
},
{
"with_groupID_account_not_in_group",
&Account{ID: 6, AccountGroups: []AccountGroup{{GroupID: 200}}},
&groupID100, false,
},
{
"with_groupID_ungrouped_account",
&Account{ID: 7, AccountGroups: nil},
&groupID100, false,
},
{
"with_groupID_multi_group_account_match_one",
&Account{ID: 8, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}},
&groupID200, true,
},
{
"with_groupID_multi_group_account_no_match",
&Account{ID: 9, AccountGroups: []AccountGroup{{GroupID: 300}, {GroupID: 400}}},
&groupID100, false,
},
// 防御性边界
{
"nil_account_nil_groupID",
nil,
nil, false,
},
{
"nil_account_with_groupID",
nil,
&groupID100, false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.isAccountInGroup(tt.account, tt.groupID)
require.Equal(t, tt.expected, got, "isAccountInGroup 结果不符预期")
})
}
}
// ============================================================================
// Part 2: 分组隔离端到端调度测试
// ============================================================================
// groupAwareMockAccountRepo 嵌入 mockAccountRepoForPlatform覆写分组隔离相关方法。
// allAccounts 存储所有账号,分组查询方法按 AccountGroups 字段进行真实过滤。
type groupAwareMockAccountRepo struct {
*mockAccountRepoForPlatform
allAccounts []Account
}
// ListSchedulableUngroupedByPlatform 仅返回未分组账号AccountGroups 为空)
func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
var result []Account
for _, acc := range m.allAccounts {
if acc.Platform == platform && acc.IsSchedulable() && len(acc.AccountGroups) == 0 {
result = append(result, acc)
}
}
return result, nil
}
// ListSchedulableUngroupedByPlatforms 仅返回未分组账号(多平台版本)
func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
platformSet := make(map[string]bool, len(platforms))
for _, p := range platforms {
platformSet[p] = true
}
var result []Account
for _, acc := range m.allAccounts {
if platformSet[acc.Platform] && acc.IsSchedulable() && len(acc.AccountGroups) == 0 {
result = append(result, acc)
}
}
return result, nil
}
// ListSchedulableByGroupIDAndPlatform 返回属于指定分组的账号
func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
var result []Account
for _, acc := range m.allAccounts {
if acc.Platform == platform && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) {
result = append(result, acc)
}
}
return result, nil
}
// ListSchedulableByGroupIDAndPlatforms 返回属于指定分组的账号(多平台版本)
func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
platformSet := make(map[string]bool, len(platforms))
for _, p := range platforms {
platformSet[p] = true
}
var result []Account
for _, acc := range m.allAccounts {
if platformSet[acc.Platform] && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) {
result = append(result, acc)
}
}
return result, nil
}
// accountBelongsToGroup 检查账号是否属于指定分组
func accountBelongsToGroup(acc Account, groupID int64) bool {
for _, ag := range acc.AccountGroups {
if ag.GroupID == groupID {
return true
}
}
return false
}
// Verify interface implementation
var _ AccountRepository = (*groupAwareMockAccountRepo)(nil)
// newGroupAwareMockRepo 创建分组感知的 mock repo
func newGroupAwareMockRepo(accounts []Account) *groupAwareMockAccountRepo {
byID := make(map[int64]*Account, len(accounts))
for i := range accounts {
byID[accounts[i].ID] = &accounts[i]
}
return &groupAwareMockAccountRepo{
mockAccountRepoForPlatform: &mockAccountRepoForPlatform{
accounts: accounts,
accountsByID: byID,
},
allAccounts: accounts,
}
}
func TestGroupIsolation_UngroupedKey_ShouldNotScheduleGroupedAccounts(t *testing.T) {
// 场景:无分组 API KeygroupID=nil池中只有已分组账号 → 应返回错误
ctx := context.Background()
accounts := []Account{
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{{GroupID: 100}}},
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{{GroupID: 200}}},
}
repo := newGroupAwareMockRepo(accounts)
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
require.Error(t, err, "无分组 Key 不应调度到已分组账号")
require.Nil(t, acc)
}
func TestGroupIsolation_GroupedKey_ShouldNotScheduleUngroupedAccounts(t *testing.T) {
// 场景:有分组 API KeygroupID=100池中只有未分组账号 → 应返回错误
ctx := context.Background()
groupID := int64(100)
accounts := []Account{
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
AccountGroups: nil},
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{}},
}
repo := newGroupAwareMockRepo(accounts)
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI)
require.Error(t, err, "有分组 Key 不应调度到未分组账号")
require.Nil(t, acc)
}
func TestGroupIsolation_UngroupedKey_ShouldOnlyScheduleUngroupedAccounts(t *testing.T) {
// 场景:无分组 API KeygroupID=nil池中有未分组和已分组账号 → 应只选中未分组的
ctx := context.Background()
accounts := []Account{
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组,不应被选中
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
AccountGroups: nil}, // 未分组,应被选中
{ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{{GroupID: 200}}}, // 已分组,不应被选中
}
repo := newGroupAwareMockRepo(accounts)
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
require.NoError(t, err, "应成功调度未分组账号")
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应选中未分组的账号 ID=2")
}
func TestGroupIsolation_GroupedKey_ShouldOnlyScheduleMatchingGroupAccounts(t *testing.T) {
// 场景:有分组 API KeygroupID=100池中有未分组和多个分组账号 → 应只选中分组 100 内的
ctx := context.Background()
groupID := int64(100)
accounts := []Account{
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
AccountGroups: nil}, // 未分组,不应被选中
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{{GroupID: 200}}}, // 属于分组 200不应被选中
{ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{{GroupID: 100}}}, // 属于分组 100应被选中
}
repo := newGroupAwareMockRepo(accounts)
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI)
require.NoError(t, err, "应成功调度分组内账号")
require.NotNil(t, acc)
require.Equal(t, int64(3), acc.ID, "应选中分组 100 内的账号 ID=3")
}
// ============================================================================
// Part 3: SimpleMode 旁路测试
// ============================================================================
func TestGroupIsolation_SimpleMode_SkipsGroupIsolation(t *testing.T) {
// SimpleMode 应跳过分组隔离,使用 ListSchedulableByPlatform 返回所有账号。
// 测试非 useMixed 路径platform=openai不会触发 mixed 调度逻辑)。
ctx := context.Background()
// 混合未分组和已分组账号SimpleMode 下应全部可调度
accounts := []Account{
{ID: 1, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组
{ID: 2, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
AccountGroups: nil}, // 未分组
}
// 使用基础 mockListSchedulableByPlatform 返回所有匹配平台的账号,不做分组过滤)
byID := make(map[int64]*Account, len(accounts))
for i := range accounts {
byID[accounts[i].ID] = &accounts[i]
}
repo := &mockAccountRepoForPlatform{
accounts: accounts,
accountsByID: byID,
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: &config.Config{RunMode: config.RunModeSimple},
}
// groupID=nil 时SimpleMode 应使用 ListSchedulableByPlatform不过滤分组
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
require.NoError(t, err, "SimpleMode 应跳过分组隔离直接返回账号")
require.NotNil(t, acc)
// 应选择优先级最高的账号Priority=1, ID=2即使它未分组
require.Equal(t, int64(2), acc.ID, "SimpleMode 应按优先级选择,不考虑分组")
}
func TestGroupIsolation_SimpleMode_GroupedAccountAlsoSchedulable(t *testing.T) {
// SimpleMode + groupID=nil 时,已分组账号也应该可被调度
ctx := context.Background()
// 只有已分组账号,在 standard 模式下 groupID=nil 会报错,但 simple 模式应正常
accounts := []Account{
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
AccountGroups: []AccountGroup{{GroupID: 100}}},
}
byID := make(map[int64]*Account, len(accounts))
for i := range accounts {
byID[accounts[i].ID] = &accounts[i]
}
repo := &mockAccountRepoForPlatform{
accounts: accounts,
accountsByID: byID,
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: &config.Config{RunMode: config.RunModeSimple},
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
require.NoError(t, err, "SimpleMode 下已分组账号也应可调度")
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "SimpleMode 应能调度已分组账号")
}

View File

@@ -147,6 +147,12 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByPlatforms(ctx context.Cont
func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) { func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
return m.ListSchedulableByPlatforms(ctx, platforms) return m.ListSchedulableByPlatforms(ctx, platforms)
} }
func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
return m.ListSchedulableByPlatform(ctx, platform)
}
func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
return m.ListSchedulableByPlatforms(ctx, platforms)
}
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil return nil
} }

View File

@@ -1782,8 +1782,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
var err error var err error
if groupID != nil { if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
} else { } else if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
} else {
accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms)
} }
if err != nil { if err != nil {
slog.Debug("account_scheduling_list_failed", slog.Debug("account_scheduling_list_failed",
@@ -1824,7 +1826,7 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
// 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询 // 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询
} else { } else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, platform)
} }
if err != nil { if err != nil {
slog.Debug("account_scheduling_list_failed", slog.Debug("account_scheduling_list_failed",
@@ -1964,14 +1966,15 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte
} }
// isAccountInGroup checks if the account belongs to the specified group. // isAccountInGroup checks if the account belongs to the specified group.
// Returns true if groupID is nil (no group restriction) or account belongs to the group. // When groupID is nil, returns true only for ungrouped accounts (no group assignments).
func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool { func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool {
if groupID == nil {
return true // 无分组限制
}
if account == nil { if account == nil {
return false return false
} }
if groupID == nil {
// 无分组的 API Key 只能使用未分组的账号
return len(account.AccountGroups) == 0
}
for _, ag := range account.AccountGroups { for _, ag := range account.AccountGroups {
if ag.GroupID == *groupID { if ag.GroupID == *groupID {
return true return true

View File

@@ -431,7 +431,10 @@ func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Co
if groupID != nil { if groupID != nil {
return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms) return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
} }
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms) if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
}
return s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, queryPlatforms)
} }
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) { func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {

View File

@@ -138,6 +138,12 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
} }
return m.ListSchedulableByPlatforms(ctx, platforms) return m.ListSchedulableByPlatforms(ctx, platforms)
} }
func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
return m.ListSchedulableByPlatform(ctx, platform)
}
func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
return m.ListSchedulableByPlatforms(ctx, platforms)
}
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil return nil
} }

View File

@@ -1343,7 +1343,7 @@ func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, grou
} else if groupID != nil { } else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
} else { } else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, PlatformOpenAI)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err) return nil, fmt.Errorf("query accounts failed: %w", err)

View File

@@ -57,6 +57,10 @@ func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, pl
return result, nil return result, nil
} }
func (r stubOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
return r.ListSchedulableByPlatform(ctx, platform)
}
type stubConcurrencyCache struct { type stubConcurrencyCache struct {
ConcurrencyCache ConcurrencyCache
loadBatchErr error loadBatchErr error

View File

@@ -605,8 +605,10 @@ func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke
var err error var err error
if groupID > 0 { if groupID > 0 {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, groupID, platforms) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, groupID, platforms)
} else { } else if s.isRunModeSimple() {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
} else {
accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@@ -624,7 +626,10 @@ func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke
if groupID > 0 { if groupID > 0 {
return s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, groupID, bucket.Platform) return s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, groupID, bucket.Platform)
} }
return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform) if s.isRunModeSimple() {
return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform)
}
return s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, bucket.Platform)
} }
func (s *SchedulerSnapshotService) bucketFor(groupID *int64, platform string, mode string) SchedulerBucket { func (s *SchedulerSnapshotService) bucketFor(groupID *int64, platform string, mode string) SchedulerBucket {