1307 lines
42 KiB
Go
1307 lines
42 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
type receivedPayload struct {
|
|
Type string
|
|
PreviousResponseID string
|
|
StreamExists bool
|
|
Stream bool
|
|
}
|
|
receivedCh := make(chan receivedPayload, 1)
|
|
|
|
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
|
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
t.Errorf("upgrade websocket failed: %v", err)
|
|
return
|
|
}
|
|
defer func() {
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
var request map[string]any
|
|
if err := conn.ReadJSON(&request); err != nil {
|
|
t.Errorf("read ws request failed: %v", err)
|
|
return
|
|
}
|
|
requestJSON := requestToJSONString(request)
|
|
receivedCh <- receivedPayload{
|
|
Type: strings.TrimSpace(gjson.Get(requestJSON, "type").String()),
|
|
PreviousResponseID: strings.TrimSpace(gjson.Get(requestJSON, "previous_response_id").String()),
|
|
StreamExists: gjson.Get(requestJSON, "stream").Exists(),
|
|
Stream: gjson.Get(requestJSON, "stream").Bool(),
|
|
}
|
|
|
|
if err := conn.WriteJSON(map[string]any{
|
|
"type": "response.created",
|
|
"response": map[string]any{
|
|
"id": "resp_new_1",
|
|
"model": "gpt-5.1",
|
|
},
|
|
}); err != nil {
|
|
t.Errorf("write response.created failed: %v", err)
|
|
return
|
|
}
|
|
if err := conn.WriteJSON(map[string]any{
|
|
"type": "response.completed",
|
|
"response": map[string]any{
|
|
"id": "resp_new_1",
|
|
"model": "gpt-5.1",
|
|
"usage": map[string]any{
|
|
"input_tokens": 12,
|
|
"output_tokens": 7,
|
|
"input_tokens_details": map[string]any{
|
|
"cached_tokens": 3,
|
|
},
|
|
},
|
|
},
|
|
}); err != nil {
|
|
t.Errorf("write response.completed failed: %v", err)
|
|
return
|
|
}
|
|
}))
|
|
defer wsServer.Close()
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
|
|
groupID := int64(1001)
|
|
c.Set("api_key", &APIKey{GroupID: &groupID})
|
|
|
|
cfg := &config.Config{}
|
|
cfg.Security.URLAllowlist.Enabled = false
|
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
|
cfg.Gateway.OpenAIWS.Enabled = true
|
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
|
|
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
|
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
|
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30
|
|
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10
|
|
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
|
|
|
|
upstream := &httpUpstreamRecorder{
|
|
resp: &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)),
|
|
},
|
|
}
|
|
|
|
cache := &stubGatewayCache{}
|
|
svc := &OpenAIGatewayService{
|
|
cfg: cfg,
|
|
httpUpstream: upstream,
|
|
cache: cache,
|
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
|
toolCorrector: NewCodexToolCorrector(),
|
|
}
|
|
|
|
account := &Account{
|
|
ID: 9,
|
|
Name: "openai-ws",
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeAPIKey,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Concurrency: 2,
|
|
Credentials: map[string]any{
|
|
"api_key": "sk-test",
|
|
"base_url": wsServer.URL,
|
|
},
|
|
Extra: map[string]any{
|
|
"responses_websockets_v2_enabled": true,
|
|
},
|
|
}
|
|
|
|
body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_1","input":[{"type":"input_text","text":"hello"}]}`)
|
|
result, err := svc.Forward(context.Background(), c, account, body)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
require.Equal(t, 12, result.Usage.InputTokens)
|
|
require.Equal(t, 7, result.Usage.OutputTokens)
|
|
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
|
|
require.Equal(t, "resp_new_1", result.RequestID)
|
|
require.True(t, result.OpenAIWSMode)
|
|
require.False(t, gjson.GetBytes(upstream.lastBody, "model").Exists(), "WSv2 成功时不应回落 HTTP 上游")
|
|
|
|
received := <-receivedCh
|
|
require.Equal(t, "response.create", received.Type)
|
|
require.Equal(t, "resp_prev_1", received.PreviousResponseID)
|
|
require.True(t, received.StreamExists, "WS 请求应携带 stream 字段")
|
|
require.False(t, received.Stream, "应保持客户端 stream=false 的原始语义")
|
|
|
|
store := svc.getOpenAIWSStateStore()
|
|
mappedAccountID, getErr := store.GetResponseAccount(context.Background(), groupID, "resp_new_1")
|
|
require.NoError(t, getErr)
|
|
require.Equal(t, account.ID, mappedAccountID)
|
|
connID, ok := store.GetResponseConn("resp_new_1")
|
|
require.True(t, ok)
|
|
require.NotEmpty(t, connID)
|
|
|
|
responseBody := rec.Body.Bytes()
|
|
require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String())
|
|
}
|
|
|
|
// requestToJSONString serializes payload to a JSON string so tests can query
// it with gjson. A nil or empty payload, or one that json.Marshal rejects,
// yields the literal "{}".
func requestToJSONString(payload map[string]any) string {
	const emptyObject = "{}"
	if len(payload) > 0 {
		if encoded, err := json.Marshal(payload); err == nil {
			return string(encoded)
		}
	}
	return emptyObject
}
|
|
|
|
func TestLogOpenAIWSBindResponseAccountWarn(t *testing.T) {
|
|
require.NotPanics(t, func() {
|
|
logOpenAIWSBindResponseAccountWarn(1, 2, "resp_ok", nil)
|
|
})
|
|
require.NotPanics(t, func() {
|
|
logOpenAIWSBindResponseAccountWarn(1, 2, "resp_err", errors.New("bind failed"))
|
|
})
|
|
}
|
|
|
|
func TestOpenAIGatewayService_Forward_WSv2_RewriteModelAndToolCallsOnCompletedEvent(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
|
|
groupID := int64(3001)
|
|
c.Set("api_key", &APIKey{GroupID: &groupID})
|
|
|
|
cfg := &config.Config{}
|
|
cfg.Security.URLAllowlist.Enabled = false
|
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
|
cfg.Gateway.OpenAIWS.Enabled = true
|
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
|
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
|
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
|
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5
|
|
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
|
|
|
captureConn := &openAIWSCaptureConn{
|
|
events: [][]byte{
|
|
[]byte(`{"type":"response.completed","response":{"id":"resp_model_tool_1","model":"gpt-5.1","tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}],"usage":{"input_tokens":2,"output_tokens":1}},"tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}]}`),
|
|
},
|
|
}
|
|
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
|
pool := newOpenAIWSConnPool(cfg)
|
|
pool.setClientDialerForTest(captureDialer)
|
|
|
|
svc := &OpenAIGatewayService{
|
|
cfg: cfg,
|
|
httpUpstream: &httpUpstreamRecorder{},
|
|
cache: &stubGatewayCache{},
|
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
|
toolCorrector: NewCodexToolCorrector(),
|
|
openaiWSPool: pool,
|
|
}
|
|
|
|
account := &Account{
|
|
ID: 1301,
|
|
Name: "openai-rewrite",
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeAPIKey,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Concurrency: 1,
|
|
Credentials: map[string]any{
|
|
"api_key": "sk-test",
|
|
"model_mapping": map[string]any{
|
|
"custom-original-model": "gpt-5.1",
|
|
},
|
|
},
|
|
Extra: map[string]any{
|
|
"responses_websockets_v2_enabled": true,
|
|
},
|
|
}
|
|
|
|
body := []byte(`{"model":"custom-original-model","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
|
result, err := svc.Forward(context.Background(), c, account, body)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
require.Equal(t, "resp_model_tool_1", result.RequestID)
|
|
require.Equal(t, "custom-original-model", gjson.GetBytes(rec.Body.Bytes(), "model").String(), "响应模型应回写为原始请求模型")
|
|
require.Equal(t, "edit", gjson.GetBytes(rec.Body.Bytes(), "tool_calls.0.function.name").String(), "工具名称应被修正为 OpenCode 规范")
|
|
}
|
|
|
|
func TestOpenAIWSPayloadString_OnlyAcceptsStringValues(t *testing.T) {
|
|
payload := map[string]any{
|
|
"type": nil,
|
|
"model": 123,
|
|
"prompt_cache_key": " cache-key ",
|
|
"previous_response_id": []byte(" resp_1 "),
|
|
}
|
|
|
|
require.Equal(t, "", openAIWSPayloadString(payload, "type"))
|
|
require.Equal(t, "", openAIWSPayloadString(payload, "model"))
|
|
require.Equal(t, "cache-key", openAIWSPayloadString(payload, "prompt_cache_key"))
|
|
require.Equal(t, "resp_1", openAIWSPayloadString(payload, "previous_response_id"))
|
|
}
|
|
|
|
func TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
var upgradeCount atomic.Int64
|
|
var sequence atomic.Int64
|
|
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
|
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
upgradeCount.Add(1)
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
t.Errorf("upgrade websocket failed: %v", err)
|
|
return
|
|
}
|
|
defer func() {
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
for {
|
|
var request map[string]any
|
|
if err := conn.ReadJSON(&request); err != nil {
|
|
return
|
|
}
|
|
idx := sequence.Add(1)
|
|
responseID := "resp_reuse_" + strconv.FormatInt(idx, 10)
|
|
if err := conn.WriteJSON(map[string]any{
|
|
"type": "response.created",
|
|
"response": map[string]any{
|
|
"id": responseID,
|
|
"model": "gpt-5.1",
|
|
},
|
|
}); err != nil {
|
|
return
|
|
}
|
|
if err := conn.WriteJSON(map[string]any{
|
|
"type": "response.completed",
|
|
"response": map[string]any{
|
|
"id": responseID,
|
|
"model": "gpt-5.1",
|
|
"usage": map[string]any{
|
|
"input_tokens": 2,
|
|
"output_tokens": 1,
|
|
},
|
|
},
|
|
}); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}))
|
|
defer wsServer.Close()
|
|
|
|
cfg := &config.Config{}
|
|
cfg.Security.URLAllowlist.Enabled = false
|
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
|
cfg.Gateway.OpenAIWS.Enabled = true
|
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
|
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
|
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
|
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
|
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30
|
|
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10
|
|
|
|
svc := &OpenAIGatewayService{
|
|
cfg: cfg,
|
|
httpUpstream: &httpUpstreamRecorder{},
|
|
cache: &stubGatewayCache{},
|
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
|
toolCorrector: NewCodexToolCorrector(),
|
|
}
|
|
account := &Account{
|
|
ID: 19,
|
|
Name: "openai-ws",
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeAPIKey,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Concurrency: 2,
|
|
Credentials: map[string]any{
|
|
"api_key": "sk-test",
|
|
"base_url": wsServer.URL,
|
|
},
|
|
Extra: map[string]any{
|
|
"responses_websockets_v2_enabled": true,
|
|
},
|
|
}
|
|
|
|
for i := 0; i < 2; i++ {
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
|
|
groupID := int64(2001)
|
|
c.Set("api_key", &APIKey{GroupID: &groupID})
|
|
|
|
body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_reuse","input":[{"type":"input_text","text":"hello"}]}`)
|
|
result, err := svc.Forward(context.Background(), c, account, body)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
require.True(t, strings.HasPrefix(result.RequestID, "resp_reuse_"))
|
|
}
|
|
|
|
require.Equal(t, int64(1), upgradeCount.Load(), "多个客户端请求应复用账号连接池而不是 1:1 对等建链")
|
|
metrics := svc.SnapshotOpenAIWSPoolMetrics()
|
|
require.GreaterOrEqual(t, metrics.AcquireReuseTotal, int64(1))
|
|
require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1))
|
|
}
|
|
|
|
func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
|
|
c.Request.Header.Set("session_id", "sess-oauth-1")
|
|
c.Request.Header.Set("conversation_id", "conv-oauth-1")
|
|
|
|
cfg := &config.Config{}
|
|
cfg.Security.URLAllowlist.Enabled = false
|
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
|
cfg.Gateway.OpenAIWS.Enabled = true
|
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
|
cfg.Gateway.OpenAIWS.AllowStoreRecovery = false
|
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
|
|
|
captureConn := &openAIWSCaptureConn{
|
|
events: [][]byte{
|
|
[]byte(`{"type":"response.completed","response":{"id":"resp_oauth_1","model":"gpt-5.1","usage":{"input_tokens":3,"output_tokens":2}}}`),
|
|
},
|
|
}
|
|
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
|
pool := newOpenAIWSConnPool(cfg)
|
|
pool.setClientDialerForTest(captureDialer)
|
|
|
|
svc := &OpenAIGatewayService{
|
|
cfg: cfg,
|
|
httpUpstream: &httpUpstreamRecorder{},
|
|
cache: &stubGatewayCache{},
|
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
|
toolCorrector: NewCodexToolCorrector(),
|
|
openaiWSPool: pool,
|
|
}
|
|
account := &Account{
|
|
ID: 29,
|
|
Name: "openai-oauth",
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeOAuth,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Concurrency: 1,
|
|
Credentials: map[string]any{
|
|
"access_token": "oauth-token-1",
|
|
},
|
|
Extra: map[string]any{
|
|
"responses_websockets_v2_enabled": true,
|
|
},
|
|
}
|
|
|
|
body := []byte(`{"model":"gpt-5.1","stream":false,"store":true,"input":[{"type":"input_text","text":"hello"}]}`)
|
|
result, err := svc.Forward(context.Background(), c, account, body)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
require.Equal(t, "resp_oauth_1", result.RequestID)
|
|
|
|
require.NotNil(t, captureConn.lastWrite)
|
|
requestJSON := requestToJSONString(captureConn.lastWrite)
|
|
require.True(t, gjson.Get(requestJSON, "store").Exists(), "OAuth WSv2 应显式写入 store 字段")
|
|
require.False(t, gjson.Get(requestJSON, "store").Bool(), "默认策略应将 OAuth store 置为 false")
|
|
require.True(t, gjson.Get(requestJSON, "stream").Exists(), "WSv2 payload 应保留 stream 字段")
|
|
require.True(t, gjson.Get(requestJSON, "stream").Bool(), "OAuth Codex 规范化后应强制 stream=true")
|
|
require.Equal(t, openAIWSBetaV2Value, captureDialer.lastHeaders.Get("OpenAI-Beta"))
|
|
require.Equal(t, "sess-oauth-1", captureDialer.lastHeaders.Get("session_id"))
|
|
require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id"))
|
|
}
|
|
|
|
func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
|
|
|
|
cfg := &config.Config{}
|
|
cfg.Security.URLAllowlist.Enabled = false
|
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
|
cfg.Gateway.OpenAIWS.Enabled = true
|
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
|
|
|
captureConn := &openAIWSCaptureConn{
|
|
events: [][]byte{
|
|
[]byte(`{"type":"response.completed","response":{"id":"resp_prompt_cache_key","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`),
|
|
},
|
|
}
|
|
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
|
pool := newOpenAIWSConnPool(cfg)
|
|
pool.setClientDialerForTest(captureDialer)
|
|
|
|
svc := &OpenAIGatewayService{
|
|
cfg: cfg,
|
|
httpUpstream: &httpUpstreamRecorder{},
|
|
cache: &stubGatewayCache{},
|
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
|
toolCorrector: NewCodexToolCorrector(),
|
|
openaiWSPool: pool,
|
|
}
|
|
account := &Account{
|
|
ID: 31,
|
|
Name: "openai-oauth",
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeOAuth,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Concurrency: 1,
|
|
Credentials: map[string]any{
|
|
"access_token": "oauth-token-1",
|
|
},
|
|
Extra: map[string]any{
|
|
"responses_websockets_v2_enabled": true,
|
|
},
|
|
}
|
|
|
|
body := []byte(`{"model":"gpt-5.1","stream":true,"prompt_cache_key":"pcache_123","input":[{"type":"input_text","text":"hi"}]}`)
|
|
result, err := svc.Forward(context.Background(), c, account, body)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
require.Equal(t, "resp_prompt_cache_key", result.RequestID)
|
|
|
|
require.Equal(t, "pcache_123", captureDialer.lastHeaders.Get("session_id"))
|
|
require.Empty(t, captureDialer.lastHeaders.Get("conversation_id"))
|
|
require.NotNil(t, captureConn.lastWrite)
|
|
require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists())
|
|
}
|
|
|
|
func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
|
|
|
|
cfg := &config.Config{}
|
|
cfg.Security.URLAllowlist.Enabled = false
|
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
|
cfg.Gateway.OpenAIWS.Enabled = true
|
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
|
cfg.Gateway.OpenAIWS.ResponsesWebsockets = true
|
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false
|
|
|
|
upstream := &httpUpstreamRecorder{
|
|
resp: &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)),
|
|
},
|
|
}
|
|
|
|
svc := &OpenAIGatewayService{
|
|
cfg: cfg,
|
|
httpUpstream: upstream,
|
|
cache: &stubGatewayCache{},
|
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
|
toolCorrector: NewCodexToolCorrector(),
|
|
}
|
|
|
|
account := &Account{
|
|
ID: 39,
|
|
Name: "openai-ws-v1",
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeAPIKey,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Concurrency: 1,
|
|
Credentials: map[string]any{
|
|
"api_key": "sk-test",
|
|
"base_url": "https://api.openai.com/v1/responses",
|
|
},
|
|
Extra: map[string]any{
|
|
"responses_websockets_v2_enabled": true,
|
|
},
|
|
}
|
|
|
|
body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_v1","input":[{"type":"input_text","text":"hello"}]}`)
|
|
result, err := svc.Forward(context.Background(), c, account, body)
|
|
require.Error(t, err)
|
|
require.Nil(t, result)
|
|
require.Contains(t, err.Error(), "ws v1")
|
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
|
require.Contains(t, rec.Body.String(), "WSv1")
|
|
require.Nil(t, upstream.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求")
|
|
}
|
|
|
|
func TestOpenAIGatewayService_Forward_WSv2_TurnStateAndMetadataReplayOnReconnect(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
var connIndex atomic.Int64
|
|
headersCh := make(chan http.Header, 4)
|
|
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
|
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
idx := connIndex.Add(1)
|
|
headersCh <- cloneHeader(r.Header)
|
|
|
|
respHeader := http.Header{}
|
|
if idx == 1 {
|
|
respHeader.Set("x-codex-turn-state", "turn_state_first")
|
|
}
|
|
conn, err := upgrader.Upgrade(w, r, respHeader)
|
|
if err != nil {
|
|
t.Errorf("upgrade websocket failed: %v", err)
|
|
return
|
|
}
|
|
defer func() {
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
var request map[string]any
|
|
if err := conn.ReadJSON(&request); err != nil {
|
|
t.Errorf("read ws request failed: %v", err)
|
|
return
|
|
}
|
|
responseID := "resp_turn_" + strconv.FormatInt(idx, 10)
|
|
if err := conn.WriteJSON(map[string]any{
|
|
"type": "response.completed",
|
|
"response": map[string]any{
|
|
"id": responseID,
|
|
"model": "gpt-5.1",
|
|
"usage": map[string]any{
|
|
"input_tokens": 2,
|
|
"output_tokens": 1,
|
|
},
|
|
},
|
|
}); err != nil {
|
|
t.Errorf("write response.completed failed: %v", err)
|
|
return
|
|
}
|
|
}))
|
|
defer wsServer.Close()
|
|
|
|
cfg := &config.Config{}
|
|
cfg.Security.URLAllowlist.Enabled = false
|
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
|
cfg.Gateway.OpenAIWS.Enabled = true
|
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0
|
|
|
|
svc := &OpenAIGatewayService{
|
|
cfg: cfg,
|
|
httpUpstream: &httpUpstreamRecorder{},
|
|
cache: &stubGatewayCache{},
|
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
|
toolCorrector: NewCodexToolCorrector(),
|
|
}
|
|
|
|
account := &Account{
|
|
ID: 49,
|
|
Name: "openai-turn-state",
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeAPIKey,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Concurrency: 1,
|
|
Credentials: map[string]any{
|
|
"api_key": "sk-test",
|
|
"base_url": wsServer.URL,
|
|
},
|
|
Extra: map[string]any{
|
|
"responses_websockets_v2_enabled": true,
|
|
},
|
|
}
|
|
|
|
reqBody := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
|
rec1 := httptest.NewRecorder()
|
|
c1, _ := gin.CreateTestContext(rec1)
|
|
c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c1.Request.Header.Set("session_id", "session_turn_state")
|
|
c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_1")
|
|
result1, err := svc.Forward(context.Background(), c1, account, reqBody)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result1)
|
|
|
|
sessionHash := svc.GenerateSessionHash(c1, reqBody)
|
|
store := svc.getOpenAIWSStateStore()
|
|
turnState, ok := store.GetSessionTurnState(0, sessionHash)
|
|
require.True(t, ok)
|
|
require.Equal(t, "turn_state_first", turnState)
|
|
|
|
// 主动淘汰连接,模拟下一次请求发生重连。
|
|
connID, hasConn := store.GetResponseConn(result1.RequestID)
|
|
require.True(t, hasConn)
|
|
svc.getOpenAIWSConnPool().evictConn(account.ID, connID)
|
|
|
|
rec2 := httptest.NewRecorder()
|
|
c2, _ := gin.CreateTestContext(rec2)
|
|
c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c2.Request.Header.Set("session_id", "session_turn_state")
|
|
c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_2")
|
|
result2, err := svc.Forward(context.Background(), c2, account, reqBody)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result2)
|
|
|
|
firstHandshakeHeaders := <-headersCh
|
|
secondHandshakeHeaders := <-headersCh
|
|
require.Equal(t, "turn_meta_1", firstHandshakeHeaders.Get("X-Codex-Turn-Metadata"))
|
|
require.Equal(t, "turn_meta_2", secondHandshakeHeaders.Get("X-Codex-Turn-Metadata"))
|
|
require.Equal(t, "turn_state_first", secondHandshakeHeaders.Get("X-Codex-Turn-State"))
|
|
}
|
|
|
|
func TestOpenAIGatewayService_Forward_WSv2_GeneratePrewarm(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c.Request.Header.Set("session_id", "session-prewarm")
|
|
|
|
cfg := &config.Config{}
|
|
cfg.Security.URLAllowlist.Enabled = false
|
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
|
cfg.Gateway.OpenAIWS.Enabled = true
|
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
|
cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true
|
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
|
|
|
captureConn := &openAIWSCaptureConn{
|
|
events: [][]byte{
|
|
[]byte(`{"type":"response.completed","response":{"id":"resp_prewarm_1","model":"gpt-5.1","usage":{"input_tokens":0,"output_tokens":0}}}`),
|
|
[]byte(`{"type":"response.completed","response":{"id":"resp_main_1","model":"gpt-5.1","usage":{"input_tokens":4,"output_tokens":2}}}`),
|
|
},
|
|
}
|
|
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
|
pool := newOpenAIWSConnPool(cfg)
|
|
pool.setClientDialerForTest(captureDialer)
|
|
|
|
svc := &OpenAIGatewayService{
|
|
cfg: cfg,
|
|
httpUpstream: &httpUpstreamRecorder{},
|
|
cache: &stubGatewayCache{},
|
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
|
toolCorrector: NewCodexToolCorrector(),
|
|
openaiWSPool: pool,
|
|
}
|
|
|
|
account := &Account{
|
|
ID: 59,
|
|
Name: "openai-prewarm",
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeAPIKey,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Concurrency: 1,
|
|
Credentials: map[string]any{
|
|
"api_key": "sk-test",
|
|
},
|
|
Extra: map[string]any{
|
|
"responses_websockets_v2_enabled": true,
|
|
},
|
|
}
|
|
|
|
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
|
result, err := svc.Forward(context.Background(), c, account, body)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
require.Equal(t, "resp_main_1", result.RequestID)
|
|
|
|
require.Len(t, captureConn.writes, 2, "开启 generate=false 预热后应发送两次 WS 请求")
|
|
firstWrite := requestToJSONString(captureConn.writes[0])
|
|
secondWrite := requestToJSONString(captureConn.writes[1])
|
|
require.True(t, gjson.Get(firstWrite, "generate").Exists())
|
|
require.False(t, gjson.Get(firstWrite, "generate").Bool())
|
|
require.False(t, gjson.Get(secondWrite, "generate").Exists())
|
|
}
|
|
|
|
func TestOpenAIGatewayService_PrewarmReadHonorsParentContext(t *testing.T) {
|
|
cfg := &config.Config{}
|
|
cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true
|
|
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5
|
|
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
|
|
|
svc := &OpenAIGatewayService{
|
|
cfg: cfg,
|
|
toolCorrector: NewCodexToolCorrector(),
|
|
}
|
|
account := &Account{
|
|
ID: 601,
|
|
Name: "openai-prewarm-timeout",
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeAPIKey,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
}
|
|
conn := newOpenAIWSConn("prewarm_ctx_conn", account.ID, &openAIWSBlockingConn{
|
|
readDelay: 200 * time.Millisecond,
|
|
}, nil)
|
|
lease := &openAIWSConnLease{
|
|
accountID: account.ID,
|
|
conn: conn,
|
|
}
|
|
payload := map[string]any{
|
|
"type": "response.create",
|
|
"model": "gpt-5.1",
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond)
|
|
defer cancel()
|
|
start := time.Now()
|
|
err := svc.performOpenAIWSGeneratePrewarm(
|
|
ctx,
|
|
lease,
|
|
OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
|
|
payload,
|
|
"",
|
|
map[string]any{"model": "gpt-5.1"},
|
|
account,
|
|
nil,
|
|
0,
|
|
)
|
|
elapsed := time.Since(start)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "prewarm_read_event")
|
|
require.Less(t, elapsed, 180*time.Millisecond, "预热读取应受父 context 取消控制,不应阻塞到 read_timeout")
|
|
}
|
|
|
|
func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
cfg := &config.Config{}
|
|
cfg.Security.URLAllowlist.Enabled = false
|
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
|
cfg.Gateway.OpenAIWS.Enabled = true
|
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
|
|
|
captureConn := &openAIWSCaptureConn{
|
|
events: [][]byte{
|
|
[]byte(`{"type":"response.completed","response":{"id":"resp_meta_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
|
[]byte(`{"type":"response.completed","response":{"id":"resp_meta_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
|
},
|
|
}
|
|
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
|
pool := newOpenAIWSConnPool(cfg)
|
|
pool.setClientDialerForTest(captureDialer)
|
|
|
|
svc := &OpenAIGatewayService{
|
|
cfg: cfg,
|
|
httpUpstream: &httpUpstreamRecorder{},
|
|
cache: &stubGatewayCache{},
|
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
|
toolCorrector: NewCodexToolCorrector(),
|
|
openaiWSPool: pool,
|
|
}
|
|
|
|
account := &Account{
|
|
ID: 69,
|
|
Name: "openai-turn-metadata",
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeAPIKey,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Concurrency: 1,
|
|
Credentials: map[string]any{
|
|
"api_key": "sk-test",
|
|
},
|
|
Extra: map[string]any{
|
|
"responses_websockets_v2_enabled": true,
|
|
},
|
|
}
|
|
|
|
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
|
|
|
rec1 := httptest.NewRecorder()
|
|
c1, _ := gin.CreateTestContext(rec1)
|
|
c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c1.Request.Header.Set("session_id", "session-metadata-reuse")
|
|
c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_1")
|
|
result1, err := svc.Forward(context.Background(), c1, account, body)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result1)
|
|
require.Equal(t, "resp_meta_1", result1.RequestID)
|
|
|
|
rec2 := httptest.NewRecorder()
|
|
c2, _ := gin.CreateTestContext(rec2)
|
|
c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c2.Request.Header.Set("session_id", "session-metadata-reuse")
|
|
c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_2")
|
|
result2, err := svc.Forward(context.Background(), c2, account, body)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result2)
|
|
require.Equal(t, "resp_meta_2", result2.RequestID)
|
|
|
|
require.Equal(t, 1, captureDialer.DialCount(), "同一账号两轮请求应复用同一 WS 连接")
|
|
require.Len(t, captureConn.writes, 2)
|
|
|
|
firstWrite := requestToJSONString(captureConn.writes[0])
|
|
secondWrite := requestToJSONString(captureConn.writes[1])
|
|
require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String())
|
|
require.Equal(t, "turn_meta_payload_2", gjson.Get(secondWrite, "client_metadata.x-codex-turn-metadata").String())
|
|
}
|
|
|
|
func TestOpenAIGatewayService_Forward_WSv2StoreFalseSessionConnIsolation(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
var upgradeCount atomic.Int64
|
|
var sequence atomic.Int64
|
|
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
|
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
upgradeCount.Add(1)
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
t.Errorf("upgrade websocket failed: %v", err)
|
|
return
|
|
}
|
|
defer func() {
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
for {
|
|
var request map[string]any
|
|
if err := conn.ReadJSON(&request); err != nil {
|
|
return
|
|
}
|
|
responseID := "resp_store_false_" + strconv.FormatInt(sequence.Add(1), 10)
|
|
if err := conn.WriteJSON(map[string]any{
|
|
"type": "response.completed",
|
|
"response": map[string]any{
|
|
"id": responseID,
|
|
"model": "gpt-5.1",
|
|
"usage": map[string]any{
|
|
"input_tokens": 1,
|
|
"output_tokens": 1,
|
|
},
|
|
},
|
|
}); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}))
|
|
defer wsServer.Close()
|
|
|
|
cfg := &config.Config{}
|
|
cfg.Security.URLAllowlist.Enabled = false
|
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
|
cfg.Gateway.OpenAIWS.Enabled = true
|
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
|
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4
|
|
cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true
|
|
|
|
svc := &OpenAIGatewayService{
|
|
cfg: cfg,
|
|
httpUpstream: &httpUpstreamRecorder{},
|
|
cache: &stubGatewayCache{},
|
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
|
toolCorrector: NewCodexToolCorrector(),
|
|
}
|
|
|
|
account := &Account{
|
|
ID: 79,
|
|
Name: "openai-store-false",
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeAPIKey,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Concurrency: 2,
|
|
Credentials: map[string]any{
|
|
"api_key": "sk-test",
|
|
"base_url": wsServer.URL,
|
|
},
|
|
Extra: map[string]any{
|
|
"responses_websockets_v2_enabled": true,
|
|
},
|
|
}
|
|
|
|
body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
|
|
|
rec1 := httptest.NewRecorder()
|
|
c1, _ := gin.CreateTestContext(rec1)
|
|
c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c1.Request.Header.Set("session_id", "session_store_false_a")
|
|
result1, err := svc.Forward(context.Background(), c1, account, body)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result1)
|
|
require.Equal(t, int64(1), upgradeCount.Load())
|
|
|
|
rec2 := httptest.NewRecorder()
|
|
c2, _ := gin.CreateTestContext(rec2)
|
|
c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c2.Request.Header.Set("session_id", "session_store_false_a")
|
|
result2, err := svc.Forward(context.Background(), c2, account, body)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result2)
|
|
require.Equal(t, int64(1), upgradeCount.Load(), "同一 session(store=false) 应复用同一 WS 连接")
|
|
|
|
rec3 := httptest.NewRecorder()
|
|
c3, _ := gin.CreateTestContext(rec3)
|
|
c3.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c3.Request.Header.Set("session_id", "session_store_false_b")
|
|
result3, err := svc.Forward(context.Background(), c3, account, body)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result3)
|
|
require.Equal(t, int64(2), upgradeCount.Load(), "不同 session(store=false) 应隔离连接,避免续链状态互相覆盖")
|
|
}
|
|
|
|
func TestOpenAIGatewayService_Forward_WSv2StoreFalseDisableForceNewConnAllowsReuse(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
var upgradeCount atomic.Int64
|
|
var sequence atomic.Int64
|
|
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
|
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
upgradeCount.Add(1)
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
t.Errorf("upgrade websocket failed: %v", err)
|
|
return
|
|
}
|
|
defer func() {
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
for {
|
|
var request map[string]any
|
|
if err := conn.ReadJSON(&request); err != nil {
|
|
return
|
|
}
|
|
responseID := "resp_store_false_reuse_" + strconv.FormatInt(sequence.Add(1), 10)
|
|
if err := conn.WriteJSON(map[string]any{
|
|
"type": "response.completed",
|
|
"response": map[string]any{
|
|
"id": responseID,
|
|
"model": "gpt-5.1",
|
|
"usage": map[string]any{
|
|
"input_tokens": 1,
|
|
"output_tokens": 1,
|
|
},
|
|
},
|
|
}); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}))
|
|
defer wsServer.Close()
|
|
|
|
cfg := &config.Config{}
|
|
cfg.Security.URLAllowlist.Enabled = false
|
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
|
cfg.Gateway.OpenAIWS.Enabled = true
|
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
|
cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false
|
|
|
|
svc := &OpenAIGatewayService{
|
|
cfg: cfg,
|
|
httpUpstream: &httpUpstreamRecorder{},
|
|
cache: &stubGatewayCache{},
|
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
|
toolCorrector: NewCodexToolCorrector(),
|
|
}
|
|
|
|
account := &Account{
|
|
ID: 80,
|
|
Name: "openai-store-false-reuse",
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeAPIKey,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Concurrency: 2,
|
|
Credentials: map[string]any{
|
|
"api_key": "sk-test",
|
|
"base_url": wsServer.URL,
|
|
},
|
|
Extra: map[string]any{
|
|
"responses_websockets_v2_enabled": true,
|
|
},
|
|
}
|
|
|
|
body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
|
|
|
rec1 := httptest.NewRecorder()
|
|
c1, _ := gin.CreateTestContext(rec1)
|
|
c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c1.Request.Header.Set("session_id", "session_store_false_reuse_a")
|
|
result1, err := svc.Forward(context.Background(), c1, account, body)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result1)
|
|
require.Equal(t, int64(1), upgradeCount.Load())
|
|
|
|
rec2 := httptest.NewRecorder()
|
|
c2, _ := gin.CreateTestContext(rec2)
|
|
c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c2.Request.Header.Set("session_id", "session_store_false_reuse_b")
|
|
result2, err := svc.Forward(context.Background(), c2, account, body)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result2)
|
|
require.Equal(t, int64(1), upgradeCount.Load(), "关闭强制新连后,不同 session(store=false) 可复用连接")
|
|
}
|
|
|
|
func TestOpenAIGatewayService_Forward_WSv2ReadTimeoutAppliesPerRead(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
|
|
|
|
cfg := &config.Config{}
|
|
cfg.Security.URLAllowlist.Enabled = false
|
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
|
cfg.Gateway.OpenAIWS.Enabled = true
|
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
|
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
|
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
|
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 1
|
|
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
|
|
|
captureConn := &openAIWSCaptureConn{
|
|
readDelays: []time.Duration{
|
|
700 * time.Millisecond,
|
|
700 * time.Millisecond,
|
|
},
|
|
events: [][]byte{
|
|
[]byte(`{"type":"response.created","response":{"id":"resp_timeout_ok","model":"gpt-5.1"}}`),
|
|
[]byte(`{"type":"response.completed","response":{"id":"resp_timeout_ok","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`),
|
|
},
|
|
}
|
|
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
|
pool := newOpenAIWSConnPool(cfg)
|
|
pool.setClientDialerForTest(captureDialer)
|
|
|
|
upstream := &httpUpstreamRecorder{
|
|
resp: &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_fallback","usage":{"input_tokens":1,"output_tokens":1}}`)),
|
|
},
|
|
}
|
|
|
|
svc := &OpenAIGatewayService{
|
|
cfg: cfg,
|
|
httpUpstream: upstream,
|
|
cache: &stubGatewayCache{},
|
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
|
toolCorrector: NewCodexToolCorrector(),
|
|
openaiWSPool: pool,
|
|
}
|
|
|
|
account := &Account{
|
|
ID: 81,
|
|
Name: "openai-read-timeout",
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeAPIKey,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Concurrency: 1,
|
|
Credentials: map[string]any{
|
|
"api_key": "sk-test",
|
|
},
|
|
Extra: map[string]any{
|
|
"responses_websockets_v2_enabled": true,
|
|
},
|
|
}
|
|
|
|
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
|
result, err := svc.Forward(context.Background(), c, account, body)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
require.Equal(t, "resp_timeout_ok", result.RequestID)
|
|
require.Nil(t, upstream.lastReq, "每次 Read 都应独立应用超时;总时长超过 read_timeout 不应误回退 HTTP")
|
|
}
|
|
|
|
type openAIWSCaptureDialer struct {
|
|
mu sync.Mutex
|
|
conn *openAIWSCaptureConn
|
|
lastHeaders http.Header
|
|
handshake http.Header
|
|
dialCount int
|
|
}
|
|
|
|
func (d *openAIWSCaptureDialer) Dial(
|
|
ctx context.Context,
|
|
wsURL string,
|
|
headers http.Header,
|
|
proxyURL string,
|
|
) (openAIWSClientConn, int, http.Header, error) {
|
|
_ = ctx
|
|
_ = wsURL
|
|
_ = proxyURL
|
|
d.mu.Lock()
|
|
d.lastHeaders = cloneHeader(headers)
|
|
d.dialCount++
|
|
respHeaders := cloneHeader(d.handshake)
|
|
d.mu.Unlock()
|
|
return d.conn, 0, respHeaders, nil
|
|
}
|
|
|
|
func (d *openAIWSCaptureDialer) DialCount() int {
|
|
d.mu.Lock()
|
|
defer d.mu.Unlock()
|
|
return d.dialCount
|
|
}
|
|
|
|
// openAIWSCaptureConn is an in-memory test connection: reads replay a scripted
// list of events (optionally delayed) and writes are captured for inspection.
type openAIWSCaptureConn struct {
	// mu is held by the methods while they touch the fields below.
	mu sync.Mutex

	// readDelays holds per-read delays; ReadMessage consumes one entry per
	// call, front first, and applies no delay once the list is empty.
	readDelays []time.Duration

	// events is the queue of raw messages handed out by ReadMessage.
	events [][]byte

	// lastWrite is a copy of the most recently captured written payload.
	lastWrite map[string]any

	// writes holds a copy of every captured written payload, in order.
	writes []map[string]any

	// closed is set by Close; reads and writes fail afterwards.
	closed bool
}
|
|
|
|
func (c *openAIWSCaptureConn) WriteJSON(ctx context.Context, value any) error {
|
|
_ = ctx
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if c.closed {
|
|
return errOpenAIWSConnClosed
|
|
}
|
|
switch payload := value.(type) {
|
|
case map[string]any:
|
|
c.lastWrite = cloneMapStringAny(payload)
|
|
c.writes = append(c.writes, cloneMapStringAny(payload))
|
|
case json.RawMessage:
|
|
var parsed map[string]any
|
|
if err := json.Unmarshal(payload, &parsed); err == nil {
|
|
c.lastWrite = cloneMapStringAny(parsed)
|
|
c.writes = append(c.writes, cloneMapStringAny(parsed))
|
|
}
|
|
case []byte:
|
|
var parsed map[string]any
|
|
if err := json.Unmarshal(payload, &parsed); err == nil {
|
|
c.lastWrite = cloneMapStringAny(parsed)
|
|
c.writes = append(c.writes, cloneMapStringAny(parsed))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
c.mu.Lock()
|
|
if c.closed {
|
|
c.mu.Unlock()
|
|
return nil, errOpenAIWSConnClosed
|
|
}
|
|
if len(c.events) == 0 {
|
|
c.mu.Unlock()
|
|
return nil, io.EOF
|
|
}
|
|
delay := time.Duration(0)
|
|
if len(c.readDelays) > 0 {
|
|
delay = c.readDelays[0]
|
|
c.readDelays = c.readDelays[1:]
|
|
}
|
|
event := c.events[0]
|
|
c.events = c.events[1:]
|
|
c.mu.Unlock()
|
|
if delay > 0 {
|
|
timer := time.NewTimer(delay)
|
|
defer timer.Stop()
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-timer.C:
|
|
}
|
|
}
|
|
return event, nil
|
|
}
|
|
|
|
func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
|
|
_ = ctx
|
|
return nil
|
|
}
|
|
|
|
func (c *openAIWSCaptureConn) Close() error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
c.closed = true
|
|
return nil
|
|
}
|
|
|
|
// cloneMapStringAny returns a shallow copy of src: the top-level keys are
// copied into a new map, while the values themselves (including any nested
// maps or slices) are shared with src. A nil src yields nil.
func cloneMapStringAny(src map[string]any) map[string]any {
	var clone map[string]any
	if src != nil {
		clone = make(map[string]any, len(src))
		for key := range src {
			clone[key] = src[key]
		}
	}
	return clone
}
|