842 lines
24 KiB
Go
842 lines
24 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"math"
|
||
"sync"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
|
||
ctx := context.Background()
|
||
groupID := int64(9)
|
||
account := Account{
|
||
ID: 1001,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeAPIKey,
|
||
Status: StatusActive,
|
||
Schedulable: true,
|
||
Concurrency: 2,
|
||
Extra: map[string]any{
|
||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||
},
|
||
}
|
||
cache := &stubGatewayCache{}
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.Enabled = true
|
||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||
cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800
|
||
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
|
||
|
||
svc := &OpenAIGatewayService{
|
||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||
cache: cache,
|
||
cfg: cfg,
|
||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||
}
|
||
|
||
store := svc.getOpenAIWSStateStore()
|
||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_001", account.ID, time.Hour))
|
||
|
||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||
ctx,
|
||
&groupID,
|
||
"resp_prev_001",
|
||
"session_hash_001",
|
||
"gpt-5.1",
|
||
nil,
|
||
OpenAIUpstreamTransportAny,
|
||
)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, selection)
|
||
require.NotNil(t, selection.Account)
|
||
require.Equal(t, account.ID, selection.Account.ID)
|
||
require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer)
|
||
require.True(t, decision.StickyPreviousHit)
|
||
require.Equal(t, account.ID, cache.sessionBindings["openai:session_hash_001"])
|
||
if selection.ReleaseFunc != nil {
|
||
selection.ReleaseFunc()
|
||
}
|
||
}
|
||
|
||
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testing.T) {
|
||
ctx := context.Background()
|
||
groupID := int64(10)
|
||
account := Account{
|
||
ID: 2001,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeOAuth,
|
||
Status: StatusActive,
|
||
Schedulable: true,
|
||
Concurrency: 1,
|
||
}
|
||
cache := &stubGatewayCache{
|
||
sessionBindings: map[string]int64{
|
||
"openai:session_hash_abc": account.ID,
|
||
},
|
||
}
|
||
|
||
svc := &OpenAIGatewayService{
|
||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||
cache: cache,
|
||
cfg: &config.Config{},
|
||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||
}
|
||
|
||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||
ctx,
|
||
&groupID,
|
||
"",
|
||
"session_hash_abc",
|
||
"gpt-5.1",
|
||
nil,
|
||
OpenAIUpstreamTransportAny,
|
||
)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, selection)
|
||
require.NotNil(t, selection.Account)
|
||
require.Equal(t, account.ID, selection.Account.ID)
|
||
require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
|
||
require.True(t, decision.StickySessionHit)
|
||
if selection.ReleaseFunc != nil {
|
||
selection.ReleaseFunc()
|
||
}
|
||
}
|
||
|
||
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsSticky(t *testing.T) {
|
||
ctx := context.Background()
|
||
groupID := int64(10100)
|
||
accounts := []Account{
|
||
{
|
||
ID: 21001,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeAPIKey,
|
||
Status: StatusActive,
|
||
Schedulable: true,
|
||
Concurrency: 1,
|
||
Priority: 0,
|
||
},
|
||
{
|
||
ID: 21002,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeAPIKey,
|
||
Status: StatusActive,
|
||
Schedulable: true,
|
||
Concurrency: 1,
|
||
Priority: 9,
|
||
},
|
||
}
|
||
cache := &stubGatewayCache{
|
||
sessionBindings: map[string]int64{
|
||
"openai:session_hash_sticky_busy": 21001,
|
||
},
|
||
}
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2
|
||
cfg.Gateway.Scheduling.StickySessionWaitTimeout = 45 * time.Second
|
||
cfg.Gateway.OpenAIWS.Enabled = true
|
||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||
|
||
concurrencyCache := stubConcurrencyCache{
|
||
acquireResults: map[int64]bool{
|
||
21001: false, // sticky 账号已满
|
||
21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换)
|
||
},
|
||
waitCounts: map[int64]int{
|
||
21001: 999,
|
||
},
|
||
loadMap: map[int64]*AccountLoadInfo{
|
||
21001: {AccountID: 21001, LoadRate: 90, WaitingCount: 9},
|
||
21002: {AccountID: 21002, LoadRate: 1, WaitingCount: 0},
|
||
},
|
||
}
|
||
|
||
svc := &OpenAIGatewayService{
|
||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||
cache: cache,
|
||
cfg: cfg,
|
||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||
}
|
||
|
||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||
ctx,
|
||
&groupID,
|
||
"",
|
||
"session_hash_sticky_busy",
|
||
"gpt-5.1",
|
||
nil,
|
||
OpenAIUpstreamTransportAny,
|
||
)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, selection)
|
||
require.NotNil(t, selection.Account)
|
||
require.Equal(t, int64(21001), selection.Account.ID, "busy sticky account should remain selected")
|
||
require.False(t, selection.Acquired)
|
||
require.NotNil(t, selection.WaitPlan)
|
||
require.Equal(t, int64(21001), selection.WaitPlan.AccountID)
|
||
require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
|
||
require.True(t, decision.StickySessionHit)
|
||
}
|
||
|
||
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP(t *testing.T) {
|
||
ctx := context.Background()
|
||
groupID := int64(1010)
|
||
account := Account{
|
||
ID: 2101,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeOAuth,
|
||
Status: StatusActive,
|
||
Schedulable: true,
|
||
Concurrency: 1,
|
||
Extra: map[string]any{
|
||
"openai_ws_force_http": true,
|
||
},
|
||
}
|
||
cache := &stubGatewayCache{
|
||
sessionBindings: map[string]int64{
|
||
"openai:session_hash_force_http": account.ID,
|
||
},
|
||
}
|
||
|
||
svc := &OpenAIGatewayService{
|
||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||
cache: cache,
|
||
cfg: &config.Config{},
|
||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||
}
|
||
|
||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||
ctx,
|
||
&groupID,
|
||
"",
|
||
"session_hash_force_http",
|
||
"gpt-5.1",
|
||
nil,
|
||
OpenAIUpstreamTransportAny,
|
||
)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, selection)
|
||
require.NotNil(t, selection.Account)
|
||
require.Equal(t, account.ID, selection.Account.ID)
|
||
require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
|
||
require.True(t, decision.StickySessionHit)
|
||
if selection.ReleaseFunc != nil {
|
||
selection.ReleaseFunc()
|
||
}
|
||
}
|
||
|
||
func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStickyHTTPAccount(t *testing.T) {
|
||
ctx := context.Background()
|
||
groupID := int64(1011)
|
||
accounts := []Account{
|
||
{
|
||
ID: 2201,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeAPIKey,
|
||
Status: StatusActive,
|
||
Schedulable: true,
|
||
Concurrency: 1,
|
||
Priority: 0,
|
||
},
|
||
{
|
||
ID: 2202,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeAPIKey,
|
||
Status: StatusActive,
|
||
Schedulable: true,
|
||
Concurrency: 1,
|
||
Priority: 5,
|
||
Extra: map[string]any{
|
||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||
},
|
||
},
|
||
}
|
||
cache := &stubGatewayCache{
|
||
sessionBindings: map[string]int64{
|
||
"openai:session_hash_ws_only": 2201,
|
||
},
|
||
}
|
||
cfg := newOpenAIWSV2TestConfig()
|
||
|
||
// 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。
|
||
concurrencyCache := stubConcurrencyCache{
|
||
loadMap: map[int64]*AccountLoadInfo{
|
||
2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0},
|
||
2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5},
|
||
},
|
||
}
|
||
|
||
svc := &OpenAIGatewayService{
|
||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||
cache: cache,
|
||
cfg: cfg,
|
||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||
}
|
||
|
||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||
ctx,
|
||
&groupID,
|
||
"",
|
||
"session_hash_ws_only",
|
||
"gpt-5.1",
|
||
nil,
|
||
OpenAIUpstreamTransportResponsesWebsocketV2,
|
||
)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, selection)
|
||
require.NotNil(t, selection.Account)
|
||
require.Equal(t, int64(2202), selection.Account.ID)
|
||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||
require.False(t, decision.StickySessionHit)
|
||
require.Equal(t, 1, decision.CandidateCount)
|
||
if selection.ReleaseFunc != nil {
|
||
selection.ReleaseFunc()
|
||
}
|
||
}
|
||
|
||
func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailableAccount(t *testing.T) {
|
||
ctx := context.Background()
|
||
groupID := int64(1012)
|
||
accounts := []Account{
|
||
{
|
||
ID: 2301,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeOAuth,
|
||
Status: StatusActive,
|
||
Schedulable: true,
|
||
Concurrency: 1,
|
||
},
|
||
}
|
||
|
||
svc := &OpenAIGatewayService{
|
||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||
cache: &stubGatewayCache{},
|
||
cfg: newOpenAIWSV2TestConfig(),
|
||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||
}
|
||
|
||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||
ctx,
|
||
&groupID,
|
||
"",
|
||
"",
|
||
"gpt-5.1",
|
||
nil,
|
||
OpenAIUpstreamTransportResponsesWebsocketV2,
|
||
)
|
||
require.Error(t, err)
|
||
require.Nil(t, selection)
|
||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||
require.Equal(t, 0, decision.CandidateCount)
|
||
}
|
||
|
||
func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback(t *testing.T) {
|
||
ctx := context.Background()
|
||
groupID := int64(11)
|
||
accounts := []Account{
|
||
{
|
||
ID: 3001,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeAPIKey,
|
||
Status: StatusActive,
|
||
Schedulable: true,
|
||
Concurrency: 1,
|
||
Priority: 0,
|
||
},
|
||
{
|
||
ID: 3002,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeAPIKey,
|
||
Status: StatusActive,
|
||
Schedulable: true,
|
||
Concurrency: 1,
|
||
Priority: 0,
|
||
},
|
||
{
|
||
ID: 3003,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeAPIKey,
|
||
Status: StatusActive,
|
||
Schedulable: true,
|
||
Concurrency: 1,
|
||
Priority: 0,
|
||
},
|
||
}
|
||
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.LBTopK = 2
|
||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.4
|
||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0
|
||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0
|
||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2
|
||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1
|
||
|
||
concurrencyCache := stubConcurrencyCache{
|
||
loadMap: map[int64]*AccountLoadInfo{
|
||
3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8},
|
||
3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1},
|
||
3003: {AccountID: 3003, LoadRate: 10, WaitingCount: 0},
|
||
},
|
||
acquireResults: map[int64]bool{
|
||
3003: false, // top1 失败,必须回退到 top-K 的下一候选
|
||
3002: true,
|
||
},
|
||
}
|
||
|
||
svc := &OpenAIGatewayService{
|
||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||
cache: &stubGatewayCache{},
|
||
cfg: cfg,
|
||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||
}
|
||
|
||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||
ctx,
|
||
&groupID,
|
||
"",
|
||
"",
|
||
"gpt-5.1",
|
||
nil,
|
||
OpenAIUpstreamTransportAny,
|
||
)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, selection)
|
||
require.NotNil(t, selection.Account)
|
||
require.Equal(t, int64(3002), selection.Account.ID)
|
||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||
require.Equal(t, 3, decision.CandidateCount)
|
||
require.Equal(t, 2, decision.TopK)
|
||
require.Greater(t, decision.LoadSkew, 0.0)
|
||
if selection.ReleaseFunc != nil {
|
||
selection.ReleaseFunc()
|
||
}
|
||
}
|
||
|
||
func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
|
||
ctx := context.Background()
|
||
groupID := int64(12)
|
||
account := Account{
|
||
ID: 4001,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeAPIKey,
|
||
Status: StatusActive,
|
||
Schedulable: true,
|
||
Concurrency: 1,
|
||
}
|
||
cache := &stubGatewayCache{
|
||
sessionBindings: map[string]int64{
|
||
"openai:session_hash_metrics": account.ID,
|
||
},
|
||
}
|
||
svc := &OpenAIGatewayService{
|
||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||
cache: cache,
|
||
cfg: &config.Config{},
|
||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||
}
|
||
|
||
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, selection)
|
||
svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120))
|
||
svc.RecordOpenAIAccountSwitch()
|
||
|
||
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
|
||
require.GreaterOrEqual(t, snapshot.SelectTotal, int64(1))
|
||
require.GreaterOrEqual(t, snapshot.StickySessionHitTotal, int64(1))
|
||
require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
|
||
require.GreaterOrEqual(t, snapshot.SchedulerLatencyMsAvg, float64(0))
|
||
require.GreaterOrEqual(t, snapshot.StickyHitRatio, 0.0)
|
||
require.GreaterOrEqual(t, snapshot.RuntimeStatsAccountCount, 1)
|
||
}
|
||
|
||
// intPtrForTest returns a pointer to a fresh int holding v.
func intPtrForTest(v int) *int {
	p := new(int)
	*p = v
	return p
}
|
||
|
||
func TestOpenAIAccountRuntimeStats_ReportAndSnapshot(t *testing.T) {
|
||
stats := newOpenAIAccountRuntimeStats()
|
||
stats.report(1001, true, nil)
|
||
firstTTFT := 100
|
||
stats.report(1001, false, &firstTTFT)
|
||
secondTTFT := 200
|
||
stats.report(1001, false, &secondTTFT)
|
||
|
||
errorRate, ttft, hasTTFT := stats.snapshot(1001)
|
||
require.True(t, hasTTFT)
|
||
require.InDelta(t, 0.36, errorRate, 1e-9)
|
||
require.InDelta(t, 120.0, ttft, 1e-9)
|
||
require.Equal(t, 1, stats.size())
|
||
}
|
||
|
||
func TestOpenAIAccountRuntimeStats_ReportConcurrent(t *testing.T) {
|
||
stats := newOpenAIAccountRuntimeStats()
|
||
|
||
const (
|
||
accountCount = 4
|
||
workers = 16
|
||
iterations = 800
|
||
)
|
||
var wg sync.WaitGroup
|
||
wg.Add(workers)
|
||
for worker := 0; worker < workers; worker++ {
|
||
worker := worker
|
||
go func() {
|
||
defer wg.Done()
|
||
for i := 0; i < iterations; i++ {
|
||
accountID := int64(i%accountCount + 1)
|
||
success := (i+worker)%3 != 0
|
||
ttft := 80 + (i+worker)%40
|
||
stats.report(accountID, success, &ttft)
|
||
}
|
||
}()
|
||
}
|
||
wg.Wait()
|
||
|
||
require.Equal(t, accountCount, stats.size())
|
||
for accountID := int64(1); accountID <= accountCount; accountID++ {
|
||
errorRate, ttft, hasTTFT := stats.snapshot(accountID)
|
||
require.GreaterOrEqual(t, errorRate, 0.0)
|
||
require.LessOrEqual(t, errorRate, 1.0)
|
||
require.True(t, hasTTFT)
|
||
require.Greater(t, ttft, 0.0)
|
||
}
|
||
}
|
||
|
||
func TestSelectTopKOpenAICandidates(t *testing.T) {
|
||
candidates := []openAIAccountCandidateScore{
|
||
{
|
||
account: &Account{ID: 11, Priority: 2},
|
||
loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 1},
|
||
score: 10.0,
|
||
},
|
||
{
|
||
account: &Account{ID: 12, Priority: 1},
|
||
loadInfo: &AccountLoadInfo{LoadRate: 20, WaitingCount: 1},
|
||
score: 9.5,
|
||
},
|
||
{
|
||
account: &Account{ID: 13, Priority: 1},
|
||
loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 0},
|
||
score: 10.0,
|
||
},
|
||
{
|
||
account: &Account{ID: 14, Priority: 0},
|
||
loadInfo: &AccountLoadInfo{LoadRate: 40, WaitingCount: 0},
|
||
score: 8.0,
|
||
},
|
||
}
|
||
|
||
top2 := selectTopKOpenAICandidates(candidates, 2)
|
||
require.Len(t, top2, 2)
|
||
require.Equal(t, int64(13), top2[0].account.ID)
|
||
require.Equal(t, int64(11), top2[1].account.ID)
|
||
|
||
topAll := selectTopKOpenAICandidates(candidates, 8)
|
||
require.Len(t, topAll, len(candidates))
|
||
require.Equal(t, int64(13), topAll[0].account.ID)
|
||
require.Equal(t, int64(11), topAll[1].account.ID)
|
||
require.Equal(t, int64(12), topAll[2].account.ID)
|
||
require.Equal(t, int64(14), topAll[3].account.ID)
|
||
}
|
||
|
||
func TestBuildOpenAIWeightedSelectionOrder_DeterministicBySessionSeed(t *testing.T) {
|
||
candidates := []openAIAccountCandidateScore{
|
||
{
|
||
account: &Account{ID: 101},
|
||
loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 0},
|
||
score: 4.2,
|
||
},
|
||
{
|
||
account: &Account{ID: 102},
|
||
loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 1},
|
||
score: 3.5,
|
||
},
|
||
{
|
||
account: &Account{ID: 103},
|
||
loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 2},
|
||
score: 2.1,
|
||
},
|
||
}
|
||
req := OpenAIAccountScheduleRequest{
|
||
GroupID: int64PtrForTest(99),
|
||
SessionHash: "session_seed_fixed",
|
||
RequestedModel: "gpt-5.1",
|
||
}
|
||
|
||
first := buildOpenAIWeightedSelectionOrder(candidates, req)
|
||
second := buildOpenAIWeightedSelectionOrder(candidates, req)
|
||
require.Len(t, first, len(candidates))
|
||
require.Len(t, second, len(candidates))
|
||
for i := range first {
|
||
require.Equal(t, first[i].account.ID, second[i].account.ID)
|
||
}
|
||
}
|
||
|
||
func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesAcrossSessions(t *testing.T) {
|
||
ctx := context.Background()
|
||
groupID := int64(15)
|
||
accounts := []Account{
|
||
{
|
||
ID: 5101,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeAPIKey,
|
||
Status: StatusActive,
|
||
Schedulable: true,
|
||
Concurrency: 3,
|
||
Priority: 0,
|
||
},
|
||
{
|
||
ID: 5102,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeAPIKey,
|
||
Status: StatusActive,
|
||
Schedulable: true,
|
||
Concurrency: 3,
|
||
Priority: 0,
|
||
},
|
||
{
|
||
ID: 5103,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeAPIKey,
|
||
Status: StatusActive,
|
||
Schedulable: true,
|
||
Concurrency: 3,
|
||
Priority: 0,
|
||
},
|
||
}
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.LBTopK = 3
|
||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1
|
||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1
|
||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1
|
||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
|
||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
|
||
|
||
concurrencyCache := stubConcurrencyCache{
|
||
loadMap: map[int64]*AccountLoadInfo{
|
||
5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1},
|
||
5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1},
|
||
5103: {AccountID: 5103, LoadRate: 20, WaitingCount: 1},
|
||
},
|
||
}
|
||
svc := &OpenAIGatewayService{
|
||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||
cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
|
||
cfg: cfg,
|
||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||
}
|
||
|
||
selected := make(map[int64]int, len(accounts))
|
||
for i := 0; i < 60; i++ {
|
||
sessionHash := fmt.Sprintf("session_hash_lb_%d", i)
|
||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||
ctx,
|
||
&groupID,
|
||
"",
|
||
sessionHash,
|
||
"gpt-5.1",
|
||
nil,
|
||
OpenAIUpstreamTransportAny,
|
||
)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, selection)
|
||
require.NotNil(t, selection.Account)
|
||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||
selected[selection.Account.ID]++
|
||
if selection.ReleaseFunc != nil {
|
||
selection.ReleaseFunc()
|
||
}
|
||
}
|
||
|
||
// 多 session 应该能打散到多个账号,避免“恒定单账号命中”。
|
||
require.GreaterOrEqual(t, len(selected), 2)
|
||
}
|
||
|
||
func TestDeriveOpenAISelectionSeed_NoAffinityAddsEntropy(t *testing.T) {
|
||
req := OpenAIAccountScheduleRequest{
|
||
RequestedModel: "gpt-5.1",
|
||
}
|
||
seed1 := deriveOpenAISelectionSeed(req)
|
||
time.Sleep(1 * time.Millisecond)
|
||
seed2 := deriveOpenAISelectionSeed(req)
|
||
require.NotZero(t, seed1)
|
||
require.NotZero(t, seed2)
|
||
require.NotEqual(t, seed1, seed2)
|
||
}
|
||
|
||
func TestBuildOpenAIWeightedSelectionOrder_HandlesInvalidScores(t *testing.T) {
|
||
candidates := []openAIAccountCandidateScore{
|
||
{
|
||
account: &Account{ID: 901},
|
||
loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
|
||
score: math.NaN(),
|
||
},
|
||
{
|
||
account: &Account{ID: 902},
|
||
loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
|
||
score: math.Inf(1),
|
||
},
|
||
{
|
||
account: &Account{ID: 903},
|
||
loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
|
||
score: -1,
|
||
},
|
||
}
|
||
req := OpenAIAccountScheduleRequest{
|
||
SessionHash: "seed_invalid_scores",
|
||
}
|
||
|
||
order := buildOpenAIWeightedSelectionOrder(candidates, req)
|
||
require.Len(t, order, len(candidates))
|
||
seen := map[int64]struct{}{}
|
||
for _, item := range order {
|
||
seen[item.account.ID] = struct{}{}
|
||
}
|
||
require.Len(t, seen, len(candidates))
|
||
}
|
||
|
||
func TestOpenAISelectionRNG_SeedZeroStillWorks(t *testing.T) {
|
||
rng := newOpenAISelectionRNG(0)
|
||
v1 := rng.nextUint64()
|
||
v2 := rng.nextUint64()
|
||
require.NotEqual(t, v1, v2)
|
||
require.GreaterOrEqual(t, rng.nextFloat64(), 0.0)
|
||
require.Less(t, rng.nextFloat64(), 1.0)
|
||
}
|
||
|
||
func TestOpenAIAccountCandidateHeap_PushPopAndInvalidType(t *testing.T) {
|
||
h := openAIAccountCandidateHeap{}
|
||
h.Push(openAIAccountCandidateScore{
|
||
account: &Account{ID: 7001},
|
||
loadInfo: &AccountLoadInfo{LoadRate: 0, WaitingCount: 0},
|
||
score: 1.0,
|
||
})
|
||
require.Equal(t, 1, h.Len())
|
||
popped, ok := h.Pop().(openAIAccountCandidateScore)
|
||
require.True(t, ok)
|
||
require.Equal(t, int64(7001), popped.account.ID)
|
||
require.Equal(t, 0, h.Len())
|
||
|
||
require.Panics(t, func() {
|
||
h.Push("bad_element_type")
|
||
})
|
||
}
|
||
|
||
func TestClamp01_AllBranches(t *testing.T) {
|
||
require.Equal(t, 0.0, clamp01(-0.2))
|
||
require.Equal(t, 1.0, clamp01(1.3))
|
||
require.Equal(t, 0.5, clamp01(0.5))
|
||
}
|
||
|
||
func TestCalcLoadSkewByMoments_Branches(t *testing.T) {
|
||
require.Equal(t, 0.0, calcLoadSkewByMoments(1, 1, 1))
|
||
// variance < 0 分支:sumSquares/count - mean^2 为负值时应钳制为 0。
|
||
require.Equal(t, 0.0, calcLoadSkewByMoments(1, 0, 2))
|
||
require.GreaterOrEqual(t, calcLoadSkewByMoments(6, 20, 3), 0.0)
|
||
}
|
||
|
||
func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
|
||
schedulerAny := newDefaultOpenAIAccountScheduler(&OpenAIGatewayService{}, nil)
|
||
scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
|
||
require.True(t, ok)
|
||
|
||
ttft := 100
|
||
scheduler.ReportResult(1001, true, &ttft)
|
||
scheduler.ReportSwitch()
|
||
scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{
|
||
Layer: openAIAccountScheduleLayerLoadBalance,
|
||
LatencyMs: 8,
|
||
LoadSkew: 0.5,
|
||
StickyPreviousHit: true,
|
||
})
|
||
scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{
|
||
Layer: openAIAccountScheduleLayerSessionSticky,
|
||
LatencyMs: 6,
|
||
LoadSkew: 0.2,
|
||
StickySessionHit: true,
|
||
})
|
||
|
||
snapshot := scheduler.SnapshotMetrics()
|
||
require.Equal(t, int64(2), snapshot.SelectTotal)
|
||
require.Equal(t, int64(1), snapshot.StickyPreviousHitTotal)
|
||
require.Equal(t, int64(1), snapshot.StickySessionHitTotal)
|
||
require.Equal(t, int64(1), snapshot.LoadBalanceSelectTotal)
|
||
require.Equal(t, int64(1), snapshot.AccountSwitchTotal)
|
||
require.Greater(t, snapshot.SchedulerLatencyMsAvg, 0.0)
|
||
require.Greater(t, snapshot.StickyHitRatio, 0.0)
|
||
require.Greater(t, snapshot.LoadSkewAvg, 0.0)
|
||
}
|
||
|
||
func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
|
||
svc := &OpenAIGatewayService{}
|
||
ttft := 120
|
||
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
|
||
svc.RecordOpenAIAccountSwitch()
|
||
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
|
||
require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
|
||
require.Equal(t, 7, svc.openAIWSLBTopK())
|
||
require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL())
|
||
|
||
defaultWeights := svc.openAIWSSchedulerWeights()
|
||
require.Equal(t, 1.0, defaultWeights.Priority)
|
||
require.Equal(t, 1.0, defaultWeights.Load)
|
||
require.Equal(t, 0.7, defaultWeights.Queue)
|
||
require.Equal(t, 0.8, defaultWeights.ErrorRate)
|
||
require.Equal(t, 0.5, defaultWeights.TTFT)
|
||
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.LBTopK = 9
|
||
cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 180
|
||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.2
|
||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0.3
|
||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.4
|
||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.5
|
||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.6
|
||
svcWithCfg := &OpenAIGatewayService{cfg: cfg}
|
||
|
||
require.Equal(t, 9, svcWithCfg.openAIWSLBTopK())
|
||
require.Equal(t, 180*time.Second, svcWithCfg.openAIWSSessionStickyTTL())
|
||
customWeights := svcWithCfg.openAIWSSchedulerWeights()
|
||
require.Equal(t, 0.2, customWeights.Priority)
|
||
require.Equal(t, 0.3, customWeights.Load)
|
||
require.Equal(t, 0.4, customWeights.Queue)
|
||
require.Equal(t, 0.5, customWeights.ErrorRate)
|
||
require.Equal(t, 0.6, customWeights.TTFT)
|
||
}
|
||
|
||
func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *testing.T) {
|
||
scheduler := &defaultOpenAIAccountScheduler{}
|
||
require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportAny))
|
||
require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE))
|
||
require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2))
|
||
|
||
cfg := newOpenAIWSV2TestConfig()
|
||
scheduler.service = &OpenAIGatewayService{cfg: cfg}
|
||
account := &Account{
|
||
ID: 8801,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeAPIKey,
|
||
Status: StatusActive,
|
||
Schedulable: true,
|
||
Concurrency: 1,
|
||
Extra: map[string]any{
|
||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||
},
|
||
}
|
||
require.True(t, scheduler.isAccountTransportCompatible(account, OpenAIUpstreamTransportResponsesWebsocketV2))
|
||
}
|
||
|
||
// int64PtrForTest returns a pointer to a fresh int64 holding v.
func int64PtrForTest(v int64) *int64 {
	p := new(int64)
	*p = v
	return p
}
|