package service import ( "context" "errors" "fmt" "net/http" "net/url" "strings" "sync/atomic" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2" coderws "github.com/coder/websocket" "github.com/gin-gonic/gin" "github.com/tidwall/gjson" ) type openAIWSClientFrameConn struct { conn *coderws.Conn } const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2" var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil) func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { if c == nil || c.conn == nil { return coderws.MessageText, nil, errOpenAIWSConnClosed } if ctx == nil { ctx = context.Background() } return c.conn.Read(ctx) } func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { if c == nil || c.conn == nil { return errOpenAIWSConnClosed } if ctx == nil { ctx = context.Background() } return c.conn.Write(ctx, msgType, payload) } func (c *openAIWSClientFrameConn) Close() error { if c == nil || c.conn == nil { return nil } _ = c.conn.Close(coderws.StatusNormalClosure, "") _ = c.conn.CloseNow() return nil } func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( ctx context.Context, c *gin.Context, clientConn *coderws.Conn, account *Account, token string, firstClientMessage []byte, hooks *OpenAIWSIngressHooks, wsDecision OpenAIWSProtocolDecision, ) error { if s == nil { return errors.New("service is nil") } if clientConn == nil { return errors.New("client websocket is nil") } if account == nil { return errors.New("account is nil") } if strings.TrimSpace(token) == "" { return errors.New("token is empty") } requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String()) requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String()) logOpenAIWSV2Passthrough( "relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d", account.ID, truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen), truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen), openaiwsv2RelayMessageTypeName(coderws.MessageText), len(firstClientMessage), ) wsURL, err := s.buildOpenAIResponsesWSURL(account) if err != nil { return fmt.Errorf("build ws url: %w", err) } wsHost := "-" wsPath := "-" if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil { wsHost = normalizeOpenAIWSLogValue(parsedURL.Host) wsPath = normalizeOpenAIWSLogValue(parsedURL.Path) } logOpenAIWSV2Passthrough( "relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v", account.ID, wsHost, wsPath, account.ProxyID != nil && account.Proxy != nil, ) isCodexCLI := false if c != nil { isCodexCLI = openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) } if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { isCodexCLI = true } headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "") proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { proxyURL = account.Proxy.URL() } dialer := s.getOpenAIWSPassthroughDialer() if dialer == nil { return errors.New("openai ws passthrough dialer is nil") } dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout()) defer cancelDial() upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL) if err != nil { logOpenAIWSV2Passthrough( "relay_dial_failed account_id=%d status_code=%d err=%s", account.ID, statusCode, truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), ) return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders) } defer func() { _ = upstreamConn.Close() }() logOpenAIWSV2Passthrough( "relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s", account.ID, statusCode, openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"), ) upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn) if !ok { return errors.New("openai ws passthrough upstream connection does not support frame relay") } completedTurns := atomic.Int32{} relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{ Ctx: ctx, ClientConn: &openAIWSClientFrameConn{conn: clientConn}, UpstreamConn: upstreamFrameConn, FirstClientMessage: firstClientMessage, Options: openaiwsv2.RelayOptions{ WriteTimeout: s.openAIWSWriteTimeout(), IdleTimeout: s.openAIWSPassthroughIdleTimeout(), FirstMessageType: coderws.MessageText, OnUsageParseFailure: func(eventType string, usageRaw string) { logOpenAIWSV2Passthrough( "usage_parse_failed event_type=%s usage_raw=%s", truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen), ) }, OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) { turnNo := int(completedTurns.Add(1)) turnResult := &OpenAIForwardResult{ RequestID: turn.RequestID, Usage: OpenAIUsage{ InputTokens: turn.Usage.InputTokens, OutputTokens: turn.Usage.OutputTokens, CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens, CacheReadInputTokens: turn.Usage.CacheReadInputTokens, }, Model: turn.RequestModel, Stream: true, OpenAIWSMode: true, ResponseHeaders: cloneHeader(handshakeHeaders), Duration: turn.Duration, FirstTokenMs: turn.FirstTokenMs, } logOpenAIWSV2Passthrough( "relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d", account.ID, turnNo, truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(turn.TerminalEventType, openAIWSLogValueMaxLen), turnResult.Duration.Milliseconds(), openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs), turnResult.Usage.InputTokens, turnResult.Usage.OutputTokens, turnResult.Usage.CacheReadInputTokens, ) if hooks != nil && hooks.AfterTurn != nil { hooks.AfterTurn(turnNo, turnResult, nil) } }, OnTrace: func(event openaiwsv2.RelayTraceEvent) { logOpenAIWSV2Passthrough( "relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s", account.ID, truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen), truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen), truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen), event.PayloadBytes, event.Graceful, event.WroteDownstream, truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen), ) }, }, }) result := &OpenAIForwardResult{ RequestID: relayResult.RequestID, Usage: OpenAIUsage{ InputTokens: relayResult.Usage.InputTokens, OutputTokens: relayResult.Usage.OutputTokens, CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens, CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens, }, Model: relayResult.RequestModel, Stream: true, OpenAIWSMode: true, ResponseHeaders: cloneHeader(handshakeHeaders), Duration: relayResult.Duration, FirstTokenMs: relayResult.FirstTokenMs, } turnCount := int(completedTurns.Load()) if relayExit == nil { logOpenAIWSV2Passthrough( "relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d", account.ID, truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(relayResult.TerminalEventType, openAIWSLogValueMaxLen), result.Duration.Milliseconds(), relayResult.ClientToUpstreamFrames, relayResult.UpstreamToClientFrames, relayResult.DroppedDownstreamFrames, turnCount, ) // 正常路径按 terminal 事件逐 turn 已回调;仅在零 turn 场景兜底回调一次。 if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil { hooks.AfterTurn(1, result, nil) } return nil } logOpenAIWSV2Passthrough( "relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d", account.ID, truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen), relayExit.WroteDownstream, truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen), result.Duration.Milliseconds(), relayResult.ClientToUpstreamFrames, relayResult.UpstreamToClientFrames, relayResult.DroppedDownstreamFrames, turnCount, ) relayErr := relayExit.Err if relayExit.Stage == "idle_timeout" { relayErr = NewOpenAIWSClientCloseError( coderws.StatusPolicyViolation, "client websocket idle timeout", relayErr, ) } turnErr := wrapOpenAIWSIngressTurnError( relayExit.Stage, relayErr, relayExit.WroteDownstream, ) if hooks != nil && hooks.AfterTurn != nil { hooks.AfterTurn(turnCount+1, nil, turnErr) } return turnErr } func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError( err error, statusCode int, handshakeHeaders http.Header, ) error { if err == nil { return nil } wrappedErr := err var dialErr *openAIWSDialError if !errors.As(err, &dialErr) { wrappedErr = &openAIWSDialError{ StatusCode: statusCode, ResponseHeaders: cloneHeader(handshakeHeaders), Err: err, } } if errors.Is(err, context.Canceled) { return err } if errors.Is(err, context.DeadlineExceeded) { return NewOpenAIWSClientCloseError( coderws.StatusTryAgainLater, "upstream websocket connect timeout", wrappedErr, ) } if statusCode == http.StatusTooManyRequests { return NewOpenAIWSClientCloseError( coderws.StatusTryAgainLater, "upstream websocket is busy, please retry later", wrappedErr, ) } if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden { return NewOpenAIWSClientCloseError( coderws.StatusPolicyViolation, "upstream websocket authentication failed", wrappedErr, ) } if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError { return NewOpenAIWSClientCloseError( coderws.StatusPolicyViolation, "upstream websocket handshake rejected", wrappedErr, ) } return fmt.Errorf("openai ws passthrough dial: %w", wrappedErr) } func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) string { switch msgType { case coderws.MessageText: return "text" case coderws.MessageBinary: return "binary" default: return fmt.Sprintf("unknown(%d)", msgType) } } func relayErrorText(err error) string { if err == nil { return "" } return err.Error() } func openAIWSFirstTokenMsForLog(firstTokenMs *int) int { if firstTokenMs == nil { return -1 } return *firstTokenMs } func logOpenAIWSV2Passthrough(format string, args ...any) { logger.LegacyPrintf( "service.openai_ws_v2", "[OpenAI WS v2 passthrough] %s "+format, append([]any{openaiWSV2PassthroughModeFields}, args...)..., ) }