Files
sub2api/backend/internal/service/openai_ws_protocol_resolver_test.go
2026-02-28 15:01:20 +08:00

204 lines
7.1 KiB
Go

package service
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// TestOpenAIWSProtocolResolver_Resolve checks the transport and reason that
// the resolver reports for a range of config/account combinations. Every
// scenario starts from the same baseline config (WS enabled, OAuth and API
// key enabled, v1 off, v2 on) and never sets ModeRouterV2Enabled.
func TestOpenAIWSProtocolResolver_Resolve(t *testing.T) {
	const (
		oauthV2Flag  = "openai_oauth_responses_websockets_v2_enabled"
		apiKeyV2Flag = "openai_apikey_responses_websockets_v2_enabled"
	)

	// newConfig builds the baseline config and then applies the optional
	// per-scenario override, so no scenario can leak settings into another.
	newConfig := func(override func(*config.Config)) *config.Config {
		cfg := &config.Config{}
		cfg.Gateway.OpenAIWS.Enabled = true
		cfg.Gateway.OpenAIWS.OAuthEnabled = true
		cfg.Gateway.OpenAIWS.APIKeyEnabled = true
		cfg.Gateway.OpenAIWS.ResponsesWebsockets = false
		cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
		if override != nil {
			override(cfg)
		}
		return cfg
	}

	// oauthAccount builds an OpenAI OAuth account carrying the given Extra map.
	oauthAccount := func(extra map[string]any) *Account {
		return &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra:    extra,
		}
	}

	// wantTransport is declared as any because require.Equal compares
	// interface values; the constant is passed through unchanged.
	scenarios := []struct {
		name          string
		override      func(*config.Config)
		account       *Account
		wantTransport any
		wantReason    string
	}{
		{
			name:          "v2优先",
			account:       oauthAccount(map[string]any{oauthV2Flag: true}),
			wantTransport: OpenAIUpstreamTransportResponsesWebsocketV2,
			wantReason:    "ws_v2_enabled",
		},
		{
			name: "v2关闭时回退v1",
			override: func(c *config.Config) {
				c.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false
				c.Gateway.OpenAIWS.ResponsesWebsockets = true
			},
			account:       oauthAccount(map[string]any{oauthV2Flag: true}),
			wantTransport: OpenAIUpstreamTransportResponsesWebsocket,
			wantReason:    "ws_v1_enabled",
		},
		{
			name:          "透传开关不影响WS协议判定",
			account:       oauthAccount(map[string]any{oauthV2Flag: true, "openai_passthrough": true}),
			wantTransport: OpenAIUpstreamTransportResponsesWebsocketV2,
			wantReason:    "ws_v2_enabled",
		},
		{
			name:          "账号级强制HTTP",
			account:       oauthAccount(map[string]any{oauthV2Flag: true, "openai_ws_force_http": true}),
			wantTransport: OpenAIUpstreamTransportHTTPSSE,
			wantReason:    "account_force_http",
		},
		{
			name: "全局关闭保持HTTP",
			override: func(c *config.Config) {
				c.Gateway.OpenAIWS.Enabled = false
			},
			account:       oauthAccount(map[string]any{oauthV2Flag: true}),
			wantTransport: OpenAIUpstreamTransportHTTPSSE,
			wantReason:    "global_disabled",
		},
		{
			name:          "账号开关关闭保持HTTP",
			account:       oauthAccount(map[string]any{oauthV2Flag: false}),
			wantTransport: OpenAIUpstreamTransportHTTPSSE,
			wantReason:    "account_disabled",
		},
		{
			name:          "OAuth账号不会读取API Key专用开关",
			account:       oauthAccount(map[string]any{apiKeyV2Flag: true}),
			wantTransport: OpenAIUpstreamTransportHTTPSSE,
			wantReason:    "account_disabled",
		},
		{
			name:          "兼容旧键openai_ws_enabled",
			account:       oauthAccount(map[string]any{"openai_ws_enabled": true}),
			wantTransport: OpenAIUpstreamTransportResponsesWebsocketV2,
			wantReason:    "ws_v2_enabled",
		},
		{
			name: "按账号类型开关控制",
			override: func(c *config.Config) {
				c.Gateway.OpenAIWS.OAuthEnabled = false
			},
			account:       oauthAccount(map[string]any{oauthV2Flag: true}),
			wantTransport: OpenAIUpstreamTransportHTTPSSE,
			wantReason:    "oauth_disabled",
		},
		{
			name: "API Key 账号关闭开关时回退HTTP",
			override: func(c *config.Config) {
				c.Gateway.OpenAIWS.APIKeyEnabled = false
			},
			account: &Account{
				Platform: PlatformOpenAI,
				Type:     AccountTypeAPIKey,
				Extra:    map[string]any{apiKeyV2Flag: true},
			},
			wantTransport: OpenAIUpstreamTransportHTTPSSE,
			wantReason:    "apikey_disabled",
		},
		{
			name: "未知认证类型回退HTTP",
			account: &Account{
				Platform: PlatformOpenAI,
				Type:     "unknown_type",
				Extra:    map[string]any{"responses_websockets_v2_enabled": true},
			},
			wantTransport: OpenAIUpstreamTransportHTTPSSE,
			wantReason:    "unknown_auth_type",
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			got := NewOpenAIWSProtocolResolver(newConfig(sc.override)).Resolve(sc.account)
			require.Equal(t, sc.wantTransport, got.Transport)
			require.Equal(t, sc.wantReason, got.Reason)
		})
	}
}
// TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2 checks the resolver's
// decisions when ModeRouterV2Enabled is set. All scenarios share one config
// (WS enabled for both account types, v2 on, IngressModeDefault set to
// OpenAIWSIngressModeShared) and differ only in the account being resolved.
func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
	const (
		oauthModeKey    = "openai_oauth_responses_websockets_v2_mode"
		apiKeyLegacyKey = "openai_apikey_responses_websockets_v2_enabled"
	)

	routerCfg := &config.Config{}
	routerCfg.Gateway.OpenAIWS.Enabled = true
	routerCfg.Gateway.OpenAIWS.OAuthEnabled = true
	routerCfg.Gateway.OpenAIWS.APIKeyEnabled = true
	routerCfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	routerCfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
	routerCfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared

	// wantTransport is declared as any because require.Equal compares
	// interface values; the constant is passed through unchanged.
	scenarios := []struct {
		name          string
		account       *Account
		wantTransport any
		wantReason    string
	}{
		{
			name: "dedicated mode routes to ws v2",
			account: &Account{
				Platform:    PlatformOpenAI,
				Type:        AccountTypeOAuth,
				Concurrency: 1,
				Extra:       map[string]any{oauthModeKey: OpenAIWSIngressModeDedicated},
			},
			wantTransport: OpenAIUpstreamTransportResponsesWebsocketV2,
			wantReason:    "ws_v2_mode_dedicated",
		},
		{
			name: "off mode routes to http",
			account: &Account{
				Platform:    PlatformOpenAI,
				Type:        AccountTypeOAuth,
				Concurrency: 1,
				Extra:       map[string]any{oauthModeKey: OpenAIWSIngressModeOff},
			},
			wantTransport: OpenAIUpstreamTransportHTTPSSE,
			wantReason:    "account_mode_off",
		},
		{
			name: "legacy boolean maps to shared in v2 router",
			account: &Account{
				Platform:    PlatformOpenAI,
				Type:        AccountTypeAPIKey,
				Concurrency: 1,
				Extra:       map[string]any{apiKeyLegacyKey: true},
			},
			wantTransport: OpenAIUpstreamTransportResponsesWebsocketV2,
			wantReason:    "ws_v2_mode_shared",
		},
		{
			name: "non-positive concurrency is rejected in v2 router",
			// Concurrency is deliberately left at its zero value here.
			account: &Account{
				Platform: PlatformOpenAI,
				Type:     AccountTypeOAuth,
				Extra:    map[string]any{oauthModeKey: OpenAIWSIngressModeShared},
			},
			wantTransport: OpenAIUpstreamTransportHTTPSSE,
			wantReason:    "account_concurrency_invalid",
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			got := NewOpenAIWSProtocolResolver(routerCfg).Resolve(sc.account)
			require.Equal(t, sc.wantTransport, got.Transport)
			require.Equal(t, sc.wantReason, got.Reason)
		})
	}
}