Files
sub2api/backend/internal/service/openai_ws_forwarder_success_test.go
2026-02-28 15:01:20 +08:00

1307 lines
42 KiB
Go

package service
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky forwards one
// non-streaming request for an API-key account over Responses WebSocket v2 to a
// local WS test server, then checks that:
//   - usage (including cached input tokens) and the request ID come from the
//     upstream response.completed event, and the result is marked as WS mode;
//   - the HTTP upstream recorder never received a body (no HTTP fallback);
//   - the frame the server read is a response.create that keeps
//     previous_response_id and carries the client's explicit stream=false;
//   - the state store binds the new response ID to the account and to a WS
//     connection ID;
//   - the body written to the client carries the new response ID.
func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// receivedPayload captures the fields of the request frame the fake upstream
	// reads, so the test goroutine can assert on them after Forward returns.
	type receivedPayload struct {
		Type               string
		PreviousResponseID string
		StreamExists       bool
		Stream             bool
	}
	receivedCh := make(chan receivedPayload, 1)
	upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			t.Errorf("upgrade websocket failed: %v", err)
			return
		}
		defer func() {
			_ = conn.Close()
		}()

		// Read exactly one request, report it to the test, then answer with
		// response.created followed by response.completed.
		var request map[string]any
		if err := conn.ReadJSON(&request); err != nil {
			t.Errorf("read ws request failed: %v", err)
			return
		}
		requestJSON := requestToJSONString(request)
		receivedCh <- receivedPayload{
			Type:               strings.TrimSpace(gjson.Get(requestJSON, "type").String()),
			PreviousResponseID: strings.TrimSpace(gjson.Get(requestJSON, "previous_response_id").String()),
			StreamExists:       gjson.Get(requestJSON, "stream").Exists(),
			Stream:             gjson.Get(requestJSON, "stream").Bool(),
		}
		if err := conn.WriteJSON(map[string]any{
			"type": "response.created",
			"response": map[string]any{
				"id":    "resp_new_1",
				"model": "gpt-5.1",
			},
		}); err != nil {
			t.Errorf("write response.created failed: %v", err)
			return
		}
		if err := conn.WriteJSON(map[string]any{
			"type": "response.completed",
			"response": map[string]any{
				"id":    "resp_new_1",
				"model": "gpt-5.1",
				"usage": map[string]any{
					"input_tokens":  12,
					"output_tokens": 7,
					"input_tokens_details": map[string]any{
						"cached_tokens": 3,
					},
				},
			},
		}); err != nil {
			t.Errorf("write response.completed failed: %v", err)
			return
		}
	}))
	defer wsServer.Close()

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
	groupID := int64(1001)
	c.Set("api_key", &APIKey{GroupID: &groupID})

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30
	cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10
	cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600

	// The HTTP upstream is only a recorder: if Forward fell back to HTTP, its
	// lastBody would contain the request.
	upstream := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)),
		},
	}
	cache := &stubGatewayCache{}
	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     upstream,
		cache:            cache,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}
	account := &Account{
		ID:          9,
		Name:        "openai-ws",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 2,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_1","input":[{"type":"input_text","text":"hello"}]}`)
	result, err := svc.Forward(context.Background(), c, account, body)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, 12, result.Usage.InputTokens)
	require.Equal(t, 7, result.Usage.OutputTokens)
	require.Equal(t, 3, result.Usage.CacheReadInputTokens)
	require.Equal(t, "resp_new_1", result.RequestID)
	require.True(t, result.OpenAIWSMode)
	require.False(t, gjson.GetBytes(upstream.lastBody, "model").Exists(), "WSv2 成功时不应回落 HTTP 上游")

	// Bounded wait: if the server handler bailed out before reporting (it only
	// calls t.Errorf and returns), fail here instead of blocking until the
	// global go test timeout.
	var received receivedPayload
	select {
	case received = <-receivedCh:
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for the WS test server to report the request it received")
	}
	require.Equal(t, "response.create", received.Type)
	require.Equal(t, "resp_prev_1", received.PreviousResponseID)
	require.True(t, received.StreamExists, "WS 请求应携带 stream 字段")
	require.False(t, received.Stream, "应保持客户端 stream=false 的原始语义")

	// The new response ID must be bound to this account and to a WS connection.
	store := svc.getOpenAIWSStateStore()
	mappedAccountID, getErr := store.GetResponseAccount(context.Background(), groupID, "resp_new_1")
	require.NoError(t, getErr)
	require.Equal(t, account.ID, mappedAccountID)
	connID, ok := store.GetResponseConn("resp_new_1")
	require.True(t, ok)
	require.NotEmpty(t, connID)

	responseBody := rec.Body.Bytes()
	require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String())
}
// requestToJSONString re-encodes a decoded WebSocket request payload as a JSON
// string so tests can query it with gjson. A nil or empty payload, or one that
// json.Marshal cannot encode, yields the literal "{}".
func requestToJSONString(payload map[string]any) string {
	if len(payload) > 0 {
		if encoded, marshalErr := json.Marshal(payload); marshalErr == nil {
			return string(encoded)
		}
	}
	return "{}"
}
// TestLogOpenAIWSBindResponseAccountWarn checks that the bind-warning logger
// does not panic when given either a nil error or a non-nil error.
func TestLogOpenAIWSBindResponseAccountWarn(t *testing.T) {
	cases := []struct {
		responseID string
		bindErr    error
	}{
		{responseID: "resp_ok", bindErr: nil},
		{responseID: "resp_err", bindErr: errors.New("bind failed")},
	}
	for _, tc := range cases {
		// The closure runs synchronously inside NotPanics, so capturing tc is safe.
		require.NotPanics(t, func() {
			logOpenAIWSBindResponseAccountWarn(1, 2, tc.responseID, tc.bindErr)
		})
	}
}
// TestOpenAIGatewayService_Forward_WSv2_RewriteModelAndToolCallsOnCompletedEvent
// forwards a request for "custom-original-model". The account's model_mapping
// maps that name to "gpt-5.1". The request goes over a captured WS connection
// that replies with one response.completed event containing an apply_patch tool
// call. The test then checks the body written to the client: the model must be
// rewritten back to the originally requested name, and the tool name must be
// corrected to "edit".
func TestOpenAIGatewayService_Forward_WSv2_RewriteModelAndToolCallsOnCompletedEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
groupID := int64(3001)
c.Set("api_key", &APIKey{GroupID: &groupID})
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
// Upstream event fed by the capture conn. The apply_patch tool call appears
// both inside "response" and at the top level of the event.
captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_model_tool_1","model":"gpt-5.1","tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}],"usage":{"input_tokens":2,"output_tokens":1}},"tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}]}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
// Swap the pool's client dialer for the capture dialer (test hook). The WS
// connection is then presumably captureConn rather than a network socket.
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}
account := &Account{
ID: 1301,
Name: "openai-rewrite",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
// The client asks for custom-original-model; upstream sees gpt-5.1.
"model_mapping": map[string]any{
"custom-original-model": "gpt-5.1",
},
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"custom-original-model","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "resp_model_tool_1", result.RequestID)
// Assertions run against the body written to the client, not the raw upstream event.
require.Equal(t, "custom-original-model", gjson.GetBytes(rec.Body.Bytes(), "model").String(), "响应模型应回写为原始请求模型")
require.Equal(t, "edit", gjson.GetBytes(rec.Body.Bytes(), "tool_calls.0.function.name").String(), "工具名称应被修正为 OpenCode 规范")
}
// TestOpenAIWSPayloadString_OnlyAcceptsStringValues pins which payload value
// types openAIWSPayloadString turns into a string. nil and numeric values yield
// "". String and []byte values are returned with surrounding whitespace trimmed.
// Despite the test name, []byte values are accepted too.
func TestOpenAIWSPayloadString_OnlyAcceptsStringValues(t *testing.T) {
	payload := map[string]any{
		"type":                 nil,
		"model":                123,
		"prompt_cache_key":     " cache-key ",
		"previous_response_id": []byte(" resp_1 "),
	}
	expectations := []struct {
		key  string
		want string
	}{
		{key: "type", want: ""},
		{key: "model", want: ""},
		{key: "prompt_cache_key", want: "cache-key"},
		{key: "previous_response_id", want: "resp_1"},
	}
	for _, exp := range expectations {
		require.Equal(t, exp.want, openAIWSPayloadString(payload, exp.key))
	}
}
// TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne sends two
// sequential non-streaming requests for the same account to a WS test server
// that answers any number of requests per connection. It asserts that only one
// WebSocket upgrade happened, meaning the pooled connection was reused instead
// of dialing one upstream connection per client request. It also asserts that
// the pool metrics record at least one reuse and one connection pick.
func TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne(t *testing.T) {
gin.SetMode(gin.TestMode)
// upgradeCount counts WS handshakes; sequence numbers response IDs across all connections.
var upgradeCount atomic.Int64
var sequence atomic.Int64
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
upgradeCount.Add(1)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
// Serve requests on this connection until a read or write fails
// (e.g. the client closed it).
for {
var request map[string]any
if err := conn.ReadJSON(&request); err != nil {
return
}
idx := sequence.Add(1)
responseID := "resp_reuse_" + strconv.FormatInt(idx, 10)
if err := conn.WriteJSON(map[string]any{
"type": "response.created",
"response": map[string]any{
"id": responseID,
"model": "gpt-5.1",
},
}); err != nil {
return
}
if err := conn.WriteJSON(map[string]any{
"type": "response.completed",
"response": map[string]any{
"id": responseID,
"model": "gpt-5.1",
"usage": map[string]any{
"input_tokens": 2,
"output_tokens": 1,
},
},
}); err != nil {
return
}
}
}))
defer wsServer.Close()
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
account := &Account{
ID: 19,
Name: "openai-ws",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 2,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
// Two back-to-back requests, each with its own gin context and recorder.
for i := 0; i < 2; i++ {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
groupID := int64(2001)
c.Set("api_key", &APIKey{GroupID: &groupID})
body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_reuse","input":[{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, strings.HasPrefix(result.RequestID, "resp_reuse_"))
}
require.Equal(t, int64(1), upgradeCount.Load(), "多个客户端请求应复用账号连接池而不是 1:1 对等建链")
metrics := svc.SnapshotOpenAIWSPoolMetrics()
require.GreaterOrEqual(t, metrics.AcquireReuseTotal, int64(1))
require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1))
}
// TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault forwards an
// OAuth-account request over a captured WS connection, with AllowStoreRecovery
// disabled. The request body sets store=true and stream=false. The test checks
// that the payload written upstream carries an explicit store=false and
// stream=true. It also checks that the handshake headers carry the v2
// OpenAI-Beta value plus the client's session_id and conversation_id.
func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
c.Request.Header.Set("session_id", "sess-oauth-1")
c.Request.Header.Set("conversation_id", "conv-oauth-1")
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.AllowStoreRecovery = false
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_oauth_1","model":"gpt-5.1","usage":{"input_tokens":3,"output_tokens":2}}}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}
account := &Account{
ID: 29,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token-1",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
// The client explicitly asks for store=true and stream=false.
body := []byte(`{"model":"gpt-5.1","stream":false,"store":true,"input":[{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "resp_oauth_1", result.RequestID)
// lastWrite is presumably the most recent payload written to the capture conn.
require.NotNil(t, captureConn.lastWrite)
requestJSON := requestToJSONString(captureConn.lastWrite)
require.True(t, gjson.Get(requestJSON, "store").Exists(), "OAuth WSv2 应显式写入 store 字段")
require.False(t, gjson.Get(requestJSON, "store").Bool(), "默认策略应将 OAuth store 置为 false")
require.True(t, gjson.Get(requestJSON, "stream").Exists(), "WSv2 payload 应保留 stream 字段")
require.True(t, gjson.Get(requestJSON, "stream").Bool(), "OAuth Codex 规范化后应强制 stream=true")
// Handshake headers seen by the capture dialer.
require.Equal(t, openAIWSBetaV2Value, captureDialer.lastHeaders.Get("OpenAI-Beta"))
require.Equal(t, "sess-oauth-1", captureDialer.lastHeaders.Get("session_id"))
require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id"))
}
// TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey
// sends an OAuth-account request that has no session_id or conversation_id
// headers but does carry prompt_cache_key "pcache_123" in the body. The WS
// handshake must use the prompt cache key as its session_id header and send no
// conversation_id. The upstream payload must still contain a stream field.
func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
// Deliberately no session_id / conversation_id headers on the client request.
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_prompt_cache_key","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}
account := &Account{
ID: 31,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token-1",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.1","stream":true,"prompt_cache_key":"pcache_123","input":[{"type":"input_text","text":"hi"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "resp_prompt_cache_key", result.RequestID)
// session_id falls back to the body's prompt_cache_key; conversation_id stays unset.
require.Equal(t, "pcache_123", captureDialer.lastHeaders.Get("session_id"))
require.Empty(t, captureDialer.lastHeaders.Get("conversation_id"))
require.NotNil(t, captureConn.lastWrite)
require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists())
}
// TestOpenAIGatewayService_Forward_WSv1_Unsupported enables Responses WebSocket
// v1 globally and disables v2. The account still opts into v2 through its Extra
// flag. Forward must fail with an error mentioning "ws v1", write a 400
// response whose body mentions "WSv1", and never issue the HTTP upstream
// request.
func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
// v1 on, v2 off: the combination under test.
cfg.Gateway.OpenAIWS.ResponsesWebsockets = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false
// Recorder that would capture an HTTP fallback; lastReq must stay nil.
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)),
},
}
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
account := &Account{
ID: 39,
Name: "openai-ws-v1",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": "https://api.openai.com/v1/responses",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_v1","input":[{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "ws v1")
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "WSv1")
require.Nil(t, upstream.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求")
}
// TestOpenAIGatewayService_Forward_WSv2_TurnStateAndMetadataReplayOnReconnect
// checks that turn state is replayed across a forced reconnect.
//
// The WS test server returns an x-codex-turn-state response header only on the
// first handshake, and answers exactly one request per connection. After the
// first Forward, the state store must hold that turn state for the session. The
// test then evicts the connection that served the first response and forwards
// again. The second handshake must replay the stored turn state, while each
// handshake carries the x-codex-turn-metadata value of its own client request.
func TestOpenAIGatewayService_Forward_WSv2_TurnStateAndMetadataReplayOnReconnect(t *testing.T) {
	gin.SetMode(gin.TestMode)
	var connIndex atomic.Int64
	// Handshake request headers, in arrival order. The buffer exceeds the two
	// expected handshakes so the handler never blocks on the send.
	headersCh := make(chan http.Header, 4)
	upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		idx := connIndex.Add(1)
		headersCh <- cloneHeader(r.Header)
		respHeader := http.Header{}
		if idx == 1 {
			// Only the first handshake hands out a turn state.
			respHeader.Set("x-codex-turn-state", "turn_state_first")
		}
		conn, err := upgrader.Upgrade(w, r, respHeader)
		if err != nil {
			t.Errorf("upgrade websocket failed: %v", err)
			return
		}
		defer func() {
			_ = conn.Close()
		}()
		var request map[string]any
		if err := conn.ReadJSON(&request); err != nil {
			t.Errorf("read ws request failed: %v", err)
			return
		}
		responseID := "resp_turn_" + strconv.FormatInt(idx, 10)
		if err := conn.WriteJSON(map[string]any{
			"type": "response.completed",
			"response": map[string]any{
				"id":    responseID,
				"model": "gpt-5.1",
				"usage": map[string]any{
					"input_tokens":  2,
					"output_tokens": 1,
				},
			},
		}); err != nil {
			t.Errorf("write response.completed failed: %v", err)
			return
		}
	}))
	defer wsServer.Close()

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0
	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}
	account := &Account{
		ID:          49,
		Name:        "openai-turn-state",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}
	reqBody := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)

	// First request: establishes the connection and receives the turn state.
	rec1 := httptest.NewRecorder()
	c1, _ := gin.CreateTestContext(rec1)
	c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	c1.Request.Header.Set("session_id", "session_turn_state")
	c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_1")
	result1, err := svc.Forward(context.Background(), c1, account, reqBody)
	require.NoError(t, err)
	require.NotNil(t, result1)

	// No api_key (and therefore no group) is set on the context. The lookup uses
	// group ID 0, presumably the no-group default.
	sessionHash := svc.GenerateSessionHash(c1, reqBody)
	store := svc.getOpenAIWSStateStore()
	turnState, ok := store.GetSessionTurnState(0, sessionHash)
	require.True(t, ok)
	require.Equal(t, "turn_state_first", turnState)

	// Evict the connection proactively to simulate a reconnect on the next request.
	connID, hasConn := store.GetResponseConn(result1.RequestID)
	require.True(t, hasConn)
	svc.getOpenAIWSConnPool().evictConn(account.ID, connID)

	// Second request: must dial a new connection and replay the stored turn state.
	rec2 := httptest.NewRecorder()
	c2, _ := gin.CreateTestContext(rec2)
	c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	c2.Request.Header.Set("session_id", "session_turn_state")
	c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_2")
	result2, err := svc.Forward(context.Background(), c2, account, reqBody)
	require.NoError(t, err)
	require.NotNil(t, result2)

	// Bounded receive: if the second Forward did not trigger a new handshake,
	// fail with a clear message instead of blocking until the global test timeout.
	nextHandshakeHeaders := func(which string) http.Header {
		t.Helper()
		select {
		case h := <-headersCh:
			return h
		case <-time.After(5 * time.Second):
			t.Fatalf("timed out waiting for the %s websocket handshake", which)
			return nil
		}
	}
	firstHandshakeHeaders := nextHandshakeHeaders("first")
	secondHandshakeHeaders := nextHandshakeHeaders("second")
	require.Equal(t, "turn_meta_1", firstHandshakeHeaders.Get("X-Codex-Turn-Metadata"))
	require.Equal(t, "turn_meta_2", secondHandshakeHeaders.Get("X-Codex-Turn-Metadata"))
	require.Equal(t, "turn_state_first", secondHandshakeHeaders.Get("X-Codex-Turn-State"))
}
// TestOpenAIGatewayService_Forward_WSv2_GeneratePrewarm enables
// PrewarmGenerateEnabled and checks that a single Forward writes two WS
// payloads. The first is a prewarm with an explicit generate=false. The second
// is the real request, with no generate field. The capture conn answers each
// with a response.completed event, and the result must carry the second (main)
// response ID.
func TestOpenAIGatewayService_Forward_WSv2_GeneratePrewarm(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("session_id", "session-prewarm")
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
// Events in order: the prewarm's completion (zero usage), then the main request's.
captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_prewarm_1","model":"gpt-5.1","usage":{"input_tokens":0,"output_tokens":0}}}`),
[]byte(`{"type":"response.completed","response":{"id":"resp_main_1","model":"gpt-5.1","usage":{"input_tokens":4,"output_tokens":2}}}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}
account := &Account{
ID: 59,
Name: "openai-prewarm",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "resp_main_1", result.RequestID)
require.Len(t, captureConn.writes, 2, "开启 generate=false 预热后应发送两次 WS 请求")
// writes[0] is the prewarm, writes[1] the real request.
firstWrite := requestToJSONString(captureConn.writes[0])
secondWrite := requestToJSONString(captureConn.writes[1])
require.True(t, gjson.Get(firstWrite, "generate").Exists())
require.False(t, gjson.Get(firstWrite, "generate").Bool())
require.False(t, gjson.Get(secondWrite, "generate").Exists())
}
// TestOpenAIGatewayService_PrewarmReadHonorsParentContext calls
// performOpenAIWSGeneratePrewarm directly. The lease wraps an
// openAIWSBlockingConn configured with a 200ms readDelay (its exact behavior is
// defined outside this file), and the parent context expires after 40ms. The
// call must fail with an error mentioning "prewarm_read_event" and return in
// under 180ms. In other words, the prewarm read follows the parent context
// instead of waiting out the configured 5s ReadTimeoutSeconds.
func TestOpenAIGatewayService_PrewarmReadHonorsParentContext(t *testing.T) {
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
svc := &OpenAIGatewayService{
cfg: cfg,
toolCorrector: NewCodexToolCorrector(),
}
account := &Account{
ID: 601,
Name: "openai-prewarm-timeout",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
}
// Wrap a slow-reading fake conn in a pool connection and a lease for this account.
conn := newOpenAIWSConn("prewarm_ctx_conn", account.ID, &openAIWSBlockingConn{
readDelay: 200 * time.Millisecond,
}, nil)
lease := &openAIWSConnLease{
accountID: account.ID,
conn: conn,
}
payload := map[string]any{
"type": "response.create",
"model": "gpt-5.1",
}
// Parent deadline (40ms) is well below both the read delay (200ms) and the read timeout (5s).
ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond)
defer cancel()
start := time.Now()
// The trailing arguments are passed as empty/nil/zero values. Their meaning is
// defined by performOpenAIWSGeneratePrewarm; confirm there.
err := svc.performOpenAIWSGeneratePrewarm(
ctx,
lease,
OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
payload,
"",
map[string]any{"model": "gpt-5.1"},
account,
nil,
0,
)
elapsed := time.Since(start)
require.Error(t, err)
require.Contains(t, err.Error(), "prewarm_read_event")
require.Less(t, elapsed, 180*time.Millisecond, "预热读取应受父 context 取消控制,不应阻塞到 read_timeout")
}
// TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse sends
// two requests in the same session. Each request carries a different
// x-codex-turn-metadata header, and both go over a single dialed WS connection
// (dial count 1). Each written payload must carry its own request's metadata
// under client_metadata."x-codex-turn-metadata". This is presumably because
// handshake headers are sent only once per reused connection.
func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
// One response.completed event per request, consumed in order.
captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_meta_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
[]byte(`{"type":"response.completed","response":{"id":"resp_meta_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}
account := &Account{
ID: 69,
Name: "openai-turn-metadata",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
// Request 1 of the session.
rec1 := httptest.NewRecorder()
c1, _ := gin.CreateTestContext(rec1)
c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c1.Request.Header.Set("session_id", "session-metadata-reuse")
c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_1")
result1, err := svc.Forward(context.Background(), c1, account, body)
require.NoError(t, err)
require.NotNil(t, result1)
require.Equal(t, "resp_meta_1", result1.RequestID)
// Request 2 of the same session, with different turn metadata.
rec2 := httptest.NewRecorder()
c2, _ := gin.CreateTestContext(rec2)
c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c2.Request.Header.Set("session_id", "session-metadata-reuse")
c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_2")
result2, err := svc.Forward(context.Background(), c2, account, body)
require.NoError(t, err)
require.NotNil(t, result2)
require.Equal(t, "resp_meta_2", result2.RequestID)
require.Equal(t, 1, captureDialer.DialCount(), "同一账号两轮请求应复用同一 WS 连接")
require.Len(t, captureConn.writes, 2)
firstWrite := requestToJSONString(captureConn.writes[0])
secondWrite := requestToJSONString(captureConn.writes[1])
require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String())
require.Equal(t, "turn_meta_payload_2", gjson.Get(secondWrite, "client_metadata.x-codex-turn-metadata").String())
}
// TestOpenAIGatewayService_Forward_WSv2StoreFalseSessionConnIsolation sends
// store=false requests with StoreDisabledForceNewConn enabled. Two requests in
// the same session must share one WS upgrade. A request from a different
// session must trigger a second upgrade, so that per-session continuation state
// on one connection is not overwritten by another session.
func TestOpenAIGatewayService_Forward_WSv2StoreFalseSessionConnIsolation(t *testing.T) {
gin.SetMode(gin.TestMode)
// upgradeCount counts WS handshakes; sequence numbers response IDs across all connections.
var upgradeCount atomic.Int64
var sequence atomic.Int64
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
upgradeCount.Add(1)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
// Answer every request on this connection until a read or write fails.
for {
var request map[string]any
if err := conn.ReadJSON(&request); err != nil {
return
}
responseID := "resp_store_false_" + strconv.FormatInt(sequence.Add(1), 10)
if err := conn.WriteJSON(map[string]any{
"type": "response.completed",
"response": map[string]any{
"id": responseID,
"model": "gpt-5.1",
"usage": map[string]any{
"input_tokens": 1,
"output_tokens": 1,
},
},
}); err != nil {
return
}
}
}))
defer wsServer.Close()
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4
cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
account := &Account{
ID: 79,
Name: "openai-store-false",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 2,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
// Session A, first request: one upgrade.
rec1 := httptest.NewRecorder()
c1, _ := gin.CreateTestContext(rec1)
c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c1.Request.Header.Set("session_id", "session_store_false_a")
result1, err := svc.Forward(context.Background(), c1, account, body)
require.NoError(t, err)
require.NotNil(t, result1)
require.Equal(t, int64(1), upgradeCount.Load())
// Session A, second request: still one upgrade.
rec2 := httptest.NewRecorder()
c2, _ := gin.CreateTestContext(rec2)
c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c2.Request.Header.Set("session_id", "session_store_false_a")
result2, err := svc.Forward(context.Background(), c2, account, body)
require.NoError(t, err)
require.NotNil(t, result2)
require.Equal(t, int64(1), upgradeCount.Load(), "同一 session(store=false) 应复用同一 WS 连接")
// Session B: must get its own connection.
rec3 := httptest.NewRecorder()
c3, _ := gin.CreateTestContext(rec3)
c3.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c3.Request.Header.Set("session_id", "session_store_false_b")
result3, err := svc.Forward(context.Background(), c3, account, body)
require.NoError(t, err)
require.NotNil(t, result3)
require.Equal(t, int64(2), upgradeCount.Load(), "不同 session(store=false) 应隔离连接,避免续链状态互相覆盖")
}
// TestOpenAIGatewayService_Forward_WSv2StoreFalseDisableForceNewConnAllowsReuse verifies that with
// StoreDisabledForceNewConn turned off, store=false requests from two different sessions are
// served over a single upstream WebSocket connection (only one handshake is observed).
func TestOpenAIGatewayService_Forward_WSv2StoreFalseDisableForceNewConnAllowsReuse(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var (
		upgradeCount atomic.Int64
		sequence     atomic.Int64
	)
	upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Count every handshake attempt, including ones that fail to upgrade.
		upgradeCount.Add(1)
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			t.Errorf("upgrade websocket failed: %v", err)
			return
		}
		defer func() {
			_ = conn.Close()
		}()
		// Keep serving on this connection until the client stops; each request gets
		// exactly one response.completed event with a unique response id.
		for {
			var request map[string]any
			if conn.ReadJSON(&request) != nil {
				return
			}
			completed := map[string]any{
				"type": "response.completed",
				"response": map[string]any{
					"id":    "resp_store_false_reuse_" + strconv.FormatInt(sequence.Add(1), 10),
					"model": "gpt-5.1",
					"usage": map[string]any{
						"input_tokens":  1,
						"output_tokens": 1,
					},
				},
			}
			if conn.WriteJSON(completed) != nil {
				return
			}
		}
	}))
	defer wsServer.Close()

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
	// The behavior under test: store=false must NOT force a fresh connection.
	cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}
	account := &Account{
		ID:          80,
		Name:        "openai-store-false-reuse",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 2,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}
	body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)

	// forwardAs pushes body through the gateway under the given session_id and
	// requires the call to succeed with a non-nil result.
	forwardAs := func(sessionID string) {
		rec := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(rec)
		c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
		c.Request.Header.Set("session_id", sessionID)
		result, err := svc.Forward(context.Background(), c, account, body)
		require.NoError(t, err)
		require.NotNil(t, result)
	}

	forwardAs("session_store_false_reuse_a")
	require.Equal(t, int64(1), upgradeCount.Load())

	// A different session must reuse the existing connection: still one handshake.
	forwardAs("session_store_false_reuse_b")
	require.Equal(t, int64(1), upgradeCount.Load(), "关闭强制新连后,不同 session(store=false) 可复用连接")
}
// TestOpenAIGatewayService_Forward_WSv2ReadTimeoutAppliesPerRead checks that the WS read timeout
// is applied to each individual read rather than to the response as a whole. Two scripted reads
// of 700ms each (1.4s total) must succeed under a 1s read timeout, and the HTTP upstream must
// never be used as a fallback.
func TestOpenAIGatewayService_Forward_WSv2ReadTimeoutAppliesPerRead(t *testing.T) {
	gin.SetMode(gin.TestMode)

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 1
	cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3

	// Each read stays under the 1s read timeout; together the reads exceed it.
	const perReadDelay = 700 * time.Millisecond
	scriptedConn := &openAIWSCaptureConn{
		readDelays: []time.Duration{perReadDelay, perReadDelay},
		events: [][]byte{
			[]byte(`{"type":"response.created","response":{"id":"resp_timeout_ok","model":"gpt-5.1"}}`),
			[]byte(`{"type":"response.completed","response":{"id":"resp_timeout_ok","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`),
		},
	}
	pool := newOpenAIWSConnPool(cfg)
	pool.setClientDialerForTest(&openAIWSCaptureDialer{conn: scriptedConn})

	// The HTTP upstream is only reachable via fallback; the test asserts it is never hit.
	httpFallback := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(`{"id":"resp_http_fallback","usage":{"input_tokens":1,"output_tokens":1}}`)),
		},
	}
	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpFallback,
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
		openaiWSPool:     pool,
	}
	account := &Account{
		ID:          81,
		Name:        "openai-read-timeout",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key": "sk-test",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")

	body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
	result, err := svc.Forward(context.Background(), c, account, body)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, "resp_timeout_ok", result.RequestID)
	require.Nil(t, httpFallback.lastReq, "每次 Read 都应独立应用超时;总时长超过 read_timeout 不应误回退 HTTP")
}
// openAIWSCaptureDialer is a fake WS dialer for tests. Each Dial records the
// outgoing request headers and counts the call. It then returns the
// preconfigured conn, with a clone of handshake as the response headers.
type openAIWSCaptureDialer struct {
	mu          sync.Mutex           // protects lastHeaders, handshake and dialCount
	conn        *openAIWSCaptureConn // returned as-is by every Dial
	lastHeaders http.Header          // clone of the headers passed to the most recent Dial
	handshake   http.Header          // cloned and returned as the handshake response headers
	dialCount   int                  // number of Dial calls so far
}
// Dial records a clone of headers and increments the dial counter. It then
// returns the preconfigured conn, HTTP status 0 and a clone of d.handshake
// as the handshake response headers. ctx, wsURL and proxyURL are ignored.
//
// If no conn is configured, Dial returns an error instead of the nil pointer.
// Returning a nil *openAIWSCaptureConn through the openAIWSClientConn interface
// would produce a non-nil interface value whose methods panic on first use.
func (d *openAIWSCaptureDialer) Dial(
	ctx context.Context,
	wsURL string,
	headers http.Header,
	proxyURL string,
) (openAIWSClientConn, int, http.Header, error) {
	_ = ctx
	_ = wsURL
	_ = proxyURL
	d.mu.Lock()
	d.lastHeaders = cloneHeader(headers)
	d.dialCount++
	respHeaders := cloneHeader(d.handshake)
	// Snapshot conn inside the same critical section as the other fields.
	conn := d.conn
	d.mu.Unlock()
	if conn == nil {
		return nil, 0, nil, errors.New("openai ws capture dialer: conn is not configured")
	}
	return conn, 0, respHeaders, nil
}
// DialCount reports how many times Dial has been called. It reads the counter
// under the dialer's mutex, so it may be called while dials are in flight.
func (d *openAIWSCaptureDialer) DialCount() int {
	d.mu.Lock()
	n := d.dialCount
	d.mu.Unlock()
	return n
}
// openAIWSCaptureConn is a scripted fake WS client connection for tests.
// Reads replay events in order, optionally delayed by readDelays. Writes are
// decoded into maps and recorded.
type openAIWSCaptureConn struct {
	mu         sync.Mutex        // guards every field below
	readDelays []time.Duration   // per-read delays, consumed front-first alongside events
	events     [][]byte          // scripted frames returned by ReadMessage, front-first
	lastWrite  map[string]any    // copy of the most recently recorded write
	writes     []map[string]any  // copies of every recorded write, in order
	closed     bool              // set by Close; later reads/writes return errOpenAIWSConnClosed
}
// WriteJSON records value as a map snapshot and always succeeds unless the
// connection was closed, in which case it returns errOpenAIWSConnClosed.
//
// How each input type is handled:
//   - A map[string]any is copied as-is.
//   - json.RawMessage and []byte are first decoded into a map.
//   - Input that fails to decode, or has any other type, is accepted but not recorded.
//
// Each recorded snapshot updates lastWrite and is appended to writes, using
// independent shallow copies.
func (c *openAIWSCaptureConn) WriteJSON(ctx context.Context, value any) error {
	_ = ctx
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.closed {
		return errOpenAIWSConnClosed
	}

	var (
		snapshot map[string]any
		captured bool
	)
	switch payload := value.(type) {
	case map[string]any:
		snapshot, captured = payload, true
	case json.RawMessage:
		captured = json.Unmarshal(payload, &snapshot) == nil
	case []byte:
		captured = json.Unmarshal(payload, &snapshot) == nil
	}
	if captured {
		c.lastWrite = cloneMapStringAny(snapshot)
		c.writes = append(c.writes, cloneMapStringAny(snapshot))
	}
	return nil
}
// ReadMessage pops and returns the next scripted event.
//
// If a delay is queued, the call first waits that long. When ctx is done
// during the wait, it returns ctx.Err(); the event has already been consumed
// by then. A nil ctx is treated as context.Background().
//
// It returns errOpenAIWSConnClosed after Close, and io.EOF once no events remain.
func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
	if ctx == nil {
		ctx = context.Background()
	}

	c.mu.Lock()
	switch {
	case c.closed:
		c.mu.Unlock()
		return nil, errOpenAIWSConnClosed
	case len(c.events) == 0:
		c.mu.Unlock()
		return nil, io.EOF
	}
	next := c.events[0]
	c.events = c.events[1:]
	var wait time.Duration
	if len(c.readDelays) > 0 {
		wait, c.readDelays = c.readDelays[0], c.readDelays[1:]
	}
	c.mu.Unlock()

	if wait <= 0 {
		return next, nil
	}
	timer := time.NewTimer(wait)
	defer timer.Stop()
	select {
	case <-timer.C:
		return next, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// Ping always succeeds. The fake connection has nothing to probe, and it
// does not consult the closed flag.
func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
	_ = ctx // explicitly unused; kept so the signature matches the client conn interface
	return nil
}
// Close marks the connection as closed. Subsequent WriteJSON and ReadMessage
// calls then return errOpenAIWSConnClosed. Close never fails and may be called
// more than once.
func (c *openAIWSCaptureConn) Close() error {
	c.mu.Lock()
	c.closed = true
	c.mu.Unlock()
	return nil
}
// cloneMapStringAny returns a shallow copy of src: a new top-level map whose
// values are the same (not deep-copied) as in src. A nil src yields nil, while
// an empty non-nil src yields an empty non-nil map.
func cloneMapStringAny(src map[string]any) (dst map[string]any) {
	if src != nil {
		dst = make(map[string]any, len(src))
		for key, value := range src {
			dst[key] = value
		}
	}
	return dst
}