fix(gateway): 分组隔离 — 禁止未分组账号被跨组调度
当 API Key 无分组时,调度仅从未分组账号池中选取。 修复 isAccountInGroup 在 groupID==nil 时的逻辑, 同时补全 scheduler_snapshot_service 和 gemini_compat_service 中的 SimpleMode 保护,确保分组隔离在所有调度路径生效。 新增 ListSchedulableUngroupedByPlatform/s 方法, 使用 Ent 的 Not(HasAccountGroups()) 谓词实现未分组账号隔离。 新增 17 个单元和端到端隔离测试,覆盖所有分支和边界条件。
This commit is contained in:
@@ -2089,6 +2089,12 @@ func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context,
|
|||||||
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) {
|
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) {
|
||||||
return r.accounts, nil
|
return r.accounts, nil
|
||||||
}
|
}
|
||||||
|
func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]service.Account, error) {
|
||||||
|
return r.accounts, nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]service.Account, error) {
|
||||||
|
return r.accounts, nil
|
||||||
|
}
|
||||||
func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error {
|
func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -182,6 +182,12 @@ func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platfo
|
|||||||
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
|
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
|
||||||
return r.ListSchedulableByPlatforms(ctx, platforms)
|
return r.ListSchedulableByPlatforms(ctx, platforms)
|
||||||
}
|
}
|
||||||
|
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||||
|
return r.ListSchedulableByPlatform(ctx, platform)
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||||
|
return r.ListSchedulableByPlatforms(ctx, platforms)
|
||||||
|
}
|
||||||
func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -829,6 +829,51 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
|
|||||||
return r.accountsToService(ctx, accounts)
|
return r.accountsToService(ctx, accounts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *accountRepository) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||||
|
now := time.Now()
|
||||||
|
accounts, err := r.client.Account.Query().
|
||||||
|
Where(
|
||||||
|
dbaccount.PlatformEQ(platform),
|
||||||
|
dbaccount.StatusEQ(service.StatusActive),
|
||||||
|
dbaccount.SchedulableEQ(true),
|
||||||
|
dbaccount.Not(dbaccount.HasAccountGroups()),
|
||||||
|
tempUnschedulablePredicate(),
|
||||||
|
notExpiredPredicate(now),
|
||||||
|
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||||
|
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||||
|
).
|
||||||
|
Order(dbent.Asc(dbaccount.FieldPriority)).
|
||||||
|
All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return r.accountsToService(ctx, accounts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *accountRepository) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||||
|
if len(platforms) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
accounts, err := r.client.Account.Query().
|
||||||
|
Where(
|
||||||
|
dbaccount.PlatformIn(platforms...),
|
||||||
|
dbaccount.StatusEQ(service.StatusActive),
|
||||||
|
dbaccount.SchedulableEQ(true),
|
||||||
|
dbaccount.Not(dbaccount.HasAccountGroups()),
|
||||||
|
tempUnschedulablePredicate(),
|
||||||
|
notExpiredPredicate(now),
|
||||||
|
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||||
|
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||||
|
).
|
||||||
|
Order(dbent.Asc(dbaccount.FieldPriority)).
|
||||||
|
All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return r.accountsToService(ctx, accounts)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
|
func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
|
||||||
if len(platforms) == 0 {
|
if len(platforms) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
|||||||
@@ -1027,6 +1027,14 @@ func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,6 +54,8 @@ type AccountRepository interface {
|
|||||||
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
|
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
|
||||||
ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
|
ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
|
||||||
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
||||||
|
ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||||
|
ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
|
||||||
|
|
||||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||||
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
|
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
|
||||||
|
|||||||
@@ -147,6 +147,14 @@ func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte
|
|||||||
panic("unexpected ListSchedulableByGroupIDAndPlatforms call")
|
panic("unexpected ListSchedulableByGroupIDAndPlatforms call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *accountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||||
|
panic("unexpected ListSchedulableUngroupedByPlatform call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountRepoStub) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||||
|
panic("unexpected ListSchedulableUngroupedByPlatforms call")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
panic("unexpected SetRateLimited call")
|
panic("unexpected SetRateLimited call")
|
||||||
}
|
}
|
||||||
|
|||||||
363
backend/internal/service/gateway_group_isolation_test.go
Normal file
363
backend/internal/service/gateway_group_isolation_test.go
Normal file
@@ -0,0 +1,363 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Part 1: isAccountInGroup 单元测试
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
func TestIsAccountInGroup(t *testing.T) {
|
||||||
|
svc := &GatewayService{}
|
||||||
|
groupID100 := int64(100)
|
||||||
|
groupID200 := int64(200)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
account *Account
|
||||||
|
groupID *int64
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
// groupID == nil(无分组 API Key)
|
||||||
|
{
|
||||||
|
"nil_groupID_ungrouped_account_nil_groups",
|
||||||
|
&Account{ID: 1, AccountGroups: nil},
|
||||||
|
nil, true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"nil_groupID_ungrouped_account_empty_slice",
|
||||||
|
&Account{ID: 2, AccountGroups: []AccountGroup{}},
|
||||||
|
nil, true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"nil_groupID_grouped_account_single",
|
||||||
|
&Account{ID: 3, AccountGroups: []AccountGroup{{GroupID: 100}}},
|
||||||
|
nil, false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"nil_groupID_grouped_account_multiple",
|
||||||
|
&Account{ID: 4, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}},
|
||||||
|
nil, false,
|
||||||
|
},
|
||||||
|
// groupID != nil(有分组 API Key)
|
||||||
|
{
|
||||||
|
"with_groupID_account_in_group",
|
||||||
|
&Account{ID: 5, AccountGroups: []AccountGroup{{GroupID: 100}}},
|
||||||
|
&groupID100, true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"with_groupID_account_not_in_group",
|
||||||
|
&Account{ID: 6, AccountGroups: []AccountGroup{{GroupID: 200}}},
|
||||||
|
&groupID100, false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"with_groupID_ungrouped_account",
|
||||||
|
&Account{ID: 7, AccountGroups: nil},
|
||||||
|
&groupID100, false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"with_groupID_multi_group_account_match_one",
|
||||||
|
&Account{ID: 8, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}},
|
||||||
|
&groupID200, true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"with_groupID_multi_group_account_no_match",
|
||||||
|
&Account{ID: 9, AccountGroups: []AccountGroup{{GroupID: 300}, {GroupID: 400}}},
|
||||||
|
&groupID100, false,
|
||||||
|
},
|
||||||
|
// 防御性边界
|
||||||
|
{
|
||||||
|
"nil_account_nil_groupID",
|
||||||
|
nil,
|
||||||
|
nil, false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"nil_account_with_groupID",
|
||||||
|
nil,
|
||||||
|
&groupID100, false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := svc.isAccountInGroup(tt.account, tt.groupID)
|
||||||
|
require.Equal(t, tt.expected, got, "isAccountInGroup 结果不符预期")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Part 2: 分组隔离端到端调度测试
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
// groupAwareMockAccountRepo 嵌入 mockAccountRepoForPlatform,覆写分组隔离相关方法。
|
||||||
|
// allAccounts 存储所有账号,分组查询方法按 AccountGroups 字段进行真实过滤。
|
||||||
|
type groupAwareMockAccountRepo struct {
|
||||||
|
*mockAccountRepoForPlatform
|
||||||
|
allAccounts []Account
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSchedulableUngroupedByPlatform 仅返回未分组账号(AccountGroups 为空)
|
||||||
|
func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||||
|
var result []Account
|
||||||
|
for _, acc := range m.allAccounts {
|
||||||
|
if acc.Platform == platform && acc.IsSchedulable() && len(acc.AccountGroups) == 0 {
|
||||||
|
result = append(result, acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSchedulableUngroupedByPlatforms 仅返回未分组账号(多平台版本)
|
||||||
|
func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||||
|
platformSet := make(map[string]bool, len(platforms))
|
||||||
|
for _, p := range platforms {
|
||||||
|
platformSet[p] = true
|
||||||
|
}
|
||||||
|
var result []Account
|
||||||
|
for _, acc := range m.allAccounts {
|
||||||
|
if platformSet[acc.Platform] && acc.IsSchedulable() && len(acc.AccountGroups) == 0 {
|
||||||
|
result = append(result, acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSchedulableByGroupIDAndPlatform 返回属于指定分组的账号
|
||||||
|
func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||||
|
var result []Account
|
||||||
|
for _, acc := range m.allAccounts {
|
||||||
|
if acc.Platform == platform && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) {
|
||||||
|
result = append(result, acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSchedulableByGroupIDAndPlatforms 返回属于指定分组的账号(多平台版本)
|
||||||
|
func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||||
|
platformSet := make(map[string]bool, len(platforms))
|
||||||
|
for _, p := range platforms {
|
||||||
|
platformSet[p] = true
|
||||||
|
}
|
||||||
|
var result []Account
|
||||||
|
for _, acc := range m.allAccounts {
|
||||||
|
if platformSet[acc.Platform] && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) {
|
||||||
|
result = append(result, acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// accountBelongsToGroup 检查账号是否属于指定分组
|
||||||
|
func accountBelongsToGroup(acc Account, groupID int64) bool {
|
||||||
|
for _, ag := range acc.AccountGroups {
|
||||||
|
if ag.GroupID == groupID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify interface implementation
|
||||||
|
var _ AccountRepository = (*groupAwareMockAccountRepo)(nil)
|
||||||
|
|
||||||
|
// newGroupAwareMockRepo 创建分组感知的 mock repo
|
||||||
|
func newGroupAwareMockRepo(accounts []Account) *groupAwareMockAccountRepo {
|
||||||
|
byID := make(map[int64]*Account, len(accounts))
|
||||||
|
for i := range accounts {
|
||||||
|
byID[accounts[i].ID] = &accounts[i]
|
||||||
|
}
|
||||||
|
return &groupAwareMockAccountRepo{
|
||||||
|
mockAccountRepoForPlatform: &mockAccountRepoForPlatform{
|
||||||
|
accounts: accounts,
|
||||||
|
accountsByID: byID,
|
||||||
|
},
|
||||||
|
allAccounts: accounts,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupIsolation_UngroupedKey_ShouldNotScheduleGroupedAccounts(t *testing.T) {
|
||||||
|
// 场景:无分组 API Key(groupID=nil),池中只有已分组账号 → 应返回错误
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
accounts := []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||||
|
AccountGroups: []AccountGroup{{GroupID: 100}}},
|
||||||
|
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
|
||||||
|
AccountGroups: []AccountGroup{{GroupID: 200}}},
|
||||||
|
}
|
||||||
|
repo := newGroupAwareMockRepo(accounts)
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: testConfig(),
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
|
||||||
|
require.Error(t, err, "无分组 Key 不应调度到已分组账号")
|
||||||
|
require.Nil(t, acc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupIsolation_GroupedKey_ShouldNotScheduleUngroupedAccounts(t *testing.T) {
|
||||||
|
// 场景:有分组 API Key(groupID=100),池中只有未分组账号 → 应返回错误
|
||||||
|
ctx := context.Background()
|
||||||
|
groupID := int64(100)
|
||||||
|
|
||||||
|
accounts := []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||||
|
AccountGroups: nil},
|
||||||
|
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
|
||||||
|
AccountGroups: []AccountGroup{}},
|
||||||
|
}
|
||||||
|
repo := newGroupAwareMockRepo(accounts)
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: testConfig(),
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI)
|
||||||
|
require.Error(t, err, "有分组 Key 不应调度到未分组账号")
|
||||||
|
require.Nil(t, acc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupIsolation_UngroupedKey_ShouldOnlyScheduleUngroupedAccounts(t *testing.T) {
|
||||||
|
// 场景:无分组 API Key(groupID=nil),池中有未分组和已分组账号 → 应只选中未分组的
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
accounts := []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||||
|
AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组,不应被选中
|
||||||
|
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
|
||||||
|
AccountGroups: nil}, // 未分组,应被选中
|
||||||
|
{ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true,
|
||||||
|
AccountGroups: []AccountGroup{{GroupID: 200}}}, // 已分组,不应被选中
|
||||||
|
}
|
||||||
|
repo := newGroupAwareMockRepo(accounts)
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: testConfig(),
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
|
||||||
|
require.NoError(t, err, "应成功调度未分组账号")
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(2), acc.ID, "应选中未分组的账号 ID=2")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupIsolation_GroupedKey_ShouldOnlyScheduleMatchingGroupAccounts(t *testing.T) {
|
||||||
|
// 场景:有分组 API Key(groupID=100),池中有未分组和多个分组账号 → 应只选中分组 100 内的
|
||||||
|
ctx := context.Background()
|
||||||
|
groupID := int64(100)
|
||||||
|
|
||||||
|
accounts := []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||||
|
AccountGroups: nil}, // 未分组,不应被选中
|
||||||
|
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
|
||||||
|
AccountGroups: []AccountGroup{{GroupID: 200}}}, // 属于分组 200,不应被选中
|
||||||
|
{ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true,
|
||||||
|
AccountGroups: []AccountGroup{{GroupID: 100}}}, // 属于分组 100,应被选中
|
||||||
|
}
|
||||||
|
repo := newGroupAwareMockRepo(accounts)
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: testConfig(),
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI)
|
||||||
|
require.NoError(t, err, "应成功调度分组内账号")
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(3), acc.ID, "应选中分组 100 内的账号 ID=3")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Part 3: SimpleMode 旁路测试
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
func TestGroupIsolation_SimpleMode_SkipsGroupIsolation(t *testing.T) {
|
||||||
|
// SimpleMode 应跳过分组隔离,使用 ListSchedulableByPlatform 返回所有账号。
|
||||||
|
// 测试非 useMixed 路径(platform=openai,不会触发 mixed 调度逻辑)。
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// 混合未分组和已分组账号,SimpleMode 下应全部可调度
|
||||||
|
accounts := []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
|
||||||
|
AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组
|
||||||
|
{ID: 2, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||||
|
AccountGroups: nil}, // 未分组
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用基础 mock(ListSchedulableByPlatform 返回所有匹配平台的账号,不做分组过滤)
|
||||||
|
byID := make(map[int64]*Account, len(accounts))
|
||||||
|
for i := range accounts {
|
||||||
|
byID[accounts[i].ID] = &accounts[i]
|
||||||
|
}
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: accounts,
|
||||||
|
accountsByID: byID,
|
||||||
|
}
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: &config.Config{RunMode: config.RunModeSimple},
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupID=nil 时,SimpleMode 应使用 ListSchedulableByPlatform(不过滤分组)
|
||||||
|
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
|
||||||
|
require.NoError(t, err, "SimpleMode 应跳过分组隔离直接返回账号")
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
// 应选择优先级最高的账号(Priority=1, ID=2),即使它未分组
|
||||||
|
require.Equal(t, int64(2), acc.ID, "SimpleMode 应按优先级选择,不考虑分组")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupIsolation_SimpleMode_GroupedAccountAlsoSchedulable(t *testing.T) {
|
||||||
|
// SimpleMode + groupID=nil 时,已分组账号也应该可被调度
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// 只有已分组账号,在 standard 模式下 groupID=nil 会报错,但 simple 模式应正常
|
||||||
|
accounts := []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||||
|
AccountGroups: []AccountGroup{{GroupID: 100}}},
|
||||||
|
}
|
||||||
|
|
||||||
|
byID := make(map[int64]*Account, len(accounts))
|
||||||
|
for i := range accounts {
|
||||||
|
byID[accounts[i].ID] = &accounts[i]
|
||||||
|
}
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: accounts,
|
||||||
|
accountsByID: byID,
|
||||||
|
}
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: &config.Config{RunMode: config.RunModeSimple},
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
|
||||||
|
require.NoError(t, err, "SimpleMode 下已分组账号也应可调度")
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(1), acc.ID, "SimpleMode 应能调度已分组账号")
|
||||||
|
}
|
||||||
@@ -147,6 +147,12 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByPlatforms(ctx context.Cont
|
|||||||
func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||||
}
|
}
|
||||||
|
func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||||
|
return m.ListSchedulableByPlatform(ctx, platform)
|
||||||
|
}
|
||||||
|
func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||||
|
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||||
|
}
|
||||||
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1782,8 +1782,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
|||||||
var err error
|
var err error
|
||||||
if groupID != nil {
|
if groupID != nil {
|
||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
|
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
|
||||||
} else {
|
} else if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
||||||
|
} else {
|
||||||
|
accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Debug("account_scheduling_list_failed",
|
slog.Debug("account_scheduling_list_failed",
|
||||||
@@ -1824,7 +1826,7 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
|||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
|
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
|
||||||
// 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询
|
// 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询
|
||||||
} else {
|
} else {
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, platform)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Debug("account_scheduling_list_failed",
|
slog.Debug("account_scheduling_list_failed",
|
||||||
@@ -1964,14 +1966,15 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte
|
|||||||
}
|
}
|
||||||
|
|
||||||
// isAccountInGroup checks if the account belongs to the specified group.
|
// isAccountInGroup checks if the account belongs to the specified group.
|
||||||
// Returns true if groupID is nil (no group restriction) or account belongs to the group.
|
// When groupID is nil, returns true only for ungrouped accounts (no group assignments).
|
||||||
func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool {
|
func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool {
|
||||||
if groupID == nil {
|
|
||||||
return true // 无分组限制
|
|
||||||
}
|
|
||||||
if account == nil {
|
if account == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if groupID == nil {
|
||||||
|
// 无分组的 API Key 只能使用未分组的账号
|
||||||
|
return len(account.AccountGroups) == 0
|
||||||
|
}
|
||||||
for _, ag := range account.AccountGroups {
|
for _, ag := range account.AccountGroups {
|
||||||
if ag.GroupID == *groupID {
|
if ag.GroupID == *groupID {
|
||||||
return true
|
return true
|
||||||
|
|||||||
@@ -431,7 +431,10 @@ func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Co
|
|||||||
if groupID != nil {
|
if groupID != nil {
|
||||||
return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
|
return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
|
||||||
}
|
}
|
||||||
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||||
|
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||||
|
}
|
||||||
|
return s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, queryPlatforms)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
|
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||||
|
|||||||
@@ -138,6 +138,12 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
|
|||||||
}
|
}
|
||||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||||
}
|
}
|
||||||
|
func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||||
|
return m.ListSchedulableByPlatform(ctx, platform)
|
||||||
|
}
|
||||||
|
func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||||
|
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||||
|
}
|
||||||
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1343,7 +1343,7 @@ func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, grou
|
|||||||
} else if groupID != nil {
|
} else if groupID != nil {
|
||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
|
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
|
||||||
} else {
|
} else {
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, PlatformOpenAI)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
|
|||||||
@@ -57,6 +57,10 @@ func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, pl
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r stubOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||||
|
return r.ListSchedulableByPlatform(ctx, platform)
|
||||||
|
}
|
||||||
|
|
||||||
type stubConcurrencyCache struct {
|
type stubConcurrencyCache struct {
|
||||||
ConcurrencyCache
|
ConcurrencyCache
|
||||||
loadBatchErr error
|
loadBatchErr error
|
||||||
|
|||||||
@@ -605,8 +605,10 @@ func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke
|
|||||||
var err error
|
var err error
|
||||||
if groupID > 0 {
|
if groupID > 0 {
|
||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, groupID, platforms)
|
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, groupID, platforms)
|
||||||
} else {
|
} else if s.isRunModeSimple() {
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
||||||
|
} else {
|
||||||
|
accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -624,7 +626,10 @@ func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke
|
|||||||
if groupID > 0 {
|
if groupID > 0 {
|
||||||
return s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, groupID, bucket.Platform)
|
return s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, groupID, bucket.Platform)
|
||||||
}
|
}
|
||||||
return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform)
|
if s.isRunModeSimple() {
|
||||||
|
return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform)
|
||||||
|
}
|
||||||
|
return s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, bucket.Platform)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SchedulerSnapshotService) bucketFor(groupID *int64, platform string, mode string) SchedulerBucket {
|
func (s *SchedulerSnapshotService) bucketFor(groupID *int64, platform string, mode string) SchedulerBucket {
|
||||||
|
|||||||
Reference in New Issue
Block a user