From ace082066a14824462cb45386f43626ed7db9547 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Tue, 21 Apr 2026 13:50:55 +0800 Subject: [PATCH] fix: honor ws transport when scheduler is disabled --- .../service/openai_account_scheduler.go | 59 ++++++++++-- .../service/openai_account_scheduler_test.go | 92 +++++++++++++++++++ 2 files changed, 143 insertions(+), 8 deletions(-) diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 09e60220..38b92b47 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -767,14 +767,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( } func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool { - // HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。 - if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { - return true - } - if s == nil || s.service == nil || account == nil { + if s == nil || s.service == nil { return false } - return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport + return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport) } func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) { @@ -899,9 +895,35 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( decision := OpenAIAccountScheduleDecision{} scheduler := s.getOpenAIAccountScheduler(ctx) if scheduler == nil { - selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs) decision.Layer = openAIAccountScheduleLayerLoadBalance - return selection, decision, err + if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { + selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs) + return selection, decision, err + } + + effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs) + for { + selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs) + if err != nil { + return nil, decision, err + } + if selection == nil || selection.Account == nil { + return selection, decision, nil + } + if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) { + return selection, decision, nil + } + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + if effectiveExcludedIDs == nil { + effectiveExcludedIDs = make(map[int64]struct{}) + } + if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists { + return nil, decision, ErrNoAvailableAccounts + } + effectiveExcludedIDs[selection.Account.ID] = struct{}{} + } } var stickyAccountID int64 @@ -922,6 +944,27 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( }) } +func cloneExcludedAccountIDs(excludedIDs map[int64]struct{}) map[int64]struct{} { + if len(excludedIDs) == 0 { + return nil + } + cloned := make(map[int64]struct{}, len(excludedIDs)) + for id := range excludedIDs { + cloned[id] = struct{}{} + } + return cloned +} + +func (s *OpenAIGatewayService) isOpenAIAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool { + if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { + return true + } + if s == nil || account == nil { + return false + } + return s.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport +} + func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) { scheduler := s.getOpenAIAccountScheduler(context.Background()) if scheduler == nil { diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index a54f2614..b02370cb 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -298,6 +298,98 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLega require.False(t, decision.StickyPreviousHit) } +func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_SkipsHTTPOnlyAccount(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10108) + accounts := []Account{ + { + ID: 36011, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 36012, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + } + cfg := newSchedulerTestOpenAIWSV2Config() + cfg.Gateway.Scheduling.LoadBatchEnabled = false + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportResponsesWebsocketV2, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(36012), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_NoAvailableAccount(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10109) + accounts := []Account{ + { + ID: 36021, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + cfg := newSchedulerTestOpenAIWSV2Config() + cfg.Gateway.Scheduling.LoadBatchEnabled = false + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportResponsesWebsocketV2, + ) + require.ErrorContains(t, err, "no available OpenAI accounts") + require.Nil(t, selection) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) +} + func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) { resetOpenAIAdvancedSchedulerSettingCacheForTest()