From 1d0872e7cace0c0386131443ac5e3e4046e96f3b Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Thu, 5 Mar 2026 11:50:58 +0800 Subject: [PATCH] =?UTF-8?q?feat(openai-ws):=20=E5=90=88=E5=B9=B6=20WS=20v2?= =?UTF-8?q?=20=E9=80=8F=E4=BC=A0=E6=A8=A1=E5=BC=8F=E4=B8=8E=E5=89=8D?= =?UTF-8?q?=E7=AB=AF=20ws=20mode?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增 OpenAI WebSocket v2 passthrough relay 数据面与服务适配层, 支持按账号 ws mode 在 ctx_pool 与 passthrough 间路由。 同步调整前端 OpenAI ws mode 选项为 off/ctx_pool/passthrough, 并补充 i18n 文案与对应单测。 新增 Caddyfile.dmit 与 docker-compose-aicodex.yml 部署配置, 用于宿主机场景下的反向代理与服务编排。 Co-Authored-By: Claude Opus 4.6 --- Caddyfile.dmit | 222 +++++ backend/internal/config/config.go | 10 +- backend/internal/config/config_test.go | 6 +- backend/internal/service/account.go | 27 +- .../account_openai_passthrough_test.go | 37 +- .../service/openai_gateway_service.go | 16 +- backend/internal/service/openai_ws_client.go | 27 + .../internal/service/openai_ws_forwarder.go | 52 +- ...penai_ws_forwarder_ingress_session_test.go | 136 ++- .../openai_ws_forwarder_success_test.go | 13 + .../service/openai_ws_protocol_resolver.go | 5 +- .../openai_ws_protocol_resolver_test.go | 28 +- .../service/openai_ws_v2/caddy_adapter.go | 24 + .../internal/service/openai_ws_v2/entry.go | 23 + .../internal/service/openai_ws_v2/metrics.go | 29 + .../service/openai_ws_v2/passthrough_relay.go | 807 ++++++++++++++++++ .../passthrough_relay_internal_test.go | 432 ++++++++++ .../openai_ws_v2/passthrough_relay_test.go | 752 ++++++++++++++++ .../openai_ws_v2_passthrough_adapter.go | 367 ++++++++ deploy/config.example.yaml | 5 +- docker-compose-aicodex.yml | 263 ++++++ .../components/account/CreateAccountModal.vue | 28 +- .../components/account/EditAccountModal.vue | 27 +- frontend/src/i18n/locales/en.ts | 3 + frontend/src/i18n/locales/zh.ts | 3 + .../src/utils/__tests__/openaiWsMode.spec.ts | 35 +- frontend/src/utils/openaiWsMode.ts | 26 +- 27 files changed, 3322 insertions(+), 81 deletions(-) create mode 100644 Caddyfile.dmit create mode 100644 backend/internal/service/openai_ws_v2/caddy_adapter.go create mode 100644 backend/internal/service/openai_ws_v2/entry.go create mode 100644 backend/internal/service/openai_ws_v2/metrics.go create mode 100644 backend/internal/service/openai_ws_v2/passthrough_relay.go create mode 100644 backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go create mode 100644 backend/internal/service/openai_ws_v2/passthrough_relay_test.go create mode 100644 backend/internal/service/openai_ws_v2_passthrough_adapter.go create mode 100644 docker-compose-aicodex.yml diff --git a/Caddyfile.dmit b/Caddyfile.dmit new file mode 100644 index 00000000..232606bb --- /dev/null +++ b/Caddyfile.dmit @@ -0,0 +1,222 @@ +# ============================================================================= +# Sub2API Caddy Reverse Proxy Configuration (宿主机部署) +# ============================================================================= +# 使用方法: +# 1. 安装 Caddy: https://caddyserver.com/docs/install +# 2. 修改下方 example.com 为你的域名 +# 3. 确保域名 DNS 已指向服务器 +# 4. 复制配置: sudo cp Caddyfile /etc/caddy/Caddyfile +# 5. 重载配置: sudo systemctl reload caddy +# +# Caddy 会自动申请和续期 Let's Encrypt SSL 证书 +# ============================================================================= + +# 全局配置 +{ + # Let's Encrypt 邮箱通知 + email mt21625457@gmail.com + + # 服务器配置 + servers { + # 启用 HTTP/2 和 HTTP/3 + protocols h1 h2 h3 + + # 超时配置 + timeouts { + read_body 30s + read_header 10s + # WebSocket/流式场景下,延长写入与空闲超时,避免长会话被过早回收 + write 3600s + idle 3600s + } + } +} + +# 修改为你的域名 +dmit.leagsoft.ai { + # ========================================================================= + # 静态资源长期缓存(高优先级,放在最前面) + # 带 hash 的文件可以永久缓存,浏览器和 CDN 都会缓存 + # ========================================================================= + @static { + path /assets/* + path /logo.png + path /favicon.ico + } + header @static { + Cache-Control "public, max-age=31536000, immutable" + # 移除可能干扰缓存的头 + -Pragma + -Expires + } + # ========================================================================= + # TLS 安全配置 + # ========================================================================= + tls { + # 仅使用 TLS 1.2 和 1.3 + protocols tls1.2 tls1.3 + + # 优先使用的加密套件 + ciphers TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + } + + # ========================================================================= + # 反向代理配置 + # ========================================================================= + # OpenAI Responses(含 WebSocket/SSE)专用代理: + # 1) 禁用流式缓冲,降低中间层等待导致的断流概率 + # 2) 上游强制 HTTP/1.1,保证 Upgrade 行为稳定 + # 3) 放宽流生命周期,避免长会话被代理提前切断 + @openai_responses { + path /openai/v1/responses* + } + reverse_proxy @openai_responses localhost:8080 { + flush_interval -1 + stream_timeout 24h + stream_close_delay 5m + + # 传递真实客户端信息 + header_up X-Real-IP {remote_host} + header_up X-Forwarded-For {remote_host} + header_up X-Forwarded-Proto {scheme} + header_up X-Forwarded-Host {host} + header_up CF-Connecting-IP {http.request.header.CF-Connecting-IP} + + transport http { + versions 1.1 + keepalive 120s + keepalive_idle_conns 256 + read_buffer 32KB + write_buffer 32KB + compression off + } + } + + reverse_proxy localhost:8080 { + # 健康检查 + health_uri /health + health_interval 30s + health_timeout 10s + health_status 200 + + # 负载均衡策略(单节点可忽略,多节点时有用) + lb_policy round_robin + lb_try_duration 5s + lb_try_interval 250ms + + # 传递真实客户端信息 + # 兼容 Cloudflare 和直连:后端应优先读取 CF-Connecting-IP,其次 X-Real-IP + header_up X-Real-IP {remote_host} + header_up X-Forwarded-For {remote_host} + header_up X-Forwarded-Proto {scheme} + header_up X-Forwarded-Host {host} + # 保留 Cloudflare 原始头(如果存在) + # 后端获取 IP 的优先级建议: CF-Connecting-IP → X-Real-IP → X-Forwarded-For + header_up CF-Connecting-IP {http.request.header.CF-Connecting-IP} + + # 连接池优化 + transport http { + keepalive 120s + keepalive_idle_conns 256 + read_buffer 16KB + write_buffer 16KB + compression off + } + + # 故障转移 + fail_duration 30s + max_fails 3 + unhealthy_status 500 502 503 504 + } + + # ========================================================================= + # 压缩配置 + # ========================================================================= + encode { + zstd + gzip 6 + minimum_length 256 + match { + header Content-Type text/* + header Content-Type application/json* + header Content-Type application/javascript* + header Content-Type application/xml* + header Content-Type application/rss+xml* + header Content-Type image/svg+xml* + } + } + + # ========================================================================= + # 速率限制 (需要 caddy-ratelimit 插件) + # 如未安装插件,请注释掉此段 + # ========================================================================= + # rate_limit { + # zone api { + # key {remote_host} + # events 100 + # window 1m + # } + # } + + # ========================================================================= + # 安全响应头 + # ========================================================================= + header { + # 防止点击劫持 + X-Frame-Options "SAMEORIGIN" + + # XSS 保护 + X-XSS-Protection "1; mode=block" + + # 防止 MIME 类型嗅探 + X-Content-Type-Options "nosniff" + + # 引用策略 + Referrer-Policy "strict-origin-when-cross-origin" + + # HSTS - 强制 HTTPS (max-age=1年) + Strict-Transport-Security "max-age=31536000; includeSubDomains; preload" + + # 内容安全策略 (根据需要调整) + # Content-Security-Policy "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:;" + + # 权限策略 + Permissions-Policy "accelerometer=(), camera=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), payment=(), usb=()" + + # 跨域资源策略 + Cross-Origin-Opener-Policy "same-origin" + Cross-Origin-Embedder-Policy "require-corp" + Cross-Origin-Resource-Policy "same-origin" + + # 移除敏感头 + -Server + -X-Powered-By + } + + # ========================================================================= + # 请求大小限制 (防止大文件攻击) + # ========================================================================= + request_body { + max_size 100MB + } + + # ========================================================================= + # 日志配置 + # ========================================================================= + log { + output file /var/log/caddy/sub2api.log { + roll_size 50mb + roll_keep 10 + roll_keep_for 720h + } + format json + level INFO + } + + # ========================================================================= + # 错误处理 + # ========================================================================= + handle_errors { + respond "{err.status_code} {err.status_text}" + } +} \ No newline at end of file diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 54be38a1..42f1e629 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -516,7 +516,7 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string { type GatewayOpenAIWSConfig struct { // ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为) ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"` - // IngressModeDefault: ingress 默认模式(off/shared/dedicated) + // IngressModeDefault: ingress 默认模式(off/ctx_pool/passthrough) IngressModeDefault string `mapstructure:"ingress_mode_default"` // Enabled: 全局总开关(默认 true) Enabled bool `mapstructure:"enabled"` @@ -1335,7 +1335,7 @@ func setDefaults() { // OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚) viper.SetDefault("gateway.openai_ws.enabled", true) viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false) - viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared") + viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool") viper.SetDefault("gateway.openai_ws.oauth_enabled", true) viper.SetDefault("gateway.openai_ws.apikey_enabled", true) viper.SetDefault("gateway.openai_ws.force_http", false) @@ -2043,9 +2043,11 @@ func (c *Config) Validate() error { } if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" { switch mode { - case "off", "shared", "dedicated": + case "off", "ctx_pool", "passthrough": + case "shared", "dedicated": + slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", "value", mode) default: - return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated") + return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough") } } if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" { diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index e3b592e2..79fcc6d0 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) { if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled { t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false") } - if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" { - t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared") + if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" { + t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool") } } @@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) { wantErr: "gateway.openai_ws.store_disabled_conn_mode", }, { - name: "ingress_mode_default 必须为 off|shared|dedicated", + name: "ingress_mode_default 必须为 off|ctx_pool|passthrough", mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" }, wantErr: "gateway.openai_ws.ingress_mode_default", }, diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 81e91aeb..7d56b754 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -853,15 +853,21 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool { } const ( - OpenAIWSIngressModeOff = "off" - OpenAIWSIngressModeShared = "shared" - OpenAIWSIngressModeDedicated = "dedicated" + OpenAIWSIngressModeOff = "off" + OpenAIWSIngressModeShared = "shared" + OpenAIWSIngressModeDedicated = "dedicated" + OpenAIWSIngressModeCtxPool = "ctx_pool" + OpenAIWSIngressModePassthrough = "passthrough" ) func normalizeOpenAIWSIngressMode(mode string) string { switch strings.ToLower(strings.TrimSpace(mode)) { case OpenAIWSIngressModeOff: return OpenAIWSIngressModeOff + case OpenAIWSIngressModeCtxPool: + return OpenAIWSIngressModeCtxPool + case OpenAIWSIngressModePassthrough: + return OpenAIWSIngressModePassthrough case OpenAIWSIngressModeShared: return OpenAIWSIngressModeShared case OpenAIWSIngressModeDedicated: @@ -873,18 +879,21 @@ func normalizeOpenAIWSIngressMode(mode string) string { func normalizeOpenAIWSIngressDefaultMode(mode string) string { if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" { + if normalized == OpenAIWSIngressModeShared || normalized == OpenAIWSIngressModeDedicated { + return OpenAIWSIngressModeCtxPool + } return normalized } - return OpenAIWSIngressModeShared + return OpenAIWSIngressModeCtxPool } -// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。 +// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/ctx_pool/passthrough)。 // // 优先级: // 1. 分类型 mode 新字段(string) // 2. 分类型 enabled 旧字段(bool) // 3. 兼容 enabled 旧字段(bool) -// 4. defaultMode(非法时回退 shared) +// 4. defaultMode(非法时回退 ctx_pool) func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string { resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode) if a == nil || !a.IsOpenAI() { @@ -919,7 +928,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri return "", false } if enabled { - return OpenAIWSIngressModeShared, true + return OpenAIWSIngressModeCtxPool, true } return OpenAIWSIngressModeOff, true } @@ -946,6 +955,10 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri if mode, ok := resolveBoolMode("openai_ws_enabled"); ok { return mode } + // 兼容旧值:shared/dedicated 语义都归并到 ctx_pool。 + if resolvedDefault == OpenAIWSIngressModeShared || resolvedDefault == OpenAIWSIngressModeDedicated { + return OpenAIWSIngressModeCtxPool + } return resolvedDefault } diff --git a/backend/internal/service/account_openai_passthrough_test.go b/backend/internal/service/account_openai_passthrough_test.go index a85c68ec..50c2b7cb 100644 --- a/backend/internal/service/account_openai_passthrough_test.go +++ b/backend/internal/service/account_openai_passthrough_test.go @@ -206,14 +206,14 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) { } func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { - t.Run("default fallback to shared", func(t *testing.T) { + t.Run("default fallback to ctx_pool", func(t *testing.T) { account := &Account{ Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}, } - require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("")) - require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid")) + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("")) + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid")) }) t.Run("oauth mode field has highest priority", func(t *testing.T) { @@ -221,15 +221,15 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{ - "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, "openai_oauth_responses_websockets_v2_enabled": false, "responses_websockets_v2_enabled": false, }, } - require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared)) + require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool)) }) - t.Run("legacy enabled maps to shared", func(t *testing.T) { + t.Run("legacy enabled maps to ctx_pool", func(t *testing.T) { account := &Account{ Platform: PlatformOpenAI, Type: AccountTypeAPIKey, @@ -237,7 +237,28 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { "responses_websockets_v2_enabled": true, }, } - require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + }) + + t.Run("shared/dedicated mode strings are compatible with ctx_pool", func(t *testing.T) { + shared := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared, + }, + } + dedicated := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + }, + } + require.Equal(t, OpenAIWSIngressModeShared, shared.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + require.Equal(t, OpenAIWSIngressModeDedicated, dedicated.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeShared)) + require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeDedicated)) }) t.Run("legacy disabled maps to off", func(t *testing.T) { @@ -249,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { "responses_websockets_v2_enabled": true, }, } - require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared)) + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool)) }) t.Run("non openai always off", func(t *testing.T) { diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 8606708f..d92b2ecf 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -263,13 +263,15 @@ type OpenAIGatewayService struct { toolCorrector *CodexToolCorrector openaiWSResolver OpenAIWSProtocolResolver - openaiWSPoolOnce sync.Once - openaiWSStateStoreOnce sync.Once - openaiSchedulerOnce sync.Once - openaiWSPool *openAIWSConnPool - openaiWSStateStore OpenAIWSStateStore - openaiScheduler OpenAIAccountScheduler - openaiAccountStats *openAIAccountRuntimeStats + openaiWSPoolOnce sync.Once + openaiWSStateStoreOnce sync.Once + openaiSchedulerOnce sync.Once + openaiWSPassthroughDialerOnce sync.Once + openaiWSPool *openAIWSConnPool + openaiWSStateStore OpenAIWSStateStore + openaiScheduler OpenAIAccountScheduler + openaiWSPassthroughDialer openAIWSClientDialer + openaiAccountStats *openAIAccountRuntimeStats openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time openaiWSRetryMetrics openAIWSRetryMetrics diff --git a/backend/internal/service/openai_ws_client.go b/backend/internal/service/openai_ws_client.go index 9f3c47b7..80b75530 100644 --- a/backend/internal/service/openai_ws_client.go +++ b/backend/internal/service/openai_ws_client.go @@ -11,6 +11,7 @@ import ( "sync/atomic" "time" + openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2" coderws "github.com/coder/websocket" "github.com/coder/websocket/wsjson" ) @@ -234,6 +235,8 @@ type coderOpenAIWSClientConn struct { conn *coderws.Conn } +var _ openaiwsv2.FrameConn = (*coderOpenAIWSClientConn)(nil) + func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error { if c == nil || c.conn == nil { return errOpenAIWSConnClosed @@ -264,6 +267,30 @@ func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, erro } } +func (c *coderOpenAIWSClientConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if c == nil || c.conn == nil { + return coderws.MessageText, nil, errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + msgType, payload, err := c.conn.Read(ctx) + if err != nil { + return coderws.MessageText, nil, err + } + return msgType, payload, nil +} + +func (c *coderOpenAIWSClientConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if c == nil || c.conn == nil { + return errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return c.conn.Write(ctx, msgType, payload) +} + func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error { if c == nil || c.conn == nil { return errOpenAIWSConnClosed diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 74ba472f..a5c2fd7a 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -46,9 +46,10 @@ const ( openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024 openAIWSPayloadSizeEstimateMaxItems = 16 - openAIWSEventFlushBatchSizeDefault = 4 - openAIWSEventFlushIntervalDefault = 25 * time.Millisecond - openAIWSPayloadLogSampleDefault = 0.2 + openAIWSEventFlushBatchSizeDefault = 4 + openAIWSEventFlushIntervalDefault = 25 * time.Millisecond + openAIWSPayloadLogSampleDefault = 0.2 + openAIWSPassthroughIdleTimeoutDefault = time.Hour openAIWSStoreDisabledConnModeStrict = "strict" openAIWSStoreDisabledConnModeAdaptive = "adaptive" @@ -904,6 +905,18 @@ func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool { return s.openaiWSPool } +func (s *OpenAIGatewayService) getOpenAIWSPassthroughDialer() openAIWSClientDialer { + if s == nil { + return nil + } + s.openaiWSPassthroughDialerOnce.Do(func() { + if s.openaiWSPassthroughDialer == nil { + s.openaiWSPassthroughDialer = newDefaultOpenAIWSClientDialer() + } + }) + return s.openaiWSPassthroughDialer +} + func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot { pool := s.getOpenAIWSConnPool() if pool == nil { @@ -967,6 +980,13 @@ func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration { return 15 * time.Minute } +func (s *OpenAIGatewayService) openAIWSPassthroughIdleTimeout() time.Duration { + if timeout := s.openAIWSReadTimeout(); timeout > 0 { + return timeout + } + return openAIWSPassthroughIdleTimeoutDefault +} + func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration { if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 { return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second @@ -2322,7 +2342,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled - ingressMode := OpenAIWSIngressModeShared + ingressMode := OpenAIWSIngressModeCtxPool if modeRouterV2Enabled { ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault) if ingressMode == OpenAIWSIngressModeOff { @@ -2332,6 +2352,30 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( nil, ) } + switch ingressMode { + case OpenAIWSIngressModePassthrough: + if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport) + } + return s.proxyResponsesWebSocketV2Passthrough( + ctx, + c, + clientConn, + account, + token, + firstClientMessage, + hooks, + wsDecision, + ) + case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated: + // continue + default: + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "websocket mode only supports ctx_pool/passthrough", + nil, + ) + } } if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport) diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go index 5a3c12c3..59e6ecad 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -149,7 +149,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式") require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式") - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + _ = clientConn.Close(coderws.StatusNormalClosure, "done") select { case serverErr := <-serverErrCh: @@ -298,6 +298,140 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoe require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接") } +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeRelaysByCaddyAdapter(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + upstreamConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_passthrough_turn_1","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":3}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: upstreamConn} + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPassthroughDialer: captureDialer, + } + + account := &Account{ + ID: 452, + Name: "openai-ingress-passthrough", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }, + } + + serverErrCh := make(chan error, 1) + resultCh := make(chan *OpenAIForwardResult, 1) + hooks := &OpenAIWSIngressHooks{ + AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) { + if turnErr == nil && result != nil { + resultCh <- result + } + }, + } + + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, event, readErr := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, readErr) + require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String()) + require.Equal(t, "resp_passthrough_turn_1", gjson.GetBytes(event, "response.id").String()) + _ = clientConn.Close(coderws.StatusNormalClosure, "done") + + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 passthrough websocket 结束超时") + } + + select { + case result := <-resultCh: + require.Equal(t, "resp_passthrough_turn_1", result.RequestID) + require.True(t, result.OpenAIWSMode) + require.Equal(t, 2, result.Usage.InputTokens) + require.Equal(t, 3, result.Usage.OutputTokens) + case <-time.After(2 * time.Second): + t.Fatal("未收到 passthrough turn 结果回调") + } + + require.Equal(t, 1, captureDialer.DialCount(), "passthrough 模式应直接建立上游 websocket") + require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create") +} + func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go index 592801f6..1beb9ae9 100644 --- a/backend/internal/service/openai_ws_forwarder_success_test.go +++ b/backend/internal/service/openai_ws_forwarder_success_test.go @@ -15,6 +15,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/stretchr/testify/require" @@ -1282,6 +1283,18 @@ func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) { return event, nil } +func (c *openAIWSCaptureConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + payload, err := c.ReadMessage(ctx) + if err != nil { + return coderws.MessageText, nil, err + } + return coderws.MessageText, payload, nil +} + +func (c *openAIWSCaptureConn) WriteFrame(ctx context.Context, _ coderws.MessageType, payload []byte) error { + return c.WriteJSON(ctx, json.RawMessage(payload)) +} + func (c *openAIWSCaptureConn) Ping(ctx context.Context) error { _ = ctx return nil diff --git a/backend/internal/service/openai_ws_protocol_resolver.go b/backend/internal/service/openai_ws_protocol_resolver.go index 368643be..7266759c 100644 --- a/backend/internal/service/openai_ws_protocol_resolver.go +++ b/backend/internal/service/openai_ws_protocol_resolver.go @@ -69,8 +69,11 @@ func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProt switch mode { case OpenAIWSIngressModeOff: return openAIWSHTTPDecision("account_mode_off") - case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated: + case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModePassthrough: // continue + case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated: + // 历史值兼容:按 ctx_pool 处理。 + mode = OpenAIWSIngressModeCtxPool default: return openAIWSHTTPDecision("account_mode_off") } diff --git a/backend/internal/service/openai_ws_protocol_resolver_test.go b/backend/internal/service/openai_ws_protocol_resolver_test.go index 5be76e28..4d5dc5f1 100644 --- a/backend/internal/service/openai_ws_protocol_resolver_test.go +++ b/backend/internal/service/openai_ws_protocol_resolver_test.go @@ -143,21 +143,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true - cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool account := &Account{ Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 1, Extra: map[string]any{ - "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool, }, } - t.Run("dedicated mode routes to ws v2", func(t *testing.T) { + t.Run("ctx_pool mode routes to ws v2", func(t *testing.T) { decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account) require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) - require.Equal(t, "ws_v2_mode_dedicated", decision.Reason) + require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason) }) t.Run("off mode routes to http", func(t *testing.T) { @@ -174,7 +174,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { require.Equal(t, "account_mode_off", decision.Reason) }) - t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) { + t.Run("legacy boolean maps to ctx_pool in v2 router", func(t *testing.T) { legacyAccount := &Account{ Platform: PlatformOpenAI, Type: AccountTypeAPIKey, @@ -185,7 +185,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { } decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount) require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) - require.Equal(t, "ws_v2_mode_shared", decision.Reason) + require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason) + }) + + t.Run("passthrough mode routes to ws v2", func(t *testing.T) { + passthroughAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }, + } + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(passthroughAccount) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_mode_passthrough", decision.Reason) }) t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) { @@ -193,7 +207,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{ - "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared, + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool, }, } decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency) diff --git a/backend/internal/service/openai_ws_v2/caddy_adapter.go b/backend/internal/service/openai_ws_v2/caddy_adapter.go new file mode 100644 index 00000000..1fecc231 --- /dev/null +++ b/backend/internal/service/openai_ws_v2/caddy_adapter.go @@ -0,0 +1,24 @@ +package openai_ws_v2 + +import ( + "context" +) + +// runCaddyStyleRelay 采用 Caddy reverseproxy 的双向隧道思想: +// 连接建立后并发复制两个方向,任一方向退出触发收敛关闭。 +// +// Reference: +// - Project: caddyserver/caddy (Apache-2.0) +// - Commit: f283062d37c50627d53ca682ebae2ce219b35515 +// - Files: +// - modules/caddyhttp/reverseproxy/streaming.go +// - modules/caddyhttp/reverseproxy/reverseproxy.go +func runCaddyStyleRelay( + ctx context.Context, + clientConn FrameConn, + upstreamConn FrameConn, + firstClientMessage []byte, + options RelayOptions, +) (RelayResult, *RelayExit) { + return Relay(ctx, clientConn, upstreamConn, firstClientMessage, options) +} diff --git a/backend/internal/service/openai_ws_v2/entry.go b/backend/internal/service/openai_ws_v2/entry.go new file mode 100644 index 00000000..176298fe --- /dev/null +++ b/backend/internal/service/openai_ws_v2/entry.go @@ -0,0 +1,23 @@ +package openai_ws_v2 + +import "context" + +// EntryInput 是 passthrough v2 数据面的入口参数。 +type EntryInput struct { + Ctx context.Context + ClientConn FrameConn + UpstreamConn FrameConn + FirstClientMessage []byte + Options RelayOptions +} + +// RunEntry 是 openai_ws_v2 包对外的统一入口。 +func RunEntry(input EntryInput) (RelayResult, *RelayExit) { + return runCaddyStyleRelay( + input.Ctx, + input.ClientConn, + input.UpstreamConn, + input.FirstClientMessage, + input.Options, + ) +} diff --git a/backend/internal/service/openai_ws_v2/metrics.go b/backend/internal/service/openai_ws_v2/metrics.go new file mode 100644 index 00000000..3708befd --- /dev/null +++ b/backend/internal/service/openai_ws_v2/metrics.go @@ -0,0 +1,29 @@ +package openai_ws_v2 + +import ( + "sync/atomic" +) + +// MetricsSnapshot 是 OpenAI WS v2 passthrough 路径的轻量运行时指标快照。 +type MetricsSnapshot struct { + SemanticMutationTotal int64 `json:"semantic_mutation_total"` + UsageParseFailureTotal int64 `json:"usage_parse_failure_total"` +} + +var ( + // passthrough 路径默认不会做语义改写,该计数通常应保持为 0(保留用于未来防御性校验)。 + passthroughSemanticMutationTotal atomic.Int64 + passthroughUsageParseFailureTotal atomic.Int64 +) + +func recordUsageParseFailure() { + passthroughUsageParseFailureTotal.Add(1) +} + +// SnapshotMetrics 返回当前 passthrough 指标快照。 +func SnapshotMetrics() MetricsSnapshot { + return MetricsSnapshot{ + SemanticMutationTotal: passthroughSemanticMutationTotal.Load(), + UsageParseFailureTotal: passthroughUsageParseFailureTotal.Load(), + } +} diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay.go b/backend/internal/service/openai_ws_v2/passthrough_relay.go new file mode 100644 index 00000000..af8ee195 --- /dev/null +++ b/backend/internal/service/openai_ws_v2/passthrough_relay.go @@ -0,0 +1,807 @@ +package openai_ws_v2 + +import ( + "context" + "errors" + "io" + "net" + "strconv" + "strings" + "sync/atomic" + "time" + + coderws "github.com/coder/websocket" + "github.com/tidwall/gjson" +) + +type FrameConn interface { + ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) + WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error + Close() error +} + +type Usage struct { + InputTokens int + OutputTokens int + CacheCreationInputTokens int + CacheReadInputTokens int +} + +type RelayResult struct { + RequestModel string + Usage Usage + RequestID string + TerminalEventType string + FirstTokenMs *int + Duration time.Duration + ClientToUpstreamFrames int64 + UpstreamToClientFrames int64 + DroppedDownstreamFrames int64 +} + +type RelayTurnResult struct { + RequestModel string + Usage Usage + RequestID string + TerminalEventType string + Duration time.Duration + FirstTokenMs *int +} + +type RelayExit struct { + Stage string + Err error + WroteDownstream bool +} + +type RelayOptions struct { + WriteTimeout time.Duration + IdleTimeout time.Duration + UpstreamDrainTimeout time.Duration + FirstMessageType coderws.MessageType + OnUsageParseFailure func(eventType string, usageRaw string) + OnTurnComplete func(turn RelayTurnResult) + OnTrace func(event RelayTraceEvent) + Now func() time.Time +} + +type RelayTraceEvent struct { + Stage string + Direction string + MessageType string + PayloadBytes int + Graceful bool + WroteDownstream bool + Error string +} + +type relayState struct { + usage Usage + requestModel string + lastResponseID string + terminalEventType string + firstTokenMs *int + turnTimingByID map[string]*relayTurnTiming +} + +type relayExitSignal struct { + stage string + err error + graceful bool + wroteDownstream bool +} + +type observedUpstreamEvent struct { + terminal bool + eventType string + responseID string + usage Usage + duration time.Duration + firstToken *int +} + +type relayTurnTiming struct { + startAt time.Time + firstTokenMs *int +} + +func Relay( + ctx context.Context, + clientConn FrameConn, + upstreamConn FrameConn, + firstClientMessage []byte, + options RelayOptions, +) (RelayResult, *RelayExit) { + result := RelayResult{RequestModel: strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())} + if clientConn == nil || upstreamConn == nil { + return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")} + } + if ctx == nil { + ctx = context.Background() + } + + nowFn := options.Now + if nowFn == nil { + nowFn = time.Now + } + writeTimeout := options.WriteTimeout + if writeTimeout <= 0 { + writeTimeout = 2 * time.Minute + } + drainTimeout := options.UpstreamDrainTimeout + if drainTimeout <= 0 { + drainTimeout = 1200 * time.Millisecond + } + firstMessageType := options.FirstMessageType + if firstMessageType != coderws.MessageBinary { + firstMessageType = coderws.MessageText + } + startAt := nowFn() + state := &relayState{requestModel: result.RequestModel} + onTrace := options.OnTrace + + relayCtx, relayCancel := context.WithCancel(ctx) + defer relayCancel() + + lastActivity := atomic.Int64{} + lastActivity.Store(nowFn().UnixNano()) + markActivity := func() { + lastActivity.Store(nowFn().UnixNano()) + } + + writeUpstream := func(msgType coderws.MessageType, payload []byte) error { + writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout) + defer cancel() + return upstreamConn.WriteFrame(writeCtx, msgType, payload) + } + writeClient := func(msgType coderws.MessageType, payload []byte) error { + writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout) + defer cancel() + return clientConn.WriteFrame(writeCtx, msgType, payload) + } + + clientToUpstreamFrames := &atomic.Int64{} + upstreamToClientFrames := &atomic.Int64{} + droppedDownstreamFrames := &atomic.Int64{} + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_start", + PayloadBytes: len(firstClientMessage), + MessageType: relayMessageTypeString(firstMessageType), + }) + + if err := writeUpstream(firstMessageType, firstClientMessage); err != nil { + result.Duration = nowFn().Sub(startAt) + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_first_message_failed", + Direction: "client_to_upstream", + MessageType: relayMessageTypeString(firstMessageType), + PayloadBytes: len(firstClientMessage), + Error: err.Error(), + }) + return result, &RelayExit{Stage: "write_upstream", Err: err} + } + clientToUpstreamFrames.Add(1) + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_first_message_ok", + Direction: "client_to_upstream", + MessageType: relayMessageTypeString(firstMessageType), + PayloadBytes: len(firstClientMessage), + }) + markActivity() + + exitCh := make(chan relayExitSignal, 3) + dropDownstreamWrites := atomic.Bool{} + go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh) + go runUpstreamToClient( + relayCtx, + upstreamConn, + writeClient, + startAt, + nowFn, + state, + options.OnUsageParseFailure, + options.OnTurnComplete, + &dropDownstreamWrites, + upstreamToClientFrames, + droppedDownstreamFrames, + markActivity, + onTrace, + exitCh, + ) + go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh) + + firstExit := <-exitCh + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "first_exit", + Direction: relayDirectionFromStage(firstExit.stage), + Graceful: firstExit.graceful, + WroteDownstream: firstExit.wroteDownstream, + Error: relayErrorString(firstExit.err), + }) + combinedWroteDownstream := firstExit.wroteDownstream + secondExit := relayExitSignal{graceful: true} + hasSecondExit := false + + // 客户端断开后尽力继续读取上游短窗口,捕获延迟 usage/terminal 事件用于计费。 + if firstExit.stage == "read_client" && firstExit.graceful { + dropDownstreamWrites.Store(true) + secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout) + } else { + relayCancel() + _ = upstreamConn.Close() + secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond) + } + if hasSecondExit { + combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "second_exit", + Direction: relayDirectionFromStage(secondExit.stage), + Graceful: secondExit.graceful, + WroteDownstream: secondExit.wroteDownstream, + Error: relayErrorString(secondExit.err), + }) + } + + relayCancel() + _ = upstreamConn.Close() + + enrichResult(&result, state, nowFn().Sub(startAt)) + result.ClientToUpstreamFrames = clientToUpstreamFrames.Load() + result.UpstreamToClientFrames = upstreamToClientFrames.Load() + result.DroppedDownstreamFrames = droppedDownstreamFrames.Load() + if firstExit.stage == "read_client" && firstExit.graceful { + stage := "client_disconnected" + exitErr := firstExit.err + if hasSecondExit && !secondExit.graceful { + stage = secondExit.stage + exitErr = secondExit.err + } + if exitErr == nil { + exitErr = io.EOF + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_exit", + Direction: relayDirectionFromStage(stage), + Graceful: false, + WroteDownstream: combinedWroteDownstream, + Error: relayErrorString(exitErr), + }) + return result, &RelayExit{ + Stage: stage, + Err: exitErr, + WroteDownstream: combinedWroteDownstream, + } + } + if firstExit.graceful && (!hasSecondExit || secondExit.graceful) { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_complete", + Graceful: true, + WroteDownstream: combinedWroteDownstream, + }) + _ = clientConn.Close() + return result, nil + } + if !firstExit.graceful { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_exit", + Direction: relayDirectionFromStage(firstExit.stage), + Graceful: false, + WroteDownstream: combinedWroteDownstream, + Error: relayErrorString(firstExit.err), + }) + return result, &RelayExit{ + Stage: firstExit.stage, + Err: firstExit.err, + WroteDownstream: combinedWroteDownstream, + } + } + if hasSecondExit && !secondExit.graceful { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_exit", + Direction: relayDirectionFromStage(secondExit.stage), + Graceful: false, + WroteDownstream: combinedWroteDownstream, + Error: relayErrorString(secondExit.err), + }) + return result, &RelayExit{ + Stage: secondExit.stage, + Err: secondExit.err, + WroteDownstream: combinedWroteDownstream, + } + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_complete", + Graceful: true, + WroteDownstream: combinedWroteDownstream, + }) + _ = clientConn.Close() + return result, nil +} + +func runClientToUpstream( + ctx context.Context, + clientConn FrameConn, + writeUpstream func(msgType coderws.MessageType, payload []byte) error, + markActivity func(), + forwardedFrames *atomic.Int64, + onTrace func(event RelayTraceEvent), + exitCh chan<- relayExitSignal, +) { + for { + msgType, payload, err := clientConn.ReadFrame(ctx) + if err != nil { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "read_client_failed", + Direction: "client_to_upstream", + Error: err.Error(), + Graceful: isDisconnectError(err), + }) + exitCh <- relayExitSignal{stage: "read_client", err: err, graceful: isDisconnectError(err)} + return + } + markActivity() + if err := writeUpstream(msgType, payload); err != nil { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_upstream_failed", + Direction: "client_to_upstream", + MessageType: relayMessageTypeString(msgType), + PayloadBytes: len(payload), + Error: err.Error(), + }) + exitCh <- relayExitSignal{stage: "write_upstream", err: err} + return + } + if forwardedFrames != nil { + forwardedFrames.Add(1) + } + markActivity() + } +} + +func runUpstreamToClient( + ctx context.Context, + upstreamConn FrameConn, + writeClient func(msgType coderws.MessageType, payload []byte) error, + startAt time.Time, + nowFn func() time.Time, + state *relayState, + onUsageParseFailure func(eventType string, usageRaw string), + onTurnComplete func(turn RelayTurnResult), + dropDownstreamWrites *atomic.Bool, + forwardedFrames *atomic.Int64, + droppedFrames *atomic.Int64, + markActivity func(), + onTrace func(event RelayTraceEvent), + exitCh chan<- relayExitSignal, +) { + wroteDownstream := false + for { + msgType, payload, err := upstreamConn.ReadFrame(ctx) + if err != nil { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "read_upstream_failed", + Direction: "upstream_to_client", + Error: err.Error(), + Graceful: isDisconnectError(err), + WroteDownstream: wroteDownstream, + }) + exitCh <- relayExitSignal{ + stage: "read_upstream", + err: err, + graceful: isDisconnectError(err), + wroteDownstream: wroteDownstream, + } + return + } + markActivity() + observedEvent := observedUpstreamEvent{} + switch msgType { + case coderws.MessageText: + observedEvent = observeUpstreamMessage(state, payload, startAt, nowFn, onUsageParseFailure) + case coderws.MessageBinary: + // binary frame 直接透传,不进入 JSON 观测路径(避免无效解析开销)。 + } + emitTurnComplete(onTurnComplete, state, observedEvent) + if dropDownstreamWrites != nil && dropDownstreamWrites.Load() { + if droppedFrames != nil { + droppedFrames.Add(1) + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "drop_downstream_frame", + Direction: "upstream_to_client", + MessageType: relayMessageTypeString(msgType), + PayloadBytes: len(payload), + WroteDownstream: wroteDownstream, + }) + if observedEvent.terminal { + exitCh <- relayExitSignal{ + stage: "drain_terminal", + graceful: true, + wroteDownstream: wroteDownstream, + } + return + } + markActivity() + continue + } + if err := writeClient(msgType, payload); err != nil { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_client_failed", + Direction: "upstream_to_client", + MessageType: relayMessageTypeString(msgType), + PayloadBytes: len(payload), + WroteDownstream: wroteDownstream, + Error: err.Error(), + }) + exitCh <- relayExitSignal{stage: "write_client", err: err, wroteDownstream: wroteDownstream} + return + } + wroteDownstream = true + if forwardedFrames != nil { + forwardedFrames.Add(1) + } + markActivity() + } +} + +func runIdleWatchdog( + ctx context.Context, + nowFn func() time.Time, + idleTimeout time.Duration, + lastActivity *atomic.Int64, + onTrace func(event RelayTraceEvent), + exitCh chan<- relayExitSignal, +) { + if idleTimeout <= 0 { + return + } + checkInterval := minDuration(idleTimeout/4, 5*time.Second) + if checkInterval < time.Second { + checkInterval = time.Second + } + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + last := time.Unix(0, lastActivity.Load()) + if nowFn().Sub(last) < idleTimeout { + continue + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "idle_timeout_triggered", + Direction: "watchdog", + Error: context.DeadlineExceeded.Error(), + }) + exitCh <- relayExitSignal{stage: "idle_timeout", err: context.DeadlineExceeded} + return + } + } +} + +func emitRelayTrace(onTrace func(event RelayTraceEvent), event RelayTraceEvent) { + if onTrace == nil { + return + } + onTrace(event) +} + +func relayMessageTypeString(msgType coderws.MessageType) string { + switch msgType { + case coderws.MessageText: + return "text" + case coderws.MessageBinary: + return "binary" + default: + return "unknown(" + strconv.Itoa(int(msgType)) + ")" + } +} + +func relayDirectionFromStage(stage string) string { + switch stage { + case "read_client", "write_upstream": + return "client_to_upstream" + case "read_upstream", "write_client", "drain_terminal": + return "upstream_to_client" + case "idle_timeout": + return "watchdog" + default: + return "" + } +} + +func relayErrorString(err error) string { + if err == nil { + return "" + } + return err.Error() +} + +func observeUpstreamMessage( + state *relayState, + message []byte, + startAt time.Time, + nowFn func() time.Time, + onUsageParseFailure func(eventType string, usageRaw string), +) observedUpstreamEvent { + if state == nil || len(message) == 0 { + return observedUpstreamEvent{} + } + values := gjson.GetManyBytes(message, "type", "response.id", "response_id", "id") + eventType := strings.TrimSpace(values[0].String()) + if eventType == "" { + return observedUpstreamEvent{} + } + responseID := strings.TrimSpace(values[1].String()) + if responseID == "" { + responseID = strings.TrimSpace(values[2].String()) + } + // 仅 terminal 事件兜底读取顶层 id,避免把 event_id 当成 response_id 关联到 turn。 + if responseID == "" && isTerminalEvent(eventType) { + responseID = strings.TrimSpace(values[3].String()) + } + now := nowFn() + + if state.firstTokenMs == nil && isTokenEvent(eventType) { + ms := int(now.Sub(startAt).Milliseconds()) + if ms >= 0 { + state.firstTokenMs = &ms + } + } + parsedUsage := parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure) + observed := observedUpstreamEvent{ + eventType: eventType, + responseID: responseID, + usage: parsedUsage, + } + if responseID != "" { + turnTiming := openAIWSRelayGetOrInitTurnTiming(state, responseID, now) + if turnTiming != nil && turnTiming.firstTokenMs == nil && isTokenEvent(eventType) { + ms := int(now.Sub(turnTiming.startAt).Milliseconds()) + if ms >= 0 { + turnTiming.firstTokenMs = &ms + } + } + } + if !isTerminalEvent(eventType) { + return observed + } + observed.terminal = true + state.terminalEventType = eventType + if responseID != "" { + state.lastResponseID = responseID + if turnTiming, ok := openAIWSRelayDeleteTurnTiming(state, responseID); ok { + duration := now.Sub(turnTiming.startAt) + if duration < 0 { + duration = 0 + } + observed.duration = duration + observed.firstToken = openAIWSRelayCloneIntPtr(turnTiming.firstTokenMs) + } + } + return observed +} + +func emitTurnComplete( + onTurnComplete func(turn RelayTurnResult), + state *relayState, + observed observedUpstreamEvent, +) { + if onTurnComplete == nil || !observed.terminal { + return + } + responseID := strings.TrimSpace(observed.responseID) + if responseID == "" { + return + } + requestModel := "" + if state != nil { + requestModel = state.requestModel + } + onTurnComplete(RelayTurnResult{ + RequestModel: requestModel, + Usage: observed.usage, + RequestID: responseID, + TerminalEventType: observed.eventType, + Duration: observed.duration, + FirstTokenMs: openAIWSRelayCloneIntPtr(observed.firstToken), + }) +} + +func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now time.Time) *relayTurnTiming { + if state == nil { + return nil + } + if state.turnTimingByID == nil { + state.turnTimingByID = make(map[string]*relayTurnTiming, 8) + } + timing, ok := state.turnTimingByID[responseID] + if !ok || timing == nil || timing.startAt.IsZero() { + timing = &relayTurnTiming{startAt: now} + state.turnTimingByID[responseID] = timing + return timing + } + return timing +} + +func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayTurnTiming, bool) { + if state == nil || state.turnTimingByID == nil { + return relayTurnTiming{}, false + } + timing, ok := state.turnTimingByID[responseID] + if !ok || timing == nil { + return relayTurnTiming{}, false + } + delete(state.turnTimingByID, responseID) + return *timing, true +} + +func openAIWSRelayCloneIntPtr(v *int) *int { + if v == nil { + return nil + } + cloned := *v + return &cloned +} + +func parseUsageAndAccumulate( + state *relayState, + message []byte, + eventType string, + onParseFailure func(eventType string, usageRaw string), +) Usage { + if state == nil || len(message) == 0 || !shouldParseUsage(eventType) { + return Usage{} + } + usageResult := gjson.GetBytes(message, "response.usage") + if !usageResult.Exists() { + return Usage{} + } + usageRaw := strings.TrimSpace(usageResult.Raw) + if usageRaw == "" || !strings.HasPrefix(usageRaw, "{") { + recordUsageParseFailure() + if onParseFailure != nil { + onParseFailure(eventType, usageRaw) + } + return Usage{} + } + + inputResult := gjson.GetBytes(message, "response.usage.input_tokens") + outputResult := gjson.GetBytes(message, "response.usage.output_tokens") + cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens") + + inputTokens, inputOK := parseUsageIntField(inputResult, true) + outputTokens, outputOK := parseUsageIntField(outputResult, true) + cachedTokens, cachedOK := parseUsageIntField(cachedResult, false) + if !inputOK || !outputOK || !cachedOK { + recordUsageParseFailure() + if onParseFailure != nil { + onParseFailure(eventType, usageRaw) + } + // 解析失败时不做部分字段累加,避免计费 usage 出现“半有效”状态。 + return Usage{} + } + parsedUsage := Usage{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + CacheReadInputTokens: cachedTokens, + } + + state.usage.InputTokens += parsedUsage.InputTokens + state.usage.OutputTokens += parsedUsage.OutputTokens + state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens + return parsedUsage +} + +func parseUsageIntField(value gjson.Result, required bool) (int, bool) { + if !value.Exists() { + return 0, !required + } + if value.Type != gjson.Number { + return 0, false + } + return int(value.Int()), true +} + +func enrichResult(result *RelayResult, state *relayState, duration time.Duration) { + if result == nil { + return + } + result.Duration = duration + if state == nil { + return + } + result.RequestModel = state.requestModel + result.Usage = state.usage + result.RequestID = state.lastResponseID + result.TerminalEventType = state.terminalEventType + result.FirstTokenMs = state.firstTokenMs +} + +func isDisconnectError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { + return true + } + switch coderws.CloseStatus(err) { + case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure: + return true + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + if message == "" { + return false + } + return strings.Contains(message, "failed to read frame header: eof") || + strings.Contains(message, "unexpected eof") || + strings.Contains(message, "use of closed network connection") || + strings.Contains(message, "connection reset by peer") || + strings.Contains(message, "broken pipe") +} + +func isTerminalEvent(eventType string) bool { + switch eventType { + case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": + return true + default: + return false + } +} + +func shouldParseUsage(eventType string) bool { + switch eventType { + case "response.completed", "response.done", "response.failed": + return true + default: + return false + } +} + +func isTokenEvent(eventType string) bool { + if eventType == "" { + return false + } + switch eventType { + case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done": + return false + } + if strings.Contains(eventType, ".delta") { + return true + } + if strings.HasPrefix(eventType, "response.output_text") { + return true + } + if strings.HasPrefix(eventType, "response.output") { + return true + } + return eventType == "response.completed" || eventType == "response.done" +} + +func minDuration(a, b time.Duration) time.Duration { + if a <= 0 { + return b + } + if b <= 0 { + return a + } + if a < b { + return a + } + return b +} + +func waitRelayExit(exitCh <-chan relayExitSignal, timeout time.Duration) (relayExitSignal, bool) { + if timeout <= 0 { + timeout = 200 * time.Millisecond + } + select { + case sig := <-exitCh: + return sig, true + case <-time.After(timeout): + return relayExitSignal{}, false + } +} diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go b/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go new file mode 100644 index 00000000..123e10ce --- /dev/null +++ b/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go @@ -0,0 +1,432 @@ +package openai_ws_v2 + +import ( + "context" + "errors" + "io" + "net" + "sync/atomic" + "testing" + "time" + + coderws "github.com/coder/websocket" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestRunEntry_DelegatesRelay(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_entry","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true) + + result, relayExit := RunEntry(EntryInput{ + Ctx: context.Background(), + ClientConn: clientConn, + UpstreamConn: upstreamConn, + FirstClientMessage: []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`), + }) + require.Nil(t, relayExit) + require.Equal(t, "resp_entry", result.RequestID) +} + +func TestRunClientToUpstream_ErrorPaths(t *testing.T) { + t.Parallel() + + t.Run("read client eof", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + runClientToUpstream( + context.Background(), + newPassthroughTestFrameConn(nil, true), + func(_ coderws.MessageType, _ []byte) error { return nil }, + func() {}, + nil, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "read_client", sig.stage) + require.True(t, sig.graceful) + }) + + t.Run("write upstream failed", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + runClientToUpstream( + context.Background(), + newPassthroughTestFrameConn([]passthroughTestFrame{ + {msgType: coderws.MessageText, payload: []byte(`{"x":1}`)}, + }, true), + func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") }, + func() {}, + nil, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "write_upstream", sig.stage) + require.False(t, sig.graceful) + }) + + t.Run("forwarded counter and trace callback", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + forwarded := &atomic.Int64{} + traces := make([]RelayTraceEvent, 0, 2) + runClientToUpstream( + context.Background(), + newPassthroughTestFrameConn([]passthroughTestFrame{ + {msgType: coderws.MessageText, payload: []byte(`{"x":1}`)}, + }, true), + func(_ coderws.MessageType, _ []byte) error { return nil }, + func() {}, + forwarded, + func(event RelayTraceEvent) { + traces = append(traces, event) + }, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "read_client", sig.stage) + require.Equal(t, int64(1), forwarded.Load()) + require.NotEmpty(t, traces) + }) +} + +func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) { + t.Parallel() + + t.Run("read upstream eof", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + drop := &atomic.Bool{} + drop.Store(false) + runUpstreamToClient( + context.Background(), + newPassthroughTestFrameConn(nil, true), + func(_ coderws.MessageType, _ []byte) error { return nil }, + time.Now(), + time.Now, + &relayState{}, + nil, + nil, + drop, + nil, + nil, + func() {}, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "read_upstream", sig.stage) + require.True(t, sig.graceful) + }) + + t.Run("write client failed", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + drop := &atomic.Bool{} + drop.Store(false) + runUpstreamToClient( + context.Background(), + newPassthroughTestFrameConn([]passthroughTestFrame{ + {msgType: coderws.MessageText, payload: []byte(`{"type":"response.output_text.delta","delta":"x"}`)}, + }, true), + func(_ coderws.MessageType, _ []byte) error { return errors.New("write failed") }, + time.Now(), + time.Now, + &relayState{}, + nil, + nil, + drop, + nil, + nil, + func() {}, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "write_client", sig.stage) + }) + + t.Run("drop downstream and stop on terminal", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + drop := &atomic.Bool{} + drop.Store(true) + dropped := &atomic.Int64{} + runUpstreamToClient( + context.Background(), + newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_drop","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true), + func(_ coderws.MessageType, _ []byte) error { return nil }, + time.Now(), + time.Now, + &relayState{}, + nil, + nil, + drop, + nil, + dropped, + func() {}, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "drain_terminal", sig.stage) + require.True(t, sig.graceful) + require.Equal(t, int64(1), dropped.Load()) + }) +} + +func TestRunIdleWatchdog_NoTimeoutWhenDisabled(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + lastActivity := &atomic.Int64{} + lastActivity.Store(time.Now().UnixNano()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go runIdleWatchdog(ctx, time.Now, 0, lastActivity, nil, exitCh) + select { + case <-exitCh: + t.Fatal("unexpected idle timeout signal") + case <-time.After(200 * time.Millisecond): + } +} + +func TestHelperFunctionsCoverage(t *testing.T) { + t.Parallel() + + require.Equal(t, "text", relayMessageTypeString(coderws.MessageText)) + require.Equal(t, "binary", relayMessageTypeString(coderws.MessageBinary)) + require.Contains(t, relayMessageTypeString(coderws.MessageType(99)), "unknown(") + + require.Equal(t, "", relayErrorString(nil)) + require.Equal(t, "x", relayErrorString(errors.New("x"))) + + require.True(t, isDisconnectError(io.EOF)) + require.True(t, isDisconnectError(net.ErrClosed)) + require.True(t, isDisconnectError(context.Canceled)) + require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusGoingAway})) + require.True(t, isDisconnectError(errors.New("broken pipe"))) + require.False(t, isDisconnectError(errors.New("unrelated"))) + + require.True(t, isTokenEvent("response.output_text.delta")) + require.True(t, isTokenEvent("response.output_audio.delta")) + require.True(t, isTokenEvent("response.completed")) + require.False(t, isTokenEvent("")) + require.False(t, isTokenEvent("response.created")) + + require.Equal(t, 2*time.Second, minDuration(2*time.Second, 5*time.Second)) + require.Equal(t, 2*time.Second, minDuration(5*time.Second, 2*time.Second)) + require.Equal(t, 5*time.Second, minDuration(0, 5*time.Second)) + require.Equal(t, 2*time.Second, minDuration(2*time.Second, 0)) + + ch := make(chan relayExitSignal, 1) + ch <- relayExitSignal{stage: "ok"} + sig, ok := waitRelayExit(ch, 10*time.Millisecond) + require.True(t, ok) + require.Equal(t, "ok", sig.stage) + ch <- relayExitSignal{stage: "ok2"} + sig, ok = waitRelayExit(ch, 0) + require.True(t, ok) + require.Equal(t, "ok2", sig.stage) + _, ok = waitRelayExit(ch, 10*time.Millisecond) + require.False(t, ok) + + n, ok := parseUsageIntField(gjson.Get(`{"n":3}`, "n"), true) + require.True(t, ok) + require.Equal(t, 3, n) + _, ok = parseUsageIntField(gjson.Get(`{"n":"x"}`, "n"), true) + require.False(t, ok) + n, ok = parseUsageIntField(gjson.Result{}, false) + require.True(t, ok) + require.Equal(t, 0, n) + _, ok = parseUsageIntField(gjson.Result{}, true) + require.False(t, ok) +} + +func TestParseUsageAndEnrichCoverage(t *testing.T) { + t.Parallel() + + state := &relayState{} + parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":"bad"}}}`), "response.completed", nil) + require.Equal(t, 0, state.usage.InputTokens) + + parseUsageAndAccumulate( + state, + []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":9,"output_tokens":"bad","input_tokens_details":{"cached_tokens":2}}}}`), + "response.completed", + nil, + ) + require.Equal(t, 0, state.usage.InputTokens, "部分字段解析失败时不应累加 usage") + require.Equal(t, 0, state.usage.OutputTokens) + require.Equal(t, 0, state.usage.CacheReadInputTokens) + + parseUsageAndAccumulate( + state, + []byte(`{"type":"response.completed","response":{"usage":{"input_tokens_details":{"cached_tokens":2}}}}`), + "response.completed", + nil, + ) + require.Equal(t, 0, state.usage.InputTokens, "必填 usage 字段缺失时不应累加 usage") + require.Equal(t, 0, state.usage.OutputTokens) + require.Equal(t, 0, state.usage.CacheReadInputTokens) + + parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil) + require.Equal(t, 2, state.usage.InputTokens) + require.Equal(t, 1, state.usage.OutputTokens) + require.Equal(t, 1, state.usage.CacheReadInputTokens) + + result := &RelayResult{} + enrichResult(result, state, 5*time.Millisecond) + require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens) + require.Equal(t, 5*time.Millisecond, result.Duration) + parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil) + require.Equal(t, 2, state.usage.InputTokens) + enrichResult(nil, state, 0) +} + +func TestEmitTurnCompleteCoverage(t *testing.T) { + t.Parallel() + + // 非 terminal 事件不应触发。 + called := 0 + emitTurnComplete(func(turn RelayTurnResult) { + called++ + }, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{ + terminal: false, + eventType: "response.output_text.delta", + responseID: "resp_ignored", + usage: Usage{InputTokens: 1}, + }) + require.Equal(t, 0, called) + + // 缺少 response_id 时不应触发。 + emitTurnComplete(func(turn RelayTurnResult) { + called++ + }, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{ + terminal: true, + eventType: "response.completed", + }) + require.Equal(t, 0, called) + + // terminal 且 response_id 存在,应该触发;state=nil 时 model 为空串。 + var got RelayTurnResult + emitTurnComplete(func(turn RelayTurnResult) { + called++ + got = turn + }, nil, observedUpstreamEvent{ + terminal: true, + eventType: "response.completed", + responseID: "resp_emit", + usage: Usage{InputTokens: 2, OutputTokens: 3}, + }) + require.Equal(t, 1, called) + require.Equal(t, "resp_emit", got.RequestID) + require.Equal(t, "response.completed", got.TerminalEventType) + require.Equal(t, 2, got.Usage.InputTokens) + require.Equal(t, 3, got.Usage.OutputTokens) + require.Equal(t, "", got.RequestModel) +} + +func TestIsDisconnectErrorCoverage_CloseStatusesAndMessageBranches(t *testing.T) { + t.Parallel() + + require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNormalClosure})) + require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNoStatusRcvd})) + require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusAbnormalClosure})) + require.True(t, isDisconnectError(errors.New("connection reset by peer"))) + require.False(t, isDisconnectError(errors.New(" "))) +} + +func TestIsTokenEventCoverageBranches(t *testing.T) { + t.Parallel() + + require.False(t, isTokenEvent("response.in_progress")) + require.False(t, isTokenEvent("response.output_item.added")) + require.True(t, isTokenEvent("response.output_audio.delta")) + require.True(t, isTokenEvent("response.output")) + require.True(t, isTokenEvent("response.done")) +} + +func TestRelayTurnTimingHelpersCoverage(t *testing.T) { + t.Parallel() + + now := time.Unix(100, 0) + // nil state + require.Nil(t, openAIWSRelayGetOrInitTurnTiming(nil, "resp_nil", now)) + _, ok := openAIWSRelayDeleteTurnTiming(nil, "resp_nil") + require.False(t, ok) + + state := &relayState{} + timing := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now) + require.NotNil(t, timing) + require.Equal(t, now, timing.startAt) + + // 再次获取返回同一条 timing + timing2 := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now.Add(5*time.Second)) + require.NotNil(t, timing2) + require.Equal(t, now, timing2.startAt) + + // 删除存在键 + deleted, ok := openAIWSRelayDeleteTurnTiming(state, "resp_a") + require.True(t, ok) + require.Equal(t, now, deleted.startAt) + + // 删除不存在键 + _, ok = openAIWSRelayDeleteTurnTiming(state, "resp_a") + require.False(t, ok) +} + +func TestObserveUpstreamMessage_ResponseIDFallbackPolicy(t *testing.T) { + t.Parallel() + + state := &relayState{requestModel: "gpt-5"} + startAt := time.Unix(0, 0) + now := startAt + nowFn := func() time.Time { + now = now.Add(5 * time.Millisecond) + return now + } + + // 非 terminal:仅有顶层 id,不应把 event id 当成 response_id。 + observed := observeUpstreamMessage( + state, + []byte(`{"type":"response.output_text.delta","id":"evt_123","delta":"hi"}`), + startAt, + nowFn, + nil, + ) + require.False(t, observed.terminal) + require.Equal(t, "", observed.responseID) + + // terminal:允许兜底用顶层 id(用于兼容少数字段变体)。 + observed = observeUpstreamMessage( + state, + []byte(`{"type":"response.completed","id":"resp_fallback","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`), + startAt, + nowFn, + nil, + ) + require.True(t, observed.terminal) + require.Equal(t, "resp_fallback", observed.responseID) +} diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay_test.go b/backend/internal/service/openai_ws_v2/passthrough_relay_test.go new file mode 100644 index 00000000..ff9b7311 --- /dev/null +++ b/backend/internal/service/openai_ws_v2/passthrough_relay_test.go @@ -0,0 +1,752 @@ +package openai_ws_v2 + +import ( + "context" + "errors" + "io" + "sync" + "sync/atomic" + "testing" + "time" + + coderws "github.com/coder/websocket" + "github.com/stretchr/testify/require" +) + +type passthroughTestFrame struct { + msgType coderws.MessageType + payload []byte +} + +type passthroughTestFrameConn struct { + mu sync.Mutex + writes []passthroughTestFrame + readCh chan passthroughTestFrame + once sync.Once +} + +type delayedReadFrameConn struct { + base FrameConn + firstDelay time.Duration + once sync.Once +} + +type closeSpyFrameConn struct { + closeCalls atomic.Int32 +} + +func newPassthroughTestFrameConn(frames []passthroughTestFrame, autoClose bool) *passthroughTestFrameConn { + c := &passthroughTestFrameConn{ + readCh: make(chan passthroughTestFrame, len(frames)+1), + } + for _, frame := range frames { + copied := passthroughTestFrame{msgType: frame.msgType, payload: append([]byte(nil), frame.payload...)} + c.readCh <- copied + } + if autoClose { + close(c.readCh) + } + return c +} + +func (c *passthroughTestFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return coderws.MessageText, nil, ctx.Err() + case frame, ok := <-c.readCh: + if !ok { + return coderws.MessageText, nil, io.EOF + } + return frame.msgType, append([]byte(nil), frame.payload...), nil + } +} + +func (c *passthroughTestFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + c.mu.Lock() + defer c.mu.Unlock() + c.writes = append(c.writes, passthroughTestFrame{msgType: msgType, payload: append([]byte(nil), payload...)}) + return nil +} + +func (c *passthroughTestFrameConn) Close() error { + c.once.Do(func() { + defer func() { _ = recover() }() + close(c.readCh) + }) + return nil +} + +func (c *passthroughTestFrameConn) Writes() []passthroughTestFrame { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]passthroughTestFrame, len(c.writes)) + copy(out, c.writes) + return out +} + +func (c *delayedReadFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if c == nil || c.base == nil { + return coderws.MessageText, nil, io.EOF + } + c.once.Do(func() { + if c.firstDelay > 0 { + timer := time.NewTimer(c.firstDelay) + defer timer.Stop() + select { + case <-ctx.Done(): + case <-timer.C: + } + } + }) + return c.base.ReadFrame(ctx) +} + +func (c *delayedReadFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if c == nil || c.base == nil { + return io.EOF + } + return c.base.WriteFrame(ctx, msgType, payload) +} + +func (c *delayedReadFrameConn) Close() error { + if c == nil || c.base == nil { + return nil + } + return c.base.Close() +} + +func (c *closeSpyFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if ctx == nil { + ctx = context.Background() + } + <-ctx.Done() + return coderws.MessageText, nil, ctx.Err() +} + +func (c *closeSpyFrameConn) WriteFrame(ctx context.Context, _ coderws.MessageType, _ []byte) error { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } +} + +func (c *closeSpyFrameConn) Close() error { + if c != nil { + c.closeCalls.Add(1) + } + return nil +} + +func (c *closeSpyFrameConn) CloseCalls() int32 { + if c == nil { + return 0 + } + return c.closeCalls.Load() +} + +func TestRelay_BasicRelayAndUsage(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"input_text","text":"hello"}]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + require.Equal(t, "gpt-5.3-codex", result.RequestModel) + require.Equal(t, "resp_123", result.RequestID) + require.Equal(t, "response.completed", result.TerminalEventType) + require.Equal(t, 7, result.Usage.InputTokens) + require.Equal(t, 3, result.Usage.OutputTokens) + require.Equal(t, 2, result.Usage.CacheReadInputTokens) + require.NotNil(t, result.FirstTokenMs) + require.Equal(t, int64(1), result.ClientToUpstreamFrames) + require.Equal(t, int64(1), result.UpstreamToClientFrames) + require.Equal(t, int64(0), result.DroppedDownstreamFrames) + + upstreamWrites := upstreamConn.Writes() + require.Len(t, upstreamWrites, 1) + require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType) + require.JSONEq(t, string(firstPayload), string(upstreamWrites[0].payload)) + + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.Equal(t, coderws.MessageText, clientWrites[0].msgType) + require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`, string(clientWrites[0].payload)) +} + +func TestRelay_FunctionCallOutputBytesPreserved(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_func","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"function_call_output","call_id":"call_abc123","output":"{\"ok\":true}"}]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + + upstreamWrites := upstreamConn.Writes() + require.Len(t, upstreamWrites, 1) + require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType) + require.Equal(t, firstPayload, upstreamWrites[0].payload) +} + +func TestRelay_UpstreamDisconnect(t *testing.T) { + t.Parallel() + + // 上游立即关闭(EOF),客户端不发送额外帧 + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, true) // 立即 close -> EOF + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + // 上游 EOF 属于 disconnect,标记为 graceful + require.Nil(t, relayExit, "上游 EOF 应被视为 graceful disconnect") + require.Equal(t, "gpt-4o", result.RequestModel) +} + +func TestRelay_ClientDisconnect(t *testing.T) { + t.Parallel() + + // 客户端立即关闭(EOF),上游阻塞读取直到 context 取消 + clientConn := newPassthroughTestFrameConn(nil, true) // 立即 close -> EOF + upstreamConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.NotNil(t, relayExit, "客户端 EOF 应返回可观测的中断状态") + require.Equal(t, "client_disconnected", relayExit.Stage) + require.Equal(t, "gpt-4o", result.RequestModel) +} + +func TestRelay_ClientDisconnect_DrainCapturesLateUsage(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, true) + upstreamBase := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_drain","usage":{"input_tokens":6,"output_tokens":4,"input_tokens_details":{"cached_tokens":1}}}}`), + }, + }, true) + upstreamConn := &delayedReadFrameConn{ + base: upstreamBase, + firstDelay: 80 * time.Millisecond, + } + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + UpstreamDrainTimeout: 400 * time.Millisecond, + }) + require.NotNil(t, relayExit) + require.Equal(t, "client_disconnected", relayExit.Stage) + require.Equal(t, "resp_drain", result.RequestID) + require.Equal(t, "response.completed", result.TerminalEventType) + require.Equal(t, 6, result.Usage.InputTokens) + require.Equal(t, 4, result.Usage.OutputTokens) + require.Equal(t, 1, result.Usage.CacheReadInputTokens) + require.Equal(t, int64(1), result.ClientToUpstreamFrames) + require.Equal(t, int64(0), result.UpstreamToClientFrames) + require.Equal(t, int64(1), result.DroppedDownstreamFrames) +} + +func TestRelay_IdleTimeout(t *testing.T) { + t.Parallel() + + // 客户端和上游都不发送帧,idle timeout 应触发 + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 使用快进时间来加速 idle timeout + now := time.Now() + callCount := 0 + nowFn := func() time.Time { + callCount++ + // 前几次调用返回正常时间(初始化阶段),之后快进 + if callCount <= 5 { + return now + } + return now.Add(time.Hour) // 快进到超时 + } + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + IdleTimeout: 2 * time.Second, + Now: nowFn, + }) + require.NotNil(t, relayExit, "应因 idle timeout 退出") + require.Equal(t, "idle_timeout", relayExit.Stage) + require.Equal(t, "gpt-4o", result.RequestModel) +} + +func TestRelay_IdleTimeoutDoesNotCloseClientOnError(t *testing.T) { + t.Parallel() + + clientConn := &closeSpyFrameConn{} + upstreamConn := &closeSpyFrameConn{} + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + now := time.Now() + callCount := 0 + nowFn := func() time.Time { + callCount++ + if callCount <= 5 { + return now + } + return now.Add(time.Hour) + } + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + IdleTimeout: 2 * time.Second, + Now: nowFn, + }) + require.NotNil(t, relayExit, "应因 idle timeout 退出") + require.Equal(t, "idle_timeout", relayExit.Stage) + require.Zero(t, clientConn.CloseCalls(), "错误路径不应提前关闭客户端连接,交给上层决定 close code") + require.GreaterOrEqual(t, upstreamConn.CloseCalls(), int32(1)) +} + +func TestRelay_NilConnections(t *testing.T) { + t.Parallel() + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx := context.Background() + + t.Run("nil client conn", func(t *testing.T) { + upstreamConn := newPassthroughTestFrameConn(nil, true) + _, relayExit := Relay(ctx, nil, upstreamConn, firstPayload, RelayOptions{}) + require.NotNil(t, relayExit) + require.Equal(t, "relay_init", relayExit.Stage) + require.Contains(t, relayExit.Err.Error(), "nil") + }) + + t.Run("nil upstream conn", func(t *testing.T) { + clientConn := newPassthroughTestFrameConn(nil, true) + _, relayExit := Relay(ctx, clientConn, nil, firstPayload, RelayOptions{}) + require.NotNil(t, relayExit) + require.Equal(t, "relay_init", relayExit.Stage) + require.Contains(t, relayExit.Err.Error(), "nil") + }) +} + +func TestRelay_MultipleUpstreamMessages(t *testing.T) { + t.Parallel() + + // 上游发送多个事件(delta + completed),验证多帧中继和 usage 聚合 + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","delta":"Hello"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","delta":" world"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_multi","usage":{"input_tokens":10,"output_tokens":5,"input_tokens_details":{"cached_tokens":3}}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[{"type":"input_text","text":"hi"}]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + require.Equal(t, "resp_multi", result.RequestID) + require.Equal(t, "response.completed", result.TerminalEventType) + require.Equal(t, 10, result.Usage.InputTokens) + require.Equal(t, 5, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) + require.NotNil(t, result.FirstTokenMs) + + // 验证所有 3 个上游帧都转发给了客户端 + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 3) +} + +func TestRelay_OnTurnComplete_PerTerminalEvent(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":3,"output_tokens":4}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + turns := make([]RelayTurnResult, 0, 2) + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + OnTurnComplete: func(turn RelayTurnResult) { + turns = append(turns, turn) + }, + }) + require.Nil(t, relayExit) + require.Len(t, turns, 2) + require.Equal(t, "resp_turn_1", turns[0].RequestID) + require.Equal(t, "response.completed", turns[0].TerminalEventType) + require.Equal(t, 2, turns[0].Usage.InputTokens) + require.Equal(t, 1, turns[0].Usage.OutputTokens) + require.Equal(t, "resp_turn_2", turns[1].RequestID) + require.Equal(t, "response.failed", turns[1].TerminalEventType) + require.Equal(t, 3, turns[1].Usage.InputTokens) + require.Equal(t, 4, turns[1].Usage.OutputTokens) + require.Equal(t, 5, result.Usage.InputTokens) + require.Equal(t, 5, result.Usage.OutputTokens) +} + +func TestRelay_OnTurnComplete_ProvidesTurnMetrics(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","response_id":"resp_metric","delta":"hi"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_metric","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + base := time.Unix(0, 0) + var nowTick atomic.Int64 + nowFn := func() time.Time { + step := nowTick.Add(1) + return base.Add(time.Duration(step) * 5 * time.Millisecond) + } + + var turn RelayTurnResult + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + Now: nowFn, + OnTurnComplete: func(current RelayTurnResult) { + turn = current + }, + }) + require.Nil(t, relayExit) + require.Equal(t, "resp_metric", turn.RequestID) + require.Equal(t, "response.completed", turn.TerminalEventType) + require.NotNil(t, turn.FirstTokenMs) + require.GreaterOrEqual(t, *turn.FirstTokenMs, 0) + require.Greater(t, turn.Duration.Milliseconds(), int64(0)) + require.NotNil(t, result.FirstTokenMs) + require.Greater(t, result.Duration.Milliseconds(), int64(0)) +} + +func TestRelay_BinaryFramePassthrough(t *testing.T) { + t.Parallel() + + // 验证 binary frame 被透传但不进行 usage 解析 + binaryPayload := []byte{0x00, 0x01, 0x02, 0x03} + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageBinary, + payload: binaryPayload, + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + // binary frame 不解析 usage + require.Equal(t, 0, result.Usage.InputTokens) + + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType) + require.Equal(t, binaryPayload, clientWrites[0].payload) +} + +func TestRelay_BinaryJSONFrameSkipsObservation(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageBinary, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_binary","usage":{"input_tokens":7,"output_tokens":3}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + require.Equal(t, 0, result.Usage.InputTokens) + require.Equal(t, "", result.RequestID) + require.Equal(t, "", result.TerminalEventType) + + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType) +} + +func TestRelay_UpstreamErrorEventPassthroughRaw(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + errorEvent := []byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found"}}`) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: errorEvent, + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.Equal(t, coderws.MessageText, clientWrites[0].msgType) + require.Equal(t, errorEvent, clientWrites[0].payload) +} + +func TestRelay_PreservesFirstMessageType(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + FirstMessageType: coderws.MessageBinary, + }) + require.Nil(t, relayExit) + + upstreamWrites := upstreamConn.Writes() + require.Len(t, upstreamWrites, 1) + require.Equal(t, coderws.MessageBinary, upstreamWrites[0].msgType) + require.Equal(t, firstPayload, upstreamWrites[0].payload) +} + +func TestRelay_UsageParseFailureDoesNotBlockRelay(t *testing.T) { + baseline := SnapshotMetrics().UsageParseFailureTotal + + // 上游发送无效 JSON(非 usage 格式),不应影响透传 + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_bad","usage":"not_an_object"}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + // usage 解析失败,值为 0 但不影响透传 + require.Equal(t, 0, result.Usage.InputTokens) + require.Equal(t, "response.completed", result.TerminalEventType) + + // 帧仍然被转发 + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.GreaterOrEqual(t, SnapshotMetrics().UsageParseFailureTotal, baseline+1) +} + +func TestRelay_WriteUpstreamFirstMessageFails(t *testing.T) { + t.Parallel() + + // 上游连接立即关闭,首包写入失败 + upstreamConn := newPassthroughTestFrameConn(nil, true) + _ = upstreamConn.Close() + + // 覆盖 WriteFrame 使其返回错误 + errConn := &errorOnWriteFrameConn{} + clientConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, relayExit := Relay(ctx, clientConn, errConn, firstPayload, RelayOptions{}) + require.NotNil(t, relayExit) + require.Equal(t, "write_upstream", relayExit.Stage) +} + +func TestRelay_ContextCanceled(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + + // 立即取消 context + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + // context 取消导致写首包失败 + require.NotNil(t, relayExit) +} + +func TestRelay_TraceEvents_ContainsLifecycleStages(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_trace","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + stages := make([]string, 0, 8) + var stagesMu sync.Mutex + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + OnTrace: func(event RelayTraceEvent) { + stagesMu.Lock() + stages = append(stages, event.Stage) + stagesMu.Unlock() + }, + }) + require.Nil(t, relayExit) + stagesMu.Lock() + capturedStages := append([]string(nil), stages...) + stagesMu.Unlock() + require.Contains(t, capturedStages, "relay_start") + require.Contains(t, capturedStages, "write_first_message_ok") + require.Contains(t, capturedStages, "first_exit") + require.Contains(t, capturedStages, "relay_complete") +} + +func TestRelay_TraceEvents_IdleTimeout(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + now := time.Now() + callCount := 0 + nowFn := func() time.Time { + callCount++ + if callCount <= 5 { + return now + } + return now.Add(time.Hour) + } + + stages := make([]string, 0, 8) + var stagesMu sync.Mutex + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + IdleTimeout: 2 * time.Second, + Now: nowFn, + OnTrace: func(event RelayTraceEvent) { + stagesMu.Lock() + stages = append(stages, event.Stage) + stagesMu.Unlock() + }, + }) + require.NotNil(t, relayExit) + require.Equal(t, "idle_timeout", relayExit.Stage) + stagesMu.Lock() + capturedStages := append([]string(nil), stages...) + stagesMu.Unlock() + require.Contains(t, capturedStages, "idle_timeout_triggered") + require.Contains(t, capturedStages, "relay_exit") +} + +// errorOnWriteFrameConn 是一个写入总是失败的 FrameConn 实现,用于测试首包写入失败。 +type errorOnWriteFrameConn struct{} + +func (c *errorOnWriteFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + <-ctx.Done() + return coderws.MessageText, nil, ctx.Err() +} + +func (c *errorOnWriteFrameConn) WriteFrame(_ context.Context, _ coderws.MessageType, _ []byte) error { + return errors.New("write failed: connection refused") +} + +func (c *errorOnWriteFrameConn) Close() error { + return nil +} diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go new file mode 100644 index 00000000..3b429f4d --- /dev/null +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go @@ -0,0 +1,367 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "sync/atomic" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" +) + +type openAIWSClientFrameConn struct { + conn *coderws.Conn +} + +const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2" + +var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil) + +func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if c == nil || c.conn == nil { + return coderws.MessageText, nil, errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return c.conn.Read(ctx) +} + +func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if c == nil || c.conn == nil { + return errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return c.conn.Write(ctx, msgType, payload) +} + +func (c *openAIWSClientFrameConn) Close() error { + if c == nil || c.conn == nil { + return nil + } + _ = c.conn.Close(coderws.StatusNormalClosure, "") + _ = c.conn.CloseNow() + return nil +} + +func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( + ctx context.Context, + c *gin.Context, + clientConn *coderws.Conn, + account *Account, + token string, + firstClientMessage []byte, + hooks *OpenAIWSIngressHooks, + wsDecision OpenAIWSProtocolDecision, +) error { + if s == nil { + return errors.New("service is nil") + } + if clientConn == nil { + return errors.New("client websocket is nil") + } + if account == nil { + return errors.New("account is nil") + } + if strings.TrimSpace(token) == "" { + return errors.New("token is empty") + } + requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String()) + requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String()) + logOpenAIWSV2Passthrough( + "relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d", + account.ID, + truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen), + openaiwsv2RelayMessageTypeName(coderws.MessageText), + len(firstClientMessage), + ) + + wsURL, err := s.buildOpenAIResponsesWSURL(account) + if err != nil { + return fmt.Errorf("build ws url: %w", err) + } + wsHost := "-" + wsPath := "-" + if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil { + wsHost = normalizeOpenAIWSLogValue(parsedURL.Host) + wsPath = normalizeOpenAIWSLogValue(parsedURL.Path) + } + logOpenAIWSV2Passthrough( + "relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v", + account.ID, + wsHost, + wsPath, + account.ProxyID != nil && account.Proxy != nil, + ) + + isCodexCLI := false + if c != nil { + isCodexCLI = openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) + } + if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + isCodexCLI = true + } + headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "") + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + dialer := s.getOpenAIWSPassthroughDialer() + if dialer == nil { + return errors.New("openai ws passthrough dialer is nil") + } + + dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout()) + defer cancelDial() + upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL) + if err != nil { + logOpenAIWSV2Passthrough( + "relay_dial_failed account_id=%d status_code=%d err=%s", + account.ID, + statusCode, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders) + } + defer func() { + _ = upstreamConn.Close() + }() + logOpenAIWSV2Passthrough( + "relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s", + account.ID, + statusCode, + openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"), + ) + + upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn) + if !ok { + return errors.New("openai ws passthrough upstream connection does not support frame relay") + } + + completedTurns := atomic.Int32{} + relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{ + Ctx: ctx, + ClientConn: &openAIWSClientFrameConn{conn: clientConn}, + UpstreamConn: upstreamFrameConn, + FirstClientMessage: firstClientMessage, + Options: openaiwsv2.RelayOptions{ + WriteTimeout: s.openAIWSWriteTimeout(), + IdleTimeout: s.openAIWSPassthroughIdleTimeout(), + FirstMessageType: coderws.MessageText, + OnUsageParseFailure: func(eventType string, usageRaw string) { + logOpenAIWSV2Passthrough( + "usage_parse_failed event_type=%s usage_raw=%s", + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen), + ) + }, + OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) { + turnNo := int(completedTurns.Add(1)) + turnResult := &OpenAIForwardResult{ + RequestID: turn.RequestID, + Usage: OpenAIUsage{ + InputTokens: turn.Usage.InputTokens, + OutputTokens: turn.Usage.OutputTokens, + CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens, + CacheReadInputTokens: turn.Usage.CacheReadInputTokens, + }, + Model: turn.RequestModel, + Stream: true, + OpenAIWSMode: true, + Duration: turn.Duration, + FirstTokenMs: turn.FirstTokenMs, + } + logOpenAIWSV2Passthrough( + "relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d", + account.ID, + turnNo, + truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(turn.TerminalEventType, openAIWSLogValueMaxLen), + turnResult.Duration.Milliseconds(), + openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs), + turnResult.Usage.InputTokens, + turnResult.Usage.OutputTokens, + turnResult.Usage.CacheReadInputTokens, + ) + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turnNo, turnResult, nil) + } + }, + OnTrace: func(event openaiwsv2.RelayTraceEvent) { + logOpenAIWSV2Passthrough( + "relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s", + account.ID, + truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen), + event.PayloadBytes, + event.Graceful, + event.WroteDownstream, + truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen), + ) + }, + }, + }) + + result := &OpenAIForwardResult{ + RequestID: relayResult.RequestID, + Usage: OpenAIUsage{ + InputTokens: relayResult.Usage.InputTokens, + OutputTokens: relayResult.Usage.OutputTokens, + CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens, + CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens, + }, + Model: relayResult.RequestModel, + Stream: true, + OpenAIWSMode: true, + Duration: relayResult.Duration, + FirstTokenMs: relayResult.FirstTokenMs, + } + + turnCount := int(completedTurns.Load()) + if relayExit == nil { + logOpenAIWSV2Passthrough( + "relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d", + account.ID, + truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(relayResult.TerminalEventType, openAIWSLogValueMaxLen), + result.Duration.Milliseconds(), + relayResult.ClientToUpstreamFrames, + relayResult.UpstreamToClientFrames, + relayResult.DroppedDownstreamFrames, + turnCount, + ) + // 正常路径按 terminal 事件逐 turn 已回调;仅在零 turn 场景兜底回调一次。 + if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(1, result, nil) + } + return nil + } + logOpenAIWSV2Passthrough( + "relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d", + account.ID, + truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen), + relayExit.WroteDownstream, + truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen), + result.Duration.Milliseconds(), + relayResult.ClientToUpstreamFrames, + relayResult.UpstreamToClientFrames, + relayResult.DroppedDownstreamFrames, + turnCount, + ) + + relayErr := relayExit.Err + if relayExit.Stage == "idle_timeout" { + relayErr = NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "client websocket idle timeout", + relayErr, + ) + } + turnErr := wrapOpenAIWSIngressTurnError( + relayExit.Stage, + relayErr, + relayExit.WroteDownstream, + ) + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turnCount+1, nil, turnErr) + } + return turnErr +} + +func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError( + err error, + statusCode int, + handshakeHeaders http.Header, +) error { + if err == nil { + return nil + } + wrappedErr := err + var dialErr *openAIWSDialError + if !errors.As(err, &dialErr) { + wrappedErr = &openAIWSDialError{ + StatusCode: statusCode, + ResponseHeaders: cloneHeader(handshakeHeaders), + Err: err, + } + } + + if errors.Is(err, context.Canceled) { + return err + } + if errors.Is(err, context.DeadlineExceeded) { + return NewOpenAIWSClientCloseError( + coderws.StatusTryAgainLater, + "upstream websocket connect timeout", + wrappedErr, + ) + } + if statusCode == http.StatusTooManyRequests { + return NewOpenAIWSClientCloseError( + coderws.StatusTryAgainLater, + "upstream websocket is busy, please retry later", + wrappedErr, + ) + } + if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "upstream websocket authentication failed", + wrappedErr, + ) + } + if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "upstream websocket handshake rejected", + wrappedErr, + ) + } + return fmt.Errorf("openai ws passthrough dial: %w", wrappedErr) +} + +func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) string { + switch msgType { + case coderws.MessageText: + return "text" + case coderws.MessageBinary: + return "binary" + default: + return fmt.Sprintf("unknown(%d)", msgType) + } +} + +func relayErrorText(err error) string { + if err == nil { + return "" + } + return err.Error() +} + +func openAIWSFirstTokenMsForLog(firstTokenMs *int) int { + if firstTokenMs == nil { + return -1 + } + return *firstTokenMs +} + +func logOpenAIWSV2Passthrough(format string, args ...any) { + logger.LegacyPrintf( + "service.openai_ws_v2", + "[OpenAI WS v2 passthrough] %s "+format, + append([]any{openaiWSV2PassthroughModeFields}, args...)..., + ) +} diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index e2eb3130..2058ced1 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -209,8 +209,9 @@ gateway: openai_ws: # 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。 mode_router_v2_enabled: false - # ingress 默认模式:off|shared|dedicated(仅 mode_router_v2_enabled=true 生效) - ingress_mode_default: shared + # ingress 默认模式:off|ctx_pool|passthrough(仅 mode_router_v2_enabled=true 生效) + # 兼容旧值:shared/dedicated 会按 ctx_pool 处理。 + ingress_mode_default: ctx_pool # 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由 enabled: true # 按账号类型细分开关 diff --git a/docker-compose-aicodex.yml b/docker-compose-aicodex.yml new file mode 100644 index 00000000..74ab0379 --- /dev/null +++ b/docker-compose-aicodex.yml @@ -0,0 +1,263 @@ +# ============================================================================= +# aicodex2api Docker Compose Host Configuration (Local Build) +# ============================================================================= +# Quick Start: +# 1. Copy .env.example to .env and configure +# 2. docker-compose -f docker-compose-host.yml up -d --build +# 3. Check logs: docker-compose -f docker-compose-host.yml logs -f aicodex2api +# 4. Access: http://localhost:8080 +# +# This configuration builds the image from source (Dockerfile in project root). +# All configuration is done via environment variables. +# No Setup Wizard needed - the system auto-initializes on first run. +# ============================================================================= + +services: + # =========================================================================== + # aicodex2api Application + # =========================================================================== + aicodex2api: + image: yangjianbo/aicodex2api:latest + build: + context: .. + dockerfile: Dockerfile + container_name: aicodex2api + restart: unless-stopped + network_mode: host + ulimits: + nofile: + soft: 800000 + hard: 800000 + volumes: + # Data persistence (config.yaml will be auto-generated here) + - aicodex2api_data:/app/data + # Mount custom config.yaml (optional, overrides auto-generated config) + #- ./config.yaml:/app/data/config.yaml:ro + environment: + # ======================================================================= + # Auto Setup (REQUIRED for Docker deployment) + # ======================================================================= + - AUTO_SETUP=true + + # ======================================================================= + # Server Configuration + # ======================================================================= + - SERVER_HOST=0.0.0.0 + - SERVER_PORT=8080 + - SERVER_MODE=${SERVER_MODE:-release} + - RUN_MODE=${RUN_MODE:-standard} + # 新用户默认并发(仅影响新注册用户;已有用户请在后台或数据库单独调整) + - DEFAULT_USER_CONCURRENCY=${DEFAULT_USER_CONCURRENCY:-12} + + # ======================================================================= + # Database Configuration (PostgreSQL) + # ======================================================================= + # Using host network: point to host/external DB by DATABASE_HOST/DATABASE_PORT + - DATABASE_HOST=${DATABASE_HOST:-127.0.0.1} + - DATABASE_PORT=${DATABASE_PORT:-5432} + - DATABASE_USER=${POSTGRES_USER:-aicodex2api} + - DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} + - DATABASE_DBNAME=${POSTGRES_DB:-aicodex2api} + - DATABASE_SSLMODE=disable + - DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50} + - DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10} + - DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30} + - DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5} + + # ======================================================================= + # Gateway Configuration + # ======================================================================= + - GATEWAY_FORCE_CODEX_CLI=${GATEWAY_FORCE_CODEX_CLI:-false} + - GATEWAY_OPENAI_WS_ENABLED=${GATEWAY_OPENAI_WS_ENABLED:-true} + - GATEWAY_OPENAI_WS_OAUTH_ENABLED=${GATEWAY_OPENAI_WS_OAUTH_ENABLED:-true} + - GATEWAY_OPENAI_WS_APIKEY_ENABLED=${GATEWAY_OPENAI_WS_APIKEY_ENABLED:-true} + - GATEWAY_OPENAI_WS_FORCE_HTTP=${GATEWAY_OPENAI_WS_FORCE_HTTP:-false} + - GATEWAY_OPENAI_WS_RESPONSES_WEBSOCKETS_V2=${GATEWAY_OPENAI_WS_RESPONSES_WEBSOCKETS_V2:-true} + # 多窗口场景建议 adaptive:兼顾会话隔离与连接复用 + - GATEWAY_OPENAI_WS_STORE_DISABLED_CONN_MODE=${GATEWAY_OPENAI_WS_STORE_DISABLED_CONN_MODE:-adaptive} + - GATEWAY_OPENAI_WS_MAX_CONNS_PER_ACCOUNT=${GATEWAY_OPENAI_WS_MAX_CONNS_PER_ACCOUNT:-128} + - GATEWAY_OPENAI_WS_MIN_IDLE_PER_ACCOUNT=${GATEWAY_OPENAI_WS_MIN_IDLE_PER_ACCOUNT:-4} + - GATEWAY_OPENAI_WS_MAX_IDLE_PER_ACCOUNT=${GATEWAY_OPENAI_WS_MAX_IDLE_PER_ACCOUNT:-16} + - GATEWAY_OPENAI_WS_DYNAMIC_MAX_CONNS_BY_ACCOUNT_CONCURRENCY_ENABLED=${GATEWAY_OPENAI_WS_DYNAMIC_MAX_CONNS_BY_ACCOUNT_CONCURRENCY_ENABLED:-true} + - GATEWAY_OPENAI_WS_OAUTH_MAX_CONNS_FACTOR=${GATEWAY_OPENAI_WS_OAUTH_MAX_CONNS_FACTOR:-1.5} + - GATEWAY_OPENAI_WS_APIKEY_MAX_CONNS_FACTOR=${GATEWAY_OPENAI_WS_APIKEY_MAX_CONNS_FACTOR:-1.5} + - GATEWAY_OPENAI_WS_DIAL_TIMEOUT_SECONDS=${GATEWAY_OPENAI_WS_DIAL_TIMEOUT_SECONDS:-15} + - GATEWAY_OPENAI_WS_READ_TIMEOUT_SECONDS=${GATEWAY_OPENAI_WS_READ_TIMEOUT_SECONDS:-900} + - GATEWAY_OPENAI_WS_WRITE_TIMEOUT_SECONDS=${GATEWAY_OPENAI_WS_WRITE_TIMEOUT_SECONDS:-120} + - GATEWAY_OPENAI_WS_QUEUE_LIMIT_PER_CONN=${GATEWAY_OPENAI_WS_QUEUE_LIMIT_PER_CONN:-128} + - GATEWAY_OPENAI_WS_RETRY_BACKOFF_INITIAL_MS=${GATEWAY_OPENAI_WS_RETRY_BACKOFF_INITIAL_MS:-150} + - GATEWAY_OPENAI_WS_RETRY_BACKOFF_MAX_MS=${GATEWAY_OPENAI_WS_RETRY_BACKOFF_MAX_MS:-3000} + - GATEWAY_OPENAI_WS_RETRY_TOTAL_BUDGET_MS=${GATEWAY_OPENAI_WS_RETRY_TOTAL_BUDGET_MS:-15000} + - GATEWAY_MAX_IDLE_CONNS=${GATEWAY_MAX_IDLE_CONNS:-2560} + - GATEWAY_MAX_IDLE_CONNS_PER_HOST=${GATEWAY_MAX_IDLE_CONNS_PER_HOST:-120} + - GATEWAY_MAX_CONNS_PER_HOST=${GATEWAY_MAX_CONNS_PER_HOST:-8192} + + # ======================================================================= + # Redis Configuration + # ======================================================================= + # Using host network: point to host/external Redis by REDIS_HOST/REDIS_PORT + - REDIS_HOST=${REDIS_HOST:-127.0.0.1} + - REDIS_PORT=${REDIS_PORT:-6379} + - REDIS_PASSWORD=${REDIS_PASSWORD:-} + - REDIS_DB=${REDIS_DB:-0} + - REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024} + - REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10} + - REDIS_ENABLE_TLS=${REDIS_ENABLE_TLS:-false} + + # ======================================================================= + # Admin Account (auto-created on first run) + # ======================================================================= + - ADMIN_EMAIL=${ADMIN_EMAIL:-admin@aicodex2api.local} + - ADMIN_PASSWORD=${ADMIN_PASSWORD:-} + + # ======================================================================= + # JWT Configuration + # ======================================================================= + # Leave empty to auto-generate (recommended) + - JWT_SECRET=${JWT_SECRET:-} + - JWT_EXPIRE_HOUR=${JWT_EXPIRE_HOUR:-24} + + # ======================================================================= + # TOTP (2FA) Configuration + # ======================================================================= + # IMPORTANT: Set a fixed encryption key for TOTP secrets. If left empty, + # a random key will be generated on each startup, causing all existing + # TOTP configurations to become invalid (users won't be able to login + # with 2FA). + # Generate a secure key: openssl rand -hex 32 + - TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-} + + # ======================================================================= + # Timezone Configuration + # This affects ALL time operations in the application: + # - Database timestamps + # - Usage statistics "today" boundary + # - Subscription expiry times + # - Log timestamps + # Common values: Asia/Shanghai, America/New_York, Europe/London, UTC + # ======================================================================= + - TZ=${TZ:-Asia/Shanghai} + + # ======================================================================= + # Gemini OAuth Configuration (for Gemini accounts) + # ======================================================================= + - GEMINI_OAUTH_CLIENT_ID=${GEMINI_OAUTH_CLIENT_ID:-} + - GEMINI_OAUTH_CLIENT_SECRET=${GEMINI_OAUTH_CLIENT_SECRET:-} + - GEMINI_OAUTH_SCOPES=${GEMINI_OAUTH_SCOPES:-} + - GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-} + + # Built-in OAuth client secrets (optional) + # SECURITY: This repo does not embed third-party client_secret. + - GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-} + - ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-} + + # ======================================================================= + # Security Configuration (URL Allowlist) + # ======================================================================= + # Allow private IP addresses for CRS sync (for internal deployments) + - SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=${SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS:-true} + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8080/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 30s + + # =========================================================================== + # PostgreSQL Database + # =========================================================================== + postgres: + image: postgres:18-alpine + container_name: aicodex2api-postgres + restart: unless-stopped + network_mode: host + ulimits: + nofile: + soft: 800000 + hard: 800000 + volumes: + - postgres_data:/var/lib/postgresql/data + environment: + # postgres:18-alpine 默认 PGDATA=/var/lib/postgresql/18/docker(位于镜像声明的匿名卷 /var/lib/postgresql 内)。 + # 若不显式设置 PGDATA,则即使挂载了 postgres_data 到 /var/lib/postgresql/data,数据也不会落盘到该命名卷, + # docker compose down/up 后会触发 initdb 重新初始化,导致用户/密码等数据丢失。 + - PGDATA=/var/lib/postgresql/data + - POSTGRES_USER=${POSTGRES_USER:-aicodex2api} + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} + - POSTGRES_DB=${POSTGRES_DB:-aicodex2api} + - TZ=${TZ:-Asia/Shanghai} + command: + - "postgres" + - "-c" + - "listen_addresses=127.0.0.1" + # 监听端口:与应用侧 DATABASE_PORT 保持一致。 + - "-c" + - "port=${DATABASE_PORT:-5432}" + # 连接数上限:需要结合应用侧 DATABASE_MAX_OPEN_CONNS 调整。 + # 注意:max_connections 过大可能导致内存占用与上下文切换开销显著上升。 + - "-c" + - "max_connections=${POSTGRES_MAX_CONNECTIONS:-1024}" + # 典型内存参数(建议结合机器内存调优;不确定就保持默认或小步调大)。 + - "-c" + - "shared_buffers=${POSTGRES_SHARED_BUFFERS:-1GB}" + - "-c" + - "effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-6GB}" + - "-c" + - "maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-128MB}" + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-aicodex2api} -d ${POSTGRES_DB:-aicodex2api} -p ${DATABASE_PORT:-5432}"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 10s + # Note: bound to localhost only; not exposed to external network by default. + + # =========================================================================== + # Redis Cache + # =========================================================================== + redis: + image: redis:8-alpine + container_name: aicodex2api-redis + restart: unless-stopped + network_mode: host + ulimits: + nofile: + soft: 100000 + hard: 100000 + volumes: + - redis_data:/data + command: > + redis-server + --bind 127.0.0.1 + --port ${REDIS_PORT:-6379} + --maxclients ${REDIS_MAXCLIENTS:-50000} + --save 60 1 + --appendonly yes + --appendfsync everysec + ${REDIS_PASSWORD:+--requirepass ${REDIS_PASSWORD}} + environment: + - TZ=${TZ:-Asia/Shanghai} + # REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag) + - REDISCLI_AUTH=${REDIS_PASSWORD:-} + healthcheck: + test: ["CMD-SHELL", "redis-cli -p ${REDIS_PORT:-6379} -a \"$REDISCLI_AUTH\" ping | grep -q PONG || redis-cli -p ${REDIS_PORT:-6379} ping | grep -q PONG"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 5s + +# ============================================================================= +# Volumes +# ============================================================================= +volumes: + aicodex2api_data: + driver: local + postgres_data: + driver: local + redis_data: + driver: local diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 75f04081..f7e6f5ff 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1807,7 +1807,7 @@ - +

- {{ t('admin.accounts.openai.wsModeConcurrencyHint') }} + {{ t(openAIWSModeConcurrencyHintKey) }}

@@ -2341,10 +2341,11 @@ import { applyInterceptWarmup } from '@/components/account/credentialsBuilder' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' import { - OPENAI_WS_MODE_DEDICATED, + OPENAI_WS_MODE_CTX_POOL, OPENAI_WS_MODE_OFF, - OPENAI_WS_MODE_SHARED, + OPENAI_WS_MODE_PASSTHROUGH, isOpenAIWSModeEnabled, + resolveOpenAIWSModeConcurrencyHintKey, type OpenAIWSMode } from '@/utils/openaiWsMode' import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue' @@ -2541,8 +2542,8 @@ const geminiSelectedTier = computed(() => { const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, - { value: OPENAI_WS_MODE_SHARED, label: t('admin.accounts.openai.wsModeShared') }, - { value: OPENAI_WS_MODE_DEDICATED, label: t('admin.accounts.openai.wsModeDedicated') } + { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, + { value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') } ]) const openaiResponsesWebSocketV2Mode = computed({ @@ -2561,6 +2562,10 @@ const openaiResponsesWebSocketV2Mode = computed({ } }) +const openAIWSModeConcurrencyHintKey = computed(() => + resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value) +) + const isOpenAIModelRestrictionDisabled = computed(() => form.platform === 'openai' && openaiPassthroughEnabled.value ) @@ -3180,10 +3185,13 @@ const buildOpenAIExtra = (base?: Record): Record = { ...(base || {}) } - extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value - extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value - extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) - extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + if (accountCategory.value === 'oauth-based') { + extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value + extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) + } else if (accountCategory.value === 'apikey') { + extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value + extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + } // 清理兼容旧键,统一改用分类型开关。 delete extra.responses_websockets_v2_enabled delete extra.openai_ws_enabled diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 24166a5c..20d785e2 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -708,7 +708,7 @@
- +

- {{ t('admin.accounts.openai.wsModeConcurrencyHint') }} + {{ t(openAIWSModeConcurrencyHintKey) }}

@@ -1273,10 +1273,11 @@ import { applyInterceptWarmup } from '@/components/account/credentialsBuilder' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' import { - OPENAI_WS_MODE_DEDICATED, + OPENAI_WS_MODE_CTX_POOL, OPENAI_WS_MODE_OFF, - OPENAI_WS_MODE_SHARED, + OPENAI_WS_MODE_PASSTHROUGH, isOpenAIWSModeEnabled, + resolveOpenAIWSModeConcurrencyHintKey, type OpenAIWSMode, resolveOpenAIWSModeFromExtra } from '@/utils/openaiWsMode' @@ -1387,8 +1388,8 @@ const codexCLIOnlyEnabled = ref(false) const anthropicPassthroughEnabled = ref(false) const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, - { value: OPENAI_WS_MODE_SHARED, label: t('admin.accounts.openai.wsModeShared') }, - { value: OPENAI_WS_MODE_DEDICATED, label: t('admin.accounts.openai.wsModeDedicated') } + { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, + { value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') } ]) const openaiResponsesWebSocketV2Mode = computed({ get: () => { @@ -1405,6 +1406,9 @@ const openaiResponsesWebSocketV2Mode = computed({ openaiOAuthResponsesWebSocketV2Mode.value = mode } }) +const openAIWSModeConcurrencyHintKey = computed(() => + resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value) +) const isOpenAIModelRestrictionDisabled = computed(() => props.account?.platform === 'openai' && openaiPassthroughEnabled.value ) @@ -2248,10 +2252,13 @@ const handleSubmit = async () => { const currentExtra = (props.account.extra as Record) || {} const newExtra: Record = { ...currentExtra } const hadCodexCLIOnlyEnabled = currentExtra.codex_cli_only === true - newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value - newExtra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value - newExtra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) - newExtra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + if (props.account.type === 'oauth') { + newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value + newExtra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) + } else if (props.account.type === 'apikey') { + newExtra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value + newExtra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + } delete newExtra.responses_websockets_v2_enabled delete newExtra.openai_ws_enabled if (openaiPassthroughEnabled.value) { diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 3963ad01..bd76db3a 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1787,10 +1787,13 @@ export default { wsMode: 'WS mode', wsModeDesc: 'Only applies to the current OpenAI account type.', wsModeOff: 'Off (off)', + wsModeCtxPool: 'Context Pool (ctx_pool)', + wsModePassthrough: 'Passthrough (passthrough)', wsModeShared: 'Shared (shared)', wsModeDedicated: 'Dedicated (dedicated)', wsModeConcurrencyHint: 'When WS mode is enabled, account concurrency becomes the WS connection pool limit for this account.', + wsModePassthroughHint: 'Passthrough mode does not use the WS connection pool.', oauthResponsesWebsocketsV2: 'OAuth WebSocket Mode', oauthResponsesWebsocketsV2Desc: 'Only applies to OpenAI OAuth. This account can use OpenAI WebSocket Mode only when enabled.', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index a692b7f6..671b4ec5 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1935,9 +1935,12 @@ export default { wsMode: 'WS mode', wsModeDesc: '仅对当前 OpenAI 账号类型生效。', wsModeOff: '关闭(off)', + wsModeCtxPool: '上下文池(ctx_pool)', + wsModePassthrough: '透传(passthrough)', wsModeShared: '共享(shared)', wsModeDedicated: '独享(dedicated)', wsModeConcurrencyHint: '启用 WS mode 后,该账号并发数将作为该账号 WS 连接池上限。', + wsModePassthroughHint: 'passthrough 模式不使用 WS 连接池。', oauthResponsesWebsocketsV2: 'OAuth WebSocket Mode', oauthResponsesWebsocketsV2Desc: '仅对 OpenAI OAuth 生效。开启后该账号才允许使用 OpenAI WebSocket Mode 协议。', diff --git a/frontend/src/utils/__tests__/openaiWsMode.spec.ts b/frontend/src/utils/__tests__/openaiWsMode.spec.ts index 39f21aef..8e4f33b2 100644 --- a/frontend/src/utils/__tests__/openaiWsMode.spec.ts +++ b/frontend/src/utils/__tests__/openaiWsMode.spec.ts @@ -1,31 +1,34 @@ import { describe, expect, it } from 'vitest' import { - OPENAI_WS_MODE_DEDICATED, + OPENAI_WS_MODE_CTX_POOL, OPENAI_WS_MODE_OFF, - OPENAI_WS_MODE_SHARED, + OPENAI_WS_MODE_PASSTHROUGH, isOpenAIWSModeEnabled, normalizeOpenAIWSMode, openAIWSModeFromEnabled, + resolveOpenAIWSModeConcurrencyHintKey, resolveOpenAIWSModeFromExtra } from '@/utils/openaiWsMode' describe('openaiWsMode utils', () => { it('normalizes mode values', () => { expect(normalizeOpenAIWSMode('off')).toBe(OPENAI_WS_MODE_OFF) - expect(normalizeOpenAIWSMode(' Shared ')).toBe(OPENAI_WS_MODE_SHARED) - expect(normalizeOpenAIWSMode('DEDICATED')).toBe(OPENAI_WS_MODE_DEDICATED) + expect(normalizeOpenAIWSMode('ctx_pool')).toBe(OPENAI_WS_MODE_CTX_POOL) + expect(normalizeOpenAIWSMode('passthrough')).toBe(OPENAI_WS_MODE_PASSTHROUGH) + expect(normalizeOpenAIWSMode(' Shared ')).toBe(OPENAI_WS_MODE_CTX_POOL) + expect(normalizeOpenAIWSMode('DEDICATED')).toBe(OPENAI_WS_MODE_CTX_POOL) expect(normalizeOpenAIWSMode('invalid')).toBeNull() }) it('maps legacy enabled flag to mode', () => { - expect(openAIWSModeFromEnabled(true)).toBe(OPENAI_WS_MODE_SHARED) + expect(openAIWSModeFromEnabled(true)).toBe(OPENAI_WS_MODE_CTX_POOL) expect(openAIWSModeFromEnabled(false)).toBe(OPENAI_WS_MODE_OFF) expect(openAIWSModeFromEnabled('true')).toBeNull() }) it('resolves by mode key first, then enabled, then fallback enabled keys', () => { const extra = { - openai_oauth_responses_websockets_v2_mode: 'dedicated', + openai_oauth_responses_websockets_v2_mode: 'passthrough', openai_oauth_responses_websockets_v2_enabled: false, responses_websockets_v2_enabled: false } @@ -34,7 +37,7 @@ describe('openaiWsMode utils', () => { enabledKey: 'openai_oauth_responses_websockets_v2_enabled', fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled'] }) - expect(mode).toBe(OPENAI_WS_MODE_DEDICATED) + expect(mode).toBe(OPENAI_WS_MODE_PASSTHROUGH) }) it('falls back to default when nothing is present', () => { @@ -47,9 +50,21 @@ describe('openaiWsMode utils', () => { expect(mode).toBe(OPENAI_WS_MODE_OFF) }) - it('treats off as disabled and shared/dedicated as enabled', () => { + it('treats off as disabled and non-off modes as enabled', () => { expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_OFF)).toBe(false) - expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_SHARED)).toBe(true) - expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_DEDICATED)).toBe(true) + expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_CTX_POOL)).toBe(true) + expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_PASSTHROUGH)).toBe(true) + }) + + it('resolves concurrency hint key by mode', () => { + expect(resolveOpenAIWSModeConcurrencyHintKey(OPENAI_WS_MODE_OFF)).toBe( + 'admin.accounts.openai.wsModeConcurrencyHint' + ) + expect(resolveOpenAIWSModeConcurrencyHintKey(OPENAI_WS_MODE_CTX_POOL)).toBe( + 'admin.accounts.openai.wsModeConcurrencyHint' + ) + expect(resolveOpenAIWSModeConcurrencyHintKey(OPENAI_WS_MODE_PASSTHROUGH)).toBe( + 'admin.accounts.openai.wsModePassthroughHint' + ) }) }) diff --git a/frontend/src/utils/openaiWsMode.ts b/frontend/src/utils/openaiWsMode.ts index b3e9cc00..52eba8b0 100644 --- a/frontend/src/utils/openaiWsMode.ts +++ b/frontend/src/utils/openaiWsMode.ts @@ -1,16 +1,16 @@ export const OPENAI_WS_MODE_OFF = 'off' -export const OPENAI_WS_MODE_SHARED = 'shared' -export const OPENAI_WS_MODE_DEDICATED = 'dedicated' +export const OPENAI_WS_MODE_CTX_POOL = 'ctx_pool' +export const OPENAI_WS_MODE_PASSTHROUGH = 'passthrough' export type OpenAIWSMode = | typeof OPENAI_WS_MODE_OFF - | typeof OPENAI_WS_MODE_SHARED - | typeof OPENAI_WS_MODE_DEDICATED + | typeof OPENAI_WS_MODE_CTX_POOL + | typeof OPENAI_WS_MODE_PASSTHROUGH const OPENAI_WS_MODES = new Set([ OPENAI_WS_MODE_OFF, - OPENAI_WS_MODE_SHARED, - OPENAI_WS_MODE_DEDICATED + OPENAI_WS_MODE_CTX_POOL, + OPENAI_WS_MODE_PASSTHROUGH ]) export interface ResolveOpenAIWSModeOptions { @@ -23,6 +23,9 @@ export interface ResolveOpenAIWSModeOptions { export const normalizeOpenAIWSMode = (mode: unknown): OpenAIWSMode | null => { if (typeof mode !== 'string') return null const normalized = mode.trim().toLowerCase() + if (normalized === 'shared' || normalized === 'dedicated') { + return OPENAI_WS_MODE_CTX_POOL + } if (OPENAI_WS_MODES.has(normalized as OpenAIWSMode)) { return normalized as OpenAIWSMode } @@ -31,13 +34,22 @@ export const normalizeOpenAIWSMode = (mode: unknown): OpenAIWSMode | null => { export const openAIWSModeFromEnabled = (enabled: unknown): OpenAIWSMode | null => { if (typeof enabled !== 'boolean') return null - return enabled ? OPENAI_WS_MODE_SHARED : OPENAI_WS_MODE_OFF + return enabled ? OPENAI_WS_MODE_CTX_POOL : OPENAI_WS_MODE_OFF } export const isOpenAIWSModeEnabled = (mode: OpenAIWSMode): boolean => { return mode !== OPENAI_WS_MODE_OFF } +export const resolveOpenAIWSModeConcurrencyHintKey = ( + mode: OpenAIWSMode +): 'admin.accounts.openai.wsModeConcurrencyHint' | 'admin.accounts.openai.wsModePassthroughHint' => { + if (mode === OPENAI_WS_MODE_PASSTHROUGH) { + return 'admin.accounts.openai.wsModePassthroughHint' + } + return 'admin.accounts.openai.wsModeConcurrencyHint' +} + export const resolveOpenAIWSModeFromExtra = ( extra: Record | null | undefined, options: ResolveOpenAIWSModeOptions