fix: honor ws transport when scheduler is disabled

This commit is contained in:
IanShaw027
2026-04-21 13:50:55 +08:00
parent 65efef1eee
commit ace082066a
2 changed files with 143 additions and 8 deletions

View File

@@ -767,14 +767,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
// HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
return true
}
if s == nil || s.service == nil || account == nil {
if s == nil || s.service == nil {
return false
}
return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport)
}
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
@@ -899,9 +895,35 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
decision := OpenAIAccountScheduleDecision{}
scheduler := s.getOpenAIAccountScheduler(ctx)
if scheduler == nil {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
decision.Layer = openAIAccountScheduleLayerLoadBalance
return selection, decision, err
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
return selection, decision, err
}
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
for {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
if err != nil {
return nil, decision, err
}
if selection == nil || selection.Account == nil {
return selection, decision, nil
}
if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) {
return selection, decision, nil
}
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if effectiveExcludedIDs == nil {
effectiveExcludedIDs = make(map[int64]struct{})
}
if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists {
return nil, decision, ErrNoAvailableAccounts
}
effectiveExcludedIDs[selection.Account.ID] = struct{}{}
}
}
var stickyAccountID int64
@@ -922,6 +944,27 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
})
}
func cloneExcludedAccountIDs(excludedIDs map[int64]struct{}) map[int64]struct{} {
if len(excludedIDs) == 0 {
return nil
}
cloned := make(map[int64]struct{}, len(excludedIDs))
for id := range excludedIDs {
cloned[id] = struct{}{}
}
return cloned
}
func (s *OpenAIGatewayService) isOpenAIAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
return true
}
if s == nil || account == nil {
return false
}
return s.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
}
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {

View File

@@ -298,6 +298,98 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLega
require.False(t, decision.StickyPreviousHit)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_SkipsHTTPOnlyAccount(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10108)
accounts := []Account{
{
ID: 36011,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
{
ID: 36012,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
},
}
cfg := newSchedulerTestOpenAIWSV2Config()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"",
"",
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(36012), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_NoAvailableAccount(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10109)
accounts := []Account{
{
ID: 36021,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
}
cfg := newSchedulerTestOpenAIWSV2Config()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"",
"",
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
)
require.ErrorContains(t, err, "no available OpenAI accounts")
require.Nil(t, selection)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()