package service import ( "context" "encoding/json" "io" "net/http" "net/http/httptest" "strings" "sync" "sync/atomic" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" ) func TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t *testing.T) { gin.SetMode(gin.TestMode) wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.NotFound(w, r) })) defer wsFallbackServer.Close() rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") upstream := &httpUpstreamRecorder{ resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"application/json"}}, Body: io.NopCloser(strings.NewReader( `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, )), }, } cfg := &config.Config{} cfg.Security.URLAllowlist.Enabled = false cfg.Security.URLAllowlist.AllowInsecureHTTP = true cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true svc := &OpenAIGatewayService{ cfg: cfg, httpUpstream: upstream, openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), } account := &Account{ ID: 1, Name: "openai-apikey", Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1, Credentials: map[string]any{ "api_key": "sk-test", "base_url": wsFallbackServer.URL, }, Extra: map[string]any{ "responses_websockets_v2_enabled": true, }, } body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`) result, err := svc.Forward(context.Background(), c, account, body) require.Error(t, err) require.Nil(t, result) require.Nil(t, upstream.lastReq, "WS 模式下失败时不应回退 HTTP") } func TestOpenAIGatewayService_Forward_HTTPIngressStaysHTTPWhenWSEnabled(t *testing.T) { gin.SetMode(gin.TestMode) wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.NotFound(w, r) })) defer wsFallbackServer.Close() rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") SetOpenAIClientTransport(c, OpenAIClientTransportHTTP) upstream := &httpUpstreamRecorder{ resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"application/json"}}, Body: io.NopCloser(strings.NewReader( `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, )), }, } cfg := &config.Config{} cfg.Security.URLAllowlist.Enabled = false cfg.Security.URLAllowlist.AllowInsecureHTTP = true cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true svc := &OpenAIGatewayService{ cfg: cfg, httpUpstream: upstream, openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), } account := &Account{ ID: 101, Name: "openai-apikey", Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1, Credentials: map[string]any{ "api_key": "sk-test", "base_url": wsFallbackServer.URL, }, Extra: map[string]any{ "responses_websockets_v2_enabled": true, }, } body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_keep","input":[{"type":"input_text","text":"hello"}]}`) result, err := svc.Forward(context.Background(), c, account, body) require.NoError(t, err) require.NotNil(t, result) require.False(t, result.OpenAIWSMode, "HTTP 入站应保持 HTTP 转发") require.NotNil(t, upstream.lastReq, "HTTP 入站应命中 HTTP 上游") require.False(t, gjson.GetBytes(upstream.lastBody, "previous_response_id").Exists(), "HTTP 路径应沿用原逻辑移除 previous_response_id") decision, _ := c.Get("openai_ws_transport_decision") reason, _ := c.Get("openai_ws_transport_reason") require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision) require.Equal(t, "client_protocol_http", reason) } func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t *testing.T) { gin.SetMode(gin.TestMode) wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.NotFound(w, r) })) defer wsFallbackServer.Close() rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") upstream := &httpUpstreamRecorder{ resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"application/json"}}, Body: io.NopCloser(strings.NewReader( `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, )), }, } cfg := &config.Config{} cfg.Security.URLAllowlist.Enabled = false cfg.Security.URLAllowlist.AllowInsecureHTTP = true cfg.Gateway.OpenAIWS.Enabled = false cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true svc := &OpenAIGatewayService{ cfg: cfg, httpUpstream: upstream, openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), } account := &Account{ ID: 1, Name: "openai-apikey", Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1, Credentials: map[string]any{ "api_key": "sk-test", "base_url": wsFallbackServer.URL, }, Extra: map[string]any{ "responses_websockets_v2_enabled": true, }, } body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`) result, err := svc.Forward(context.Background(), c, account, body) require.NoError(t, err) require.NotNil(t, result) require.False(t, gjson.GetBytes(upstream.lastBody, "previous_response_id").Exists()) } func TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) { gin.SetMode(gin.TestMode) ws426Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUpgradeRequired) _, _ = w.Write([]byte(`upgrade required`)) })) defer ws426Server.Close() rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") upstream := &httpUpstreamRecorder{ resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"application/json"}}, Body: io.NopCloser(strings.NewReader( `{"usage":{"input_tokens":8,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}`, )), }, } cfg := &config.Config{} cfg.Security.URLAllowlist.Enabled = false cfg.Security.URLAllowlist.AllowInsecureHTTP = true cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 svc := &OpenAIGatewayService{ cfg: cfg, httpUpstream: upstream, openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), } account := &Account{ ID: 12, Name: "openai-apikey", Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1, Credentials: map[string]any{ "api_key": "sk-test", "base_url": ws426Server.URL, }, Extra: map[string]any{ "responses_websockets_v2_enabled": true, }, } body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_426","input":[{"type":"input_text","text":"hello"}]}`) result, err := svc.Forward(context.Background(), c, account, body) require.Error(t, err) require.Nil(t, result) require.Contains(t, err.Error(), "upgrade_required") require.Nil(t, upstream.lastReq, "WS 模式下不应再回退 HTTP") require.Equal(t, http.StatusUpgradeRequired, rec.Code) require.Contains(t, rec.Body.String(), "426") } func TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS(t *testing.T) { gin.SetMode(gin.TestMode) wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.NotFound(w, r) })) defer wsServer.Close() rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") upstream := &httpUpstreamRecorder{ resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"application/json"}}, Body: io.NopCloser(strings.NewReader( `{"usage":{"input_tokens":2,"output_tokens":3,"input_tokens_details":{"cached_tokens":0}}}`, )), }, } cfg := &config.Config{} cfg.Security.URLAllowlist.Enabled = false cfg.Security.URLAllowlist.AllowInsecureHTTP = true cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 30 svc := &OpenAIGatewayService{ cfg: cfg, httpUpstream: upstream, openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), } account := &Account{ ID: 21, Name: "openai-apikey", Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1, Credentials: map[string]any{ "api_key": "sk-test", "base_url": wsServer.URL, }, Extra: map[string]any{ "responses_websockets_v2_enabled": true, }, } svc.markOpenAIWSFallbackCooling(account.ID, "upgrade_required") body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_cooling","input":[{"type":"input_text","text":"hello"}]}`) result, err := svc.Forward(context.Background(), c, account, body) require.Error(t, err) require.Nil(t, result) require.Nil(t, upstream.lastReq, "WS 模式下不应再回退 HTTP") _, ok := c.Get("openai_ws_fallback_cooling") require.False(t, ok, "已移除 fallback cooling 快捷回退路径") } func TestOpenAIGatewayService_Forward_ReturnErrorWhenOnlyWSv1Enabled(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") upstream := &httpUpstreamRecorder{ resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"application/json"}}, Body: io.NopCloser(strings.NewReader( `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, )), }, } cfg := &config.Config{} cfg.Security.URLAllowlist.Enabled = false cfg.Security.URLAllowlist.AllowInsecureHTTP = true cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsockets = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false svc := &OpenAIGatewayService{ cfg: cfg, httpUpstream: upstream, openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), } account := &Account{ ID: 31, Name: "openai-apikey", Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1, Credentials: map[string]any{ "api_key": "sk-test", "base_url": "https://api.openai.com/v1/responses", }, Extra: map[string]any{ "responses_websockets_v2_enabled": true, }, } body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_v1","input":[{"type":"input_text","text":"hello"}]}`) result, err := svc.Forward(context.Background(), c, account, body) require.Error(t, err) require.Nil(t, result) require.Contains(t, err.Error(), "ws v1") require.Equal(t, http.StatusBadRequest, rec.Code) require.Contains(t, rec.Body.String(), "WSv1") require.Nil(t, upstream.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求") } func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { cfg := &config.Config{} svc := NewOpenAIGatewayService( nil, nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil, ) decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) require.Equal(t, "account_missing", decision.Reason) } func TestOpenAIGatewayService_Forward_WSv2FallbackWhenResponseAlreadyWrittenReturnsWSError(t *testing.T) { gin.SetMode(gin.TestMode) ws426Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUpgradeRequired) _, _ = w.Write([]byte(`upgrade required`)) })) defer ws426Server.Close() rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") c.String(http.StatusAccepted, "already-written") upstream := &httpUpstreamRecorder{ resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"application/json"}}, Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)), }, } cfg := &config.Config{} cfg.Security.URLAllowlist.Enabled = false cfg.Security.URLAllowlist.AllowInsecureHTTP = true cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 svc := &OpenAIGatewayService{ cfg: cfg, httpUpstream: upstream, openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), } account := &Account{ ID: 41, Name: "openai-apikey", Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1, Credentials: map[string]any{ "api_key": "sk-test", "base_url": ws426Server.URL, }, Extra: map[string]any{ "responses_websockets_v2_enabled": true, }, } body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) result, err := svc.Forward(context.Background(), c, account, body) require.Error(t, err) require.Nil(t, result) require.Contains(t, err.Error(), "ws fallback") require.Nil(t, upstream.lastReq, "已写下游响应时,不应再回退 HTTP") } func TestOpenAIGatewayService_Forward_WSv2StreamEarlyCloseFallbackHTTP(t *testing.T) { gin.SetMode(gin.TestMode) upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { t.Errorf("upgrade websocket failed: %v", err) return } defer func() { _ = conn.Close() }() var req map[string]any if err := conn.ReadJSON(&req); err != nil { t.Errorf("read ws request failed: %v", err) return } // 仅发送 response.created(非 token 事件)后立即关闭, // 模拟线上“上游早期内部错误断连”的场景。 if err := conn.WriteJSON(map[string]any{ "type": "response.created", "response": map[string]any{ "id": "resp_ws_created_only", "model": "gpt-5.3-codex", }, }); err != nil { t.Errorf("write response.created failed: %v", err) return } closePayload := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "") _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)) })) defer wsServer.Close() rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") upstream := &httpUpstreamRecorder{ resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"text/event-stream"}}, Body: io.NopCloser(strings.NewReader( "data: {\"type\":\"response.output_text.delta\",\"delta\":\"ok\"}\n\n" + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_http_fallback\",\"usage\":{\"input_tokens\":2,\"output_tokens\":1}}}\n\n" + "data: [DONE]\n\n", )), }, } cfg := &config.Config{} cfg.Security.URLAllowlist.Enabled = false cfg.Security.URLAllowlist.AllowInsecureHTTP = true cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 svc := &OpenAIGatewayService{ cfg: cfg, httpUpstream: upstream, openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), toolCorrector: NewCodexToolCorrector(), } account := &Account{ ID: 88, Name: "openai-apikey", Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1, Credentials: map[string]any{ "api_key": "sk-test", "base_url": wsServer.URL, }, Extra: map[string]any{ "responses_websockets_v2_enabled": true, }, } body := []byte(`{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`) result, err := svc.Forward(context.Background(), c, account, body) require.Error(t, err) require.Nil(t, result) require.Nil(t, upstream.lastReq, "WS 早期断连后不应再回退 HTTP") require.Empty(t, rec.Body.String(), "未产出 token 前上游断连时不应写入下游半截流") } func TestOpenAIGatewayService_Forward_WSv2RetryFiveTimesThenFallbackHTTP(t *testing.T) { gin.SetMode(gin.TestMode) var wsAttempts atomic.Int32 upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { wsAttempts.Add(1) conn, err := upgrader.Upgrade(w, r, nil) if err != nil { t.Errorf("upgrade websocket failed: %v", err) return } defer func() { _ = conn.Close() }() var req map[string]any if err := conn.ReadJSON(&req); err != nil { t.Errorf("read ws request failed: %v", err) return } closePayload := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "") _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)) })) defer wsServer.Close() rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") upstream := &httpUpstreamRecorder{ resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"text/event-stream"}}, Body: io.NopCloser(strings.NewReader( "data: {\"type\":\"response.output_text.delta\",\"delta\":\"ok\"}\n\n" + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_retry_http_fallback\",\"usage\":{\"input_tokens\":2,\"output_tokens\":1}}}\n\n" + "data: [DONE]\n\n", )), }, } cfg := &config.Config{} cfg.Security.URLAllowlist.Enabled = false cfg.Security.URLAllowlist.AllowInsecureHTTP = true cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 svc := &OpenAIGatewayService{ cfg: cfg, httpUpstream: upstream, openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), toolCorrector: NewCodexToolCorrector(), } account := &Account{ ID: 89, Name: "openai-apikey", Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1, Credentials: map[string]any{ "api_key": "sk-test", "base_url": wsServer.URL, }, Extra: map[string]any{ "responses_websockets_v2_enabled": true, }, } body := []byte(`{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`) result, err := svc.Forward(context.Background(), c, account, body) require.Error(t, err) require.Nil(t, result) require.Nil(t, upstream.lastReq, "WS 重连耗尽后不应再回退 HTTP") require.Equal(t, int32(openAIWSReconnectRetryLimit+1), wsAttempts.Load()) } func TestOpenAIGatewayService_Forward_WSv2PolicyViolationFastFallbackHTTP(t *testing.T) { gin.SetMode(gin.TestMode) var wsAttempts atomic.Int32 upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { wsAttempts.Add(1) conn, err := upgrader.Upgrade(w, r, nil) if err != nil { t.Errorf("upgrade websocket failed: %v", err) return } defer func() { _ = conn.Close() }() var req map[string]any if err := conn.ReadJSON(&req); err != nil { t.Errorf("read ws request failed: %v", err) return } closePayload := websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "") _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)) })) defer wsServer.Close() rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") upstream := &httpUpstreamRecorder{ resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"application/json"}}, Body: io.NopCloser(strings.NewReader(`{"id":"resp_policy_fallback","usage":{"input_tokens":1,"output_tokens":1}}`)), }, } cfg := &config.Config{} cfg.Security.URLAllowlist.Enabled = false cfg.Security.URLAllowlist.AllowInsecureHTTP = true cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 cfg.Gateway.OpenAIWS.RetryBackoffInitialMS = 1 cfg.Gateway.OpenAIWS.RetryBackoffMaxMS = 2 cfg.Gateway.OpenAIWS.RetryJitterRatio = 0 svc := &OpenAIGatewayService{ cfg: cfg, httpUpstream: upstream, openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), toolCorrector: NewCodexToolCorrector(), } account := &Account{ ID: 8901, Name: "openai-apikey", Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1, Credentials: map[string]any{ "api_key": "sk-test", "base_url": wsServer.URL, }, Extra: map[string]any{ "responses_websockets_v2_enabled": true, }, } body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) result, err := svc.Forward(context.Background(), c, account, body) require.Error(t, err) require.Nil(t, result) require.Nil(t, upstream.lastReq, "策略违规关闭后不应回退 HTTP") require.Equal(t, int32(1), wsAttempts.Load(), "策略违规不应进行 WS 重试") } func TestOpenAIGatewayService_Forward_WSv2ConnectionLimitReachedRetryThenFallbackHTTP(t *testing.T) { gin.SetMode(gin.TestMode) var wsAttempts atomic.Int32 upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { wsAttempts.Add(1) conn, err := upgrader.Upgrade(w, r, nil) if err != nil { t.Errorf("upgrade websocket failed: %v", err) return } defer func() { _ = conn.Close() }() var req map[string]any if err := conn.ReadJSON(&req); err != nil { t.Errorf("read ws request failed: %v", err) return } _ = conn.WriteJSON(map[string]any{ "type": "error", "error": map[string]any{ "code": "websocket_connection_limit_reached", "type": "server_error", "message": "websocket connection limit reached", }, }) })) defer wsServer.Close() rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") upstream := &httpUpstreamRecorder{ resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"application/json"}}, Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_retry_limit","usage":{"input_tokens":1,"output_tokens":1}}`)), }, } cfg := &config.Config{} cfg.Security.URLAllowlist.Enabled = false cfg.Security.URLAllowlist.AllowInsecureHTTP = true cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 svc := &OpenAIGatewayService{ cfg: cfg, httpUpstream: upstream, openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), toolCorrector: NewCodexToolCorrector(), } account := &Account{ ID: 90, Name: "openai-apikey", Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1, Credentials: map[string]any{ "api_key": "sk-test", "base_url": wsServer.URL, }, Extra: map[string]any{ "responses_websockets_v2_enabled": true, }, } body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) result, err := svc.Forward(context.Background(), c, account, body) require.Error(t, err) require.Nil(t, result) require.Nil(t, upstream.lastReq, "触发 websocket_connection_limit_reached 后不应回退 HTTP") require.Equal(t, int32(openAIWSReconnectRetryLimit+1), wsAttempts.Load()) } func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundRecoversByDroppingPreviousResponseID(t *testing.T) { gin.SetMode(gin.TestMode) var wsAttempts atomic.Int32 var wsRequestPayloads [][]byte var wsRequestMu sync.Mutex upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attempt := wsAttempts.Add(1) conn, err := upgrader.Upgrade(w, r, nil) if err != nil { t.Errorf("upgrade websocket failed: %v", err) return } defer func() { _ = conn.Close() }() var req map[string]any if err := conn.ReadJSON(&req); err != nil { t.Errorf("read ws request failed: %v", err) return } reqRaw, _ := json.Marshal(req) wsRequestMu.Lock() wsRequestPayloads = append(wsRequestPayloads, reqRaw) wsRequestMu.Unlock() if attempt == 1 { _ = conn.WriteJSON(map[string]any{ "type": "error", "error": map[string]any{ "code": "previous_response_not_found", "type": "invalid_request_error", "message": "previous response not found", }, }) return } _ = conn.WriteJSON(map[string]any{ "type": "response.completed", "response": map[string]any{ "id": "resp_ws_prev_recover_ok", "model": "gpt-5.3-codex", "usage": map[string]any{ "input_tokens": 1, "output_tokens": 1, "input_tokens_details": map[string]any{ "cached_tokens": 0, }, }, }, }) })) defer wsServer.Close() rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") upstream := &httpUpstreamRecorder{ resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"application/json"}}, Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), }, } cfg := &config.Config{} cfg.Security.URLAllowlist.Enabled = false cfg.Security.URLAllowlist.AllowInsecureHTTP = true cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 svc := &OpenAIGatewayService{ cfg: cfg, httpUpstream: upstream, openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), toolCorrector: NewCodexToolCorrector(), } account := &Account{ ID: 91, Name: "openai-apikey", Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1, Credentials: map[string]any{ "api_key": "sk-test", "base_url": wsServer.URL, }, Extra: map[string]any{ "responses_websockets_v2_enabled": true, }, } body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`) result, err := svc.Forward(context.Background(), c, account, body) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, "resp_ws_prev_recover_ok", result.RequestID) require.Nil(t, upstream.lastReq, "previous_response_not_found 不应回退 HTTP") require.Equal(t, int32(2), wsAttempts.Load(), "previous_response_not_found 应触发一次去掉 previous_response_id 的恢复重试") require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, "resp_ws_prev_recover_ok", gjson.Get(rec.Body.String(), "id").String()) wsRequestMu.Lock() requests := append([][]byte(nil), wsRequestPayloads...) wsRequestMu.Unlock() require.Len(t, requests, 2) require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应保留 previous_response_id") require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id") } func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryForFunctionCallOutput(t *testing.T) { gin.SetMode(gin.TestMode) var wsAttempts atomic.Int32 var wsRequestPayloads [][]byte var wsRequestMu sync.Mutex upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { wsAttempts.Add(1) conn, err := upgrader.Upgrade(w, r, nil) if err != nil { t.Errorf("upgrade websocket failed: %v", err) return } defer func() { _ = conn.Close() }() var req map[string]any if err := conn.ReadJSON(&req); err != nil { t.Errorf("read ws request failed: %v", err) return } reqRaw, _ := json.Marshal(req) wsRequestMu.Lock() wsRequestPayloads = append(wsRequestPayloads, reqRaw) wsRequestMu.Unlock() _ = conn.WriteJSON(map[string]any{ "type": "error", "error": map[string]any{ "code": "previous_response_not_found", "type": "invalid_request_error", "message": "previous response not found", }, }) })) defer wsServer.Close() rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") upstream := &httpUpstreamRecorder{ resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"application/json"}}, Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), }, } cfg := &config.Config{} cfg.Security.URLAllowlist.Enabled = false cfg.Security.URLAllowlist.AllowInsecureHTTP = true cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 svc := &OpenAIGatewayService{ cfg: cfg, httpUpstream: upstream, openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), toolCorrector: NewCodexToolCorrector(), } account := &Account{ ID: 92, Name: "openai-apikey", Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1, Credentials: map[string]any{ "api_key": "sk-test", "base_url": wsServer.URL, }, Extra: map[string]any{ "responses_websockets_v2_enabled": true, }, } body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) result, err := svc.Forward(context.Background(), c, account, body) require.Error(t, err) require.Nil(t, result) require.Nil(t, upstream.lastReq, "previous_response_not_found 不应回退 HTTP") require.Equal(t, int32(1), wsAttempts.Load(), "function_call_output 场景应跳过 previous_response_not_found 自动恢复") require.Equal(t, http.StatusBadRequest, rec.Code) require.Contains(t, strings.ToLower(rec.Body.String()), "previous response not found") wsRequestMu.Lock() requests := append([][]byte(nil), wsRequestPayloads...) wsRequestMu.Unlock() require.Len(t, requests, 1) require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists()) } func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryWithoutPreviousResponseID(t *testing.T) { gin.SetMode(gin.TestMode) var wsAttempts atomic.Int32 var wsRequestPayloads [][]byte var wsRequestMu sync.Mutex upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { wsAttempts.Add(1) conn, err := upgrader.Upgrade(w, r, nil) if err != nil { t.Errorf("upgrade websocket failed: %v", err) return } defer func() { _ = conn.Close() }() var req map[string]any if err := conn.ReadJSON(&req); err != nil { t.Errorf("read ws request failed: %v", err) return } reqRaw, _ := json.Marshal(req) wsRequestMu.Lock() wsRequestPayloads = append(wsRequestPayloads, reqRaw) wsRequestMu.Unlock() _ = conn.WriteJSON(map[string]any{ "type": "error", "error": map[string]any{ "code": "previous_response_not_found", "type": "invalid_request_error", "message": "previous response not found", }, }) })) defer wsServer.Close() rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") upstream := &httpUpstreamRecorder{ resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"application/json"}}, Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), }, } cfg := &config.Config{} cfg.Security.URLAllowlist.Enabled = false cfg.Security.URLAllowlist.AllowInsecureHTTP = true cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 svc := &OpenAIGatewayService{ cfg: cfg, httpUpstream: upstream, openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), toolCorrector: NewCodexToolCorrector(), } account := &Account{ ID: 93, Name: "openai-apikey", Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1, Credentials: map[string]any{ "api_key": "sk-test", "base_url": wsServer.URL, }, Extra: map[string]any{ "responses_websockets_v2_enabled": true, }, } body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) result, err := svc.Forward(context.Background(), c, account, body) require.Error(t, err) require.Nil(t, result) require.Nil(t, upstream.lastReq, "WS 模式下 previous_response_not_found 不应回退 HTTP") require.Equal(t, int32(1), wsAttempts.Load(), "缺少 previous_response_id 时应跳过自动恢复重试") require.Equal(t, http.StatusBadRequest, rec.Code) wsRequestMu.Lock() requests := append([][]byte(nil), wsRequestPayloads...) wsRequestMu.Unlock() require.Len(t, requests, 1) require.False(t, gjson.GetBytes(requests[0], "previous_response_id").Exists()) } func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOnce(t *testing.T) { gin.SetMode(gin.TestMode) var wsAttempts atomic.Int32 var wsRequestPayloads [][]byte var wsRequestMu sync.Mutex upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { wsAttempts.Add(1) conn, err := upgrader.Upgrade(w, r, nil) if err != nil { t.Errorf("upgrade websocket failed: %v", err) return } defer func() { _ = conn.Close() }() var req map[string]any if err := conn.ReadJSON(&req); err != nil { t.Errorf("read ws request failed: %v", err) return } reqRaw, _ := json.Marshal(req) wsRequestMu.Lock() wsRequestPayloads = append(wsRequestPayloads, reqRaw) wsRequestMu.Unlock() _ = conn.WriteJSON(map[string]any{ "type": "error", "error": map[string]any{ "code": "previous_response_not_found", "type": "invalid_request_error", "message": "previous response not found", }, }) })) defer wsServer.Close() rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") upstream := &httpUpstreamRecorder{ resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"application/json"}}, Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), }, } cfg := &config.Config{} cfg.Security.URLAllowlist.Enabled = false cfg.Security.URLAllowlist.AllowInsecureHTTP = true cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 svc := &OpenAIGatewayService{ cfg: cfg, httpUpstream: upstream, openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), toolCorrector: NewCodexToolCorrector(), } account := &Account{ ID: 94, Name: "openai-apikey", Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1, Credentials: map[string]any{ "api_key": "sk-test", "base_url": wsServer.URL, }, Extra: map[string]any{ "responses_websockets_v2_enabled": true, }, } body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`) result, err := svc.Forward(context.Background(), c, account, body) require.Error(t, err) require.Nil(t, result) require.Nil(t, upstream.lastReq, "WS 模式下 previous_response_not_found 不应回退 HTTP") require.Equal(t, int32(2), wsAttempts.Load(), "应只允许一次自动恢复重试") require.Equal(t, http.StatusBadRequest, rec.Code) wsRequestMu.Lock() requests := append([][]byte(nil), wsRequestPayloads...) wsRequestMu.Unlock() require.Len(t, requests, 2) require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应包含 previous_response_id") require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id") }