Merge pull request #772 from mt21625457/aicodex2api-main
feat(openai-ws): 合并 WS v2 透传模式与前端 ws mode
This commit is contained in:
@@ -516,7 +516,7 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
|
|||||||
type GatewayOpenAIWSConfig struct {
|
type GatewayOpenAIWSConfig struct {
|
||||||
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
|
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
|
||||||
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
|
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
|
||||||
// IngressModeDefault: ingress 默认模式(off/shared/dedicated)
|
// IngressModeDefault: ingress 默认模式(off/ctx_pool/passthrough)
|
||||||
IngressModeDefault string `mapstructure:"ingress_mode_default"`
|
IngressModeDefault string `mapstructure:"ingress_mode_default"`
|
||||||
// Enabled: 全局总开关(默认 true)
|
// Enabled: 全局总开关(默认 true)
|
||||||
Enabled bool `mapstructure:"enabled"`
|
Enabled bool `mapstructure:"enabled"`
|
||||||
@@ -1335,7 +1335,7 @@ func setDefaults() {
|
|||||||
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
|
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
|
||||||
viper.SetDefault("gateway.openai_ws.enabled", true)
|
viper.SetDefault("gateway.openai_ws.enabled", true)
|
||||||
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
|
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
|
||||||
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
|
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool")
|
||||||
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
|
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
|
||||||
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
|
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
|
||||||
viper.SetDefault("gateway.openai_ws.force_http", false)
|
viper.SetDefault("gateway.openai_ws.force_http", false)
|
||||||
@@ -2043,9 +2043,11 @@ func (c *Config) Validate() error {
|
|||||||
}
|
}
|
||||||
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
|
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
|
||||||
switch mode {
|
switch mode {
|
||||||
case "off", "shared", "dedicated":
|
case "off", "ctx_pool", "passthrough":
|
||||||
|
case "shared", "dedicated":
|
||||||
|
slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", "value", mode)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
|
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
|
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
|
||||||
|
|||||||
@@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
|
|||||||
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
|
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
|
||||||
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
|
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
|
||||||
}
|
}
|
||||||
if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
|
if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" {
|
||||||
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
|
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
|
|||||||
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
|
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "ingress_mode_default 必须为 off|shared|dedicated",
|
name: "ingress_mode_default 必须为 off|ctx_pool|passthrough",
|
||||||
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
|
||||||
wantErr: "gateway.openai_ws.ingress_mode_default",
|
wantErr: "gateway.openai_ws.ingress_mode_default",
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -853,15 +853,21 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
OpenAIWSIngressModeOff = "off"
|
OpenAIWSIngressModeOff = "off"
|
||||||
OpenAIWSIngressModeShared = "shared"
|
OpenAIWSIngressModeShared = "shared"
|
||||||
OpenAIWSIngressModeDedicated = "dedicated"
|
OpenAIWSIngressModeDedicated = "dedicated"
|
||||||
|
OpenAIWSIngressModeCtxPool = "ctx_pool"
|
||||||
|
OpenAIWSIngressModePassthrough = "passthrough"
|
||||||
)
|
)
|
||||||
|
|
||||||
func normalizeOpenAIWSIngressMode(mode string) string {
|
func normalizeOpenAIWSIngressMode(mode string) string {
|
||||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||||
case OpenAIWSIngressModeOff:
|
case OpenAIWSIngressModeOff:
|
||||||
return OpenAIWSIngressModeOff
|
return OpenAIWSIngressModeOff
|
||||||
|
case OpenAIWSIngressModeCtxPool:
|
||||||
|
return OpenAIWSIngressModeCtxPool
|
||||||
|
case OpenAIWSIngressModePassthrough:
|
||||||
|
return OpenAIWSIngressModePassthrough
|
||||||
case OpenAIWSIngressModeShared:
|
case OpenAIWSIngressModeShared:
|
||||||
return OpenAIWSIngressModeShared
|
return OpenAIWSIngressModeShared
|
||||||
case OpenAIWSIngressModeDedicated:
|
case OpenAIWSIngressModeDedicated:
|
||||||
@@ -873,18 +879,21 @@ func normalizeOpenAIWSIngressMode(mode string) string {
|
|||||||
|
|
||||||
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
|
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
|
||||||
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
|
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
|
||||||
|
if normalized == OpenAIWSIngressModeShared || normalized == OpenAIWSIngressModeDedicated {
|
||||||
|
return OpenAIWSIngressModeCtxPool
|
||||||
|
}
|
||||||
return normalized
|
return normalized
|
||||||
}
|
}
|
||||||
return OpenAIWSIngressModeShared
|
return OpenAIWSIngressModeCtxPool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。
|
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/ctx_pool/passthrough)。
|
||||||
//
|
//
|
||||||
// 优先级:
|
// 优先级:
|
||||||
// 1. 分类型 mode 新字段(string)
|
// 1. 分类型 mode 新字段(string)
|
||||||
// 2. 分类型 enabled 旧字段(bool)
|
// 2. 分类型 enabled 旧字段(bool)
|
||||||
// 3. 兼容 enabled 旧字段(bool)
|
// 3. 兼容 enabled 旧字段(bool)
|
||||||
// 4. defaultMode(非法时回退 shared)
|
// 4. defaultMode(非法时回退 ctx_pool)
|
||||||
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
|
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
|
||||||
resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
|
resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
|
||||||
if a == nil || !a.IsOpenAI() {
|
if a == nil || !a.IsOpenAI() {
|
||||||
@@ -919,7 +928,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
|
|||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
if enabled {
|
if enabled {
|
||||||
return OpenAIWSIngressModeShared, true
|
return OpenAIWSIngressModeCtxPool, true
|
||||||
}
|
}
|
||||||
return OpenAIWSIngressModeOff, true
|
return OpenAIWSIngressModeOff, true
|
||||||
}
|
}
|
||||||
@@ -946,6 +955,10 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
|
|||||||
if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
|
if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
|
||||||
return mode
|
return mode
|
||||||
}
|
}
|
||||||
|
// 兼容旧值:shared/dedicated 语义都归并到 ctx_pool。
|
||||||
|
if resolvedDefault == OpenAIWSIngressModeShared || resolvedDefault == OpenAIWSIngressModeDedicated {
|
||||||
|
return OpenAIWSIngressModeCtxPool
|
||||||
|
}
|
||||||
return resolvedDefault
|
return resolvedDefault
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -206,14 +206,14 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||||
t.Run("default fallback to shared", func(t *testing.T) {
|
t.Run("default fallback to ctx_pool", func(t *testing.T) {
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Extra: map[string]any{},
|
Extra: map[string]any{},
|
||||||
}
|
}
|
||||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
|
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
|
||||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
|
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("oauth mode field has highest priority", func(t *testing.T) {
|
t.Run("oauth mode field has highest priority", func(t *testing.T) {
|
||||||
@@ -221,15 +221,15 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
|||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Extra: map[string]any{
|
Extra: map[string]any{
|
||||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
|
||||||
"openai_oauth_responses_websockets_v2_enabled": false,
|
"openai_oauth_responses_websockets_v2_enabled": false,
|
||||||
"responses_websockets_v2_enabled": false,
|
"responses_websockets_v2_enabled": false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
|
require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("legacy enabled maps to shared", func(t *testing.T) {
|
t.Run("legacy enabled maps to ctx_pool", func(t *testing.T) {
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeAPIKey,
|
Type: AccountTypeAPIKey,
|
||||||
@@ -237,7 +237,28 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
|||||||
"responses_websockets_v2_enabled": true,
|
"responses_websockets_v2_enabled": true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("shared/dedicated mode strings are compatible with ctx_pool", func(t *testing.T) {
|
||||||
|
shared := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dedicated := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.Equal(t, OpenAIWSIngressModeShared, shared.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||||
|
require.Equal(t, OpenAIWSIngressModeDedicated, dedicated.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||||
|
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeShared))
|
||||||
|
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeDedicated))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("legacy disabled maps to off", func(t *testing.T) {
|
t.Run("legacy disabled maps to off", func(t *testing.T) {
|
||||||
@@ -249,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
|||||||
"responses_websockets_v2_enabled": true,
|
"responses_websockets_v2_enabled": true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
|
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("non openai always off", func(t *testing.T) {
|
t.Run("non openai always off", func(t *testing.T) {
|
||||||
|
|||||||
@@ -263,13 +263,15 @@ type OpenAIGatewayService struct {
|
|||||||
toolCorrector *CodexToolCorrector
|
toolCorrector *CodexToolCorrector
|
||||||
openaiWSResolver OpenAIWSProtocolResolver
|
openaiWSResolver OpenAIWSProtocolResolver
|
||||||
|
|
||||||
openaiWSPoolOnce sync.Once
|
openaiWSPoolOnce sync.Once
|
||||||
openaiWSStateStoreOnce sync.Once
|
openaiWSStateStoreOnce sync.Once
|
||||||
openaiSchedulerOnce sync.Once
|
openaiSchedulerOnce sync.Once
|
||||||
openaiWSPool *openAIWSConnPool
|
openaiWSPassthroughDialerOnce sync.Once
|
||||||
openaiWSStateStore OpenAIWSStateStore
|
openaiWSPool *openAIWSConnPool
|
||||||
openaiScheduler OpenAIAccountScheduler
|
openaiWSStateStore OpenAIWSStateStore
|
||||||
openaiAccountStats *openAIAccountRuntimeStats
|
openaiScheduler OpenAIAccountScheduler
|
||||||
|
openaiWSPassthroughDialer openAIWSClientDialer
|
||||||
|
openaiAccountStats *openAIAccountRuntimeStats
|
||||||
|
|
||||||
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
||||||
openaiWSRetryMetrics openAIWSRetryMetrics
|
openaiWSRetryMetrics openAIWSRetryMetrics
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
|
||||||
coderws "github.com/coder/websocket"
|
coderws "github.com/coder/websocket"
|
||||||
"github.com/coder/websocket/wsjson"
|
"github.com/coder/websocket/wsjson"
|
||||||
)
|
)
|
||||||
@@ -234,6 +235,8 @@ type coderOpenAIWSClientConn struct {
|
|||||||
conn *coderws.Conn
|
conn *coderws.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ openaiwsv2.FrameConn = (*coderOpenAIWSClientConn)(nil)
|
||||||
|
|
||||||
func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
|
func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
|
||||||
if c == nil || c.conn == nil {
|
if c == nil || c.conn == nil {
|
||||||
return errOpenAIWSConnClosed
|
return errOpenAIWSConnClosed
|
||||||
@@ -264,6 +267,30 @@ func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, erro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *coderOpenAIWSClientConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||||
|
if c == nil || c.conn == nil {
|
||||||
|
return coderws.MessageText, nil, errOpenAIWSConnClosed
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
msgType, payload, err := c.conn.Read(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return coderws.MessageText, nil, err
|
||||||
|
}
|
||||||
|
return msgType, payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *coderOpenAIWSClientConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
||||||
|
if c == nil || c.conn == nil {
|
||||||
|
return errOpenAIWSConnClosed
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return c.conn.Write(ctx, msgType, payload)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
|
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
|
||||||
if c == nil || c.conn == nil {
|
if c == nil || c.conn == nil {
|
||||||
return errOpenAIWSConnClosed
|
return errOpenAIWSConnClosed
|
||||||
|
|||||||
@@ -46,9 +46,10 @@ const (
|
|||||||
openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024
|
openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024
|
||||||
openAIWSPayloadSizeEstimateMaxItems = 16
|
openAIWSPayloadSizeEstimateMaxItems = 16
|
||||||
|
|
||||||
openAIWSEventFlushBatchSizeDefault = 4
|
openAIWSEventFlushBatchSizeDefault = 4
|
||||||
openAIWSEventFlushIntervalDefault = 25 * time.Millisecond
|
openAIWSEventFlushIntervalDefault = 25 * time.Millisecond
|
||||||
openAIWSPayloadLogSampleDefault = 0.2
|
openAIWSPayloadLogSampleDefault = 0.2
|
||||||
|
openAIWSPassthroughIdleTimeoutDefault = time.Hour
|
||||||
|
|
||||||
openAIWSStoreDisabledConnModeStrict = "strict"
|
openAIWSStoreDisabledConnModeStrict = "strict"
|
||||||
openAIWSStoreDisabledConnModeAdaptive = "adaptive"
|
openAIWSStoreDisabledConnModeAdaptive = "adaptive"
|
||||||
@@ -904,6 +905,18 @@ func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool {
|
|||||||
return s.openaiWSPool
|
return s.openaiWSPool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) getOpenAIWSPassthroughDialer() openAIWSClientDialer {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.openaiWSPassthroughDialerOnce.Do(func() {
|
||||||
|
if s.openaiWSPassthroughDialer == nil {
|
||||||
|
s.openaiWSPassthroughDialer = newDefaultOpenAIWSClientDialer()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return s.openaiWSPassthroughDialer
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot {
|
func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot {
|
||||||
pool := s.getOpenAIWSConnPool()
|
pool := s.getOpenAIWSConnPool()
|
||||||
if pool == nil {
|
if pool == nil {
|
||||||
@@ -967,6 +980,13 @@ func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration {
|
|||||||
return 15 * time.Minute
|
return 15 * time.Minute
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) openAIWSPassthroughIdleTimeout() time.Duration {
|
||||||
|
if timeout := s.openAIWSReadTimeout(); timeout > 0 {
|
||||||
|
return timeout
|
||||||
|
}
|
||||||
|
return openAIWSPassthroughIdleTimeoutDefault
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration {
|
func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration {
|
||||||
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 {
|
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 {
|
||||||
return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second
|
return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second
|
||||||
@@ -2322,7 +2342,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
|
|
||||||
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
|
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
|
||||||
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
|
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
|
||||||
ingressMode := OpenAIWSIngressModeShared
|
ingressMode := OpenAIWSIngressModeCtxPool
|
||||||
if modeRouterV2Enabled {
|
if modeRouterV2Enabled {
|
||||||
ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault)
|
ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault)
|
||||||
if ingressMode == OpenAIWSIngressModeOff {
|
if ingressMode == OpenAIWSIngressModeOff {
|
||||||
@@ -2332,6 +2352,30 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
switch ingressMode {
|
||||||
|
case OpenAIWSIngressModePassthrough:
|
||||||
|
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
|
||||||
|
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
|
||||||
|
}
|
||||||
|
return s.proxyResponsesWebSocketV2Passthrough(
|
||||||
|
ctx,
|
||||||
|
c,
|
||||||
|
clientConn,
|
||||||
|
account,
|
||||||
|
token,
|
||||||
|
firstClientMessage,
|
||||||
|
hooks,
|
||||||
|
wsDecision,
|
||||||
|
)
|
||||||
|
case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
|
||||||
|
// continue
|
||||||
|
default:
|
||||||
|
return NewOpenAIWSClientCloseError(
|
||||||
|
coderws.StatusPolicyViolation,
|
||||||
|
"websocket mode only supports ctx_pool/passthrough",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
|
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
|
||||||
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
|
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
|
||||||
|
|||||||
@@ -149,7 +149,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT
|
|||||||
require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式")
|
require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式")
|
||||||
require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式")
|
require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式")
|
||||||
|
|
||||||
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
|
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case serverErr := <-serverErrCh:
|
case serverErr := <-serverErrCh:
|
||||||
@@ -298,6 +298,140 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoe
|
|||||||
require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接")
|
require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeRelaysByCaddyAdapter(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Security.URLAllowlist.Enabled = false
|
||||||
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||||
|
cfg.Gateway.OpenAIWS.Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||||
|
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
|
||||||
|
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||||
|
|
||||||
|
upstreamConn := &openAIWSCaptureConn{
|
||||||
|
events: [][]byte{
|
||||||
|
[]byte(`{"type":"response.completed","response":{"id":"resp_passthrough_turn_1","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":3}}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
captureDialer := &openAIWSCaptureDialer{conn: upstreamConn}
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: cfg,
|
||||||
|
httpUpstream: &httpUpstreamRecorder{},
|
||||||
|
cache: &stubGatewayCache{},
|
||||||
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
openaiWSPassthroughDialer: captureDialer,
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 452,
|
||||||
|
Name: "openai-ingress-passthrough",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "sk-test",
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
serverErrCh := make(chan error, 1)
|
||||||
|
resultCh := make(chan *OpenAIForwardResult, 1)
|
||||||
|
hooks := &OpenAIWSIngressHooks{
|
||||||
|
AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
|
||||||
|
if turnErr == nil && result != nil {
|
||||||
|
resultCh <- result
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||||
|
CompressionMode: coderws.CompressionContextTakeover,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = conn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := r.Clone(r.Context())
|
||||||
|
req.Header = req.Header.Clone()
|
||||||
|
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||||
|
ginCtx.Request = req
|
||||||
|
|
||||||
|
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||||
|
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||||
|
cancel()
|
||||||
|
if readErr != nil {
|
||||||
|
serverErrCh <- readErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||||
|
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
|
||||||
|
}))
|
||||||
|
defer wsServer.Close()
|
||||||
|
|
||||||
|
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||||
|
cancelDial()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = clientConn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
|
||||||
|
cancelWrite()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
_, event, readErr := clientConn.Read(readCtx)
|
||||||
|
cancelRead()
|
||||||
|
require.NoError(t, readErr)
|
||||||
|
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
|
||||||
|
require.Equal(t, "resp_passthrough_turn_1", gjson.GetBytes(event, "response.id").String())
|
||||||
|
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case serverErr := <-serverErrCh:
|
||||||
|
require.NoError(t, serverErr)
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("等待 passthrough websocket 结束超时")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case result := <-resultCh:
|
||||||
|
require.Equal(t, "resp_passthrough_turn_1", result.RequestID)
|
||||||
|
require.True(t, result.OpenAIWSMode)
|
||||||
|
require.Equal(t, 2, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 3, result.Usage.OutputTokens)
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("未收到 passthrough turn 结果回调")
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, 1, captureDialer.DialCount(), "passthrough 模式应直接建立上游 websocket")
|
||||||
|
require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create")
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
|
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
coderws "github.com/coder/websocket"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -1282,6 +1283,18 @@ func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
|
|||||||
return event, nil
|
return event, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *openAIWSCaptureConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||||
|
payload, err := c.ReadMessage(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return coderws.MessageText, nil, err
|
||||||
|
}
|
||||||
|
return coderws.MessageText, payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *openAIWSCaptureConn) WriteFrame(ctx context.Context, _ coderws.MessageType, payload []byte) error {
|
||||||
|
return c.WriteJSON(ctx, json.RawMessage(payload))
|
||||||
|
}
|
||||||
|
|
||||||
func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
|
func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
|
||||||
_ = ctx
|
_ = ctx
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -69,8 +69,11 @@ func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProt
|
|||||||
switch mode {
|
switch mode {
|
||||||
case OpenAIWSIngressModeOff:
|
case OpenAIWSIngressModeOff:
|
||||||
return openAIWSHTTPDecision("account_mode_off")
|
return openAIWSHTTPDecision("account_mode_off")
|
||||||
case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
|
case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModePassthrough:
|
||||||
// continue
|
// continue
|
||||||
|
case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
|
||||||
|
// 历史值兼容:按 ctx_pool 处理。
|
||||||
|
mode = OpenAIWSIngressModeCtxPool
|
||||||
default:
|
default:
|
||||||
return openAIWSHTTPDecision("account_mode_off")
|
return openAIWSHTTPDecision("account_mode_off")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -143,21 +143,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
|
|||||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||||
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||||
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
|
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
|
||||||
|
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Concurrency: 1,
|
Concurrency: 1,
|
||||||
Extra: map[string]any{
|
Extra: map[string]any{
|
||||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("dedicated mode routes to ws v2", func(t *testing.T) {
|
t.Run("ctx_pool mode routes to ws v2", func(t *testing.T) {
|
||||||
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
|
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
|
||||||
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
|
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
|
||||||
require.Equal(t, "ws_v2_mode_dedicated", decision.Reason)
|
require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("off mode routes to http", func(t *testing.T) {
|
t.Run("off mode routes to http", func(t *testing.T) {
|
||||||
@@ -174,7 +174,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
|
|||||||
require.Equal(t, "account_mode_off", decision.Reason)
|
require.Equal(t, "account_mode_off", decision.Reason)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) {
|
t.Run("legacy boolean maps to ctx_pool in v2 router", func(t *testing.T) {
|
||||||
legacyAccount := &Account{
|
legacyAccount := &Account{
|
||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeAPIKey,
|
Type: AccountTypeAPIKey,
|
||||||
@@ -185,7 +185,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
|
|||||||
}
|
}
|
||||||
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
|
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
|
||||||
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
|
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
|
||||||
require.Equal(t, "ws_v2_mode_shared", decision.Reason)
|
require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("passthrough mode routes to ws v2", func(t *testing.T) {
|
||||||
|
passthroughAccount := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(passthroughAccount)
|
||||||
|
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
|
||||||
|
require.Equal(t, "ws_v2_mode_passthrough", decision.Reason)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
|
t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
|
||||||
@@ -193,7 +207,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
|
|||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Extra: map[string]any{
|
Extra: map[string]any{
|
||||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
|
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)
|
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)
|
||||||
|
|||||||
24
backend/internal/service/openai_ws_v2/caddy_adapter.go
Normal file
24
backend/internal/service/openai_ws_v2/caddy_adapter.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package openai_ws_v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// runCaddyStyleRelay 采用 Caddy reverseproxy 的双向隧道思想:
|
||||||
|
// 连接建立后并发复制两个方向,任一方向退出触发收敛关闭。
|
||||||
|
//
|
||||||
|
// Reference:
|
||||||
|
// - Project: caddyserver/caddy (Apache-2.0)
|
||||||
|
// - Commit: f283062d37c50627d53ca682ebae2ce219b35515
|
||||||
|
// - Files:
|
||||||
|
// - modules/caddyhttp/reverseproxy/streaming.go
|
||||||
|
// - modules/caddyhttp/reverseproxy/reverseproxy.go
|
||||||
|
func runCaddyStyleRelay(
|
||||||
|
ctx context.Context,
|
||||||
|
clientConn FrameConn,
|
||||||
|
upstreamConn FrameConn,
|
||||||
|
firstClientMessage []byte,
|
||||||
|
options RelayOptions,
|
||||||
|
) (RelayResult, *RelayExit) {
|
||||||
|
return Relay(ctx, clientConn, upstreamConn, firstClientMessage, options)
|
||||||
|
}
|
||||||
23
backend/internal/service/openai_ws_v2/entry.go
Normal file
23
backend/internal/service/openai_ws_v2/entry.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package openai_ws_v2
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// EntryInput 是 passthrough v2 数据面的入口参数。
|
||||||
|
type EntryInput struct {
|
||||||
|
Ctx context.Context
|
||||||
|
ClientConn FrameConn
|
||||||
|
UpstreamConn FrameConn
|
||||||
|
FirstClientMessage []byte
|
||||||
|
Options RelayOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunEntry 是 openai_ws_v2 包对外的统一入口。
|
||||||
|
func RunEntry(input EntryInput) (RelayResult, *RelayExit) {
|
||||||
|
return runCaddyStyleRelay(
|
||||||
|
input.Ctx,
|
||||||
|
input.ClientConn,
|
||||||
|
input.UpstreamConn,
|
||||||
|
input.FirstClientMessage,
|
||||||
|
input.Options,
|
||||||
|
)
|
||||||
|
}
|
||||||
29
backend/internal/service/openai_ws_v2/metrics.go
Normal file
29
backend/internal/service/openai_ws_v2/metrics.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package openai_ws_v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MetricsSnapshot 是 OpenAI WS v2 passthrough 路径的轻量运行时指标快照。
|
||||||
|
type MetricsSnapshot struct {
|
||||||
|
SemanticMutationTotal int64 `json:"semantic_mutation_total"`
|
||||||
|
UsageParseFailureTotal int64 `json:"usage_parse_failure_total"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// passthrough 路径默认不会做语义改写,该计数通常应保持为 0(保留用于未来防御性校验)。
|
||||||
|
passthroughSemanticMutationTotal atomic.Int64
|
||||||
|
passthroughUsageParseFailureTotal atomic.Int64
|
||||||
|
)
|
||||||
|
|
||||||
|
func recordUsageParseFailure() {
|
||||||
|
passthroughUsageParseFailureTotal.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SnapshotMetrics 返回当前 passthrough 指标快照。
|
||||||
|
func SnapshotMetrics() MetricsSnapshot {
|
||||||
|
return MetricsSnapshot{
|
||||||
|
SemanticMutationTotal: passthroughSemanticMutationTotal.Load(),
|
||||||
|
UsageParseFailureTotal: passthroughUsageParseFailureTotal.Load(),
|
||||||
|
}
|
||||||
|
}
|
||||||
807
backend/internal/service/openai_ws_v2/passthrough_relay.go
Normal file
807
backend/internal/service/openai_ws_v2/passthrough_relay.go
Normal file
@@ -0,0 +1,807 @@
|
|||||||
|
package openai_ws_v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
coderws "github.com/coder/websocket"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FrameConn interface {
|
||||||
|
ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error)
|
||||||
|
WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Usage struct {
|
||||||
|
InputTokens int
|
||||||
|
OutputTokens int
|
||||||
|
CacheCreationInputTokens int
|
||||||
|
CacheReadInputTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
type RelayResult struct {
|
||||||
|
RequestModel string
|
||||||
|
Usage Usage
|
||||||
|
RequestID string
|
||||||
|
TerminalEventType string
|
||||||
|
FirstTokenMs *int
|
||||||
|
Duration time.Duration
|
||||||
|
ClientToUpstreamFrames int64
|
||||||
|
UpstreamToClientFrames int64
|
||||||
|
DroppedDownstreamFrames int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type RelayTurnResult struct {
|
||||||
|
RequestModel string
|
||||||
|
Usage Usage
|
||||||
|
RequestID string
|
||||||
|
TerminalEventType string
|
||||||
|
Duration time.Duration
|
||||||
|
FirstTokenMs *int
|
||||||
|
}
|
||||||
|
|
||||||
|
type RelayExit struct {
|
||||||
|
Stage string
|
||||||
|
Err error
|
||||||
|
WroteDownstream bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type RelayOptions struct {
|
||||||
|
WriteTimeout time.Duration
|
||||||
|
IdleTimeout time.Duration
|
||||||
|
UpstreamDrainTimeout time.Duration
|
||||||
|
FirstMessageType coderws.MessageType
|
||||||
|
OnUsageParseFailure func(eventType string, usageRaw string)
|
||||||
|
OnTurnComplete func(turn RelayTurnResult)
|
||||||
|
OnTrace func(event RelayTraceEvent)
|
||||||
|
Now func() time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type RelayTraceEvent struct {
|
||||||
|
Stage string
|
||||||
|
Direction string
|
||||||
|
MessageType string
|
||||||
|
PayloadBytes int
|
||||||
|
Graceful bool
|
||||||
|
WroteDownstream bool
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
type relayState struct {
|
||||||
|
usage Usage
|
||||||
|
requestModel string
|
||||||
|
lastResponseID string
|
||||||
|
terminalEventType string
|
||||||
|
firstTokenMs *int
|
||||||
|
turnTimingByID map[string]*relayTurnTiming
|
||||||
|
}
|
||||||
|
|
||||||
|
type relayExitSignal struct {
|
||||||
|
stage string
|
||||||
|
err error
|
||||||
|
graceful bool
|
||||||
|
wroteDownstream bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type observedUpstreamEvent struct {
|
||||||
|
terminal bool
|
||||||
|
eventType string
|
||||||
|
responseID string
|
||||||
|
usage Usage
|
||||||
|
duration time.Duration
|
||||||
|
firstToken *int
|
||||||
|
}
|
||||||
|
|
||||||
|
type relayTurnTiming struct {
|
||||||
|
startAt time.Time
|
||||||
|
firstTokenMs *int
|
||||||
|
}
|
||||||
|
|
||||||
|
func Relay(
|
||||||
|
ctx context.Context,
|
||||||
|
clientConn FrameConn,
|
||||||
|
upstreamConn FrameConn,
|
||||||
|
firstClientMessage []byte,
|
||||||
|
options RelayOptions,
|
||||||
|
) (RelayResult, *RelayExit) {
|
||||||
|
result := RelayResult{RequestModel: strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())}
|
||||||
|
if clientConn == nil || upstreamConn == nil {
|
||||||
|
return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")}
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
nowFn := options.Now
|
||||||
|
if nowFn == nil {
|
||||||
|
nowFn = time.Now
|
||||||
|
}
|
||||||
|
writeTimeout := options.WriteTimeout
|
||||||
|
if writeTimeout <= 0 {
|
||||||
|
writeTimeout = 2 * time.Minute
|
||||||
|
}
|
||||||
|
drainTimeout := options.UpstreamDrainTimeout
|
||||||
|
if drainTimeout <= 0 {
|
||||||
|
drainTimeout = 1200 * time.Millisecond
|
||||||
|
}
|
||||||
|
firstMessageType := options.FirstMessageType
|
||||||
|
if firstMessageType != coderws.MessageBinary {
|
||||||
|
firstMessageType = coderws.MessageText
|
||||||
|
}
|
||||||
|
startAt := nowFn()
|
||||||
|
state := &relayState{requestModel: result.RequestModel}
|
||||||
|
onTrace := options.OnTrace
|
||||||
|
|
||||||
|
relayCtx, relayCancel := context.WithCancel(ctx)
|
||||||
|
defer relayCancel()
|
||||||
|
|
||||||
|
lastActivity := atomic.Int64{}
|
||||||
|
lastActivity.Store(nowFn().UnixNano())
|
||||||
|
markActivity := func() {
|
||||||
|
lastActivity.Store(nowFn().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
writeUpstream := func(msgType coderws.MessageType, payload []byte) error {
|
||||||
|
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
|
||||||
|
defer cancel()
|
||||||
|
return upstreamConn.WriteFrame(writeCtx, msgType, payload)
|
||||||
|
}
|
||||||
|
writeClient := func(msgType coderws.MessageType, payload []byte) error {
|
||||||
|
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
|
||||||
|
defer cancel()
|
||||||
|
return clientConn.WriteFrame(writeCtx, msgType, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientToUpstreamFrames := &atomic.Int64{}
|
||||||
|
upstreamToClientFrames := &atomic.Int64{}
|
||||||
|
droppedDownstreamFrames := &atomic.Int64{}
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "relay_start",
|
||||||
|
PayloadBytes: len(firstClientMessage),
|
||||||
|
MessageType: relayMessageTypeString(firstMessageType),
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
|
||||||
|
result.Duration = nowFn().Sub(startAt)
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "write_first_message_failed",
|
||||||
|
Direction: "client_to_upstream",
|
||||||
|
MessageType: relayMessageTypeString(firstMessageType),
|
||||||
|
PayloadBytes: len(firstClientMessage),
|
||||||
|
Error: err.Error(),
|
||||||
|
})
|
||||||
|
return result, &RelayExit{Stage: "write_upstream", Err: err}
|
||||||
|
}
|
||||||
|
clientToUpstreamFrames.Add(1)
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "write_first_message_ok",
|
||||||
|
Direction: "client_to_upstream",
|
||||||
|
MessageType: relayMessageTypeString(firstMessageType),
|
||||||
|
PayloadBytes: len(firstClientMessage),
|
||||||
|
})
|
||||||
|
markActivity()
|
||||||
|
|
||||||
|
exitCh := make(chan relayExitSignal, 3)
|
||||||
|
dropDownstreamWrites := atomic.Bool{}
|
||||||
|
go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
|
||||||
|
go runUpstreamToClient(
|
||||||
|
relayCtx,
|
||||||
|
upstreamConn,
|
||||||
|
writeClient,
|
||||||
|
startAt,
|
||||||
|
nowFn,
|
||||||
|
state,
|
||||||
|
options.OnUsageParseFailure,
|
||||||
|
options.OnTurnComplete,
|
||||||
|
&dropDownstreamWrites,
|
||||||
|
upstreamToClientFrames,
|
||||||
|
droppedDownstreamFrames,
|
||||||
|
markActivity,
|
||||||
|
onTrace,
|
||||||
|
exitCh,
|
||||||
|
)
|
||||||
|
go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh)
|
||||||
|
|
||||||
|
firstExit := <-exitCh
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "first_exit",
|
||||||
|
Direction: relayDirectionFromStage(firstExit.stage),
|
||||||
|
Graceful: firstExit.graceful,
|
||||||
|
WroteDownstream: firstExit.wroteDownstream,
|
||||||
|
Error: relayErrorString(firstExit.err),
|
||||||
|
})
|
||||||
|
combinedWroteDownstream := firstExit.wroteDownstream
|
||||||
|
secondExit := relayExitSignal{graceful: true}
|
||||||
|
hasSecondExit := false
|
||||||
|
|
||||||
|
// 客户端断开后尽力继续读取上游短窗口,捕获延迟 usage/terminal 事件用于计费。
|
||||||
|
if firstExit.stage == "read_client" && firstExit.graceful {
|
||||||
|
dropDownstreamWrites.Store(true)
|
||||||
|
secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout)
|
||||||
|
} else {
|
||||||
|
relayCancel()
|
||||||
|
_ = upstreamConn.Close()
|
||||||
|
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
|
||||||
|
}
|
||||||
|
if hasSecondExit {
|
||||||
|
combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "second_exit",
|
||||||
|
Direction: relayDirectionFromStage(secondExit.stage),
|
||||||
|
Graceful: secondExit.graceful,
|
||||||
|
WroteDownstream: secondExit.wroteDownstream,
|
||||||
|
Error: relayErrorString(secondExit.err),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
relayCancel()
|
||||||
|
_ = upstreamConn.Close()
|
||||||
|
|
||||||
|
enrichResult(&result, state, nowFn().Sub(startAt))
|
||||||
|
result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
|
||||||
|
result.UpstreamToClientFrames = upstreamToClientFrames.Load()
|
||||||
|
result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
|
||||||
|
if firstExit.stage == "read_client" && firstExit.graceful {
|
||||||
|
stage := "client_disconnected"
|
||||||
|
exitErr := firstExit.err
|
||||||
|
if hasSecondExit && !secondExit.graceful {
|
||||||
|
stage = secondExit.stage
|
||||||
|
exitErr = secondExit.err
|
||||||
|
}
|
||||||
|
if exitErr == nil {
|
||||||
|
exitErr = io.EOF
|
||||||
|
}
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "relay_exit",
|
||||||
|
Direction: relayDirectionFromStage(stage),
|
||||||
|
Graceful: false,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
Error: relayErrorString(exitErr),
|
||||||
|
})
|
||||||
|
return result, &RelayExit{
|
||||||
|
Stage: stage,
|
||||||
|
Err: exitErr,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if firstExit.graceful && (!hasSecondExit || secondExit.graceful) {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "relay_complete",
|
||||||
|
Graceful: true,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
})
|
||||||
|
_ = clientConn.Close()
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
if !firstExit.graceful {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "relay_exit",
|
||||||
|
Direction: relayDirectionFromStage(firstExit.stage),
|
||||||
|
Graceful: false,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
Error: relayErrorString(firstExit.err),
|
||||||
|
})
|
||||||
|
return result, &RelayExit{
|
||||||
|
Stage: firstExit.stage,
|
||||||
|
Err: firstExit.err,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasSecondExit && !secondExit.graceful {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "relay_exit",
|
||||||
|
Direction: relayDirectionFromStage(secondExit.stage),
|
||||||
|
Graceful: false,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
Error: relayErrorString(secondExit.err),
|
||||||
|
})
|
||||||
|
return result, &RelayExit{
|
||||||
|
Stage: secondExit.stage,
|
||||||
|
Err: secondExit.err,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "relay_complete",
|
||||||
|
Graceful: true,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
})
|
||||||
|
_ = clientConn.Close()
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runClientToUpstream(
|
||||||
|
ctx context.Context,
|
||||||
|
clientConn FrameConn,
|
||||||
|
writeUpstream func(msgType coderws.MessageType, payload []byte) error,
|
||||||
|
markActivity func(),
|
||||||
|
forwardedFrames *atomic.Int64,
|
||||||
|
onTrace func(event RelayTraceEvent),
|
||||||
|
exitCh chan<- relayExitSignal,
|
||||||
|
) {
|
||||||
|
for {
|
||||||
|
msgType, payload, err := clientConn.ReadFrame(ctx)
|
||||||
|
if err != nil {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "read_client_failed",
|
||||||
|
Direction: "client_to_upstream",
|
||||||
|
Error: err.Error(),
|
||||||
|
Graceful: isDisconnectError(err),
|
||||||
|
})
|
||||||
|
exitCh <- relayExitSignal{stage: "read_client", err: err, graceful: isDisconnectError(err)}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
markActivity()
|
||||||
|
if err := writeUpstream(msgType, payload); err != nil {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "write_upstream_failed",
|
||||||
|
Direction: "client_to_upstream",
|
||||||
|
MessageType: relayMessageTypeString(msgType),
|
||||||
|
PayloadBytes: len(payload),
|
||||||
|
Error: err.Error(),
|
||||||
|
})
|
||||||
|
exitCh <- relayExitSignal{stage: "write_upstream", err: err}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if forwardedFrames != nil {
|
||||||
|
forwardedFrames.Add(1)
|
||||||
|
}
|
||||||
|
markActivity()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runUpstreamToClient(
|
||||||
|
ctx context.Context,
|
||||||
|
upstreamConn FrameConn,
|
||||||
|
writeClient func(msgType coderws.MessageType, payload []byte) error,
|
||||||
|
startAt time.Time,
|
||||||
|
nowFn func() time.Time,
|
||||||
|
state *relayState,
|
||||||
|
onUsageParseFailure func(eventType string, usageRaw string),
|
||||||
|
onTurnComplete func(turn RelayTurnResult),
|
||||||
|
dropDownstreamWrites *atomic.Bool,
|
||||||
|
forwardedFrames *atomic.Int64,
|
||||||
|
droppedFrames *atomic.Int64,
|
||||||
|
markActivity func(),
|
||||||
|
onTrace func(event RelayTraceEvent),
|
||||||
|
exitCh chan<- relayExitSignal,
|
||||||
|
) {
|
||||||
|
wroteDownstream := false
|
||||||
|
for {
|
||||||
|
msgType, payload, err := upstreamConn.ReadFrame(ctx)
|
||||||
|
if err != nil {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "read_upstream_failed",
|
||||||
|
Direction: "upstream_to_client",
|
||||||
|
Error: err.Error(),
|
||||||
|
Graceful: isDisconnectError(err),
|
||||||
|
WroteDownstream: wroteDownstream,
|
||||||
|
})
|
||||||
|
exitCh <- relayExitSignal{
|
||||||
|
stage: "read_upstream",
|
||||||
|
err: err,
|
||||||
|
graceful: isDisconnectError(err),
|
||||||
|
wroteDownstream: wroteDownstream,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
markActivity()
|
||||||
|
observedEvent := observedUpstreamEvent{}
|
||||||
|
switch msgType {
|
||||||
|
case coderws.MessageText:
|
||||||
|
observedEvent = observeUpstreamMessage(state, payload, startAt, nowFn, onUsageParseFailure)
|
||||||
|
case coderws.MessageBinary:
|
||||||
|
// binary frame 直接透传,不进入 JSON 观测路径(避免无效解析开销)。
|
||||||
|
}
|
||||||
|
emitTurnComplete(onTurnComplete, state, observedEvent)
|
||||||
|
if dropDownstreamWrites != nil && dropDownstreamWrites.Load() {
|
||||||
|
if droppedFrames != nil {
|
||||||
|
droppedFrames.Add(1)
|
||||||
|
}
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "drop_downstream_frame",
|
||||||
|
Direction: "upstream_to_client",
|
||||||
|
MessageType: relayMessageTypeString(msgType),
|
||||||
|
PayloadBytes: len(payload),
|
||||||
|
WroteDownstream: wroteDownstream,
|
||||||
|
})
|
||||||
|
if observedEvent.terminal {
|
||||||
|
exitCh <- relayExitSignal{
|
||||||
|
stage: "drain_terminal",
|
||||||
|
graceful: true,
|
||||||
|
wroteDownstream: wroteDownstream,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
markActivity()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := writeClient(msgType, payload); err != nil {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "write_client_failed",
|
||||||
|
Direction: "upstream_to_client",
|
||||||
|
MessageType: relayMessageTypeString(msgType),
|
||||||
|
PayloadBytes: len(payload),
|
||||||
|
WroteDownstream: wroteDownstream,
|
||||||
|
Error: err.Error(),
|
||||||
|
})
|
||||||
|
exitCh <- relayExitSignal{stage: "write_client", err: err, wroteDownstream: wroteDownstream}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wroteDownstream = true
|
||||||
|
if forwardedFrames != nil {
|
||||||
|
forwardedFrames.Add(1)
|
||||||
|
}
|
||||||
|
markActivity()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runIdleWatchdog(
|
||||||
|
ctx context.Context,
|
||||||
|
nowFn func() time.Time,
|
||||||
|
idleTimeout time.Duration,
|
||||||
|
lastActivity *atomic.Int64,
|
||||||
|
onTrace func(event RelayTraceEvent),
|
||||||
|
exitCh chan<- relayExitSignal,
|
||||||
|
) {
|
||||||
|
if idleTimeout <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
checkInterval := minDuration(idleTimeout/4, 5*time.Second)
|
||||||
|
if checkInterval < time.Second {
|
||||||
|
checkInterval = time.Second
|
||||||
|
}
|
||||||
|
ticker := time.NewTicker(checkInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
last := time.Unix(0, lastActivity.Load())
|
||||||
|
if nowFn().Sub(last) < idleTimeout {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "idle_timeout_triggered",
|
||||||
|
Direction: "watchdog",
|
||||||
|
Error: context.DeadlineExceeded.Error(),
|
||||||
|
})
|
||||||
|
exitCh <- relayExitSignal{stage: "idle_timeout", err: context.DeadlineExceeded}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func emitRelayTrace(onTrace func(event RelayTraceEvent), event RelayTraceEvent) {
|
||||||
|
if onTrace == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
onTrace(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func relayMessageTypeString(msgType coderws.MessageType) string {
|
||||||
|
switch msgType {
|
||||||
|
case coderws.MessageText:
|
||||||
|
return "text"
|
||||||
|
case coderws.MessageBinary:
|
||||||
|
return "binary"
|
||||||
|
default:
|
||||||
|
return "unknown(" + strconv.Itoa(int(msgType)) + ")"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func relayDirectionFromStage(stage string) string {
|
||||||
|
switch stage {
|
||||||
|
case "read_client", "write_upstream":
|
||||||
|
return "client_to_upstream"
|
||||||
|
case "read_upstream", "write_client", "drain_terminal":
|
||||||
|
return "upstream_to_client"
|
||||||
|
case "idle_timeout":
|
||||||
|
return "watchdog"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func relayErrorString(err error) string {
|
||||||
|
if err == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func observeUpstreamMessage(
|
||||||
|
state *relayState,
|
||||||
|
message []byte,
|
||||||
|
startAt time.Time,
|
||||||
|
nowFn func() time.Time,
|
||||||
|
onUsageParseFailure func(eventType string, usageRaw string),
|
||||||
|
) observedUpstreamEvent {
|
||||||
|
if state == nil || len(message) == 0 {
|
||||||
|
return observedUpstreamEvent{}
|
||||||
|
}
|
||||||
|
values := gjson.GetManyBytes(message, "type", "response.id", "response_id", "id")
|
||||||
|
eventType := strings.TrimSpace(values[0].String())
|
||||||
|
if eventType == "" {
|
||||||
|
return observedUpstreamEvent{}
|
||||||
|
}
|
||||||
|
responseID := strings.TrimSpace(values[1].String())
|
||||||
|
if responseID == "" {
|
||||||
|
responseID = strings.TrimSpace(values[2].String())
|
||||||
|
}
|
||||||
|
// 仅 terminal 事件兜底读取顶层 id,避免把 event_id 当成 response_id 关联到 turn。
|
||||||
|
if responseID == "" && isTerminalEvent(eventType) {
|
||||||
|
responseID = strings.TrimSpace(values[3].String())
|
||||||
|
}
|
||||||
|
now := nowFn()
|
||||||
|
|
||||||
|
if state.firstTokenMs == nil && isTokenEvent(eventType) {
|
||||||
|
ms := int(now.Sub(startAt).Milliseconds())
|
||||||
|
if ms >= 0 {
|
||||||
|
state.firstTokenMs = &ms
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parsedUsage := parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure)
|
||||||
|
observed := observedUpstreamEvent{
|
||||||
|
eventType: eventType,
|
||||||
|
responseID: responseID,
|
||||||
|
usage: parsedUsage,
|
||||||
|
}
|
||||||
|
if responseID != "" {
|
||||||
|
turnTiming := openAIWSRelayGetOrInitTurnTiming(state, responseID, now)
|
||||||
|
if turnTiming != nil && turnTiming.firstTokenMs == nil && isTokenEvent(eventType) {
|
||||||
|
ms := int(now.Sub(turnTiming.startAt).Milliseconds())
|
||||||
|
if ms >= 0 {
|
||||||
|
turnTiming.firstTokenMs = &ms
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !isTerminalEvent(eventType) {
|
||||||
|
return observed
|
||||||
|
}
|
||||||
|
observed.terminal = true
|
||||||
|
state.terminalEventType = eventType
|
||||||
|
if responseID != "" {
|
||||||
|
state.lastResponseID = responseID
|
||||||
|
if turnTiming, ok := openAIWSRelayDeleteTurnTiming(state, responseID); ok {
|
||||||
|
duration := now.Sub(turnTiming.startAt)
|
||||||
|
if duration < 0 {
|
||||||
|
duration = 0
|
||||||
|
}
|
||||||
|
observed.duration = duration
|
||||||
|
observed.firstToken = openAIWSRelayCloneIntPtr(turnTiming.firstTokenMs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return observed
|
||||||
|
}
|
||||||
|
|
||||||
|
func emitTurnComplete(
|
||||||
|
onTurnComplete func(turn RelayTurnResult),
|
||||||
|
state *relayState,
|
||||||
|
observed observedUpstreamEvent,
|
||||||
|
) {
|
||||||
|
if onTurnComplete == nil || !observed.terminal {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
responseID := strings.TrimSpace(observed.responseID)
|
||||||
|
if responseID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
requestModel := ""
|
||||||
|
if state != nil {
|
||||||
|
requestModel = state.requestModel
|
||||||
|
}
|
||||||
|
onTurnComplete(RelayTurnResult{
|
||||||
|
RequestModel: requestModel,
|
||||||
|
Usage: observed.usage,
|
||||||
|
RequestID: responseID,
|
||||||
|
TerminalEventType: observed.eventType,
|
||||||
|
Duration: observed.duration,
|
||||||
|
FirstTokenMs: openAIWSRelayCloneIntPtr(observed.firstToken),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now time.Time) *relayTurnTiming {
|
||||||
|
if state == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if state.turnTimingByID == nil {
|
||||||
|
state.turnTimingByID = make(map[string]*relayTurnTiming, 8)
|
||||||
|
}
|
||||||
|
timing, ok := state.turnTimingByID[responseID]
|
||||||
|
if !ok || timing == nil || timing.startAt.IsZero() {
|
||||||
|
timing = &relayTurnTiming{startAt: now}
|
||||||
|
state.turnTimingByID[responseID] = timing
|
||||||
|
return timing
|
||||||
|
}
|
||||||
|
return timing
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayTurnTiming, bool) {
|
||||||
|
if state == nil || state.turnTimingByID == nil {
|
||||||
|
return relayTurnTiming{}, false
|
||||||
|
}
|
||||||
|
timing, ok := state.turnTimingByID[responseID]
|
||||||
|
if !ok || timing == nil {
|
||||||
|
return relayTurnTiming{}, false
|
||||||
|
}
|
||||||
|
delete(state.turnTimingByID, responseID)
|
||||||
|
return *timing, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIWSRelayCloneIntPtr(v *int) *int {
|
||||||
|
if v == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cloned := *v
|
||||||
|
return &cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseUsageAndAccumulate(
|
||||||
|
state *relayState,
|
||||||
|
message []byte,
|
||||||
|
eventType string,
|
||||||
|
onParseFailure func(eventType string, usageRaw string),
|
||||||
|
) Usage {
|
||||||
|
if state == nil || len(message) == 0 || !shouldParseUsage(eventType) {
|
||||||
|
return Usage{}
|
||||||
|
}
|
||||||
|
usageResult := gjson.GetBytes(message, "response.usage")
|
||||||
|
if !usageResult.Exists() {
|
||||||
|
return Usage{}
|
||||||
|
}
|
||||||
|
usageRaw := strings.TrimSpace(usageResult.Raw)
|
||||||
|
if usageRaw == "" || !strings.HasPrefix(usageRaw, "{") {
|
||||||
|
recordUsageParseFailure()
|
||||||
|
if onParseFailure != nil {
|
||||||
|
onParseFailure(eventType, usageRaw)
|
||||||
|
}
|
||||||
|
return Usage{}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputResult := gjson.GetBytes(message, "response.usage.input_tokens")
|
||||||
|
outputResult := gjson.GetBytes(message, "response.usage.output_tokens")
|
||||||
|
cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens")
|
||||||
|
|
||||||
|
inputTokens, inputOK := parseUsageIntField(inputResult, true)
|
||||||
|
outputTokens, outputOK := parseUsageIntField(outputResult, true)
|
||||||
|
cachedTokens, cachedOK := parseUsageIntField(cachedResult, false)
|
||||||
|
if !inputOK || !outputOK || !cachedOK {
|
||||||
|
recordUsageParseFailure()
|
||||||
|
if onParseFailure != nil {
|
||||||
|
onParseFailure(eventType, usageRaw)
|
||||||
|
}
|
||||||
|
// 解析失败时不做部分字段累加,避免计费 usage 出现“半有效”状态。
|
||||||
|
return Usage{}
|
||||||
|
}
|
||||||
|
parsedUsage := Usage{
|
||||||
|
InputTokens: inputTokens,
|
||||||
|
OutputTokens: outputTokens,
|
||||||
|
CacheReadInputTokens: cachedTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
state.usage.InputTokens += parsedUsage.InputTokens
|
||||||
|
state.usage.OutputTokens += parsedUsage.OutputTokens
|
||||||
|
state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
|
||||||
|
return parsedUsage
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseUsageIntField(value gjson.Result, required bool) (int, bool) {
|
||||||
|
if !value.Exists() {
|
||||||
|
return 0, !required
|
||||||
|
}
|
||||||
|
if value.Type != gjson.Number {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return int(value.Int()), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func enrichResult(result *RelayResult, state *relayState, duration time.Duration) {
|
||||||
|
if result == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result.Duration = duration
|
||||||
|
if state == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result.RequestModel = state.requestModel
|
||||||
|
result.Usage = state.usage
|
||||||
|
result.RequestID = state.lastResponseID
|
||||||
|
result.TerminalEventType = state.terminalEventType
|
||||||
|
result.FirstTokenMs = state.firstTokenMs
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDisconnectError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
switch coderws.CloseStatus(err) {
|
||||||
|
case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
message := strings.ToLower(strings.TrimSpace(err.Error()))
|
||||||
|
if message == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(message, "failed to read frame header: eof") ||
|
||||||
|
strings.Contains(message, "unexpected eof") ||
|
||||||
|
strings.Contains(message, "use of closed network connection") ||
|
||||||
|
strings.Contains(message, "connection reset by peer") ||
|
||||||
|
strings.Contains(message, "broken pipe")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTerminalEvent(eventType string) bool {
|
||||||
|
switch eventType {
|
||||||
|
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldParseUsage(eventType string) bool {
|
||||||
|
switch eventType {
|
||||||
|
case "response.completed", "response.done", "response.failed":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTokenEvent(eventType string) bool {
|
||||||
|
if eventType == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch eventType {
|
||||||
|
case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.Contains(eventType, ".delta") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(eventType, "response.output_text") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(eventType, "response.output") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return eventType == "response.completed" || eventType == "response.done"
|
||||||
|
}
|
||||||
|
|
||||||
|
func minDuration(a, b time.Duration) time.Duration {
|
||||||
|
if a <= 0 {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
if b <= 0 {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitRelayExit(exitCh <-chan relayExitSignal, timeout time.Duration) (relayExitSignal, bool) {
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = 200 * time.Millisecond
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case sig := <-exitCh:
|
||||||
|
return sig, true
|
||||||
|
case <-time.After(timeout):
|
||||||
|
return relayExitSignal{}, false
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,432 @@
|
|||||||
|
package openai_ws_v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
coderws "github.com/coder/websocket"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRunEntry_DelegatesRelay(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_entry","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
result, relayExit := RunEntry(EntryInput{
|
||||||
|
Ctx: context.Background(),
|
||||||
|
ClientConn: clientConn,
|
||||||
|
UpstreamConn: upstreamConn,
|
||||||
|
FirstClientMessage: []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`),
|
||||||
|
})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
require.Equal(t, "resp_entry", result.RequestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("read client eof", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
exitCh := make(chan relayExitSignal, 1)
|
||||||
|
runClientToUpstream(
|
||||||
|
context.Background(),
|
||||||
|
newPassthroughTestFrameConn(nil, true),
|
||||||
|
func(_ coderws.MessageType, _ []byte) error { return nil },
|
||||||
|
func() {},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
exitCh,
|
||||||
|
)
|
||||||
|
sig := <-exitCh
|
||||||
|
require.Equal(t, "read_client", sig.stage)
|
||||||
|
require.True(t, sig.graceful)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("write upstream failed", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
exitCh := make(chan relayExitSignal, 1)
|
||||||
|
runClientToUpstream(
|
||||||
|
context.Background(),
|
||||||
|
newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
|
||||||
|
}, true),
|
||||||
|
func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") },
|
||||||
|
func() {},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
exitCh,
|
||||||
|
)
|
||||||
|
sig := <-exitCh
|
||||||
|
require.Equal(t, "write_upstream", sig.stage)
|
||||||
|
require.False(t, sig.graceful)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("forwarded counter and trace callback", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
exitCh := make(chan relayExitSignal, 1)
|
||||||
|
forwarded := &atomic.Int64{}
|
||||||
|
traces := make([]RelayTraceEvent, 0, 2)
|
||||||
|
runClientToUpstream(
|
||||||
|
context.Background(),
|
||||||
|
newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
|
||||||
|
}, true),
|
||||||
|
func(_ coderws.MessageType, _ []byte) error { return nil },
|
||||||
|
func() {},
|
||||||
|
forwarded,
|
||||||
|
func(event RelayTraceEvent) {
|
||||||
|
traces = append(traces, event)
|
||||||
|
},
|
||||||
|
exitCh,
|
||||||
|
)
|
||||||
|
sig := <-exitCh
|
||||||
|
require.Equal(t, "read_client", sig.stage)
|
||||||
|
require.Equal(t, int64(1), forwarded.Load())
|
||||||
|
require.NotEmpty(t, traces)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("read upstream eof", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
exitCh := make(chan relayExitSignal, 1)
|
||||||
|
drop := &atomic.Bool{}
|
||||||
|
drop.Store(false)
|
||||||
|
runUpstreamToClient(
|
||||||
|
context.Background(),
|
||||||
|
newPassthroughTestFrameConn(nil, true),
|
||||||
|
func(_ coderws.MessageType, _ []byte) error { return nil },
|
||||||
|
time.Now(),
|
||||||
|
time.Now,
|
||||||
|
&relayState{},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
drop,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
func() {},
|
||||||
|
nil,
|
||||||
|
exitCh,
|
||||||
|
)
|
||||||
|
sig := <-exitCh
|
||||||
|
require.Equal(t, "read_upstream", sig.stage)
|
||||||
|
require.True(t, sig.graceful)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("write client failed", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
exitCh := make(chan relayExitSignal, 1)
|
||||||
|
drop := &atomic.Bool{}
|
||||||
|
drop.Store(false)
|
||||||
|
runUpstreamToClient(
|
||||||
|
context.Background(),
|
||||||
|
newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{msgType: coderws.MessageText, payload: []byte(`{"type":"response.output_text.delta","delta":"x"}`)},
|
||||||
|
}, true),
|
||||||
|
func(_ coderws.MessageType, _ []byte) error { return errors.New("write failed") },
|
||||||
|
time.Now(),
|
||||||
|
time.Now,
|
||||||
|
&relayState{},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
drop,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
func() {},
|
||||||
|
nil,
|
||||||
|
exitCh,
|
||||||
|
)
|
||||||
|
sig := <-exitCh
|
||||||
|
require.Equal(t, "write_client", sig.stage)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("drop downstream and stop on terminal", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
exitCh := make(chan relayExitSignal, 1)
|
||||||
|
drop := &atomic.Bool{}
|
||||||
|
drop.Store(true)
|
||||||
|
dropped := &atomic.Int64{}
|
||||||
|
runUpstreamToClient(
|
||||||
|
context.Background(),
|
||||||
|
newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_drop","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}, true),
|
||||||
|
func(_ coderws.MessageType, _ []byte) error { return nil },
|
||||||
|
time.Now(),
|
||||||
|
time.Now,
|
||||||
|
&relayState{},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
drop,
|
||||||
|
nil,
|
||||||
|
dropped,
|
||||||
|
func() {},
|
||||||
|
nil,
|
||||||
|
exitCh,
|
||||||
|
)
|
||||||
|
sig := <-exitCh
|
||||||
|
require.Equal(t, "drain_terminal", sig.stage)
|
||||||
|
require.True(t, sig.graceful)
|
||||||
|
require.Equal(t, int64(1), dropped.Load())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunIdleWatchdog_NoTimeoutWhenDisabled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
exitCh := make(chan relayExitSignal, 1)
|
||||||
|
lastActivity := &atomic.Int64{}
|
||||||
|
lastActivity.Store(time.Now().UnixNano())
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go runIdleWatchdog(ctx, time.Now, 0, lastActivity, nil, exitCh)
|
||||||
|
select {
|
||||||
|
case <-exitCh:
|
||||||
|
t.Fatal("unexpected idle timeout signal")
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHelperFunctionsCoverage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
require.Equal(t, "text", relayMessageTypeString(coderws.MessageText))
|
||||||
|
require.Equal(t, "binary", relayMessageTypeString(coderws.MessageBinary))
|
||||||
|
require.Contains(t, relayMessageTypeString(coderws.MessageType(99)), "unknown(")
|
||||||
|
|
||||||
|
require.Equal(t, "", relayErrorString(nil))
|
||||||
|
require.Equal(t, "x", relayErrorString(errors.New("x")))
|
||||||
|
|
||||||
|
require.True(t, isDisconnectError(io.EOF))
|
||||||
|
require.True(t, isDisconnectError(net.ErrClosed))
|
||||||
|
require.True(t, isDisconnectError(context.Canceled))
|
||||||
|
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusGoingAway}))
|
||||||
|
require.True(t, isDisconnectError(errors.New("broken pipe")))
|
||||||
|
require.False(t, isDisconnectError(errors.New("unrelated")))
|
||||||
|
|
||||||
|
require.True(t, isTokenEvent("response.output_text.delta"))
|
||||||
|
require.True(t, isTokenEvent("response.output_audio.delta"))
|
||||||
|
require.True(t, isTokenEvent("response.completed"))
|
||||||
|
require.False(t, isTokenEvent(""))
|
||||||
|
require.False(t, isTokenEvent("response.created"))
|
||||||
|
|
||||||
|
require.Equal(t, 2*time.Second, minDuration(2*time.Second, 5*time.Second))
|
||||||
|
require.Equal(t, 2*time.Second, minDuration(5*time.Second, 2*time.Second))
|
||||||
|
require.Equal(t, 5*time.Second, minDuration(0, 5*time.Second))
|
||||||
|
require.Equal(t, 2*time.Second, minDuration(2*time.Second, 0))
|
||||||
|
|
||||||
|
ch := make(chan relayExitSignal, 1)
|
||||||
|
ch <- relayExitSignal{stage: "ok"}
|
||||||
|
sig, ok := waitRelayExit(ch, 10*time.Millisecond)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "ok", sig.stage)
|
||||||
|
ch <- relayExitSignal{stage: "ok2"}
|
||||||
|
sig, ok = waitRelayExit(ch, 0)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "ok2", sig.stage)
|
||||||
|
_, ok = waitRelayExit(ch, 10*time.Millisecond)
|
||||||
|
require.False(t, ok)
|
||||||
|
|
||||||
|
n, ok := parseUsageIntField(gjson.Get(`{"n":3}`, "n"), true)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, 3, n)
|
||||||
|
_, ok = parseUsageIntField(gjson.Get(`{"n":"x"}`, "n"), true)
|
||||||
|
require.False(t, ok)
|
||||||
|
n, ok = parseUsageIntField(gjson.Result{}, false)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, 0, n)
|
||||||
|
_, ok = parseUsageIntField(gjson.Result{}, true)
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseUsageAndEnrichCoverage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
state := &relayState{}
|
||||||
|
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":"bad"}}}`), "response.completed", nil)
|
||||||
|
require.Equal(t, 0, state.usage.InputTokens)
|
||||||
|
|
||||||
|
parseUsageAndAccumulate(
|
||||||
|
state,
|
||||||
|
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens":9,"output_tokens":"bad","input_tokens_details":{"cached_tokens":2}}}}`),
|
||||||
|
"response.completed",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.Equal(t, 0, state.usage.InputTokens, "部分字段解析失败时不应累加 usage")
|
||||||
|
require.Equal(t, 0, state.usage.OutputTokens)
|
||||||
|
require.Equal(t, 0, state.usage.CacheReadInputTokens)
|
||||||
|
|
||||||
|
parseUsageAndAccumulate(
|
||||||
|
state,
|
||||||
|
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens_details":{"cached_tokens":2}}}}`),
|
||||||
|
"response.completed",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.Equal(t, 0, state.usage.InputTokens, "必填 usage 字段缺失时不应累加 usage")
|
||||||
|
require.Equal(t, 0, state.usage.OutputTokens)
|
||||||
|
require.Equal(t, 0, state.usage.CacheReadInputTokens)
|
||||||
|
|
||||||
|
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil)
|
||||||
|
require.Equal(t, 2, state.usage.InputTokens)
|
||||||
|
require.Equal(t, 1, state.usage.OutputTokens)
|
||||||
|
require.Equal(t, 1, state.usage.CacheReadInputTokens)
|
||||||
|
|
||||||
|
result := &RelayResult{}
|
||||||
|
enrichResult(result, state, 5*time.Millisecond)
|
||||||
|
require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 5*time.Millisecond, result.Duration)
|
||||||
|
parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil)
|
||||||
|
require.Equal(t, 2, state.usage.InputTokens)
|
||||||
|
enrichResult(nil, state, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmitTurnCompleteCoverage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 非 terminal 事件不应触发。
|
||||||
|
called := 0
|
||||||
|
emitTurnComplete(func(turn RelayTurnResult) {
|
||||||
|
called++
|
||||||
|
}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
|
||||||
|
terminal: false,
|
||||||
|
eventType: "response.output_text.delta",
|
||||||
|
responseID: "resp_ignored",
|
||||||
|
usage: Usage{InputTokens: 1},
|
||||||
|
})
|
||||||
|
require.Equal(t, 0, called)
|
||||||
|
|
||||||
|
// 缺少 response_id 时不应触发。
|
||||||
|
emitTurnComplete(func(turn RelayTurnResult) {
|
||||||
|
called++
|
||||||
|
}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
|
||||||
|
terminal: true,
|
||||||
|
eventType: "response.completed",
|
||||||
|
})
|
||||||
|
require.Equal(t, 0, called)
|
||||||
|
|
||||||
|
// terminal 且 response_id 存在,应该触发;state=nil 时 model 为空串。
|
||||||
|
var got RelayTurnResult
|
||||||
|
emitTurnComplete(func(turn RelayTurnResult) {
|
||||||
|
called++
|
||||||
|
got = turn
|
||||||
|
}, nil, observedUpstreamEvent{
|
||||||
|
terminal: true,
|
||||||
|
eventType: "response.completed",
|
||||||
|
responseID: "resp_emit",
|
||||||
|
usage: Usage{InputTokens: 2, OutputTokens: 3},
|
||||||
|
})
|
||||||
|
require.Equal(t, 1, called)
|
||||||
|
require.Equal(t, "resp_emit", got.RequestID)
|
||||||
|
require.Equal(t, "response.completed", got.TerminalEventType)
|
||||||
|
require.Equal(t, 2, got.Usage.InputTokens)
|
||||||
|
require.Equal(t, 3, got.Usage.OutputTokens)
|
||||||
|
require.Equal(t, "", got.RequestModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsDisconnectErrorCoverage_CloseStatusesAndMessageBranches(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNormalClosure}))
|
||||||
|
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNoStatusRcvd}))
|
||||||
|
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusAbnormalClosure}))
|
||||||
|
require.True(t, isDisconnectError(errors.New("connection reset by peer")))
|
||||||
|
require.False(t, isDisconnectError(errors.New(" ")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsTokenEventCoverageBranches(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
require.False(t, isTokenEvent("response.in_progress"))
|
||||||
|
require.False(t, isTokenEvent("response.output_item.added"))
|
||||||
|
require.True(t, isTokenEvent("response.output_audio.delta"))
|
||||||
|
require.True(t, isTokenEvent("response.output"))
|
||||||
|
require.True(t, isTokenEvent("response.done"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelayTurnTimingHelpersCoverage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Unix(100, 0)
|
||||||
|
// nil state
|
||||||
|
require.Nil(t, openAIWSRelayGetOrInitTurnTiming(nil, "resp_nil", now))
|
||||||
|
_, ok := openAIWSRelayDeleteTurnTiming(nil, "resp_nil")
|
||||||
|
require.False(t, ok)
|
||||||
|
|
||||||
|
state := &relayState{}
|
||||||
|
timing := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now)
|
||||||
|
require.NotNil(t, timing)
|
||||||
|
require.Equal(t, now, timing.startAt)
|
||||||
|
|
||||||
|
// 再次获取返回同一条 timing
|
||||||
|
timing2 := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now.Add(5*time.Second))
|
||||||
|
require.NotNil(t, timing2)
|
||||||
|
require.Equal(t, now, timing2.startAt)
|
||||||
|
|
||||||
|
// 删除存在键
|
||||||
|
deleted, ok := openAIWSRelayDeleteTurnTiming(state, "resp_a")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, now, deleted.startAt)
|
||||||
|
|
||||||
|
// 删除不存在键
|
||||||
|
_, ok = openAIWSRelayDeleteTurnTiming(state, "resp_a")
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestObserveUpstreamMessage_ResponseIDFallbackPolicy(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
state := &relayState{requestModel: "gpt-5"}
|
||||||
|
startAt := time.Unix(0, 0)
|
||||||
|
now := startAt
|
||||||
|
nowFn := func() time.Time {
|
||||||
|
now = now.Add(5 * time.Millisecond)
|
||||||
|
return now
|
||||||
|
}
|
||||||
|
|
||||||
|
// 非 terminal:仅有顶层 id,不应把 event id 当成 response_id。
|
||||||
|
observed := observeUpstreamMessage(
|
||||||
|
state,
|
||||||
|
[]byte(`{"type":"response.output_text.delta","id":"evt_123","delta":"hi"}`),
|
||||||
|
startAt,
|
||||||
|
nowFn,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.False(t, observed.terminal)
|
||||||
|
require.Equal(t, "", observed.responseID)
|
||||||
|
|
||||||
|
// terminal:允许兜底用顶层 id(用于兼容少数字段变体)。
|
||||||
|
observed = observeUpstreamMessage(
|
||||||
|
state,
|
||||||
|
[]byte(`{"type":"response.completed","id":"resp_fallback","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
startAt,
|
||||||
|
nowFn,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.True(t, observed.terminal)
|
||||||
|
require.Equal(t, "resp_fallback", observed.responseID)
|
||||||
|
}
|
||||||
752
backend/internal/service/openai_ws_v2/passthrough_relay_test.go
Normal file
752
backend/internal/service/openai_ws_v2/passthrough_relay_test.go
Normal file
@@ -0,0 +1,752 @@
|
|||||||
|
package openai_ws_v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
coderws "github.com/coder/websocket"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type passthroughTestFrame struct {
|
||||||
|
msgType coderws.MessageType
|
||||||
|
payload []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type passthroughTestFrameConn struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
writes []passthroughTestFrame
|
||||||
|
readCh chan passthroughTestFrame
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
type delayedReadFrameConn struct {
|
||||||
|
base FrameConn
|
||||||
|
firstDelay time.Duration
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
type closeSpyFrameConn struct {
|
||||||
|
closeCalls atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPassthroughTestFrameConn(frames []passthroughTestFrame, autoClose bool) *passthroughTestFrameConn {
|
||||||
|
c := &passthroughTestFrameConn{
|
||||||
|
readCh: make(chan passthroughTestFrame, len(frames)+1),
|
||||||
|
}
|
||||||
|
for _, frame := range frames {
|
||||||
|
copied := passthroughTestFrame{msgType: frame.msgType, payload: append([]byte(nil), frame.payload...)}
|
||||||
|
c.readCh <- copied
|
||||||
|
}
|
||||||
|
if autoClose {
|
||||||
|
close(c.readCh)
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *passthroughTestFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return coderws.MessageText, nil, ctx.Err()
|
||||||
|
case frame, ok := <-c.readCh:
|
||||||
|
if !ok {
|
||||||
|
return coderws.MessageText, nil, io.EOF
|
||||||
|
}
|
||||||
|
return frame.msgType, append([]byte(nil), frame.payload...), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *passthroughTestFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.writes = append(c.writes, passthroughTestFrame{msgType: msgType, payload: append([]byte(nil), payload...)})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *passthroughTestFrameConn) Close() error {
|
||||||
|
c.once.Do(func() {
|
||||||
|
defer func() { _ = recover() }()
|
||||||
|
close(c.readCh)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *passthroughTestFrameConn) Writes() []passthroughTestFrame {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
out := make([]passthroughTestFrame, len(c.writes))
|
||||||
|
copy(out, c.writes)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *delayedReadFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||||
|
if c == nil || c.base == nil {
|
||||||
|
return coderws.MessageText, nil, io.EOF
|
||||||
|
}
|
||||||
|
c.once.Do(func() {
|
||||||
|
if c.firstDelay > 0 {
|
||||||
|
timer := time.NewTimer(c.firstDelay)
|
||||||
|
defer timer.Stop()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return c.base.ReadFrame(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *delayedReadFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
||||||
|
if c == nil || c.base == nil {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
return c.base.WriteFrame(ctx, msgType, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *delayedReadFrameConn) Close() error {
|
||||||
|
if c == nil || c.base == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.base.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *closeSpyFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
<-ctx.Done()
|
||||||
|
return coderws.MessageText, nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *closeSpyFrameConn) WriteFrame(ctx context.Context, _ coderws.MessageType, _ []byte) error {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *closeSpyFrameConn) Close() error {
|
||||||
|
if c != nil {
|
||||||
|
c.closeCalls.Add(1)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *closeSpyFrameConn) CloseCalls() int32 {
|
||||||
|
if c == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return c.closeCalls.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_BasicRelayAndUsage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"input_text","text":"hello"}]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
require.Equal(t, "gpt-5.3-codex", result.RequestModel)
|
||||||
|
require.Equal(t, "resp_123", result.RequestID)
|
||||||
|
require.Equal(t, "response.completed", result.TerminalEventType)
|
||||||
|
require.Equal(t, 7, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 3, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 2, result.Usage.CacheReadInputTokens)
|
||||||
|
require.NotNil(t, result.FirstTokenMs)
|
||||||
|
require.Equal(t, int64(1), result.ClientToUpstreamFrames)
|
||||||
|
require.Equal(t, int64(1), result.UpstreamToClientFrames)
|
||||||
|
require.Equal(t, int64(0), result.DroppedDownstreamFrames)
|
||||||
|
|
||||||
|
upstreamWrites := upstreamConn.Writes()
|
||||||
|
require.Len(t, upstreamWrites, 1)
|
||||||
|
require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType)
|
||||||
|
require.JSONEq(t, string(firstPayload), string(upstreamWrites[0].payload))
|
||||||
|
|
||||||
|
clientWrites := clientConn.Writes()
|
||||||
|
require.Len(t, clientWrites, 1)
|
||||||
|
require.Equal(t, coderws.MessageText, clientWrites[0].msgType)
|
||||||
|
require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`, string(clientWrites[0].payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_FunctionCallOutputBytesPreserved(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_func","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"function_call_output","call_id":"call_abc123","output":"{\"ok\":true}"}]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
|
||||||
|
upstreamWrites := upstreamConn.Writes()
|
||||||
|
require.Len(t, upstreamWrites, 1)
|
||||||
|
require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType)
|
||||||
|
require.Equal(t, firstPayload, upstreamWrites[0].payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_UpstreamDisconnect(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 上游立即关闭(EOF),客户端不发送额外帧
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn(nil, true) // 立即 close -> EOF
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
// 上游 EOF 属于 disconnect,标记为 graceful
|
||||||
|
require.Nil(t, relayExit, "上游 EOF 应被视为 graceful disconnect")
|
||||||
|
require.Equal(t, "gpt-4o", result.RequestModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_ClientDisconnect(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 客户端立即关闭(EOF),上游阻塞读取直到 context 取消
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, true) // 立即 close -> EOF
|
||||||
|
upstreamConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.NotNil(t, relayExit, "客户端 EOF 应返回可观测的中断状态")
|
||||||
|
require.Equal(t, "client_disconnected", relayExit.Stage)
|
||||||
|
require.Equal(t, "gpt-4o", result.RequestModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_ClientDisconnect_DrainCapturesLateUsage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, true)
|
||||||
|
upstreamBase := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_drain","usage":{"input_tokens":6,"output_tokens":4,"input_tokens_details":{"cached_tokens":1}}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
upstreamConn := &delayedReadFrameConn{
|
||||||
|
base: upstreamBase,
|
||||||
|
firstDelay: 80 * time.Millisecond,
|
||||||
|
}
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||||
|
UpstreamDrainTimeout: 400 * time.Millisecond,
|
||||||
|
})
|
||||||
|
require.NotNil(t, relayExit)
|
||||||
|
require.Equal(t, "client_disconnected", relayExit.Stage)
|
||||||
|
require.Equal(t, "resp_drain", result.RequestID)
|
||||||
|
require.Equal(t, "response.completed", result.TerminalEventType)
|
||||||
|
require.Equal(t, 6, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 4, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 1, result.Usage.CacheReadInputTokens)
|
||||||
|
require.Equal(t, int64(1), result.ClientToUpstreamFrames)
|
||||||
|
require.Equal(t, int64(0), result.UpstreamToClientFrames)
|
||||||
|
require.Equal(t, int64(1), result.DroppedDownstreamFrames)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_IdleTimeout(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 客户端和上游都不发送帧,idle timeout 应触发
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// 使用快进时间来加速 idle timeout
|
||||||
|
now := time.Now()
|
||||||
|
callCount := 0
|
||||||
|
nowFn := func() time.Time {
|
||||||
|
callCount++
|
||||||
|
// 前几次调用返回正常时间(初始化阶段),之后快进
|
||||||
|
if callCount <= 5 {
|
||||||
|
return now
|
||||||
|
}
|
||||||
|
return now.Add(time.Hour) // 快进到超时
|
||||||
|
}
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||||
|
IdleTimeout: 2 * time.Second,
|
||||||
|
Now: nowFn,
|
||||||
|
})
|
||||||
|
require.NotNil(t, relayExit, "应因 idle timeout 退出")
|
||||||
|
require.Equal(t, "idle_timeout", relayExit.Stage)
|
||||||
|
require.Equal(t, "gpt-4o", result.RequestModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_IdleTimeoutDoesNotCloseClientOnError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := &closeSpyFrameConn{}
|
||||||
|
upstreamConn := &closeSpyFrameConn{}
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
callCount := 0
|
||||||
|
nowFn := func() time.Time {
|
||||||
|
callCount++
|
||||||
|
if callCount <= 5 {
|
||||||
|
return now
|
||||||
|
}
|
||||||
|
return now.Add(time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||||
|
IdleTimeout: 2 * time.Second,
|
||||||
|
Now: nowFn,
|
||||||
|
})
|
||||||
|
require.NotNil(t, relayExit, "应因 idle timeout 退出")
|
||||||
|
require.Equal(t, "idle_timeout", relayExit.Stage)
|
||||||
|
require.Zero(t, clientConn.CloseCalls(), "错误路径不应提前关闭客户端连接,交给上层决定 close code")
|
||||||
|
require.GreaterOrEqual(t, upstreamConn.CloseCalls(), int32(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_NilConnections(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("nil client conn", func(t *testing.T) {
|
||||||
|
upstreamConn := newPassthroughTestFrameConn(nil, true)
|
||||||
|
_, relayExit := Relay(ctx, nil, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.NotNil(t, relayExit)
|
||||||
|
require.Equal(t, "relay_init", relayExit.Stage)
|
||||||
|
require.Contains(t, relayExit.Err.Error(), "nil")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil upstream conn", func(t *testing.T) {
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, true)
|
||||||
|
_, relayExit := Relay(ctx, clientConn, nil, firstPayload, RelayOptions{})
|
||||||
|
require.NotNil(t, relayExit)
|
||||||
|
require.Equal(t, "relay_init", relayExit.Stage)
|
||||||
|
require.Contains(t, relayExit.Err.Error(), "nil")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_MultipleUpstreamMessages(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 上游发送多个事件(delta + completed),验证多帧中继和 usage 聚合
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.output_text.delta","delta":"Hello"}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.output_text.delta","delta":" world"}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_multi","usage":{"input_tokens":10,"output_tokens":5,"input_tokens_details":{"cached_tokens":3}}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[{"type":"input_text","text":"hi"}]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
require.Equal(t, "resp_multi", result.RequestID)
|
||||||
|
require.Equal(t, "response.completed", result.TerminalEventType)
|
||||||
|
require.Equal(t, 10, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 5, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
|
||||||
|
require.NotNil(t, result.FirstTokenMs)
|
||||||
|
|
||||||
|
// 验证所有 3 个上游帧都转发给了客户端
|
||||||
|
clientWrites := clientConn.Writes()
|
||||||
|
require.Len(t, clientWrites, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_OnTurnComplete_PerTerminalEvent(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":2,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":3,"output_tokens":4}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
turns := make([]RelayTurnResult, 0, 2)
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||||
|
OnTurnComplete: func(turn RelayTurnResult) {
|
||||||
|
turns = append(turns, turn)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
require.Len(t, turns, 2)
|
||||||
|
require.Equal(t, "resp_turn_1", turns[0].RequestID)
|
||||||
|
require.Equal(t, "response.completed", turns[0].TerminalEventType)
|
||||||
|
require.Equal(t, 2, turns[0].Usage.InputTokens)
|
||||||
|
require.Equal(t, 1, turns[0].Usage.OutputTokens)
|
||||||
|
require.Equal(t, "resp_turn_2", turns[1].RequestID)
|
||||||
|
require.Equal(t, "response.failed", turns[1].TerminalEventType)
|
||||||
|
require.Equal(t, 3, turns[1].Usage.InputTokens)
|
||||||
|
require.Equal(t, 4, turns[1].Usage.OutputTokens)
|
||||||
|
require.Equal(t, 5, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 5, result.Usage.OutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_OnTurnComplete_ProvidesTurnMetrics(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.output_text.delta","response_id":"resp_metric","delta":"hi"}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_metric","usage":{"input_tokens":2,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
base := time.Unix(0, 0)
|
||||||
|
var nowTick atomic.Int64
|
||||||
|
nowFn := func() time.Time {
|
||||||
|
step := nowTick.Add(1)
|
||||||
|
return base.Add(time.Duration(step) * 5 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
var turn RelayTurnResult
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||||
|
Now: nowFn,
|
||||||
|
OnTurnComplete: func(current RelayTurnResult) {
|
||||||
|
turn = current
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
require.Equal(t, "resp_metric", turn.RequestID)
|
||||||
|
require.Equal(t, "response.completed", turn.TerminalEventType)
|
||||||
|
require.NotNil(t, turn.FirstTokenMs)
|
||||||
|
require.GreaterOrEqual(t, *turn.FirstTokenMs, 0)
|
||||||
|
require.Greater(t, turn.Duration.Milliseconds(), int64(0))
|
||||||
|
require.NotNil(t, result.FirstTokenMs)
|
||||||
|
require.Greater(t, result.Duration.Milliseconds(), int64(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_BinaryFramePassthrough(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 验证 binary frame 被透传但不进行 usage 解析
|
||||||
|
binaryPayload := []byte{0x00, 0x01, 0x02, 0x03}
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageBinary,
|
||||||
|
payload: binaryPayload,
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
// binary frame 不解析 usage
|
||||||
|
require.Equal(t, 0, result.Usage.InputTokens)
|
||||||
|
|
||||||
|
clientWrites := clientConn.Writes()
|
||||||
|
require.Len(t, clientWrites, 1)
|
||||||
|
require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType)
|
||||||
|
require.Equal(t, binaryPayload, clientWrites[0].payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_BinaryJSONFrameSkipsObservation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageBinary,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_binary","usage":{"input_tokens":7,"output_tokens":3}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
require.Equal(t, 0, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, "", result.RequestID)
|
||||||
|
require.Equal(t, "", result.TerminalEventType)
|
||||||
|
|
||||||
|
clientWrites := clientConn.Writes()
|
||||||
|
require.Len(t, clientWrites, 1)
|
||||||
|
require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_UpstreamErrorEventPassthroughRaw(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
errorEvent := []byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found"}}`)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: errorEvent,
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
|
||||||
|
clientWrites := clientConn.Writes()
|
||||||
|
require.Len(t, clientWrites, 1)
|
||||||
|
require.Equal(t, coderws.MessageText, clientWrites[0].msgType)
|
||||||
|
require.Equal(t, errorEvent, clientWrites[0].payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_PreservesFirstMessageType(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn(nil, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||||
|
FirstMessageType: coderws.MessageBinary,
|
||||||
|
})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
|
||||||
|
upstreamWrites := upstreamConn.Writes()
|
||||||
|
require.Len(t, upstreamWrites, 1)
|
||||||
|
require.Equal(t, coderws.MessageBinary, upstreamWrites[0].msgType)
|
||||||
|
require.Equal(t, firstPayload, upstreamWrites[0].payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_UsageParseFailureDoesNotBlockRelay(t *testing.T) {
|
||||||
|
baseline := SnapshotMetrics().UsageParseFailureTotal
|
||||||
|
|
||||||
|
// 上游发送无效 JSON(非 usage 格式),不应影响透传
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_bad","usage":"not_an_object"}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
// usage 解析失败,值为 0 但不影响透传
|
||||||
|
require.Equal(t, 0, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, "response.completed", result.TerminalEventType)
|
||||||
|
|
||||||
|
// 帧仍然被转发
|
||||||
|
clientWrites := clientConn.Writes()
|
||||||
|
require.Len(t, clientWrites, 1)
|
||||||
|
require.GreaterOrEqual(t, SnapshotMetrics().UsageParseFailureTotal, baseline+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_WriteUpstreamFirstMessageFails(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 上游连接立即关闭,首包写入失败
|
||||||
|
upstreamConn := newPassthroughTestFrameConn(nil, true)
|
||||||
|
_ = upstreamConn.Close()
|
||||||
|
|
||||||
|
// 覆盖 WriteFrame 使其返回错误
|
||||||
|
errConn := &errorOnWriteFrameConn{}
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, relayExit := Relay(ctx, clientConn, errConn, firstPayload, RelayOptions{})
|
||||||
|
require.NotNil(t, relayExit)
|
||||||
|
require.Equal(t, "write_upstream", relayExit.Stage)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_ContextCanceled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
|
||||||
|
// 立即取消 context
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||||||
|
// context 取消导致写首包失败
|
||||||
|
require.NotNil(t, relayExit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_TraceEvents_ContainsLifecycleStages(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
|
{
|
||||||
|
msgType: coderws.MessageText,
|
||||||
|
payload: []byte(`{"type":"response.completed","response":{"id":"resp_trace","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
stages := make([]string, 0, 8)
|
||||||
|
var stagesMu sync.Mutex
|
||||||
|
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||||
|
OnTrace: func(event RelayTraceEvent) {
|
||||||
|
stagesMu.Lock()
|
||||||
|
stages = append(stages, event.Stage)
|
||||||
|
stagesMu.Unlock()
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.Nil(t, relayExit)
|
||||||
|
stagesMu.Lock()
|
||||||
|
capturedStages := append([]string(nil), stages...)
|
||||||
|
stagesMu.Unlock()
|
||||||
|
require.Contains(t, capturedStages, "relay_start")
|
||||||
|
require.Contains(t, capturedStages, "write_first_message_ok")
|
||||||
|
require.Contains(t, capturedStages, "first_exit")
|
||||||
|
require.Contains(t, capturedStages, "relay_complete")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelay_TraceEvents_IdleTimeout(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
clientConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
upstreamConn := newPassthroughTestFrameConn(nil, false)
|
||||||
|
|
||||||
|
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
callCount := 0
|
||||||
|
nowFn := func() time.Time {
|
||||||
|
callCount++
|
||||||
|
if callCount <= 5 {
|
||||||
|
return now
|
||||||
|
}
|
||||||
|
return now.Add(time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
stages := make([]string, 0, 8)
|
||||||
|
var stagesMu sync.Mutex
|
||||||
|
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||||||
|
IdleTimeout: 2 * time.Second,
|
||||||
|
Now: nowFn,
|
||||||
|
OnTrace: func(event RelayTraceEvent) {
|
||||||
|
stagesMu.Lock()
|
||||||
|
stages = append(stages, event.Stage)
|
||||||
|
stagesMu.Unlock()
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NotNil(t, relayExit)
|
||||||
|
require.Equal(t, "idle_timeout", relayExit.Stage)
|
||||||
|
stagesMu.Lock()
|
||||||
|
capturedStages := append([]string(nil), stages...)
|
||||||
|
stagesMu.Unlock()
|
||||||
|
require.Contains(t, capturedStages, "idle_timeout_triggered")
|
||||||
|
require.Contains(t, capturedStages, "relay_exit")
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorOnWriteFrameConn 是一个写入总是失败的 FrameConn 实现,用于测试首包写入失败。
|
||||||
|
type errorOnWriteFrameConn struct{}
|
||||||
|
|
||||||
|
func (c *errorOnWriteFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||||
|
<-ctx.Done()
|
||||||
|
return coderws.MessageText, nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *errorOnWriteFrameConn) WriteFrame(_ context.Context, _ coderws.MessageType, _ []byte) error {
|
||||||
|
return errors.New("write failed: connection refused")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *errorOnWriteFrameConn) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
367
backend/internal/service/openai_ws_v2_passthrough_adapter.go
Normal file
367
backend/internal/service/openai_ws_v2_passthrough_adapter.go
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
|
||||||
|
coderws "github.com/coder/websocket"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openAIWSClientFrameConn struct {
|
||||||
|
conn *coderws.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
|
||||||
|
|
||||||
|
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
|
||||||
|
|
||||||
|
func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||||
|
if c == nil || c.conn == nil {
|
||||||
|
return coderws.MessageText, nil, errOpenAIWSConnClosed
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return c.conn.Read(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
||||||
|
if c == nil || c.conn == nil {
|
||||||
|
return errOpenAIWSConnClosed
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return c.conn.Write(ctx, msgType, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *openAIWSClientFrameConn) Close() error {
|
||||||
|
if c == nil || c.conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_ = c.conn.Close(coderws.StatusNormalClosure, "")
|
||||||
|
_ = c.conn.CloseNow()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||||
|
ctx context.Context,
|
||||||
|
c *gin.Context,
|
||||||
|
clientConn *coderws.Conn,
|
||||||
|
account *Account,
|
||||||
|
token string,
|
||||||
|
firstClientMessage []byte,
|
||||||
|
hooks *OpenAIWSIngressHooks,
|
||||||
|
wsDecision OpenAIWSProtocolDecision,
|
||||||
|
) error {
|
||||||
|
if s == nil {
|
||||||
|
return errors.New("service is nil")
|
||||||
|
}
|
||||||
|
if clientConn == nil {
|
||||||
|
return errors.New("client websocket is nil")
|
||||||
|
}
|
||||||
|
if account == nil {
|
||||||
|
return errors.New("account is nil")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(token) == "" {
|
||||||
|
return errors.New("token is empty")
|
||||||
|
}
|
||||||
|
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
|
||||||
|
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
|
||||||
|
account.ID,
|
||||||
|
truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen),
|
||||||
|
truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen),
|
||||||
|
openaiwsv2RelayMessageTypeName(coderws.MessageText),
|
||||||
|
len(firstClientMessage),
|
||||||
|
)
|
||||||
|
|
||||||
|
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("build ws url: %w", err)
|
||||||
|
}
|
||||||
|
wsHost := "-"
|
||||||
|
wsPath := "-"
|
||||||
|
if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil {
|
||||||
|
wsHost = normalizeOpenAIWSLogValue(parsedURL.Host)
|
||||||
|
wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
|
||||||
|
}
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v",
|
||||||
|
account.ID,
|
||||||
|
wsHost,
|
||||||
|
wsPath,
|
||||||
|
account.ProxyID != nil && account.Proxy != nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
isCodexCLI := false
|
||||||
|
if c != nil {
|
||||||
|
isCodexCLI = openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
||||||
|
}
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
||||||
|
isCodexCLI = true
|
||||||
|
}
|
||||||
|
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "")
|
||||||
|
proxyURL := ""
|
||||||
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := s.getOpenAIWSPassthroughDialer()
|
||||||
|
if dialer == nil {
|
||||||
|
return errors.New("openai ws passthrough dialer is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout())
|
||||||
|
defer cancelDial()
|
||||||
|
upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_dial_failed account_id=%d status_code=%d err=%s",
|
||||||
|
account.ID,
|
||||||
|
statusCode,
|
||||||
|
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
|
||||||
|
)
|
||||||
|
return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = upstreamConn.Close()
|
||||||
|
}()
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s",
|
||||||
|
account.ID,
|
||||||
|
statusCode,
|
||||||
|
openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("openai ws passthrough upstream connection does not support frame relay")
|
||||||
|
}
|
||||||
|
|
||||||
|
completedTurns := atomic.Int32{}
|
||||||
|
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
|
||||||
|
Ctx: ctx,
|
||||||
|
ClientConn: &openAIWSClientFrameConn{conn: clientConn},
|
||||||
|
UpstreamConn: upstreamFrameConn,
|
||||||
|
FirstClientMessage: firstClientMessage,
|
||||||
|
Options: openaiwsv2.RelayOptions{
|
||||||
|
WriteTimeout: s.openAIWSWriteTimeout(),
|
||||||
|
IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
|
||||||
|
FirstMessageType: coderws.MessageText,
|
||||||
|
OnUsageParseFailure: func(eventType string, usageRaw string) {
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"usage_parse_failed event_type=%s usage_raw=%s",
|
||||||
|
truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
|
||||||
|
truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) {
|
||||||
|
turnNo := int(completedTurns.Add(1))
|
||||||
|
turnResult := &OpenAIForwardResult{
|
||||||
|
RequestID: turn.RequestID,
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: turn.Usage.InputTokens,
|
||||||
|
OutputTokens: turn.Usage.OutputTokens,
|
||||||
|
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
|
||||||
|
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
||||||
|
},
|
||||||
|
Model: turn.RequestModel,
|
||||||
|
Stream: true,
|
||||||
|
OpenAIWSMode: true,
|
||||||
|
Duration: turn.Duration,
|
||||||
|
FirstTokenMs: turn.FirstTokenMs,
|
||||||
|
}
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d",
|
||||||
|
account.ID,
|
||||||
|
turnNo,
|
||||||
|
truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen),
|
||||||
|
truncateOpenAIWSLogValue(turn.TerminalEventType, openAIWSLogValueMaxLen),
|
||||||
|
turnResult.Duration.Milliseconds(),
|
||||||
|
openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs),
|
||||||
|
turnResult.Usage.InputTokens,
|
||||||
|
turnResult.Usage.OutputTokens,
|
||||||
|
turnResult.Usage.CacheReadInputTokens,
|
||||||
|
)
|
||||||
|
if hooks != nil && hooks.AfterTurn != nil {
|
||||||
|
hooks.AfterTurn(turnNo, turnResult, nil)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
OnTrace: func(event openaiwsv2.RelayTraceEvent) {
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",
|
||||||
|
account.ID,
|
||||||
|
truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen),
|
||||||
|
truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen),
|
||||||
|
truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen),
|
||||||
|
event.PayloadBytes,
|
||||||
|
event.Graceful,
|
||||||
|
event.WroteDownstream,
|
||||||
|
truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
result := &OpenAIForwardResult{
|
||||||
|
RequestID: relayResult.RequestID,
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: relayResult.Usage.InputTokens,
|
||||||
|
OutputTokens: relayResult.Usage.OutputTokens,
|
||||||
|
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
|
||||||
|
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
||||||
|
},
|
||||||
|
Model: relayResult.RequestModel,
|
||||||
|
Stream: true,
|
||||||
|
OpenAIWSMode: true,
|
||||||
|
Duration: relayResult.Duration,
|
||||||
|
FirstTokenMs: relayResult.FirstTokenMs,
|
||||||
|
}
|
||||||
|
|
||||||
|
turnCount := int(completedTurns.Load())
|
||||||
|
if relayExit == nil {
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
|
||||||
|
account.ID,
|
||||||
|
truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen),
|
||||||
|
truncateOpenAIWSLogValue(relayResult.TerminalEventType, openAIWSLogValueMaxLen),
|
||||||
|
result.Duration.Milliseconds(),
|
||||||
|
relayResult.ClientToUpstreamFrames,
|
||||||
|
relayResult.UpstreamToClientFrames,
|
||||||
|
relayResult.DroppedDownstreamFrames,
|
||||||
|
turnCount,
|
||||||
|
)
|
||||||
|
// 正常路径按 terminal 事件逐 turn 已回调;仅在零 turn 场景兜底回调一次。
|
||||||
|
if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil {
|
||||||
|
hooks.AfterTurn(1, result, nil)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
|
||||||
|
account.ID,
|
||||||
|
truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen),
|
||||||
|
relayExit.WroteDownstream,
|
||||||
|
truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen),
|
||||||
|
result.Duration.Milliseconds(),
|
||||||
|
relayResult.ClientToUpstreamFrames,
|
||||||
|
relayResult.UpstreamToClientFrames,
|
||||||
|
relayResult.DroppedDownstreamFrames,
|
||||||
|
turnCount,
|
||||||
|
)
|
||||||
|
|
||||||
|
relayErr := relayExit.Err
|
||||||
|
if relayExit.Stage == "idle_timeout" {
|
||||||
|
relayErr = NewOpenAIWSClientCloseError(
|
||||||
|
coderws.StatusPolicyViolation,
|
||||||
|
"client websocket idle timeout",
|
||||||
|
relayErr,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
turnErr := wrapOpenAIWSIngressTurnError(
|
||||||
|
relayExit.Stage,
|
||||||
|
relayErr,
|
||||||
|
relayExit.WroteDownstream,
|
||||||
|
)
|
||||||
|
if hooks != nil && hooks.AfterTurn != nil {
|
||||||
|
hooks.AfterTurn(turnCount+1, nil, turnErr)
|
||||||
|
}
|
||||||
|
return turnErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError(
|
||||||
|
err error,
|
||||||
|
statusCode int,
|
||||||
|
handshakeHeaders http.Header,
|
||||||
|
) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
wrappedErr := err
|
||||||
|
var dialErr *openAIWSDialError
|
||||||
|
if !errors.As(err, &dialErr) {
|
||||||
|
wrappedErr = &openAIWSDialError{
|
||||||
|
StatusCode: statusCode,
|
||||||
|
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return NewOpenAIWSClientCloseError(
|
||||||
|
coderws.StatusTryAgainLater,
|
||||||
|
"upstream websocket connect timeout",
|
||||||
|
wrappedErr,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if statusCode == http.StatusTooManyRequests {
|
||||||
|
return NewOpenAIWSClientCloseError(
|
||||||
|
coderws.StatusTryAgainLater,
|
||||||
|
"upstream websocket is busy, please retry later",
|
||||||
|
wrappedErr,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||||||
|
return NewOpenAIWSClientCloseError(
|
||||||
|
coderws.StatusPolicyViolation,
|
||||||
|
"upstream websocket authentication failed",
|
||||||
|
wrappedErr,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError {
|
||||||
|
return NewOpenAIWSClientCloseError(
|
||||||
|
coderws.StatusPolicyViolation,
|
||||||
|
"upstream websocket handshake rejected",
|
||||||
|
wrappedErr,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("openai ws passthrough dial: %w", wrappedErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) string {
|
||||||
|
switch msgType {
|
||||||
|
case coderws.MessageText:
|
||||||
|
return "text"
|
||||||
|
case coderws.MessageBinary:
|
||||||
|
return "binary"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("unknown(%d)", msgType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func relayErrorText(err error) string {
|
||||||
|
if err == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIWSFirstTokenMsForLog(firstTokenMs *int) int {
|
||||||
|
if firstTokenMs == nil {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return *firstTokenMs
|
||||||
|
}
|
||||||
|
|
||||||
|
func logOpenAIWSV2Passthrough(format string, args ...any) {
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"service.openai_ws_v2",
|
||||||
|
"[OpenAI WS v2 passthrough] %s "+format,
|
||||||
|
append([]any{openaiWSV2PassthroughModeFields}, args...)...,
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -209,8 +209,9 @@ gateway:
|
|||||||
openai_ws:
|
openai_ws:
|
||||||
# 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。
|
# 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。
|
||||||
mode_router_v2_enabled: false
|
mode_router_v2_enabled: false
|
||||||
# ingress 默认模式:off|shared|dedicated(仅 mode_router_v2_enabled=true 生效)
|
# ingress 默认模式:off|ctx_pool|passthrough(仅 mode_router_v2_enabled=true 生效)
|
||||||
ingress_mode_default: shared
|
# 兼容旧值:shared/dedicated 会按 ctx_pool 处理。
|
||||||
|
ingress_mode_default: ctx_pool
|
||||||
# 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由
|
# 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由
|
||||||
enabled: true
|
enabled: true
|
||||||
# 按账号类型细分开关
|
# 按账号类型细分开关
|
||||||
|
|||||||
@@ -1807,7 +1807,7 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- OpenAI WS Mode 三态(off/shared/dedicated) -->
|
<!-- OpenAI WS Mode 三态(off/ctx_pool/passthrough) -->
|
||||||
<div
|
<div
|
||||||
v-if="form.platform === 'openai' && (accountCategory === 'oauth-based' || accountCategory === 'apikey')"
|
v-if="form.platform === 'openai' && (accountCategory === 'oauth-based' || accountCategory === 'apikey')"
|
||||||
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
||||||
@@ -1819,7 +1819,7 @@
|
|||||||
{{ t('admin.accounts.openai.wsModeDesc') }}
|
{{ t('admin.accounts.openai.wsModeDesc') }}
|
||||||
</p>
|
</p>
|
||||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||||
{{ t('admin.accounts.openai.wsModeConcurrencyHint') }}
|
{{ t(openAIWSModeConcurrencyHintKey) }}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<div class="w-52">
|
<div class="w-52">
|
||||||
@@ -2341,10 +2341,11 @@ import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
|
|||||||
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
||||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||||
import {
|
import {
|
||||||
OPENAI_WS_MODE_DEDICATED,
|
OPENAI_WS_MODE_CTX_POOL,
|
||||||
OPENAI_WS_MODE_OFF,
|
OPENAI_WS_MODE_OFF,
|
||||||
OPENAI_WS_MODE_SHARED,
|
OPENAI_WS_MODE_PASSTHROUGH,
|
||||||
isOpenAIWSModeEnabled,
|
isOpenAIWSModeEnabled,
|
||||||
|
resolveOpenAIWSModeConcurrencyHintKey,
|
||||||
type OpenAIWSMode
|
type OpenAIWSMode
|
||||||
} from '@/utils/openaiWsMode'
|
} from '@/utils/openaiWsMode'
|
||||||
import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue'
|
import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue'
|
||||||
@@ -2541,8 +2542,8 @@ const geminiSelectedTier = computed(() => {
|
|||||||
|
|
||||||
const openAIWSModeOptions = computed(() => [
|
const openAIWSModeOptions = computed(() => [
|
||||||
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
|
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
|
||||||
{ value: OPENAI_WS_MODE_SHARED, label: t('admin.accounts.openai.wsModeShared') },
|
{ value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') },
|
||||||
{ value: OPENAI_WS_MODE_DEDICATED, label: t('admin.accounts.openai.wsModeDedicated') }
|
{ value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') }
|
||||||
])
|
])
|
||||||
|
|
||||||
const openaiResponsesWebSocketV2Mode = computed({
|
const openaiResponsesWebSocketV2Mode = computed({
|
||||||
@@ -2561,6 +2562,10 @@ const openaiResponsesWebSocketV2Mode = computed({
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const openAIWSModeConcurrencyHintKey = computed(() =>
|
||||||
|
resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value)
|
||||||
|
)
|
||||||
|
|
||||||
const isOpenAIModelRestrictionDisabled = computed(() =>
|
const isOpenAIModelRestrictionDisabled = computed(() =>
|
||||||
form.platform === 'openai' && openaiPassthroughEnabled.value
|
form.platform === 'openai' && openaiPassthroughEnabled.value
|
||||||
)
|
)
|
||||||
@@ -3180,10 +3185,13 @@ const buildOpenAIExtra = (base?: Record<string, unknown>): Record<string, unknow
|
|||||||
}
|
}
|
||||||
|
|
||||||
const extra: Record<string, unknown> = { ...(base || {}) }
|
const extra: Record<string, unknown> = { ...(base || {}) }
|
||||||
extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
|
if (accountCategory.value === 'oauth-based') {
|
||||||
extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
|
extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
|
||||||
extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value)
|
extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value)
|
||||||
extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value)
|
} else if (accountCategory.value === 'apikey') {
|
||||||
|
extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
|
||||||
|
extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value)
|
||||||
|
}
|
||||||
// 清理兼容旧键,统一改用分类型开关。
|
// 清理兼容旧键,统一改用分类型开关。
|
||||||
delete extra.responses_websockets_v2_enabled
|
delete extra.responses_websockets_v2_enabled
|
||||||
delete extra.openai_ws_enabled
|
delete extra.openai_ws_enabled
|
||||||
|
|||||||
@@ -708,7 +708,7 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- OpenAI WS Mode 三态(off/shared/dedicated) -->
|
<!-- OpenAI WS Mode 三态(off/ctx_pool/passthrough) -->
|
||||||
<div
|
<div
|
||||||
v-if="account?.platform === 'openai' && (account?.type === 'oauth' || account?.type === 'apikey')"
|
v-if="account?.platform === 'openai' && (account?.type === 'oauth' || account?.type === 'apikey')"
|
||||||
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
||||||
@@ -720,7 +720,7 @@
|
|||||||
{{ t('admin.accounts.openai.wsModeDesc') }}
|
{{ t('admin.accounts.openai.wsModeDesc') }}
|
||||||
</p>
|
</p>
|
||||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||||
{{ t('admin.accounts.openai.wsModeConcurrencyHint') }}
|
{{ t(openAIWSModeConcurrencyHintKey) }}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<div class="w-52">
|
<div class="w-52">
|
||||||
@@ -1273,10 +1273,11 @@ import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
|
|||||||
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
||||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||||
import {
|
import {
|
||||||
OPENAI_WS_MODE_DEDICATED,
|
OPENAI_WS_MODE_CTX_POOL,
|
||||||
OPENAI_WS_MODE_OFF,
|
OPENAI_WS_MODE_OFF,
|
||||||
OPENAI_WS_MODE_SHARED,
|
OPENAI_WS_MODE_PASSTHROUGH,
|
||||||
isOpenAIWSModeEnabled,
|
isOpenAIWSModeEnabled,
|
||||||
|
resolveOpenAIWSModeConcurrencyHintKey,
|
||||||
type OpenAIWSMode,
|
type OpenAIWSMode,
|
||||||
resolveOpenAIWSModeFromExtra
|
resolveOpenAIWSModeFromExtra
|
||||||
} from '@/utils/openaiWsMode'
|
} from '@/utils/openaiWsMode'
|
||||||
@@ -1387,8 +1388,8 @@ const codexCLIOnlyEnabled = ref(false)
|
|||||||
const anthropicPassthroughEnabled = ref(false)
|
const anthropicPassthroughEnabled = ref(false)
|
||||||
const openAIWSModeOptions = computed(() => [
|
const openAIWSModeOptions = computed(() => [
|
||||||
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
|
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
|
||||||
{ value: OPENAI_WS_MODE_SHARED, label: t('admin.accounts.openai.wsModeShared') },
|
{ value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') },
|
||||||
{ value: OPENAI_WS_MODE_DEDICATED, label: t('admin.accounts.openai.wsModeDedicated') }
|
{ value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') }
|
||||||
])
|
])
|
||||||
const openaiResponsesWebSocketV2Mode = computed({
|
const openaiResponsesWebSocketV2Mode = computed({
|
||||||
get: () => {
|
get: () => {
|
||||||
@@ -1405,6 +1406,9 @@ const openaiResponsesWebSocketV2Mode = computed({
|
|||||||
openaiOAuthResponsesWebSocketV2Mode.value = mode
|
openaiOAuthResponsesWebSocketV2Mode.value = mode
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
const openAIWSModeConcurrencyHintKey = computed(() =>
|
||||||
|
resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value)
|
||||||
|
)
|
||||||
const isOpenAIModelRestrictionDisabled = computed(() =>
|
const isOpenAIModelRestrictionDisabled = computed(() =>
|
||||||
props.account?.platform === 'openai' && openaiPassthroughEnabled.value
|
props.account?.platform === 'openai' && openaiPassthroughEnabled.value
|
||||||
)
|
)
|
||||||
@@ -2248,10 +2252,13 @@ const handleSubmit = async () => {
|
|||||||
const currentExtra = (props.account.extra as Record<string, unknown>) || {}
|
const currentExtra = (props.account.extra as Record<string, unknown>) || {}
|
||||||
const newExtra: Record<string, unknown> = { ...currentExtra }
|
const newExtra: Record<string, unknown> = { ...currentExtra }
|
||||||
const hadCodexCLIOnlyEnabled = currentExtra.codex_cli_only === true
|
const hadCodexCLIOnlyEnabled = currentExtra.codex_cli_only === true
|
||||||
newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
|
if (props.account.type === 'oauth') {
|
||||||
newExtra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
|
newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
|
||||||
newExtra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value)
|
newExtra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value)
|
||||||
newExtra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value)
|
} else if (props.account.type === 'apikey') {
|
||||||
|
newExtra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
|
||||||
|
newExtra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value)
|
||||||
|
}
|
||||||
delete newExtra.responses_websockets_v2_enabled
|
delete newExtra.responses_websockets_v2_enabled
|
||||||
delete newExtra.openai_ws_enabled
|
delete newExtra.openai_ws_enabled
|
||||||
if (openaiPassthroughEnabled.value) {
|
if (openaiPassthroughEnabled.value) {
|
||||||
|
|||||||
@@ -1846,10 +1846,13 @@ export default {
|
|||||||
wsMode: 'WS mode',
|
wsMode: 'WS mode',
|
||||||
wsModeDesc: 'Only applies to the current OpenAI account type.',
|
wsModeDesc: 'Only applies to the current OpenAI account type.',
|
||||||
wsModeOff: 'Off (off)',
|
wsModeOff: 'Off (off)',
|
||||||
|
wsModeCtxPool: 'Context Pool (ctx_pool)',
|
||||||
|
wsModePassthrough: 'Passthrough (passthrough)',
|
||||||
wsModeShared: 'Shared (shared)',
|
wsModeShared: 'Shared (shared)',
|
||||||
wsModeDedicated: 'Dedicated (dedicated)',
|
wsModeDedicated: 'Dedicated (dedicated)',
|
||||||
wsModeConcurrencyHint:
|
wsModeConcurrencyHint:
|
||||||
'When WS mode is enabled, account concurrency becomes the WS connection pool limit for this account.',
|
'When WS mode is enabled, account concurrency becomes the WS connection pool limit for this account.',
|
||||||
|
wsModePassthroughHint: 'Passthrough mode does not use the WS connection pool.',
|
||||||
oauthResponsesWebsocketsV2: 'OAuth WebSocket Mode',
|
oauthResponsesWebsocketsV2: 'OAuth WebSocket Mode',
|
||||||
oauthResponsesWebsocketsV2Desc:
|
oauthResponsesWebsocketsV2Desc:
|
||||||
'Only applies to OpenAI OAuth. This account can use OpenAI WebSocket Mode only when enabled.',
|
'Only applies to OpenAI OAuth. This account can use OpenAI WebSocket Mode only when enabled.',
|
||||||
|
|||||||
@@ -1994,9 +1994,12 @@ export default {
|
|||||||
wsMode: 'WS mode',
|
wsMode: 'WS mode',
|
||||||
wsModeDesc: '仅对当前 OpenAI 账号类型生效。',
|
wsModeDesc: '仅对当前 OpenAI 账号类型生效。',
|
||||||
wsModeOff: '关闭(off)',
|
wsModeOff: '关闭(off)',
|
||||||
|
wsModeCtxPool: '上下文池(ctx_pool)',
|
||||||
|
wsModePassthrough: '透传(passthrough)',
|
||||||
wsModeShared: '共享(shared)',
|
wsModeShared: '共享(shared)',
|
||||||
wsModeDedicated: '独享(dedicated)',
|
wsModeDedicated: '独享(dedicated)',
|
||||||
wsModeConcurrencyHint: '启用 WS mode 后,该账号并发数将作为该账号 WS 连接池上限。',
|
wsModeConcurrencyHint: '启用 WS mode 后,该账号并发数将作为该账号 WS 连接池上限。',
|
||||||
|
wsModePassthroughHint: 'passthrough 模式不使用 WS 连接池。',
|
||||||
oauthResponsesWebsocketsV2: 'OAuth WebSocket Mode',
|
oauthResponsesWebsocketsV2: 'OAuth WebSocket Mode',
|
||||||
oauthResponsesWebsocketsV2Desc:
|
oauthResponsesWebsocketsV2Desc:
|
||||||
'仅对 OpenAI OAuth 生效。开启后该账号才允许使用 OpenAI WebSocket Mode 协议。',
|
'仅对 OpenAI OAuth 生效。开启后该账号才允许使用 OpenAI WebSocket Mode 协议。',
|
||||||
|
|||||||
@@ -1,31 +1,34 @@
|
|||||||
import { describe, expect, it } from 'vitest'
|
import { describe, expect, it } from 'vitest'
|
||||||
import {
|
import {
|
||||||
OPENAI_WS_MODE_DEDICATED,
|
OPENAI_WS_MODE_CTX_POOL,
|
||||||
OPENAI_WS_MODE_OFF,
|
OPENAI_WS_MODE_OFF,
|
||||||
OPENAI_WS_MODE_SHARED,
|
OPENAI_WS_MODE_PASSTHROUGH,
|
||||||
isOpenAIWSModeEnabled,
|
isOpenAIWSModeEnabled,
|
||||||
normalizeOpenAIWSMode,
|
normalizeOpenAIWSMode,
|
||||||
openAIWSModeFromEnabled,
|
openAIWSModeFromEnabled,
|
||||||
|
resolveOpenAIWSModeConcurrencyHintKey,
|
||||||
resolveOpenAIWSModeFromExtra
|
resolveOpenAIWSModeFromExtra
|
||||||
} from '@/utils/openaiWsMode'
|
} from '@/utils/openaiWsMode'
|
||||||
|
|
||||||
describe('openaiWsMode utils', () => {
|
describe('openaiWsMode utils', () => {
|
||||||
it('normalizes mode values', () => {
|
it('normalizes mode values', () => {
|
||||||
expect(normalizeOpenAIWSMode('off')).toBe(OPENAI_WS_MODE_OFF)
|
expect(normalizeOpenAIWSMode('off')).toBe(OPENAI_WS_MODE_OFF)
|
||||||
expect(normalizeOpenAIWSMode(' Shared ')).toBe(OPENAI_WS_MODE_SHARED)
|
expect(normalizeOpenAIWSMode('ctx_pool')).toBe(OPENAI_WS_MODE_CTX_POOL)
|
||||||
expect(normalizeOpenAIWSMode('DEDICATED')).toBe(OPENAI_WS_MODE_DEDICATED)
|
expect(normalizeOpenAIWSMode('passthrough')).toBe(OPENAI_WS_MODE_PASSTHROUGH)
|
||||||
|
expect(normalizeOpenAIWSMode(' Shared ')).toBe(OPENAI_WS_MODE_CTX_POOL)
|
||||||
|
expect(normalizeOpenAIWSMode('DEDICATED')).toBe(OPENAI_WS_MODE_CTX_POOL)
|
||||||
expect(normalizeOpenAIWSMode('invalid')).toBeNull()
|
expect(normalizeOpenAIWSMode('invalid')).toBeNull()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('maps legacy enabled flag to mode', () => {
|
it('maps legacy enabled flag to mode', () => {
|
||||||
expect(openAIWSModeFromEnabled(true)).toBe(OPENAI_WS_MODE_SHARED)
|
expect(openAIWSModeFromEnabled(true)).toBe(OPENAI_WS_MODE_CTX_POOL)
|
||||||
expect(openAIWSModeFromEnabled(false)).toBe(OPENAI_WS_MODE_OFF)
|
expect(openAIWSModeFromEnabled(false)).toBe(OPENAI_WS_MODE_OFF)
|
||||||
expect(openAIWSModeFromEnabled('true')).toBeNull()
|
expect(openAIWSModeFromEnabled('true')).toBeNull()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('resolves by mode key first, then enabled, then fallback enabled keys', () => {
|
it('resolves by mode key first, then enabled, then fallback enabled keys', () => {
|
||||||
const extra = {
|
const extra = {
|
||||||
openai_oauth_responses_websockets_v2_mode: 'dedicated',
|
openai_oauth_responses_websockets_v2_mode: 'passthrough',
|
||||||
openai_oauth_responses_websockets_v2_enabled: false,
|
openai_oauth_responses_websockets_v2_enabled: false,
|
||||||
responses_websockets_v2_enabled: false
|
responses_websockets_v2_enabled: false
|
||||||
}
|
}
|
||||||
@@ -34,7 +37,7 @@ describe('openaiWsMode utils', () => {
|
|||||||
enabledKey: 'openai_oauth_responses_websockets_v2_enabled',
|
enabledKey: 'openai_oauth_responses_websockets_v2_enabled',
|
||||||
fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled']
|
fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled']
|
||||||
})
|
})
|
||||||
expect(mode).toBe(OPENAI_WS_MODE_DEDICATED)
|
expect(mode).toBe(OPENAI_WS_MODE_PASSTHROUGH)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('falls back to default when nothing is present', () => {
|
it('falls back to default when nothing is present', () => {
|
||||||
@@ -47,9 +50,21 @@ describe('openaiWsMode utils', () => {
|
|||||||
expect(mode).toBe(OPENAI_WS_MODE_OFF)
|
expect(mode).toBe(OPENAI_WS_MODE_OFF)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('treats off as disabled and shared/dedicated as enabled', () => {
|
it('treats off as disabled and non-off modes as enabled', () => {
|
||||||
expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_OFF)).toBe(false)
|
expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_OFF)).toBe(false)
|
||||||
expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_SHARED)).toBe(true)
|
expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_CTX_POOL)).toBe(true)
|
||||||
expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_DEDICATED)).toBe(true)
|
expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_PASSTHROUGH)).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('resolves concurrency hint key by mode', () => {
|
||||||
|
expect(resolveOpenAIWSModeConcurrencyHintKey(OPENAI_WS_MODE_OFF)).toBe(
|
||||||
|
'admin.accounts.openai.wsModeConcurrencyHint'
|
||||||
|
)
|
||||||
|
expect(resolveOpenAIWSModeConcurrencyHintKey(OPENAI_WS_MODE_CTX_POOL)).toBe(
|
||||||
|
'admin.accounts.openai.wsModeConcurrencyHint'
|
||||||
|
)
|
||||||
|
expect(resolveOpenAIWSModeConcurrencyHintKey(OPENAI_WS_MODE_PASSTHROUGH)).toBe(
|
||||||
|
'admin.accounts.openai.wsModePassthroughHint'
|
||||||
|
)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
export const OPENAI_WS_MODE_OFF = 'off'
|
export const OPENAI_WS_MODE_OFF = 'off'
|
||||||
export const OPENAI_WS_MODE_SHARED = 'shared'
|
export const OPENAI_WS_MODE_CTX_POOL = 'ctx_pool'
|
||||||
export const OPENAI_WS_MODE_DEDICATED = 'dedicated'
|
export const OPENAI_WS_MODE_PASSTHROUGH = 'passthrough'
|
||||||
|
|
||||||
export type OpenAIWSMode =
|
export type OpenAIWSMode =
|
||||||
| typeof OPENAI_WS_MODE_OFF
|
| typeof OPENAI_WS_MODE_OFF
|
||||||
| typeof OPENAI_WS_MODE_SHARED
|
| typeof OPENAI_WS_MODE_CTX_POOL
|
||||||
| typeof OPENAI_WS_MODE_DEDICATED
|
| typeof OPENAI_WS_MODE_PASSTHROUGH
|
||||||
|
|
||||||
const OPENAI_WS_MODES = new Set<OpenAIWSMode>([
|
const OPENAI_WS_MODES = new Set<OpenAIWSMode>([
|
||||||
OPENAI_WS_MODE_OFF,
|
OPENAI_WS_MODE_OFF,
|
||||||
OPENAI_WS_MODE_SHARED,
|
OPENAI_WS_MODE_CTX_POOL,
|
||||||
OPENAI_WS_MODE_DEDICATED
|
OPENAI_WS_MODE_PASSTHROUGH
|
||||||
])
|
])
|
||||||
|
|
||||||
export interface ResolveOpenAIWSModeOptions {
|
export interface ResolveOpenAIWSModeOptions {
|
||||||
@@ -23,6 +23,9 @@ export interface ResolveOpenAIWSModeOptions {
|
|||||||
export const normalizeOpenAIWSMode = (mode: unknown): OpenAIWSMode | null => {
|
export const normalizeOpenAIWSMode = (mode: unknown): OpenAIWSMode | null => {
|
||||||
if (typeof mode !== 'string') return null
|
if (typeof mode !== 'string') return null
|
||||||
const normalized = mode.trim().toLowerCase()
|
const normalized = mode.trim().toLowerCase()
|
||||||
|
if (normalized === 'shared' || normalized === 'dedicated') {
|
||||||
|
return OPENAI_WS_MODE_CTX_POOL
|
||||||
|
}
|
||||||
if (OPENAI_WS_MODES.has(normalized as OpenAIWSMode)) {
|
if (OPENAI_WS_MODES.has(normalized as OpenAIWSMode)) {
|
||||||
return normalized as OpenAIWSMode
|
return normalized as OpenAIWSMode
|
||||||
}
|
}
|
||||||
@@ -31,13 +34,22 @@ export const normalizeOpenAIWSMode = (mode: unknown): OpenAIWSMode | null => {
|
|||||||
|
|
||||||
export const openAIWSModeFromEnabled = (enabled: unknown): OpenAIWSMode | null => {
|
export const openAIWSModeFromEnabled = (enabled: unknown): OpenAIWSMode | null => {
|
||||||
if (typeof enabled !== 'boolean') return null
|
if (typeof enabled !== 'boolean') return null
|
||||||
return enabled ? OPENAI_WS_MODE_SHARED : OPENAI_WS_MODE_OFF
|
return enabled ? OPENAI_WS_MODE_CTX_POOL : OPENAI_WS_MODE_OFF
|
||||||
}
|
}
|
||||||
|
|
||||||
export const isOpenAIWSModeEnabled = (mode: OpenAIWSMode): boolean => {
|
export const isOpenAIWSModeEnabled = (mode: OpenAIWSMode): boolean => {
|
||||||
return mode !== OPENAI_WS_MODE_OFF
|
return mode !== OPENAI_WS_MODE_OFF
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const resolveOpenAIWSModeConcurrencyHintKey = (
|
||||||
|
mode: OpenAIWSMode
|
||||||
|
): 'admin.accounts.openai.wsModeConcurrencyHint' | 'admin.accounts.openai.wsModePassthroughHint' => {
|
||||||
|
if (mode === OPENAI_WS_MODE_PASSTHROUGH) {
|
||||||
|
return 'admin.accounts.openai.wsModePassthroughHint'
|
||||||
|
}
|
||||||
|
return 'admin.accounts.openai.wsModeConcurrencyHint'
|
||||||
|
}
|
||||||
|
|
||||||
export const resolveOpenAIWSModeFromExtra = (
|
export const resolveOpenAIWSModeFromExtra = (
|
||||||
extra: Record<string, unknown> | null | undefined,
|
extra: Record<string, unknown> | null | undefined,
|
||||||
options: ResolveOpenAIWSModeOptions
|
options: ResolveOpenAIWSModeOptions
|
||||||
|
|||||||
Reference in New Issue
Block a user