From 2bfb16291f892b58e2c7c30143036b8cabbc6f05 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Mon, 9 Feb 2026 21:35:41 +0800 Subject: [PATCH] =?UTF-8?q?fix(unit):=20=E4=BF=AE=E5=A4=8D=20unit=20tag=20?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=BC=96=E8=AF=91=E4=B8=8E=E8=B4=A6=E5=8F=B7?= =?UTF-8?q?=E9=80=89=E6=8B=A9=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../handler/sora_gateway_handler_test.go | 12 +- .../service/gateway_account_selection_test.go | 164 +++++++++--------- backend/internal/service/gateway_service.go | 35 +++- .../service/scheduler_shuffle_test.go | 12 +- backend/internal/testutil/stubs.go | 12 -- 5 files changed, 130 insertions(+), 105 deletions(-) diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index ba266d5c..bc042478 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -78,6 +78,9 @@ func (r *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID st func (r *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) { return nil, nil } +func (r *stubAccountRepo) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { + return map[string]int64{}, nil +} func (r *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return nil } func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error { return nil } func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { @@ -138,9 +141,6 @@ func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } -func (r *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error { - return nil -} func (r *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { return nil } @@ -227,6 +227,9 @@ func (r *stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs [] func (r *stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error { return nil } +func (r *stubGroupRepo) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error { + return nil +} type stubUsageLogRepo struct{} @@ -367,7 +370,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { nil, nil, nil, - nil, + testutil.StubGatewayCache{}, cfg, nil, concurrencyService, @@ -378,6 +381,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { nil, deferredService, nil, + testutil.StubSessionLimitCache{}, nil, ) diff --git a/backend/internal/service/gateway_account_selection_test.go b/backend/internal/service/gateway_account_selection_test.go index 70c5d6c5..0a82fade 100644 --- a/backend/internal/service/gateway_account_selection_test.go +++ b/backend/internal/service/gateway_account_selection_test.go @@ -74,11 +74,24 @@ func TestSortAccountsByPriorityAndLastUsed_StableSort(t *testing.T) { {ID: 2, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, {ID: 3, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, } - sortAccountsByPriorityAndLastUsed(accounts, false) - // 稳定排序:相同键值的元素保持原始顺序 - require.Equal(t, int64(1), accounts[0].ID) - require.Equal(t, int64(2), accounts[1].ID) - require.Equal(t, int64(3), accounts[2].ID) + + // sortAccountsByPriorityAndLastUsed 内部会在同组(Priority+LastUsedAt)内做随机打散, + // 因此这里不再断言“稳定排序”。我们只验证: + // 1) 元素集合不变;2) 多次运行能产生不同的顺序。 + seenFirst := map[int64]bool{} + for i := 0; i < 100; i++ { + cpy := make([]*Account, len(accounts)) + copy(cpy, accounts) + sortAccountsByPriorityAndLastUsed(cpy, false) + seenFirst[cpy[0].ID] = true + + ids := map[int64]bool{} + for _, a := range cpy { + ids[a.ID] = true + } + require.True(t, ids[1] && ids[2] && ids[3]) + } + require.GreaterOrEqual(t, len(seenFirst), 2, "同组账号应能被随机打散") } func TestSortAccountsByPriorityAndLastUsed_MixedPriorityAndTime(t *testing.T) { @@ -98,101 +111,96 @@ func TestSortAccountsByPriorityAndLastUsed_MixedPriorityAndTime(t *testing.T) { require.Equal(t, int64(4), accounts[3].ID, "优先级2 + 有时间") } -// --- selectByCallCount --- +// --- filterByMinPriority --- -func TestSelectByCallCount_Empty(t *testing.T) { - result := selectByCallCount(nil, nil, false) +func TestFilterByMinPriority_Empty(t *testing.T) { + result := filterByMinPriority(nil) require.Nil(t, result) } -func TestSelectByCallCount_Single(t *testing.T) { +func TestFilterByMinPriority_SelectsMinPriority(t *testing.T) { accounts := []accountWithLoad{ - makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey), + makeAccWithLoad(1, 5, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(3, 1, 20, nil, AccountTypeAPIKey), + makeAccWithLoad(4, 2, 10, nil, AccountTypeAPIKey), } - result := selectByCallCount(accounts, map[int64]*ModelLoadInfo{1: {CallCount: 10}}, false) + result := filterByMinPriority(accounts) + require.Len(t, result, 2) + require.Equal(t, int64(2), result[0].account.ID) + require.Equal(t, int64(3), result[1].account.ID) +} + +// --- filterByMinLoadRate --- + +func TestFilterByMinLoadRate_Empty(t *testing.T) { + result := filterByMinLoadRate(nil) + require.Nil(t, result) +} + +func TestFilterByMinLoadRate_SelectsMinLoadRate(t *testing.T) { + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 30, nil, AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(3, 1, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(4, 1, 20, nil, AccountTypeAPIKey), + } + result := filterByMinLoadRate(accounts) + require.Len(t, result, 2) + require.Equal(t, int64(2), result[0].account.ID) + require.Equal(t, int64(3), result[1].account.ID) +} + +// --- selectByLRU --- + +func TestSelectByLRU_Empty(t *testing.T) { + result := selectByLRU(nil, false) + require.Nil(t, result) +} + +func TestSelectByLRU_Single(t *testing.T) { + accounts := []accountWithLoad{makeAccWithLoad(1, 1, 10, nil, AccountTypeAPIKey)} + result := selectByLRU(accounts, false) require.NotNil(t, result) require.Equal(t, int64(1), result.account.ID) } -func TestSelectByCallCount_NilModelLoadFallsBackToLRU(t *testing.T) { +func TestSelectByLRU_NilLastUsedAtWins(t *testing.T) { now := time.Now() accounts := []accountWithLoad{ - makeAccWithLoad(1, 1, 50, testTimePtr(now), AccountTypeAPIKey), - makeAccWithLoad(2, 1, 50, testTimePtr(now.Add(-1*time.Hour)), AccountTypeAPIKey), + makeAccWithLoad(1, 1, 10, testTimePtr(now), AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(3, 1, 10, testTimePtr(now.Add(-1*time.Hour)), AccountTypeAPIKey), } - result := selectByCallCount(accounts, nil, false) + result := selectByLRU(accounts, false) require.NotNil(t, result) - require.Equal(t, int64(2), result.account.ID, "nil modelLoadMap 应回退到 LRU 选择") + require.Equal(t, int64(2), result.account.ID) } -func TestSelectByCallCount_SelectsMinCallCount(t *testing.T) { +func TestSelectByLRU_EarliestTimeWins(t *testing.T) { + now := time.Now() accounts := []accountWithLoad{ - makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey), - makeAccWithLoad(2, 1, 50, nil, AccountTypeAPIKey), - makeAccWithLoad(3, 1, 50, nil, AccountTypeAPIKey), - } - modelLoad := map[int64]*ModelLoadInfo{ - 1: {CallCount: 100}, - 2: {CallCount: 5}, - 3: {CallCount: 50}, - } - // 运行多次确认总是选调用次数最少的 - for i := 0; i < 10; i++ { - result := selectByCallCount(accounts, modelLoad, false) - require.NotNil(t, result) - require.Equal(t, int64(2), result.account.ID, "应选择调用次数最少的账号") + makeAccWithLoad(1, 1, 10, testTimePtr(now), AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, testTimePtr(now.Add(-1*time.Hour)), AccountTypeAPIKey), + makeAccWithLoad(3, 1, 10, testTimePtr(now.Add(-2*time.Hour)), AccountTypeAPIKey), } + result := selectByLRU(accounts, false) + require.NotNil(t, result) + require.Equal(t, int64(3), result.account.ID) } -func TestSelectByCallCount_NewAccountUsesAverage(t *testing.T) { +func TestSelectByLRU_TiePreferOAuth(t *testing.T) { + now := time.Now() + // 账号 1/2 LastUsedAt 相同,且同为最小值。 accounts := []accountWithLoad{ - makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey), - makeAccWithLoad(2, 1, 50, nil, AccountTypeAPIKey), - makeAccWithLoad(3, 1, 50, nil, AccountTypeAPIKey), + makeAccWithLoad(1, 1, 10, testTimePtr(now), AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, testTimePtr(now), AccountTypeOAuth), + makeAccWithLoad(3, 1, 10, testTimePtr(now.Add(1*time.Hour)), AccountTypeAPIKey), } - // 账号1和2有调用记录,账号3是新账号(CallCount=0) - // 平均调用次数 = (100 + 200) / 2 = 150 - // 新账号用平均值 150,比账号1(100)多,所以应选账号1 - modelLoad := map[int64]*ModelLoadInfo{ - 1: {CallCount: 100}, - 2: {CallCount: 200}, - // 3 没有记录 - } - for i := 0; i < 10; i++ { - result := selectByCallCount(accounts, modelLoad, false) + for i := 0; i < 50; i++ { + result := selectByLRU(accounts, true) require.NotNil(t, result) - require.Equal(t, int64(1), result.account.ID, "新账号虚拟调用次数(150)高于账号1(100),应选账号1") - } -} - -func TestSelectByCallCount_AllNewAccountsFallToAvgZero(t *testing.T) { - accounts := []accountWithLoad{ - makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey), - makeAccWithLoad(2, 1, 50, nil, AccountTypeAPIKey), - } - // 所有账号都是新的,avgCallCount = 0,所有人 effectiveCallCount 都是 0 - modelLoad := map[int64]*ModelLoadInfo{} - validIDs := map[int64]bool{1: true, 2: true} - for i := 0; i < 10; i++ { - result := selectByCallCount(accounts, modelLoad, false) - require.NotNil(t, result) - require.True(t, validIDs[result.account.ID], "所有新账号应随机选择") - } -} - -func TestSelectByCallCount_PreferOAuth(t *testing.T) { - accounts := []accountWithLoad{ - makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey), - makeAccWithLoad(2, 1, 50, nil, AccountTypeOAuth), - } - // 两个账号调用次数相同 - modelLoad := map[int64]*ModelLoadInfo{ - 1: {CallCount: 10}, - 2: {CallCount: 10}, - } - for i := 0; i < 10; i++ { - result := selectByCallCount(accounts, modelLoad, true) - require.NotNil(t, result) - require.Equal(t, int64(2), result.account.ID, "调用次数相同时应优先选择 OAuth 账号") + require.Equal(t, AccountTypeOAuth, result.account.Type) + require.Equal(t, int64(2), result.account.ID) } } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 040745a8..2e1b0ba4 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1937,7 +1937,7 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { return a.LastUsedAt.Before(*b.LastUsedAt) } }) - shuffleWithinPriorityAndLastUsed(accounts) + shuffleWithinPriorityAndLastUsed(accounts, preferOAuth) } // shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。 @@ -1973,7 +1973,12 @@ func sameAccountWithLoadGroup(a, b accountWithLoad) bool { } // shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。 -func shuffleWithinPriorityAndLastUsed(accounts []*Account) { +// +// 注意:当 preferOAuth=true 时,需要保证 OAuth 账号在同组内仍然优先,否则会把排序时的偏好打散掉。 +// 因此这里采用“组内分区 + 分区内 shuffle”的方式: +// - 先把同组账号按 (OAuth / 非 OAuth) 拆成两段,保持 OAuth 段在前; +// - 再分别在各段内随机打散,避免热点。 +func shuffleWithinPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { if len(accounts) <= 1 { return } @@ -1984,9 +1989,29 @@ func shuffleWithinPriorityAndLastUsed(accounts []*Account) { j++ } if j-i > 1 { - mathrand.Shuffle(j-i, func(a, b int) { - accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] - }) + if preferOAuth { + oauth := make([]*Account, 0, j-i) + others := make([]*Account, 0, j-i) + for _, acc := range accounts[i:j] { + if acc.Type == AccountTypeOAuth { + oauth = append(oauth, acc) + } else { + others = append(others, acc) + } + } + if len(oauth) > 1 { + mathrand.Shuffle(len(oauth), func(a, b int) { oauth[a], oauth[b] = oauth[b], oauth[a] }) + } + if len(others) > 1 { + mathrand.Shuffle(len(others), func(a, b int) { others[a], others[b] = others[b], others[a] }) + } + copy(accounts[i:], oauth) + copy(accounts[i+len(oauth):], others) + } else { + mathrand.Shuffle(j-i, func(a, b int) { + accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] + }) + } } i = j } diff --git a/backend/internal/service/scheduler_shuffle_test.go b/backend/internal/service/scheduler_shuffle_test.go index 78ac5f57..0d82b2f3 100644 --- a/backend/internal/service/scheduler_shuffle_test.go +++ b/backend/internal/service/scheduler_shuffle_test.go @@ -125,13 +125,13 @@ func TestShuffleWithinSortGroups_MixedGroups(t *testing.T) { // ============ shuffleWithinPriorityAndLastUsed 测试 ============ func TestShuffleWithinPriorityAndLastUsed_Empty(t *testing.T) { - shuffleWithinPriorityAndLastUsed(nil) - shuffleWithinPriorityAndLastUsed([]*Account{}) + shuffleWithinPriorityAndLastUsed(nil, false) + shuffleWithinPriorityAndLastUsed([]*Account{}, false) } func TestShuffleWithinPriorityAndLastUsed_SingleElement(t *testing.T) { accounts := []*Account{{ID: 1, Priority: 1}} - shuffleWithinPriorityAndLastUsed(accounts) + shuffleWithinPriorityAndLastUsed(accounts, false) require.Equal(t, int64(1), accounts[0].ID) } @@ -146,7 +146,7 @@ func TestShuffleWithinPriorityAndLastUsed_SameGroup_Shuffled(t *testing.T) { for i := 0; i < 100; i++ { cpy := make([]*Account, len(accounts)) copy(cpy, accounts) - shuffleWithinPriorityAndLastUsed(cpy) + shuffleWithinPriorityAndLastUsed(cpy, false) seen[cpy[0].ID] = true } require.GreaterOrEqual(t, len(seen), 2, "same group should be shuffled") @@ -162,7 +162,7 @@ func TestShuffleWithinPriorityAndLastUsed_DifferentPriority_OrderPreserved(t *te for i := 0; i < 20; i++ { cpy := make([]*Account, len(accounts)) copy(cpy, accounts) - shuffleWithinPriorityAndLastUsed(cpy) + shuffleWithinPriorityAndLastUsed(cpy, false) require.Equal(t, int64(1), cpy[0].ID) require.Equal(t, int64(2), cpy[1].ID) require.Equal(t, int64(3), cpy[2].ID) @@ -182,7 +182,7 @@ func TestShuffleWithinPriorityAndLastUsed_DifferentLastUsedAt_OrderPreserved(t * for i := 0; i < 20; i++ { cpy := make([]*Account, len(accounts)) copy(cpy, accounts) - shuffleWithinPriorityAndLastUsed(cpy) + shuffleWithinPriorityAndLastUsed(cpy, false) require.Equal(t, int64(1), cpy[0].ID) require.Equal(t, int64(2), cpy[1].ID) require.Equal(t, int64(3), cpy[2].ID) diff --git a/backend/internal/testutil/stubs.go b/backend/internal/testutil/stubs.go index 81c40c42..3569db17 100644 --- a/backend/internal/testutil/stubs.go +++ b/backend/internal/testutil/stubs.go @@ -90,18 +90,6 @@ func (c StubGatewayCache) RefreshSessionTTL(_ context.Context, _ int64, _ string func (c StubGatewayCache) DeleteSessionAccountID(_ context.Context, _ int64, _ string) error { return nil } -func (c StubGatewayCache) IncrModelCallCount(_ context.Context, _ int64, _ string) (int64, error) { - return 0, nil -} -func (c StubGatewayCache) GetModelLoadBatch(_ context.Context, _ []int64, _ string) (map[int64]*service.ModelLoadInfo, error) { - return nil, nil -} -func (c StubGatewayCache) FindGeminiSession(_ context.Context, _ int64, _, _ string) (string, int64, bool) { - return "", 0, false -} -func (c StubGatewayCache) SaveGeminiSession(_ context.Context, _ int64, _, _, _ string, _ int64) error { - return nil -} // ============================================================ // StubSessionLimitCache — service.SessionLimitCache 的空实现