package service import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "math/rand" "net" "net/http" "net/url" "sort" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" coderws "github.com/coder/websocket" "github.com/gin-gonic/gin" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "go.uber.org/zap" ) const ( openAIWSBetaV1Value = "responses_websockets=2026-02-04" openAIWSBetaV2Value = "responses_websockets=2026-02-06" openAIWSTurnStateHeader = "x-codex-turn-state" openAIWSTurnMetadataHeader = "x-codex-turn-metadata" openAIWSLogValueMaxLen = 160 openAIWSHeaderValueMaxLen = 120 openAIWSIDValueMaxLen = 64 openAIWSEventLogHeadLimit = 20 openAIWSEventLogEveryN = 50 openAIWSBufferLogHeadLimit = 8 openAIWSBufferLogEveryN = 20 openAIWSPrewarmEventLogHead = 10 openAIWSPayloadKeySizeTopN = 6 openAIWSPayloadSizeEstimateDepth = 3 openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024 openAIWSPayloadSizeEstimateMaxItems = 16 openAIWSEventFlushBatchSizeDefault = 4 openAIWSEventFlushIntervalDefault = 25 * time.Millisecond openAIWSPayloadLogSampleDefault = 0.2 openAIWSStoreDisabledConnModeStrict = "strict" openAIWSStoreDisabledConnModeAdaptive = "adaptive" openAIWSStoreDisabledConnModeOff = "off" openAIWSIngressStagePreviousResponseNotFound = "previous_response_not_found" openAIWSMaxPrevResponseIDDeletePasses = 8 ) var openAIWSLogValueReplacer = strings.NewReplacer( "error", "err", "fallback", "fb", "warning", "warnx", "failed", "fail", ) var openAIWSIngressPreflightPingIdle = 20 * time.Second // openAIWSFallbackError 表示可安全回退到 HTTP 的 WS 错误(尚未写下游)。 type openAIWSFallbackError struct { Reason string Err error } func (e *openAIWSFallbackError) Error() string { if e == nil { return "" } if e.Err == nil { return fmt.Sprintf("openai ws fallback: %s", strings.TrimSpace(e.Reason)) } return fmt.Sprintf("openai ws fallback: %s: %v", strings.TrimSpace(e.Reason), e.Err) } func (e *openAIWSFallbackError) Unwrap() error { if e == nil { return nil } return e.Err } func wrapOpenAIWSFallback(reason string, err error) error { return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err} } // OpenAIWSClientCloseError 表示应以指定 WebSocket close code 主动关闭客户端连接的错误。 type OpenAIWSClientCloseError struct { statusCode coderws.StatusCode reason string err error } type openAIWSIngressTurnError struct { stage string cause error wroteDownstream bool } func (e *openAIWSIngressTurnError) Error() string { if e == nil { return "" } if e.cause == nil { return strings.TrimSpace(e.stage) } return e.cause.Error() } func (e *openAIWSIngressTurnError) Unwrap() error { if e == nil { return nil } return e.cause } func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error { if cause == nil { return nil } return &openAIWSIngressTurnError{ stage: strings.TrimSpace(stage), cause: cause, wroteDownstream: wroteDownstream, } } func isOpenAIWSIngressTurnRetryable(err error) bool { var turnErr *openAIWSIngressTurnError if !errors.As(err, &turnErr) || turnErr == nil { return false } if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) { return false } if turnErr.wroteDownstream { return false } switch turnErr.stage { case "write_upstream", "read_upstream": return true default: return false } } func openAIWSIngressTurnRetryReason(err error) string { var turnErr *openAIWSIngressTurnError if !errors.As(err, &turnErr) || turnErr == nil { return "unknown" } if turnErr.stage == "" { return "unknown" } return turnErr.stage } func isOpenAIWSIngressPreviousResponseNotFound(err error) bool { var turnErr *openAIWSIngressTurnError if !errors.As(err, &turnErr) || turnErr == nil { return false } if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound { return false } return !turnErr.wroteDownstream } // NewOpenAIWSClientCloseError 创建一个客户端 WS 关闭错误。 func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err error) error { return &OpenAIWSClientCloseError{ statusCode: statusCode, reason: strings.TrimSpace(reason), err: err, } } func (e *OpenAIWSClientCloseError) Error() string { if e == nil { return "" } if e.err == nil { return fmt.Sprintf("openai ws client close: %d %s", int(e.statusCode), strings.TrimSpace(e.reason)) } return fmt.Sprintf("openai ws client close: %d %s: %v", int(e.statusCode), strings.TrimSpace(e.reason), e.err) } func (e *OpenAIWSClientCloseError) Unwrap() error { if e == nil { return nil } return e.err } func (e *OpenAIWSClientCloseError) StatusCode() coderws.StatusCode { if e == nil { return coderws.StatusInternalError } return e.statusCode } func (e *OpenAIWSClientCloseError) Reason() string { if e == nil { return "" } return strings.TrimSpace(e.reason) } // OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。 type OpenAIWSIngressHooks struct { BeforeTurn func(turn int) error AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error) } func normalizeOpenAIWSLogValue(value string) string { trimmed := strings.TrimSpace(value) if trimmed == "" { return "-" } return openAIWSLogValueReplacer.Replace(trimmed) } func truncateOpenAIWSLogValue(value string, maxLen int) string { normalized := normalizeOpenAIWSLogValue(value) if normalized == "-" || maxLen <= 0 { return normalized } if len(normalized) <= maxLen { return normalized } return normalized[:maxLen] + "..." } func openAIWSHeaderValueForLog(headers http.Header, key string) string { if headers == nil { return "-" } return truncateOpenAIWSLogValue(headers.Get(key), openAIWSHeaderValueMaxLen) } func hasOpenAIWSHeader(headers http.Header, key string) bool { if headers == nil { return false } return strings.TrimSpace(headers.Get(key)) != "" } type openAIWSSessionHeaderResolution struct { SessionID string ConversationID string SessionSource string ConversationSource string } func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution { resolution := openAIWSSessionHeaderResolution{ SessionSource: "none", ConversationSource: "none", } if c != nil && c.Request != nil { if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" { resolution.SessionID = sessionID resolution.SessionSource = "header_session_id" } if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" { resolution.ConversationID = conversationID resolution.ConversationSource = "header_conversation_id" if resolution.SessionID == "" { resolution.SessionID = conversationID resolution.SessionSource = "header_conversation_id" } } } cacheKey := strings.TrimSpace(promptCacheKey) if cacheKey != "" { if resolution.SessionID == "" { resolution.SessionID = cacheKey resolution.SessionSource = "prompt_cache_key" } } return resolution } func shouldLogOpenAIWSEvent(idx int, eventType string) bool { if idx <= openAIWSEventLogHeadLimit { return true } if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 { return true } if eventType == "error" || isOpenAIWSTerminalEvent(eventType) { return true } return false } func shouldLogOpenAIWSBufferedEvent(idx int) bool { if idx <= openAIWSBufferLogHeadLimit { return true } if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN == 0 { return true } return false } func openAIWSEventMayContainModel(eventType string) bool { switch eventType { case "response.created", "response.in_progress", "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": return true default: trimmed := strings.TrimSpace(eventType) if trimmed == eventType { return false } switch trimmed { case "response.created", "response.in_progress", "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": return true default: return false } } } func openAIWSEventMayContainToolCalls(eventType string) bool { eventType = strings.TrimSpace(eventType) if eventType == "" { return false } if strings.Contains(eventType, "function_call") || strings.Contains(eventType, "tool_call") { return true } switch eventType { case "response.output_item.added", "response.output_item.done", "response.completed", "response.done": return true default: return false } } func openAIWSEventShouldParseUsage(eventType string) bool { return eventType == "response.completed" || strings.TrimSpace(eventType) == "response.completed" } func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) { if len(message) == 0 { return "", "", gjson.Result{} } values := gjson.GetManyBytes(message, "type", "response.id", "id", "response") eventType = strings.TrimSpace(values[0].String()) if id := strings.TrimSpace(values[1].String()); id != "" { responseID = id } else { responseID = strings.TrimSpace(values[2].String()) } return eventType, responseID, values[3] } func openAIWSMessageLikelyContainsToolCalls(message []byte) bool { if len(message) == 0 { return false } return bytes.Contains(message, []byte(`"tool_calls"`)) || bytes.Contains(message, []byte(`"tool_call"`)) || bytes.Contains(message, []byte(`"function_call"`)) } func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) { if usage == nil || len(message) == 0 { return } values := gjson.GetManyBytes( message, "response.usage.input_tokens", "response.usage.output_tokens", "response.usage.input_tokens_details.cached_tokens", ) usage.InputTokens = int(values[0].Int()) usage.OutputTokens = int(values[1].Int()) usage.CacheReadInputTokens = int(values[2].Int()) } func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { if len(message) == 0 { return "", "", "" } values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message") return strings.TrimSpace(values[0].String()), strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String()) } func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) { code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen) errType = truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen) errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen) return code, errType, errMessage } func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { if len(message) == 0 { return "-", "-", "-" } return summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message)) } func summarizeOpenAIWSPayloadKeySizes(payload map[string]any, topN int) string { if len(payload) == 0 { return "-" } type keySize struct { Key string Size int } sizes := make([]keySize, 0, len(payload)) for key, value := range payload { size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth) sizes = append(sizes, keySize{Key: key, Size: size}) } sort.Slice(sizes, func(i, j int) bool { if sizes[i].Size == sizes[j].Size { return sizes[i].Key < sizes[j].Key } return sizes[i].Size > sizes[j].Size }) if topN <= 0 || topN > len(sizes) { topN = len(sizes) } parts := make([]string, 0, topN) for idx := 0; idx < topN; idx++ { item := sizes[idx] parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size)) } return strings.Join(parts, ",") } func estimateOpenAIWSPayloadValueSize(value any, depth int) int { if depth <= 0 { return -1 } switch v := value.(type) { case nil: return 0 case string: return len(v) case []byte: return len(v) case bool: return 1 case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: return 8 case float32, float64: return 8 case map[string]any: if len(v) == 0 { return 2 } total := 2 count := 0 for key, item := range v { count++ if count > openAIWSPayloadSizeEstimateMaxItems { return -1 } itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1) if itemSize < 0 { return -1 } total += len(key) + itemSize + 3 if total > openAIWSPayloadSizeEstimateMaxBytes { return -1 } } return total case []any: if len(v) == 0 { return 2 } total := 2 limit := len(v) if limit > openAIWSPayloadSizeEstimateMaxItems { return -1 } for i := 0; i < limit; i++ { itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1) if itemSize < 0 { return -1 } total += itemSize + 1 if total > openAIWSPayloadSizeEstimateMaxBytes { return -1 } } return total default: raw, err := json.Marshal(v) if err != nil { return -1 } if len(raw) > openAIWSPayloadSizeEstimateMaxBytes { return -1 } return len(raw) } } func openAIWSPayloadString(payload map[string]any, key string) string { if len(payload) == 0 { return "" } raw, ok := payload[key] if !ok { return "" } switch v := raw.(type) { case nil: return "" case string: return strings.TrimSpace(v) case []byte: return strings.TrimSpace(string(v)) default: return "" } } func openAIWSPayloadStringFromRaw(payload []byte, key string) string { if len(payload) == 0 || strings.TrimSpace(key) == "" { return "" } return strings.TrimSpace(gjson.GetBytes(payload, key).String()) } func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool { if len(payload) == 0 || strings.TrimSpace(key) == "" { return defaultValue } value := gjson.GetBytes(payload, key) if !value.Exists() { return defaultValue } if value.Type != gjson.True && value.Type != gjson.False { return defaultValue } return value.Bool() } func openAIWSSessionHashesFromID(sessionID string) (string, string) { return deriveOpenAISessionHashes(sessionID) } func extractOpenAIWSImageURL(value any) string { switch v := value.(type) { case string: return strings.TrimSpace(v) case map[string]any: if raw, ok := v["url"].(string); ok { return strings.TrimSpace(raw) } } return "" } func summarizeOpenAIWSInput(input any) string { items, ok := input.([]any) if !ok || len(items) == 0 { return "-" } itemCount := len(items) textChars := 0 imageDataURLs := 0 imageDataURLChars := 0 imageRemoteURLs := 0 handleContentItem := func(contentItem map[string]any) { contentType, _ := contentItem["type"].(string) switch strings.TrimSpace(contentType) { case "input_text", "output_text", "text": if text, ok := contentItem["text"].(string); ok { textChars += len(text) } case "input_image": imageURL := extractOpenAIWSImageURL(contentItem["image_url"]) if imageURL == "" { return } if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") { imageDataURLs++ imageDataURLChars += len(imageURL) return } imageRemoteURLs++ } } handleInputItem := func(inputItem map[string]any) { if content, ok := inputItem["content"].([]any); ok { for _, rawContent := range content { contentItem, ok := rawContent.(map[string]any) if !ok { continue } handleContentItem(contentItem) } return } itemType, _ := inputItem["type"].(string) switch strings.TrimSpace(itemType) { case "input_text", "output_text", "text": if text, ok := inputItem["text"].(string); ok { textChars += len(text) } case "input_image": imageURL := extractOpenAIWSImageURL(inputItem["image_url"]) if imageURL == "" { return } if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") { imageDataURLs++ imageDataURLChars += len(imageURL) return } imageRemoteURLs++ } } for _, rawItem := range items { inputItem, ok := rawItem.(map[string]any) if !ok { continue } handleInputItem(inputItem) } return fmt.Sprintf( "items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d", itemCount, textChars, imageDataURLs, imageDataURLChars, imageRemoteURLs, ) } func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) { if len(payload) == 0 || strings.TrimSpace(key) == "" { return } if _, exists := payload[key]; !exists { return } delete(payload, key) *removed = append(*removed, key) } // applyOpenAIWSRetryPayloadStrategy 在 WS 连续失败时仅移除无语义字段, // 避免重试成功却改变原始请求语义。 // 注意:prompt_cache_key 不应在重试中移除;它常用于会话稳定标识(session_id 兜底)。 func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) { if len(payload) == 0 { return "empty", nil } if attempt <= 1 { return "full", nil } removed := make([]string, 0, 2) if attempt >= 2 { dropOpenAIWSPayloadKey(payload, "include", &removed) } if len(removed) == 0 { return "full", nil } sort.Strings(removed) return "trim_optional_fields", removed } func logOpenAIWSModeInfo(format string, args ...any) { logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...) } func isOpenAIWSModeDebugEnabled() bool { return logger.L().Core().Enabled(zap.DebugLevel) } func logOpenAIWSModeDebug(format string, args ...any) { if !isOpenAIWSModeDebugEnabled() { return } logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...) } func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) { if err == nil { return } logger.L().Warn( "openai.ws_bind_response_account_failed", zap.Int64("group_id", groupID), zap.Int64("account_id", accountID), zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)), zap.Error(err), ) } func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) { if err == nil { return "-", "-" } statusCode := coderws.CloseStatus(err) if statusCode == -1 { return "-", "-" } closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String()) closeReason := "-" var closeErr coderws.CloseError if errors.As(err, &closeErr) { reasonText := strings.TrimSpace(closeErr.Reason) if reasonText != "" { closeReason = normalizeOpenAIWSLogValue(reasonText) } } return normalizeOpenAIWSLogValue(closeStatus), closeReason } func unwrapOpenAIWSDialBaseError(err error) error { if err == nil { return nil } var dialErr *openAIWSDialError if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil { return dialErr.Err } return err } func openAIWSDialRespHeaderForLog(err error, key string) string { var dialErr *openAIWSDialError if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil { return "-" } return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen) } func classifyOpenAIWSDialError(err error) string { if err == nil { return "-" } baseErr := unwrapOpenAIWSDialBaseError(err) if baseErr == nil { return "-" } if errors.Is(baseErr, context.DeadlineExceeded) { return "ctx_deadline_exceeded" } if errors.Is(baseErr, context.Canceled) { return "ctx_canceled" } var netErr net.Error if errors.As(baseErr, &netErr) && netErr.Timeout() { return "net_timeout" } if status := coderws.CloseStatus(baseErr); status != -1 { return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status))) } message := strings.ToLower(strings.TrimSpace(baseErr.Error())) switch { case strings.Contains(message, "handshake not finished"): return "handshake_not_finished" case strings.Contains(message, "bad handshake"): return "bad_handshake" case strings.Contains(message, "connection refused"): return "connection_refused" case strings.Contains(message, "no such host"): return "dns_not_found" case strings.Contains(message, "tls"): return "tls_error" case strings.Contains(message, "i/o timeout"): return "io_timeout" case strings.Contains(message, "context deadline exceeded"): return "ctx_deadline_exceeded" default: return "dial_error" } } func summarizeOpenAIWSDialError(err error) ( statusCode int, dialClass string, closeStatus string, closeReason string, respServer string, respVia string, respCFRay string, respRequestID string, ) { dialClass = "-" closeStatus = "-" closeReason = "-" respServer = "-" respVia = "-" respCFRay = "-" respRequestID = "-" if err == nil { return } var dialErr *openAIWSDialError if errors.As(err, &dialErr) && dialErr != nil { statusCode = dialErr.StatusCode respServer = openAIWSDialRespHeaderForLog(err, "server") respVia = openAIWSDialRespHeaderForLog(err, "via") respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray") respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id") } dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err)) closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err)) return } func isOpenAIWSClientDisconnectError(err error) bool { if err == nil { return false } if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { return true } switch coderws.CloseStatus(err) { case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure: return true } message := strings.ToLower(strings.TrimSpace(err.Error())) if message == "" { return false } return strings.Contains(message, "failed to read frame header: eof") || strings.Contains(message, "unexpected eof") || strings.Contains(message, "use of closed network connection") || strings.Contains(message, "connection reset by peer") || strings.Contains(message, "broken pipe") } func classifyOpenAIWSReadFallbackReason(err error) string { if err == nil { return "read_event" } switch coderws.CloseStatus(err) { case coderws.StatusPolicyViolation: return "policy_violation" case coderws.StatusMessageTooBig: return "message_too_big" default: return "read_event" } } func sortedKeys(m map[string]any) []string { if len(m) == 0 { return nil } keys := make([]string, 0, len(m)) for k := range m { keys = append(keys, k) } sort.Strings(keys) return keys } func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool { if s == nil { return nil } s.openaiWSPoolOnce.Do(func() { if s.openaiWSPool == nil { s.openaiWSPool = newOpenAIWSConnPool(s.cfg) } }) return s.openaiWSPool } func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot { pool := s.getOpenAIWSConnPool() if pool == nil { return OpenAIWSPoolMetricsSnapshot{} } return pool.SnapshotMetrics() } type OpenAIWSPerformanceMetricsSnapshot struct { Pool OpenAIWSPoolMetricsSnapshot `json:"pool"` Retry OpenAIWSRetryMetricsSnapshot `json:"retry"` Transport OpenAIWSTransportMetricsSnapshot `json:"transport"` } func (s *OpenAIGatewayService) SnapshotOpenAIWSPerformanceMetrics() OpenAIWSPerformanceMetricsSnapshot { pool := s.getOpenAIWSConnPool() snapshot := OpenAIWSPerformanceMetricsSnapshot{ Retry: s.SnapshotOpenAIWSRetryMetrics(), } if pool == nil { return snapshot } snapshot.Pool = pool.SnapshotMetrics() snapshot.Transport = pool.SnapshotTransportMetrics() return snapshot } func (s *OpenAIGatewayService) getOpenAIWSStateStore() OpenAIWSStateStore { if s == nil { return nil } s.openaiWSStateStoreOnce.Do(func() { if s.openaiWSStateStore == nil { s.openaiWSStateStore = NewOpenAIWSStateStore(s.cache) } }) return s.openaiWSStateStore } func (s *OpenAIGatewayService) openAIWSResponseStickyTTL() time.Duration { if s != nil && s.cfg != nil { seconds := s.cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds if seconds > 0 { return time.Duration(seconds) * time.Second } } return time.Hour } func (s *OpenAIGatewayService) openAIWSIngressPreviousResponseRecoveryEnabled() bool { if s != nil && s.cfg != nil { return s.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled } return true } func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration { if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds > 0 { return time.Duration(s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds) * time.Second } return 15 * time.Minute } func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration { if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 { return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second } return 2 * time.Minute } func (s *OpenAIGatewayService) openAIWSEventFlushBatchSize() int { if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushBatchSize > 0 { return s.cfg.Gateway.OpenAIWS.EventFlushBatchSize } return openAIWSEventFlushBatchSizeDefault } func (s *OpenAIGatewayService) openAIWSEventFlushInterval() time.Duration { if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS >= 0 { if s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS == 0 { return 0 } return time.Duration(s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS) * time.Millisecond } return openAIWSEventFlushIntervalDefault } func (s *OpenAIGatewayService) openAIWSPayloadLogSampleRate() float64 { if s != nil && s.cfg != nil { rate := s.cfg.Gateway.OpenAIWS.PayloadLogSampleRate if rate < 0 { return 0 } if rate > 1 { return 1 } return rate } return openAIWSPayloadLogSampleDefault } func (s *OpenAIGatewayService) shouldLogOpenAIWSPayloadSchema(attempt int) bool { // 首次尝试保留一条完整 payload_schema 便于排障。 if attempt <= 1 { return true } rate := s.openAIWSPayloadLogSampleRate() if rate <= 0 { return false } if rate >= 1 { return true } return rand.Float64() < rate } func (s *OpenAIGatewayService) shouldEmitOpenAIWSPayloadSchema(attempt int) bool { if !s.shouldLogOpenAIWSPayloadSchema(attempt) { return false } return logger.L().Core().Enabled(zap.DebugLevel) } func (s *OpenAIGatewayService) openAIWSDialTimeout() time.Duration { if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 { return time.Duration(s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second } return 10 * time.Second } func (s *OpenAIGatewayService) openAIWSAcquireTimeout() time.Duration { // Acquire 覆盖“连接复用命中/排队/新建连接”三个阶段。 // 这里不再叠加 write_timeout,避免高并发排队时把 TTFT 长尾拉到分钟级。 dial := s.openAIWSDialTimeout() if dial <= 0 { dial = 10 * time.Second } return dial + 2*time.Second } func (s *OpenAIGatewayService) buildOpenAIResponsesWSURL(account *Account) (string, error) { if account == nil { return "", errors.New("account is nil") } var targetURL string switch account.Type { case AccountTypeOAuth: targetURL = chatgptCodexURL case AccountTypeAPIKey: baseURL := account.GetOpenAIBaseURL() if baseURL == "" { targetURL = openaiPlatformAPIURL } else { validatedURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return "", err } targetURL = buildOpenAIResponsesURL(validatedURL) } default: targetURL = openaiPlatformAPIURL } parsed, err := url.Parse(strings.TrimSpace(targetURL)) if err != nil { return "", fmt.Errorf("invalid target url: %w", err) } switch strings.ToLower(parsed.Scheme) { case "https": parsed.Scheme = "wss" case "http": parsed.Scheme = "ws" case "wss", "ws": // 保持不变 default: return "", fmt.Errorf("unsupported scheme for ws: %s", parsed.Scheme) } return parsed.String(), nil } func (s *OpenAIGatewayService) buildOpenAIWSHeaders( c *gin.Context, account *Account, token string, decision OpenAIWSProtocolDecision, isCodexCLI bool, turnState string, turnMetadata string, promptCacheKey string, ) (http.Header, openAIWSSessionHeaderResolution) { headers := make(http.Header) headers.Set("authorization", "Bearer "+token) sessionResolution := resolveOpenAIWSSessionHeaders(c, promptCacheKey) if c != nil && c.Request != nil { if v := strings.TrimSpace(c.Request.Header.Get("accept-language")); v != "" { headers.Set("accept-language", v) } } if sessionResolution.SessionID != "" { headers.Set("session_id", sessionResolution.SessionID) } if sessionResolution.ConversationID != "" { headers.Set("conversation_id", sessionResolution.ConversationID) } if state := strings.TrimSpace(turnState); state != "" { headers.Set(openAIWSTurnStateHeader, state) } if metadata := strings.TrimSpace(turnMetadata); metadata != "" { headers.Set(openAIWSTurnMetadataHeader, metadata) } if account != nil && account.Type == AccountTypeOAuth { if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { headers.Set("chatgpt-account-id", chatgptAccountID) } if isCodexCLI { headers.Set("originator", "codex_cli_rs") } else { headers.Set("originator", "opencode") } } betaValue := openAIWSBetaV2Value if decision.Transport == OpenAIUpstreamTransportResponsesWebsocket { betaValue = openAIWSBetaV1Value } headers.Set("OpenAI-Beta", betaValue) customUA := "" if account != nil { customUA = account.GetOpenAIUserAgent() } if strings.TrimSpace(customUA) != "" { headers.Set("user-agent", customUA) } else if c != nil { if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" { headers.Set("user-agent", ua) } } if s != nil && s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { headers.Set("user-agent", codexCLIUserAgent) } if account != nil && account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(headers.Get("user-agent")) { headers.Set("user-agent", codexCLIUserAgent) } return headers, sessionResolution } func (s *OpenAIGatewayService) buildOpenAIWSCreatePayload(reqBody map[string]any, account *Account) map[string]any { // OpenAI WS Mode 协议:response.create 字段与 HTTP /responses 基本一致。 // 保留 stream 字段(与 Codex CLI 一致),仅移除 background。 payload := make(map[string]any, len(reqBody)+1) for k, v := range reqBody { payload[k] = v } delete(payload, "background") if _, exists := payload["stream"]; !exists { payload["stream"] = true } payload["type"] = "response.create" // OAuth 默认保持 store=false,避免误依赖服务端历史。 if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { payload["store"] = false } return payload } func setOpenAIWSTurnMetadata(payload map[string]any, turnMetadata string) { if len(payload) == 0 { return } metadata := strings.TrimSpace(turnMetadata) if metadata == "" { return } switch existing := payload["client_metadata"].(type) { case map[string]any: existing[openAIWSTurnMetadataHeader] = metadata payload["client_metadata"] = existing case map[string]string: next := make(map[string]any, len(existing)+1) for k, v := range existing { next[k] = v } next[openAIWSTurnMetadataHeader] = metadata payload["client_metadata"] = next default: payload["client_metadata"] = map[string]any{ openAIWSTurnMetadataHeader: metadata, } } } func (s *OpenAIGatewayService) isOpenAIWSStoreRecoveryAllowed(account *Account) bool { if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() { return true } if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery { return true } return false } func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool { if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { return true } if len(reqBody) == 0 { return false } rawStore, ok := reqBody["store"] if !ok { return false } storeEnabled, ok := rawStore.(bool) if !ok { return false } return !storeEnabled } func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool { if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { return true } if len(reqBody) == 0 { return false } storeValue := gjson.GetBytes(reqBody, "store") if !storeValue.Exists() { return false } if storeValue.Type != gjson.True && storeValue.Type != gjson.False { return false } return !storeValue.Bool() } func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string { if s == nil || s.cfg == nil { return openAIWSStoreDisabledConnModeStrict } mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode)) switch mode { case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff: return mode case "": // 兼容旧配置:仅配置了布尔开关时按旧语义推导。 if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn { return openAIWSStoreDisabledConnModeStrict } return openAIWSStoreDisabledConnModeOff default: return openAIWSStoreDisabledConnModeStrict } } func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool { switch mode { case openAIWSStoreDisabledConnModeOff: return false case openAIWSStoreDisabledConnModeAdaptive: reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_") switch reason { case "policy_violation", "message_too_big", "auth_failed", "write_request", "write": return true default: return false } default: return true } } func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) { return dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes) } func dropPreviousResponseIDFromRawPayloadWithDeleteFn( payload []byte, deleteFn func([]byte, string) ([]byte, error), ) ([]byte, bool, error) { if len(payload) == 0 { return payload, false, nil } if !gjson.GetBytes(payload, "previous_response_id").Exists() { return payload, false, nil } if deleteFn == nil { deleteFn = sjson.DeleteBytes } updated := payload for i := 0; i < openAIWSMaxPrevResponseIDDeletePasses && gjson.GetBytes(updated, "previous_response_id").Exists(); i++ { next, err := deleteFn(updated, "previous_response_id") if err != nil { return payload, false, err } updated = next } return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil } func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) { normalizedPrevID := strings.TrimSpace(previousResponseID) if len(payload) == 0 || normalizedPrevID == "" { return payload, nil } updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID) if err == nil { return updated, nil } var reqBody map[string]any if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil { return nil, err } reqBody["previous_response_id"] = normalizedPrevID rebuilt, marshalErr := json.Marshal(reqBody) if marshalErr != nil { return nil, marshalErr } return rebuilt, nil } func shouldInferIngressFunctionCallOutputPreviousResponseID( storeDisabled bool, turn int, hasFunctionCallOutput bool, currentPreviousResponseID string, expectedPreviousResponseID string, ) bool { if !storeDisabled || turn <= 1 || !hasFunctionCallOutput { return false } if strings.TrimSpace(currentPreviousResponseID) != "" { return false } return strings.TrimSpace(expectedPreviousResponseID) != "" } func alignStoreDisabledPreviousResponseID( payload []byte, expectedPreviousResponseID string, ) ([]byte, bool, error) { if len(payload) == 0 { return payload, false, nil } expected := strings.TrimSpace(expectedPreviousResponseID) if expected == "" { return payload, false, nil } current := openAIWSPayloadStringFromRaw(payload, "previous_response_id") if current == "" || current == expected { return payload, false, nil } withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload) if dropErr != nil { return payload, false, dropErr } if !removed { return payload, false, nil } updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected) if setErr != nil { return payload, false, setErr } return updated, true, nil } func cloneOpenAIWSPayloadBytes(payload []byte) []byte { if len(payload) == 0 { return nil } cloned := make([]byte, len(payload)) copy(cloned, payload) return cloned } func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage { if items == nil { return nil } cloned := make([]json.RawMessage, 0, len(items)) for idx := range items { cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx]))) } return cloned } func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) { trimmed := bytes.TrimSpace(raw) if len(trimmed) == 0 { return nil, errors.New("json is empty") } var decoded any if err := json.Unmarshal(trimmed, &decoded); err != nil { return nil, err } return json.Marshal(decoded) } func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte { normalized, err := normalizeOpenAIWSJSONForCompare(raw) if err != nil { return bytes.TrimSpace(raw) } return normalized } func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) ([]byte, error) { if len(payload) == 0 { return nil, errors.New("payload is empty") } var decoded map[string]any if err := json.Unmarshal(payload, &decoded); err != nil { return nil, err } delete(decoded, "input") delete(decoded, "previous_response_id") return json.Marshal(decoded) } func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) { if len(payload) == 0 { return nil, false, nil } inputValue := gjson.GetBytes(payload, "input") if !inputValue.Exists() { return nil, false, nil } if inputValue.Type == gjson.JSON { raw := strings.TrimSpace(inputValue.Raw) if strings.HasPrefix(raw, "[") { var items []json.RawMessage if err := json.Unmarshal([]byte(raw), &items); err != nil { return nil, true, err } return items, true, nil } return []json.RawMessage{json.RawMessage(raw)}, true, nil } if inputValue.Type == gjson.String { encoded, _ := json.Marshal(inputValue.String()) return []json.RawMessage{encoded}, true, nil } return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil } func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) { previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload) if prevErr != nil { return false, prevErr } currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload) if currentErr != nil { return false, currentErr } if !previousExists && !currentExists { return true, nil } if !previousExists { return len(currentItems) == 0, nil } if !currentExists { return len(previousItems) == 0, nil } if len(currentItems) < len(previousItems) { return false, nil } for idx := range previousItems { previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx]) currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx]) if !bytes.Equal(previousNormalized, currentNormalized) { return false, nil } } return true, nil } func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool { if len(prefix) == 0 { return true } if len(items) < len(prefix) { return false } for idx := range prefix { previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx]) currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(items[idx]) if !bytes.Equal(previousNormalized, currentNormalized) { return false } } return true } func buildOpenAIWSReplayInputSequence( previousFullInput []json.RawMessage, previousFullInputExists bool, currentPayload []byte, hasPreviousResponseID bool, ) ([]json.RawMessage, bool, error) { currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload) if currentErr != nil { return nil, false, currentErr } if !hasPreviousResponseID { return cloneOpenAIWSRawMessages(currentItems), currentExists, nil } if !previousFullInputExists { return cloneOpenAIWSRawMessages(currentItems), currentExists, nil } if !currentExists || len(currentItems) == 0 { return cloneOpenAIWSRawMessages(previousFullInput), true, nil } if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) { return cloneOpenAIWSRawMessages(currentItems), true, nil } merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems)) merged = append(merged, cloneOpenAIWSRawMessages(previousFullInput)...) merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...) return merged, true, nil } func setOpenAIWSPayloadInputSequence( payload []byte, fullInput []json.RawMessage, fullInputExists bool, ) ([]byte, error) { if !fullInputExists { return payload, nil } // Preserve [] vs null semantics when input exists but is empty. inputForMarshal := fullInput if inputForMarshal == nil { inputForMarshal = []json.RawMessage{} } inputRaw, marshalErr := json.Marshal(inputForMarshal) if marshalErr != nil { return nil, marshalErr } return sjson.SetRawBytes(payload, "input", inputRaw) } func shouldKeepIngressPreviousResponseID( previousPayload []byte, currentPayload []byte, lastTurnResponseID string, hasFunctionCallOutput bool, ) (bool, string, error) { if hasFunctionCallOutput { return true, "has_function_call_output", nil } currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) if currentPreviousResponseID == "" { return false, "missing_previous_response_id", nil } expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID) if expectedPreviousResponseID == "" { return false, "missing_last_turn_response_id", nil } if currentPreviousResponseID != expectedPreviousResponseID { return false, "previous_response_id_mismatch", nil } if len(previousPayload) == 0 { return false, "missing_previous_turn_payload", nil } previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload) if previousComparableErr != nil { return false, "non_input_compare_error", previousComparableErr } currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload) if currentComparableErr != nil { return false, "non_input_compare_error", currentComparableErr } if !bytes.Equal(previousComparable, currentComparable) { return false, "non_input_changed", nil } return true, "strict_incremental_ok", nil } type openAIWSIngressPreviousTurnStrictState struct { nonInputComparable []byte } func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) { if len(payload) == 0 { return nil, nil } nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload) if nonInputErr != nil { return nil, nonInputErr } return &openAIWSIngressPreviousTurnStrictState{ nonInputComparable: nonInputComparable, }, nil } func shouldKeepIngressPreviousResponseIDWithStrictState( previousState *openAIWSIngressPreviousTurnStrictState, currentPayload []byte, lastTurnResponseID string, hasFunctionCallOutput bool, ) (bool, string, error) { if hasFunctionCallOutput { return true, "has_function_call_output", nil } currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) if currentPreviousResponseID == "" { return false, "missing_previous_response_id", nil } expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID) if expectedPreviousResponseID == "" { return false, "missing_last_turn_response_id", nil } if currentPreviousResponseID != expectedPreviousResponseID { return false, "previous_response_id_mismatch", nil } if previousState == nil { return false, "missing_previous_turn_payload", nil } currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload) if currentComparableErr != nil { return false, "non_input_compare_error", currentComparableErr } if !bytes.Equal(previousState.nonInputComparable, currentComparable) { return false, "non_input_changed", nil } return true, "strict_incremental_ok", nil } func (s *OpenAIGatewayService) forwardOpenAIWSV2( ctx context.Context, c *gin.Context, account *Account, reqBody map[string]any, token string, decision OpenAIWSProtocolDecision, isCodexCLI bool, reqStream bool, originalModel string, mappedModel string, startTime time.Time, attempt int, lastFailureReason string, ) (*OpenAIForwardResult, error) { if s == nil || account == nil { return nil, wrapOpenAIWSFallback("invalid_state", errors.New("service or account is nil")) } wsURL, err := s.buildOpenAIResponsesWSURL(account) if err != nil { return nil, wrapOpenAIWSFallback("build_ws_url", err) } wsHost := "-" wsPath := "-" if parsed, parseErr := url.Parse(wsURL); parseErr == nil && parsed != nil { if h := strings.TrimSpace(parsed.Host); h != "" { wsHost = normalizeOpenAIWSLogValue(h) } if p := strings.TrimSpace(parsed.Path); p != "" { wsPath = normalizeOpenAIWSLogValue(p) } } logOpenAIWSModeDebug( "dial_target account_id=%d account_type=%s ws_host=%s ws_path=%s", account.ID, account.Type, wsHost, wsPath, ) payload := s.buildOpenAIWSCreatePayload(reqBody, account) payloadStrategy, removedKeys := applyOpenAIWSRetryPayloadStrategy(payload, attempt) previousResponseID := openAIWSPayloadString(payload, "previous_response_id") previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID) promptCacheKey := openAIWSPayloadString(payload, "prompt_cache_key") _, hasTools := payload["tools"] debugEnabled := isOpenAIWSModeDebugEnabled() payloadBytes := -1 resolvePayloadBytes := func() int { if payloadBytes >= 0 { return payloadBytes } payloadBytes = len(payloadAsJSONBytes(payload)) return payloadBytes } streamValue := "-" if raw, ok := payload["stream"]; ok { streamValue = normalizeOpenAIWSLogValue(strings.TrimSpace(fmt.Sprintf("%v", raw))) } turnState := "" turnMetadata := "" if c != nil && c.Request != nil { turnState = strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader)) turnMetadata = strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)) } setOpenAIWSTurnMetadata(payload, turnMetadata) payloadEventType := openAIWSPayloadString(payload, "type") if payloadEventType == "" { payloadEventType = "response.create" } if s.shouldEmitOpenAIWSPayloadSchema(attempt) { logOpenAIWSModeInfo( "[debug] payload_schema account_id=%d attempt=%d event=%s payload_keys=%s payload_bytes=%d payload_key_sizes=%s input_summary=%s stream=%s payload_strategy=%s removed_keys=%s has_previous_response_id=%v has_prompt_cache_key=%v has_tools=%v", account.ID, attempt, payloadEventType, normalizeOpenAIWSLogValue(strings.Join(sortedKeys(payload), ",")), resolvePayloadBytes(), normalizeOpenAIWSLogValue(summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN)), normalizeOpenAIWSLogValue(summarizeOpenAIWSInput(payload["input"])), streamValue, normalizeOpenAIWSLogValue(payloadStrategy), normalizeOpenAIWSLogValue(strings.Join(removedKeys, ",")), previousResponseID != "", promptCacheKey != "", hasTools, ) } stateStore := s.getOpenAIWSStateStore() groupID := getOpenAIGroupIDFromContext(c) sessionHash := s.GenerateSessionHash(c, nil) if sessionHash == "" { var legacySessionHash string sessionHash, legacySessionHash = openAIWSSessionHashesFromID(promptCacheKey) attachOpenAILegacySessionHashToGin(c, legacySessionHash) } if turnState == "" && stateStore != nil && sessionHash != "" { if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok { turnState = savedTurnState } } preferredConnID := "" if stateStore != nil && previousResponseID != "" { if connID, ok := stateStore.GetResponseConn(previousResponseID); ok { preferredConnID = connID } } storeDisabled := s.isOpenAIWSStoreDisabledInRequest(reqBody, account) if stateStore != nil && storeDisabled && previousResponseID == "" && sessionHash != "" { if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok { preferredConnID = connID } } storeDisabledConnMode := s.openAIWSStoreDisabledConnMode() forceNewConnByPolicy := shouldForceNewConnOnStoreDisabled(storeDisabledConnMode, lastFailureReason) forceNewConn := forceNewConnByPolicy && storeDisabled && previousResponseID == "" && sessionHash != "" && preferredConnID == "" wsHeaders, sessionResolution := s.buildOpenAIWSHeaders(c, account, token, decision, isCodexCLI, turnState, turnMetadata, promptCacheKey) logOpenAIWSModeDebug( "acquire_start account_id=%d account_type=%s transport=%s preferred_conn_id=%s has_previous_response_id=%v session_hash=%s has_turn_state=%v turn_state_len=%d has_turn_metadata=%v turn_metadata_len=%d store_disabled=%v store_disabled_conn_mode=%s retry_last_reason=%s force_new_conn=%v header_user_agent=%s header_openai_beta=%s header_originator=%s header_accept_language=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_prompt_cache_key=%v has_chatgpt_account_id=%v has_authorization=%v has_session_id=%v has_conversation_id=%v proxy_enabled=%v", account.ID, account.Type, normalizeOpenAIWSLogValue(string(decision.Transport)), truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), previousResponseID != "", truncateOpenAIWSLogValue(sessionHash, 12), turnState != "", len(turnState), turnMetadata != "", len(turnMetadata), storeDisabled, normalizeOpenAIWSLogValue(storeDisabledConnMode), truncateOpenAIWSLogValue(lastFailureReason, openAIWSLogValueMaxLen), forceNewConn, openAIWSHeaderValueForLog(wsHeaders, "user-agent"), openAIWSHeaderValueForLog(wsHeaders, "openai-beta"), openAIWSHeaderValueForLog(wsHeaders, "originator"), openAIWSHeaderValueForLog(wsHeaders, "accept-language"), openAIWSHeaderValueForLog(wsHeaders, "session_id"), openAIWSHeaderValueForLog(wsHeaders, "conversation_id"), normalizeOpenAIWSLogValue(sessionResolution.SessionSource), normalizeOpenAIWSLogValue(sessionResolution.ConversationSource), promptCacheKey != "", hasOpenAIWSHeader(wsHeaders, "chatgpt-account-id"), hasOpenAIWSHeader(wsHeaders, "authorization"), hasOpenAIWSHeader(wsHeaders, "session_id"), hasOpenAIWSHeader(wsHeaders, "conversation_id"), account.ProxyID != nil && account.Proxy != nil, ) acquireCtx, acquireCancel := context.WithTimeout(ctx, s.openAIWSAcquireTimeout()) defer acquireCancel() lease, err := s.getOpenAIWSConnPool().Acquire(acquireCtx, openAIWSAcquireRequest{ Account: account, WSURL: wsURL, Headers: wsHeaders, PreferredConnID: preferredConnID, ForceNewConn: forceNewConn, ProxyURL: func() string { if account.ProxyID != nil && account.Proxy != nil { return account.Proxy.URL() } return "" }(), }) if err != nil { dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(err) logOpenAIWSModeInfo( "acquire_fail account_id=%d account_type=%s transport=%s reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_new_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v", account.ID, account.Type, normalizeOpenAIWSLogValue(string(decision.Transport)), normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(err)), dialStatus, dialClass, dialCloseStatus, truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen), dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID, truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), forceNewConn, wsHost, wsPath, account.ProxyID != nil && account.Proxy != nil, ) return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err) } defer lease.Release() connID := strings.TrimSpace(lease.ConnID()) logOpenAIWSModeDebug( "connected account_id=%d account_type=%s transport=%s conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d has_previous_response_id=%v", account.ID, account.Type, normalizeOpenAIWSLogValue(string(decision.Transport)), connID, lease.Reused(), lease.ConnPickDuration().Milliseconds(), lease.QueueWaitDuration().Milliseconds(), previousResponseID != "", ) if previousResponseID != "" { logOpenAIWSModeInfo( "continuation_probe account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s conn_reused=%v store_disabled=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v", account.ID, account.Type, truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), normalizeOpenAIWSLogValue(previousResponseIDKind), truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), lease.Reused(), storeDisabled, truncateOpenAIWSLogValue(sessionHash, 12), openAIWSHeaderValueForLog(wsHeaders, "session_id"), openAIWSHeaderValueForLog(wsHeaders, "conversation_id"), normalizeOpenAIWSLogValue(sessionResolution.SessionSource), normalizeOpenAIWSLogValue(sessionResolution.ConversationSource), turnState != "", len(turnState), promptCacheKey != "", ) } if c != nil { SetOpsLatencyMs(c, OpsOpenAIWSConnPickMsKey, lease.ConnPickDuration().Milliseconds()) SetOpsLatencyMs(c, OpsOpenAIWSQueueWaitMsKey, lease.QueueWaitDuration().Milliseconds()) c.Set(OpsOpenAIWSConnReusedKey, lease.Reused()) if connID != "" { c.Set(OpsOpenAIWSConnIDKey, connID) } } handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader)) logOpenAIWSModeDebug( "handshake account_id=%d conn_id=%s has_turn_state=%v turn_state_len=%d", account.ID, connID, handshakeTurnState != "", len(handshakeTurnState), ) if handshakeTurnState != "" { if stateStore != nil && sessionHash != "" { stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL()) } if c != nil { c.Header(http.CanonicalHeaderKey(openAIWSTurnStateHeader), handshakeTurnState) } } if err := s.performOpenAIWSGeneratePrewarm( ctx, lease, decision, payload, previousResponseID, reqBody, account, stateStore, groupID, ); err != nil { return nil, err } if err := lease.WriteJSONWithContextTimeout(ctx, payload, s.openAIWSWriteTimeout()); err != nil { lease.MarkBroken() logOpenAIWSModeInfo( "write_request_fail account_id=%d conn_id=%s cause=%s payload_bytes=%d", account.ID, connID, truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), resolvePayloadBytes(), ) return nil, wrapOpenAIWSFallback("write_request", err) } if debugEnabled { logOpenAIWSModeDebug( "write_request_sent account_id=%d conn_id=%s stream=%v payload_bytes=%d previous_response_id=%s", account.ID, connID, reqStream, resolvePayloadBytes(), truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), ) } usage := &OpenAIUsage{} var firstTokenMs *int responseID := "" var finalResponse []byte wroteDownstream := false needModelReplace := originalModel != mappedModel var mappedModelBytes []byte if needModelReplace && mappedModel != "" { mappedModelBytes = []byte(mappedModel) } bufferedStreamEvents := make([][]byte, 0, 4) eventCount := 0 tokenEventCount := 0 terminalEventCount := 0 bufferedEventCount := 0 flushedBufferedEventCount := 0 firstEventType := "" lastEventType := "" var flusher http.Flusher if reqStream { if s.responseHeaderFilter != nil { responseheaders.WriteFilteredHeaders(c.Writer.Header(), http.Header{}, s.responseHeaderFilter) } c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") c.Header("X-Accel-Buffering", "no") f, ok := c.Writer.(http.Flusher) if !ok { lease.MarkBroken() return nil, wrapOpenAIWSFallback("streaming_not_supported", errors.New("streaming not supported")) } flusher = f } clientDisconnected := false flushBatchSize := s.openAIWSEventFlushBatchSize() flushInterval := s.openAIWSEventFlushInterval() pendingFlushEvents := 0 lastFlushAt := time.Now() flushStreamWriter := func(force bool) { if clientDisconnected || flusher == nil || pendingFlushEvents <= 0 { return } if !force && flushBatchSize > 1 && pendingFlushEvents < flushBatchSize { if flushInterval <= 0 || time.Since(lastFlushAt) < flushInterval { return } } flusher.Flush() pendingFlushEvents = 0 lastFlushAt = time.Now() } emitStreamMessage := func(message []byte, forceFlush bool) { if clientDisconnected { return } frame := make([]byte, 0, len(message)+8) frame = append(frame, "data: "...) frame = append(frame, message...) frame = append(frame, '\n', '\n') _, wErr := c.Writer.Write(frame) if wErr == nil { wroteDownstream = true pendingFlushEvents++ flushStreamWriter(forceFlush) return } clientDisconnected = true logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode] client disconnected, continue draining upstream: account=%d", account.ID) } flushBufferedStreamEvents := func(reason string) { if len(bufferedStreamEvents) == 0 { return } flushed := len(bufferedStreamEvents) for _, buffered := range bufferedStreamEvents { emitStreamMessage(buffered, false) } bufferedStreamEvents = bufferedStreamEvents[:0] flushStreamWriter(true) flushedBufferedEventCount += flushed if debugEnabled { logOpenAIWSModeDebug( "buffer_flush account_id=%d conn_id=%s reason=%s flushed=%d total_flushed=%d client_disconnected=%v", account.ID, connID, truncateOpenAIWSLogValue(reason, openAIWSLogValueMaxLen), flushed, flushedBufferedEventCount, clientDisconnected, ) } } readTimeout := s.openAIWSReadTimeout() for { message, readErr := lease.ReadMessageWithContextTimeout(ctx, readTimeout) if readErr != nil { lease.MarkBroken() closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) logOpenAIWSModeInfo( "read_fail account_id=%d conn_id=%s wrote_downstream=%v close_status=%s close_reason=%s cause=%s events=%d token_events=%d terminal_events=%d buffered_pending=%d buffered_flushed=%d first_event=%s last_event=%s", account.ID, connID, wroteDownstream, closeStatus, closeReason, truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), eventCount, tokenEventCount, terminalEventCount, len(bufferedStreamEvents), flushedBufferedEventCount, truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), ) if !wroteDownstream { return nil, wrapOpenAIWSFallback(classifyOpenAIWSReadFallbackReason(readErr), readErr) } if clientDisconnected { break } setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(readErr.Error()), "") return nil, fmt.Errorf("openai ws read event: %w", readErr) } eventType, eventResponseID, responseField := parseOpenAIWSEventEnvelope(message) if eventType == "" { continue } eventCount++ if firstEventType == "" { firstEventType = eventType } lastEventType = eventType if responseID == "" && eventResponseID != "" { responseID = eventResponseID } isTokenEvent := isOpenAIWSTokenEvent(eventType) if isTokenEvent { tokenEventCount++ } isTerminalEvent := isOpenAIWSTerminalEvent(eventType) if isTerminalEvent { terminalEventCount++ } if firstTokenMs == nil && isTokenEvent { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } if debugEnabled && shouldLogOpenAIWSEvent(eventCount, eventType) { logOpenAIWSModeDebug( "event_received account_id=%d conn_id=%s idx=%d type=%s bytes=%d token=%v terminal=%v buffered_pending=%d", account.ID, connID, eventCount, truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), len(message), isTokenEvent, isTerminalEvent, len(bufferedStreamEvents), ) } if !clientDisconnected { if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(message, mappedModelBytes) { message = replaceOpenAIWSMessageModel(message, mappedModel, originalModel) } if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(message) { if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(message); changed { message = corrected } } } if openAIWSEventShouldParseUsage(eventType) { parseOpenAIWSResponseUsageFromCompletedEvent(message, usage) } if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) errMsg := strings.TrimSpace(errMsgRaw) if errMsg == "" { errMsg = "Upstream websocket error" } fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) logOpenAIWSModeInfo( "error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s", account.ID, connID, eventCount, truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), canFallback, errCode, errType, errMessage, ) if fallbackReason == "previous_response_not_found" { logOpenAIWSModeInfo( "previous_response_not_found_diag account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s event_idx=%d req_stream=%v store_disabled=%v conn_reused=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v err_code=%s err_type=%s err_message=%s", account.ID, account.Type, connID, truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), normalizeOpenAIWSLogValue(previousResponseIDKind), truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), eventCount, reqStream, storeDisabled, lease.Reused(), truncateOpenAIWSLogValue(sessionHash, 12), openAIWSHeaderValueForLog(wsHeaders, "session_id"), openAIWSHeaderValueForLog(wsHeaders, "conversation_id"), normalizeOpenAIWSLogValue(sessionResolution.SessionSource), normalizeOpenAIWSLogValue(sessionResolution.ConversationSource), turnState != "", len(turnState), promptCacheKey != "", errCode, errType, errMessage, ) } // error 事件后连接不再可复用,避免回池后污染下一请求。 lease.MarkBroken() if !wroteDownstream && canFallback { return nil, wrapOpenAIWSFallback(fallbackReason, errors.New(errMsg)) } statusCode := openAIWSErrorHTTPStatusFromRaw(errCodeRaw, errTypeRaw) setOpsUpstreamError(c, statusCode, errMsg, "") if reqStream && !clientDisconnected { flushBufferedStreamEvents("error_event") emitStreamMessage(message, true) } if !reqStream { c.JSON(statusCode, gin.H{ "error": gin.H{ "type": "upstream_error", "message": errMsg, }, }) } return nil, fmt.Errorf("openai ws error event: %s", errMsg) } if reqStream { // 在首个 token 前先缓冲事件(如 response.created), // 以便上游早期断连时仍可安全回退到 HTTP,不给下游发送半截流。 shouldBuffer := firstTokenMs == nil && !isTokenEvent && !isTerminalEvent if shouldBuffer { buffered := make([]byte, len(message)) copy(buffered, message) bufferedStreamEvents = append(bufferedStreamEvents, buffered) bufferedEventCount++ if debugEnabled && shouldLogOpenAIWSBufferedEvent(bufferedEventCount) { logOpenAIWSModeDebug( "buffer_enqueue account_id=%d conn_id=%s idx=%d event_idx=%d event_type=%s buffer_size=%d", account.ID, connID, bufferedEventCount, eventCount, truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), len(bufferedStreamEvents), ) } } else { flushBufferedStreamEvents(eventType) emitStreamMessage(message, isTerminalEvent) } } else { if responseField.Exists() && responseField.Type == gjson.JSON { finalResponse = []byte(responseField.Raw) } } if isTerminalEvent { break } } if !reqStream { if len(finalResponse) == 0 { logOpenAIWSModeInfo( "missing_final_response account_id=%d conn_id=%s events=%d token_events=%d terminal_events=%d wrote_downstream=%v", account.ID, connID, eventCount, tokenEventCount, terminalEventCount, wroteDownstream, ) if !wroteDownstream { return nil, wrapOpenAIWSFallback("missing_final_response", errors.New("no terminal response payload")) } return nil, errors.New("ws finished without final response") } if needModelReplace { finalResponse = s.replaceModelInResponseBody(finalResponse, mappedModel, originalModel) } finalResponse = s.correctToolCallsInResponseBody(finalResponse) populateOpenAIUsageFromResponseJSON(finalResponse, usage) if responseID == "" { responseID = strings.TrimSpace(gjson.GetBytes(finalResponse, "id").String()) } c.Data(http.StatusOK, "application/json", finalResponse) } else { flushStreamWriter(true) } if responseID != "" && stateStore != nil { ttl := s.openAIWSResponseStickyTTL() logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl)) stateStore.BindResponseConn(responseID, lease.ConnID(), ttl) } if stateStore != nil && storeDisabled && sessionHash != "" { stateStore.BindSessionConn(groupID, sessionHash, lease.ConnID(), s.openAIWSSessionStickyTTL()) } firstTokenMsValue := -1 if firstTokenMs != nil { firstTokenMsValue = *firstTokenMs } logOpenAIWSModeDebug( "completed account_id=%d conn_id=%s response_id=%s stream=%v duration_ms=%d events=%d token_events=%d terminal_events=%d buffered_events=%d buffered_flushed=%d first_event=%s last_event=%s first_token_ms=%d wrote_downstream=%v client_disconnected=%v", account.ID, connID, truncateOpenAIWSLogValue(strings.TrimSpace(responseID), openAIWSIDValueMaxLen), reqStream, time.Since(startTime).Milliseconds(), eventCount, tokenEventCount, terminalEventCount, bufferedEventCount, flushedBufferedEventCount, truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), firstTokenMsValue, wroteDownstream, clientDisconnected, ) return &OpenAIForwardResult{ RequestID: responseID, Usage: *usage, Model: originalModel, ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), Stream: reqStream, OpenAIWSMode: true, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, }, nil } // ProxyResponsesWebSocketFromClient 处理客户端入站 WebSocket(OpenAI Responses WS Mode)并转发到上游。 // 当前实现按“单请求 -> 终止事件 -> 下一请求”的顺序代理,适配 Codex CLI 的 turn 模式。 func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ctx context.Context, c *gin.Context, clientConn *coderws.Conn, account *Account, token string, firstClientMessage []byte, hooks *OpenAIWSIngressHooks, ) error { if s == nil { return errors.New("service is nil") } if c == nil { return errors.New("gin context is nil") } if clientConn == nil { return errors.New("client websocket is nil") } if account == nil { return errors.New("account is nil") } if strings.TrimSpace(token) == "" { return errors.New("token is empty") } wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled ingressMode := OpenAIWSIngressModeShared if modeRouterV2Enabled { ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault) if ingressMode == OpenAIWSIngressModeOff { return NewOpenAIWSClientCloseError( coderws.StatusPolicyViolation, "websocket mode is disabled for this account", nil, ) } } if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport) } dedicatedMode := modeRouterV2Enabled && ingressMode == OpenAIWSIngressModeDedicated wsURL, err := s.buildOpenAIResponsesWSURL(account) if err != nil { return fmt.Errorf("build ws url: %w", err) } wsHost := "-" wsPath := "-" if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil { wsHost = normalizeOpenAIWSLogValue(parsedURL.Host) wsPath = normalizeOpenAIWSLogValue(parsedURL.Path) } debugEnabled := isOpenAIWSModeDebugEnabled() type openAIWSClientPayload struct { payloadRaw []byte rawForHash []byte promptCacheKey string previousResponseID string originalModel string payloadBytes int } applyPayloadMutation := func(current []byte, path string, value any) ([]byte, error) { next, err := sjson.SetBytes(current, path, value) if err == nil { return next, nil } // 仅在确实需要修改 payload 且 sjson 失败时,退回 map 路径确保兼容性。 payload := make(map[string]any) if unmarshalErr := json.Unmarshal(current, &payload); unmarshalErr != nil { return nil, err } switch path { case "type", "model": payload[path] = value case "client_metadata." + openAIWSTurnMetadataHeader: setOpenAIWSTurnMetadata(payload, fmt.Sprintf("%v", value)) default: return nil, err } rebuilt, marshalErr := json.Marshal(payload) if marshalErr != nil { return nil, marshalErr } return rebuilt, nil } parseClientPayload := func(raw []byte) (openAIWSClientPayload, error) { trimmed := bytes.TrimSpace(raw) if len(trimmed) == 0 { return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "empty websocket request payload", nil) } if !gjson.ValidBytes(trimmed) { return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json")) } values := gjson.GetManyBytes(trimmed, "type", "model", "prompt_cache_key", "previous_response_id") eventType := strings.TrimSpace(values[0].String()) normalized := trimmed switch eventType { case "": eventType = "response.create" next, setErr := applyPayloadMutation(normalized, "type", eventType) if setErr != nil { return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) } normalized = next case "response.create": case "response.append": return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( coderws.StatusPolicyViolation, "response.append is not supported in ws v2; use response.create with previous_response_id", nil, ) default: return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( coderws.StatusPolicyViolation, fmt.Sprintf("unsupported websocket request type: %s", eventType), nil, ) } originalModel := strings.TrimSpace(values[1].String()) if originalModel == "" { return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( coderws.StatusPolicyViolation, "model is required in response.create payload", nil, ) } promptCacheKey := strings.TrimSpace(values[2].String()) previousResponseID := strings.TrimSpace(values[3].String()) previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID) if previousResponseID != "" && previousResponseIDKind == OpenAIPreviousResponseIDKindMessageID { return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id", nil, ) } if turnMetadata := strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)); turnMetadata != "" { next, setErr := applyPayloadMutation(normalized, "client_metadata."+openAIWSTurnMetadataHeader, turnMetadata) if setErr != nil { return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) } normalized = next } mappedModel := account.GetMappedModel(originalModel) if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { mappedModel = normalizedModel } if mappedModel != originalModel { next, setErr := applyPayloadMutation(normalized, "model", mappedModel) if setErr != nil { return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) } normalized = next } return openAIWSClientPayload{ payloadRaw: normalized, rawForHash: trimmed, promptCacheKey: promptCacheKey, previousResponseID: previousResponseID, originalModel: originalModel, payloadBytes: len(normalized), }, nil } firstPayload, err := parseClientPayload(firstClientMessage) if err != nil { return err } turnState := strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader)) stateStore := s.getOpenAIWSStateStore() groupID := getOpenAIGroupIDFromContext(c) sessionHash := s.GenerateSessionHash(c, firstPayload.rawForHash) if turnState == "" && stateStore != nil && sessionHash != "" { if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok { turnState = savedTurnState } } preferredConnID := "" if stateStore != nil && firstPayload.previousResponseID != "" { if connID, ok := stateStore.GetResponseConn(firstPayload.previousResponseID); ok { preferredConnID = connID } } storeDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(firstPayload.payloadRaw, account) storeDisabledConnMode := s.openAIWSStoreDisabledConnMode() if stateStore != nil && storeDisabled && firstPayload.previousResponseID == "" && sessionHash != "" { if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok { preferredConnID = connID } } isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey) baseAcquireReq := openAIWSAcquireRequest{ Account: account, WSURL: wsURL, Headers: wsHeaders, ProxyURL: func() string { if account.ProxyID != nil && account.Proxy != nil { return account.Proxy.URL() } return "" }(), ForceNewConn: false, } pool := s.getOpenAIWSConnPool() if pool == nil { return errors.New("openai ws conn pool is nil") } logOpenAIWSModeInfo( "ingress_ws_protocol_confirm account_id=%d account_type=%s transport=%s ws_host=%s ws_path=%s ws_mode=%s store_disabled=%v has_session_hash=%v has_previous_response_id=%v", account.ID, account.Type, normalizeOpenAIWSLogValue(string(wsDecision.Transport)), wsHost, wsPath, normalizeOpenAIWSLogValue(ingressMode), storeDisabled, sessionHash != "", firstPayload.previousResponseID != "", ) if debugEnabled { logOpenAIWSModeDebug( "ingress_ws_start account_id=%d account_type=%s transport=%s ws_host=%s preferred_conn_id=%s has_session_hash=%v has_previous_response_id=%v store_disabled=%v", account.ID, account.Type, normalizeOpenAIWSLogValue(string(wsDecision.Transport)), wsHost, truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), sessionHash != "", firstPayload.previousResponseID != "", storeDisabled, ) } if firstPayload.previousResponseID != "" { firstPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(firstPayload.previousResponseID) logOpenAIWSModeInfo( "ingress_ws_continuation_probe account_id=%d turn=%d previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s session_hash=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v", account.ID, 1, truncateOpenAIWSLogValue(firstPayload.previousResponseID, openAIWSIDValueMaxLen), normalizeOpenAIWSLogValue(firstPreviousResponseIDKind), truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(sessionHash, 12), openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"), openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"), turnState != "", len(turnState), firstPayload.promptCacheKey != "", storeDisabled, ) } acquireTimeout := s.openAIWSAcquireTimeout() if acquireTimeout <= 0 { acquireTimeout = 30 * time.Second } acquireTurnLease := func(turn int, preferred string, forcePreferredConn bool) (*openAIWSConnLease, error) { req := cloneOpenAIWSAcquireRequest(baseAcquireReq) req.PreferredConnID = strings.TrimSpace(preferred) req.ForcePreferredConn = forcePreferredConn // dedicated 模式下每次获取均新建连接,避免跨会话复用残留上下文。 req.ForceNewConn = dedicatedMode acquireCtx, acquireCancel := context.WithTimeout(ctx, acquireTimeout) lease, acquireErr := pool.Acquire(acquireCtx, req) acquireCancel() if acquireErr != nil { dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(acquireErr) logOpenAIWSModeInfo( "ingress_ws_upstream_acquire_fail account_id=%d turn=%d reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_preferred_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v", account.ID, turn, normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(acquireErr)), dialStatus, dialClass, dialCloseStatus, truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen), dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID, truncateOpenAIWSLogValue(acquireErr.Error(), openAIWSLogValueMaxLen), truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen), forcePreferredConn, wsHost, wsPath, account.ProxyID != nil && account.Proxy != nil, ) if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) { return nil, NewOpenAIWSClientCloseError( coderws.StatusPolicyViolation, "upstream continuation connection is unavailable; please restart the conversation", acquireErr, ) } if errors.Is(acquireErr, context.DeadlineExceeded) || errors.Is(acquireErr, errOpenAIWSConnQueueFull) { return nil, NewOpenAIWSClientCloseError( coderws.StatusTryAgainLater, "upstream websocket is busy, please retry later", acquireErr, ) } return nil, acquireErr } connID := strings.TrimSpace(lease.ConnID()) if handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader)); handshakeTurnState != "" { turnState = handshakeTurnState if stateStore != nil && sessionHash != "" { stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL()) } updatedHeaders := cloneHeader(baseAcquireReq.Headers) if updatedHeaders == nil { updatedHeaders = make(http.Header) } updatedHeaders.Set(openAIWSTurnStateHeader, handshakeTurnState) baseAcquireReq.Headers = updatedHeaders } logOpenAIWSModeInfo( "ingress_ws_upstream_connected account_id=%d turn=%d conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d preferred_conn_id=%s", account.ID, turn, truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), lease.Reused(), lease.ConnPickDuration().Milliseconds(), lease.QueueWaitDuration().Milliseconds(), truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen), ) return lease, nil } writeClientMessage := func(message []byte) error { writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) defer cancel() return clientConn.Write(writeCtx, coderws.MessageText, message) } readClientMessage := func() ([]byte, error) { msgType, payload, readErr := clientConn.Read(ctx) if readErr != nil { return nil, readErr } if msgType != coderws.MessageText && msgType != coderws.MessageBinary { return nil, NewOpenAIWSClientCloseError( coderws.StatusPolicyViolation, fmt.Sprintf("unsupported websocket client message type: %s", msgType.String()), nil, ) } return payload, nil } sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) { if lease == nil { return nil, errors.New("upstream websocket lease is nil") } turnStart := time.Now() wroteDownstream := false if err := lease.WriteJSONWithContextTimeout(ctx, json.RawMessage(payload), s.openAIWSWriteTimeout()); err != nil { return nil, wrapOpenAIWSIngressTurnError( "write_upstream", fmt.Errorf("write upstream websocket request: %w", err), false, ) } if debugEnabled { logOpenAIWSModeDebug( "ingress_ws_turn_request_sent account_id=%d turn=%d conn_id=%s payload_bytes=%d", account.ID, turn, truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), payloadBytes, ) } responseID := "" usage := OpenAIUsage{} var firstTokenMs *int reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true) turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID) turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key") turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account) turnHasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() eventCount := 0 tokenEventCount := 0 terminalEventCount := 0 firstEventType := "" lastEventType := "" needModelReplace := false clientDisconnected := false mappedModel := "" var mappedModelBytes []byte if originalModel != "" { mappedModel = account.GetMappedModel(originalModel) if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { mappedModel = normalizedModel } needModelReplace = mappedModel != "" && mappedModel != originalModel if needModelReplace { mappedModelBytes = []byte(mappedModel) } } for { upstreamMessage, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout()) if readErr != nil { lease.MarkBroken() return nil, wrapOpenAIWSIngressTurnError( "read_upstream", fmt.Errorf("read upstream websocket event: %w", readErr), wroteDownstream, ) } eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(upstreamMessage) if responseID == "" && eventResponseID != "" { responseID = eventResponseID } if eventType != "" { eventCount++ if firstEventType == "" { firstEventType = eventType } lastEventType = eventType } if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage) fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound && turnPreviousResponseID != "" && !turnHasFunctionCallOutput && s.openAIWSIngressPreviousResponseRecoveryEnabled() && !wroteDownstream if recoverablePrevNotFound { // 可恢复场景使用非 error 关键字日志,避免被 LegacyPrintf 误判为 ERROR 级别。 logOpenAIWSModeInfo( "ingress_ws_prev_response_recoverable account_id=%d turn=%d conn_id=%s idx=%d reason=%s code=%s type=%s message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v", account.ID, turn, truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), eventCount, truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), errCode, errType, errMessage, truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen), normalizeOpenAIWSLogValue(turnPreviousResponseIDKind), truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), turnStoreDisabled, turnPromptCacheKey != "", ) } else { logOpenAIWSModeInfo( "ingress_ws_error_event account_id=%d turn=%d conn_id=%s idx=%d fallback_reason=%s err_code=%s err_type=%s err_message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v", account.ID, turn, truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), eventCount, truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), errCode, errType, errMessage, truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen), normalizeOpenAIWSLogValue(turnPreviousResponseIDKind), truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), turnStoreDisabled, turnPromptCacheKey != "", ) } // previous_response_not_found 在 ingress 模式支持单次恢复重试: // 不把该 error 直接下发客户端,而是由上层去掉 previous_response_id 后重放当前 turn。 if recoverablePrevNotFound { lease.MarkBroken() errMsg := strings.TrimSpace(errMsgRaw) if errMsg == "" { errMsg = "previous response not found" } return nil, wrapOpenAIWSIngressTurnError( openAIWSIngressStagePreviousResponseNotFound, errors.New(errMsg), false, ) } } isTokenEvent := isOpenAIWSTokenEvent(eventType) if isTokenEvent { tokenEventCount++ } isTerminalEvent := isOpenAIWSTerminalEvent(eventType) if isTerminalEvent { terminalEventCount++ } if firstTokenMs == nil && isTokenEvent { ms := int(time.Since(turnStart).Milliseconds()) firstTokenMs = &ms } if openAIWSEventShouldParseUsage(eventType) { parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage) } if !clientDisconnected { if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) { upstreamMessage = replaceOpenAIWSMessageModel(upstreamMessage, mappedModel, originalModel) } if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(upstreamMessage) { if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(upstreamMessage); changed { upstreamMessage = corrected } } if err := writeClientMessage(upstreamMessage); err != nil { if isOpenAIWSClientDisconnectError(err) { clientDisconnected = true closeStatus, closeReason := summarizeOpenAIWSReadCloseError(err) logOpenAIWSModeInfo( "ingress_ws_client_disconnected_drain account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s", account.ID, turn, truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), closeStatus, truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), ) } else { return nil, wrapOpenAIWSIngressTurnError( "write_client", fmt.Errorf("write client websocket event: %w", err), wroteDownstream, ) } } else { wroteDownstream = true } } if isTerminalEvent { // 客户端已断连时,上游连接的 session 状态不可信,标记 broken 避免回池复用。 if clientDisconnected { lease.MarkBroken() } firstTokenMsValue := -1 if firstTokenMs != nil { firstTokenMsValue = *firstTokenMs } if debugEnabled { logOpenAIWSModeDebug( "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d events=%d token_events=%d terminal_events=%d first_event=%s last_event=%s first_token_ms=%d client_disconnected=%v", account.ID, turn, truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), time.Since(turnStart).Milliseconds(), eventCount, tokenEventCount, terminalEventCount, truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), firstTokenMsValue, clientDisconnected, ) } return &OpenAIForwardResult{ RequestID: responseID, Usage: usage, Model: originalModel, ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), Stream: reqStream, OpenAIWSMode: true, Duration: time.Since(turnStart), FirstTokenMs: firstTokenMs, }, nil } } } currentPayload := firstPayload.payloadRaw currentOriginalModel := firstPayload.originalModel currentPayloadBytes := firstPayload.payloadBytes isStrictAffinityTurn := func(payload []byte) bool { if !storeDisabled { return false } return strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) != "" } var sessionLease *openAIWSConnLease sessionConnID := "" pinnedSessionConnID := "" unpinSessionConn := func(connID string) { connID = strings.TrimSpace(connID) if connID == "" || pinnedSessionConnID != connID { return } pool.UnpinConn(account.ID, connID) pinnedSessionConnID = "" } pinSessionConn := func(connID string) { if !storeDisabled { return } connID = strings.TrimSpace(connID) if connID == "" || pinnedSessionConnID == connID { return } if pinnedSessionConnID != "" { pool.UnpinConn(account.ID, pinnedSessionConnID) pinnedSessionConnID = "" } if pool.PinConn(account.ID, connID) { pinnedSessionConnID = connID } } releaseSessionLease := func() { if sessionLease == nil { return } if dedicatedMode { // dedicated 会话结束后主动标记损坏,确保连接不会跨会话复用。 sessionLease.MarkBroken() } unpinSessionConn(sessionConnID) sessionLease.Release() if debugEnabled { logOpenAIWSModeDebug( "ingress_ws_upstream_released account_id=%d conn_id=%s", account.ID, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), ) } } defer releaseSessionLease() turn := 1 turnRetry := 0 turnPrevRecoveryTried := false lastTurnFinishedAt := time.Time{} lastTurnResponseID := "" lastTurnPayload := []byte(nil) var lastTurnStrictState *openAIWSIngressPreviousTurnStrictState lastTurnReplayInput := []json.RawMessage(nil) lastTurnReplayInputExists := false currentTurnReplayInput := []json.RawMessage(nil) currentTurnReplayInputExists := false skipBeforeTurn := false resetSessionLease := func(markBroken bool) { if sessionLease == nil { return } if markBroken { sessionLease.MarkBroken() } releaseSessionLease() sessionLease = nil sessionConnID = "" preferredConnID = "" } recoverIngressPrevResponseNotFound := func(relayErr error, turn int, connID string) bool { if !isOpenAIWSIngressPreviousResponseNotFound(relayErr) { return false } if turnPrevRecoveryTried || !s.openAIWSIngressPreviousResponseRecoveryEnabled() { return false } if isStrictAffinityTurn(currentPayload) { // Layer 2:严格亲和链路命中 previous_response_not_found 时,降级为“去掉 previous_response_id 后重放一次”。 // 该错误说明续链锚点已失效,继续 strict fail-close 只会直接中断本轮请求。 logOpenAIWSModeInfo( "ingress_ws_prev_response_recovery_layer2 account_id=%d turn=%d conn_id=%s store_disabled_conn_mode=%s action=drop_previous_response_id_retry", account.ID, turn, truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), normalizeOpenAIWSLogValue(storeDisabledConnMode), ) } turnPrevRecoveryTried = true updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) if dropErr != nil || !removed { reason := "not_removed" if dropErr != nil { reason = "drop_error" } logOpenAIWSModeInfo( "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s", account.ID, turn, truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), normalizeOpenAIWSLogValue(reason), ) return false } updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( updatedPayload, currentTurnReplayInput, currentTurnReplayInputExists, ) if setInputErr != nil { logOpenAIWSModeInfo( "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s", account.ID, turn, truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), ) return false } logOpenAIWSModeInfo( "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id retry=1", account.ID, turn, truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), ) currentPayload = updatedWithInput currentPayloadBytes = len(updatedWithInput) resetSessionLease(true) skipBeforeTurn = true return true } retryIngressTurn := func(relayErr error, turn int, connID string) bool { if !isOpenAIWSIngressTurnRetryable(relayErr) || turnRetry >= 1 { return false } if isStrictAffinityTurn(currentPayload) { logOpenAIWSModeInfo( "ingress_ws_turn_retry_skip account_id=%d turn=%d conn_id=%s reason=strict_affinity", account.ID, turn, truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), ) return false } turnRetry++ logOpenAIWSModeInfo( "ingress_ws_turn_retry account_id=%d turn=%d retry=%d reason=%s conn_id=%s", account.ID, turn, turnRetry, truncateOpenAIWSLogValue(openAIWSIngressTurnRetryReason(relayErr), openAIWSLogValueMaxLen), truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), ) resetSessionLease(true) skipBeforeTurn = true return true } for { if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil { if err := hooks.BeforeTurn(turn); err != nil { return err } } skipBeforeTurn = false currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id") expectedPrev := strings.TrimSpace(lastTurnResponseID) hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() // store=false + function_call_output 场景必须有续链锚点。 // 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。 if shouldInferIngressFunctionCallOutputPreviousResponseID( storeDisabled, turn, hasFunctionCallOutput, currentPreviousResponseID, expectedPrev, ) { updatedPayload, setPrevErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev) if setPrevErr != nil { logOpenAIWSModeInfo( "ingress_ws_function_call_output_prev_infer_skip account_id=%d turn=%d conn_id=%s reason=set_previous_response_id_error cause=%s expected_previous_response_id=%s", account.ID, turn, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(setPrevErr.Error(), openAIWSLogValueMaxLen), truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), ) } else { currentPayload = updatedPayload currentPayloadBytes = len(updatedPayload) currentPreviousResponseID = expectedPrev logOpenAIWSModeInfo( "ingress_ws_function_call_output_prev_infer account_id=%d turn=%d conn_id=%s action=set_previous_response_id previous_response_id=%s", account.ID, turn, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), ) } } nextReplayInput, nextReplayInputExists, replayInputErr := buildOpenAIWSReplayInputSequence( lastTurnReplayInput, lastTurnReplayInputExists, currentPayload, currentPreviousResponseID != "", ) if replayInputErr != nil { logOpenAIWSModeInfo( "ingress_ws_replay_input_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s", account.ID, turn, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(replayInputErr.Error(), openAIWSLogValueMaxLen), ) currentTurnReplayInput = nil currentTurnReplayInputExists = false } else { currentTurnReplayInput = nextReplayInput currentTurnReplayInputExists = nextReplayInputExists } if storeDisabled && turn > 1 && currentPreviousResponseID != "" { shouldKeepPreviousResponseID := false strictReason := "" var strictErr error if lastTurnStrictState != nil { shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseIDWithStrictState( lastTurnStrictState, currentPayload, lastTurnResponseID, hasFunctionCallOutput, ) } else { shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseID( lastTurnPayload, currentPayload, lastTurnResponseID, hasFunctionCallOutput, ) } if strictErr != nil { logOpenAIWSModeInfo( "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s cause=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", account.ID, turn, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), normalizeOpenAIWSLogValue(strictReason), truncateOpenAIWSLogValue(strictErr.Error(), openAIWSLogValueMaxLen), truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), hasFunctionCallOutput, ) } else if !shouldKeepPreviousResponseID { updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) if dropErr != nil || !removed { dropReason := "not_removed" if dropErr != nil { dropReason = "drop_error" } logOpenAIWSModeInfo( "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", account.ID, turn, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), normalizeOpenAIWSLogValue(strictReason), normalizeOpenAIWSLogValue(dropReason), truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), hasFunctionCallOutput, ) } else { updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( updatedPayload, currentTurnReplayInput, currentTurnReplayInputExists, ) if setInputErr != nil { logOpenAIWSModeInfo( "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=set_full_input_error previous_response_id=%s expected_previous_response_id=%s cause=%s has_function_call_output=%v", account.ID, turn, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), normalizeOpenAIWSLogValue(strictReason), truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), hasFunctionCallOutput, ) } else { currentPayload = updatedWithInput currentPayloadBytes = len(updatedWithInput) logOpenAIWSModeInfo( "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_full_create reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", account.ID, turn, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), normalizeOpenAIWSLogValue(strictReason), truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), hasFunctionCallOutput, ) currentPreviousResponseID = "" } } } } forcePreferredConn := isStrictAffinityTurn(currentPayload) if sessionLease == nil { acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn) if acquireErr != nil { return fmt.Errorf("acquire upstream websocket: %w", acquireErr) } sessionLease = acquiredLease sessionConnID = strings.TrimSpace(sessionLease.ConnID()) if storeDisabled { pinSessionConn(sessionConnID) } else { unpinSessionConn(sessionConnID) } } shouldPreflightPing := turn > 1 && sessionLease != nil && turnRetry == 0 if shouldPreflightPing && openAIWSIngressPreflightPingIdle > 0 && !lastTurnFinishedAt.IsZero() { if time.Since(lastTurnFinishedAt) < openAIWSIngressPreflightPingIdle { shouldPreflightPing = false } } if shouldPreflightPing { if pingErr := sessionLease.PingWithTimeout(openAIWSConnHealthCheckTO); pingErr != nil { logOpenAIWSModeInfo( "ingress_ws_upstream_preflight_ping_fail account_id=%d turn=%d conn_id=%s cause=%s", account.ID, turn, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(pingErr.Error(), openAIWSLogValueMaxLen), ) if forcePreferredConn { if !turnPrevRecoveryTried && currentPreviousResponseID != "" { updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) if dropErr != nil || !removed { reason := "not_removed" if dropErr != nil { reason = "drop_error" } logOpenAIWSModeInfo( "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s previous_response_id=%s", account.ID, turn, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), normalizeOpenAIWSLogValue(reason), truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), ) } else { updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( updatedPayload, currentTurnReplayInput, currentTurnReplayInputExists, ) if setInputErr != nil { logOpenAIWSModeInfo( "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error previous_response_id=%s cause=%s", account.ID, turn, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), ) } else { logOpenAIWSModeInfo( "ingress_ws_preflight_ping_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s", account.ID, turn, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), ) turnPrevRecoveryTried = true currentPayload = updatedWithInput currentPayloadBytes = len(updatedWithInput) resetSessionLease(true) skipBeforeTurn = true continue } } } resetSessionLease(true) return NewOpenAIWSClientCloseError( coderws.StatusPolicyViolation, "upstream continuation connection is unavailable; please restart the conversation", pingErr, ) } resetSessionLease(true) acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn) if acquireErr != nil { return fmt.Errorf("acquire upstream websocket after preflight ping fail: %w", acquireErr) } sessionLease = acquiredLease sessionConnID = strings.TrimSpace(sessionLease.ConnID()) if storeDisabled { pinSessionConn(sessionConnID) } } } connID := sessionConnID if currentPreviousResponseID != "" { chainedFromLast := expectedPrev != "" && currentPreviousResponseID == expectedPrev currentPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(currentPreviousResponseID) logOpenAIWSModeInfo( "ingress_ws_turn_chain account_id=%d turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v preferred_conn_id=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v", account.ID, turn, truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), normalizeOpenAIWSLogValue(currentPreviousResponseIDKind), truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), chainedFromLast, truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"), openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"), turnState != "", len(turnState), openAIWSPayloadStringFromRaw(currentPayload, "prompt_cache_key") != "", storeDisabled, ) } result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel) if relayErr != nil { if recoverIngressPrevResponseNotFound(relayErr, turn, connID) { continue } if retryIngressTurn(relayErr, turn, connID) { continue } finalErr := relayErr if unwrapped := errors.Unwrap(relayErr); unwrapped != nil { finalErr = unwrapped } if hooks != nil && hooks.AfterTurn != nil { hooks.AfterTurn(turn, nil, finalErr) } sessionLease.MarkBroken() return finalErr } turnRetry = 0 turnPrevRecoveryTried = false lastTurnFinishedAt = time.Now() if hooks != nil && hooks.AfterTurn != nil { hooks.AfterTurn(turn, result, nil) } if result == nil { return errors.New("websocket turn result is nil") } responseID := strings.TrimSpace(result.RequestID) lastTurnResponseID = responseID lastTurnPayload = cloneOpenAIWSPayloadBytes(currentPayload) lastTurnReplayInput = cloneOpenAIWSRawMessages(currentTurnReplayInput) lastTurnReplayInputExists = currentTurnReplayInputExists nextStrictState, strictStateErr := buildOpenAIWSIngressPreviousTurnStrictState(currentPayload) if strictStateErr != nil { lastTurnStrictState = nil logOpenAIWSModeInfo( "ingress_ws_prev_response_strict_state_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s", account.ID, turn, truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(strictStateErr.Error(), openAIWSLogValueMaxLen), ) } else { lastTurnStrictState = nextStrictState } if responseID != "" && stateStore != nil { ttl := s.openAIWSResponseStickyTTL() logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl)) stateStore.BindResponseConn(responseID, connID, ttl) } if stateStore != nil && storeDisabled && sessionHash != "" { stateStore.BindSessionConn(groupID, sessionHash, connID, s.openAIWSSessionStickyTTL()) } if connID != "" { preferredConnID = connID } nextClientMessage, readErr := readClientMessage() if readErr != nil { if isOpenAIWSClientDisconnectError(readErr) { closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) logOpenAIWSModeInfo( "ingress_ws_client_closed account_id=%d conn_id=%s close_status=%s close_reason=%s", account.ID, truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), closeStatus, truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), ) return nil } return fmt.Errorf("read client websocket request: %w", readErr) } nextPayload, parseErr := parseClientPayload(nextClientMessage) if parseErr != nil { return parseErr } if nextPayload.promptCacheKey != "" { // ingress 会话在整个客户端 WS 生命周期内复用同一上游连接; // prompt_cache_key 对握手头的更新仅在未来需要重新建连时生效。 updatedHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), nextPayload.promptCacheKey) baseAcquireReq.Headers = updatedHeaders } if nextPayload.previousResponseID != "" { expectedPrev := strings.TrimSpace(lastTurnResponseID) chainedFromLast := expectedPrev != "" && nextPayload.previousResponseID == expectedPrev nextPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(nextPayload.previousResponseID) logOpenAIWSModeInfo( "ingress_ws_next_turn_chain account_id=%d turn=%d next_turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v has_prompt_cache_key=%v store_disabled=%v", account.ID, turn, turn+1, truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen), normalizeOpenAIWSLogValue(nextPreviousResponseIDKind), truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), chainedFromLast, nextPayload.promptCacheKey != "", storeDisabled, ) } if stateStore != nil && nextPayload.previousResponseID != "" { if stickyConnID, ok := stateStore.GetResponseConn(nextPayload.previousResponseID); ok { if sessionConnID != "" && stickyConnID != "" && stickyConnID != sessionConnID { logOpenAIWSModeInfo( "ingress_ws_keep_session_conn account_id=%d turn=%d conn_id=%s sticky_conn_id=%s previous_response_id=%s", account.ID, turn, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(stickyConnID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen), ) } else { preferredConnID = stickyConnID } } } currentPayload = nextPayload.payloadRaw currentOriginalModel = nextPayload.originalModel currentPayloadBytes = nextPayload.payloadBytes storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account) if !storeDisabled { unpinSessionConn(sessionConnID) } turn++ } } func (s *OpenAIGatewayService) isOpenAIWSGeneratePrewarmEnabled() bool { return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled } // performOpenAIWSGeneratePrewarm 在 WSv2 下执行可选的 generate=false 预热。 // 预热默认关闭,仅在配置开启后生效;失败时按可恢复错误回退到 HTTP。 func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm( ctx context.Context, lease *openAIWSConnLease, decision OpenAIWSProtocolDecision, payload map[string]any, previousResponseID string, reqBody map[string]any, account *Account, stateStore OpenAIWSStateStore, groupID int64, ) error { if s == nil { return nil } if lease == nil || account == nil { logOpenAIWSModeInfo("prewarm_skip reason=invalid_state has_lease=%v has_account=%v", lease != nil, account != nil) return nil } connID := strings.TrimSpace(lease.ConnID()) if !s.isOpenAIWSGeneratePrewarmEnabled() { return nil } if decision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { logOpenAIWSModeInfo( "prewarm_skip account_id=%d conn_id=%s reason=transport_not_v2 transport=%s", account.ID, connID, normalizeOpenAIWSLogValue(string(decision.Transport)), ) return nil } if strings.TrimSpace(previousResponseID) != "" { logOpenAIWSModeInfo( "prewarm_skip account_id=%d conn_id=%s reason=has_previous_response_id previous_response_id=%s", account.ID, connID, truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), ) return nil } if lease.IsPrewarmed() { logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=already_prewarmed", account.ID, connID) return nil } if NeedsToolContinuation(reqBody) { logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=tool_continuation", account.ID, connID) return nil } prewarmStart := time.Now() logOpenAIWSModeInfo("prewarm_start account_id=%d conn_id=%s", account.ID, connID) prewarmPayload := make(map[string]any, len(payload)+1) for k, v := range payload { prewarmPayload[k] = v } prewarmPayload["generate"] = false prewarmPayloadJSON := payloadAsJSONBytes(prewarmPayload) if err := lease.WriteJSONWithContextTimeout(ctx, prewarmPayload, s.openAIWSWriteTimeout()); err != nil { lease.MarkBroken() logOpenAIWSModeInfo( "prewarm_write_fail account_id=%d conn_id=%s cause=%s", account.ID, connID, truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), ) return wrapOpenAIWSFallback("prewarm_write", err) } logOpenAIWSModeInfo("prewarm_write_sent account_id=%d conn_id=%s payload_bytes=%d", account.ID, connID, len(prewarmPayloadJSON)) prewarmResponseID := "" prewarmEventCount := 0 prewarmTerminalCount := 0 for { message, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout()) if readErr != nil { lease.MarkBroken() closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) logOpenAIWSModeInfo( "prewarm_read_fail account_id=%d conn_id=%s close_status=%s close_reason=%s cause=%s events=%d", account.ID, connID, closeStatus, closeReason, truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), prewarmEventCount, ) return wrapOpenAIWSFallback("prewarm_"+classifyOpenAIWSReadFallbackReason(readErr), readErr) } eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(message) if eventType == "" { continue } prewarmEventCount++ if prewarmResponseID == "" && eventResponseID != "" { prewarmResponseID = eventResponseID } if prewarmEventCount <= openAIWSPrewarmEventLogHead || eventType == "error" || isOpenAIWSTerminalEvent(eventType) { logOpenAIWSModeInfo( "prewarm_event account_id=%d conn_id=%s idx=%d type=%s bytes=%d", account.ID, connID, prewarmEventCount, truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), len(message), ) } if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) errMsg := strings.TrimSpace(errMsgRaw) if errMsg == "" { errMsg = "OpenAI websocket prewarm error" } fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) logOpenAIWSModeInfo( "prewarm_error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s", account.ID, connID, prewarmEventCount, truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), canFallback, errCode, errType, errMessage, ) lease.MarkBroken() if canFallback { return wrapOpenAIWSFallback("prewarm_"+fallbackReason, errors.New(errMsg)) } return wrapOpenAIWSFallback("prewarm_error_event", errors.New(errMsg)) } if isOpenAIWSTerminalEvent(eventType) { prewarmTerminalCount++ break } } lease.MarkPrewarmed() if prewarmResponseID != "" && stateStore != nil { ttl := s.openAIWSResponseStickyTTL() logOpenAIWSBindResponseAccountWarn(groupID, account.ID, prewarmResponseID, stateStore.BindResponseAccount(ctx, groupID, prewarmResponseID, account.ID, ttl)) stateStore.BindResponseConn(prewarmResponseID, lease.ConnID(), ttl) } logOpenAIWSModeInfo( "prewarm_done account_id=%d conn_id=%s response_id=%s events=%d terminal_events=%d duration_ms=%d", account.ID, connID, truncateOpenAIWSLogValue(prewarmResponseID, openAIWSIDValueMaxLen), prewarmEventCount, prewarmTerminalCount, time.Since(prewarmStart).Milliseconds(), ) return nil } func payloadAsJSON(payload map[string]any) string { return string(payloadAsJSONBytes(payload)) } func payloadAsJSONBytes(payload map[string]any) []byte { if len(payload) == 0 { return []byte("{}") } body, err := json.Marshal(payload) if err != nil { return []byte("{}") } return body } func isOpenAIWSTerminalEvent(eventType string) bool { switch strings.TrimSpace(eventType) { case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": return true default: return false } } func isOpenAIWSTokenEvent(eventType string) bool { eventType = strings.TrimSpace(eventType) if eventType == "" { return false } switch eventType { case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done": return false } if strings.Contains(eventType, ".delta") { return true } if strings.HasPrefix(eventType, "response.output_text") { return true } if strings.HasPrefix(eventType, "response.output") { return true } return eventType == "response.completed" || eventType == "response.done" } func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte { if len(message) == 0 { return message } if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel { return message } if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) { return message } modelValues := gjson.GetManyBytes(message, "model", "response.model") replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel if !replaceModel && !replaceResponseModel { return message } updated := message if replaceModel { if next, err := sjson.SetBytes(updated, "model", toModel); err == nil { updated = next } } if replaceResponseModel { if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil { updated = next } } return updated } func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) { if usage == nil || len(body) == 0 { return } values := gjson.GetManyBytes( body, "usage.input_tokens", "usage.output_tokens", "usage.input_tokens_details.cached_tokens", ) usage.InputTokens = int(values[0].Int()) usage.OutputTokens = int(values[1].Int()) usage.CacheReadInputTokens = int(values[2].Int()) } func getOpenAIGroupIDFromContext(c *gin.Context) int64 { if c == nil { return 0 } value, exists := c.Get("api_key") if !exists { return 0 } apiKey, ok := value.(*APIKey) if !ok || apiKey == nil || apiKey.GroupID == nil { return 0 } return *apiKey.GroupID } // SelectAccountByPreviousResponseID 按 previous_response_id 命中账号粘连。 // 未命中或账号不可用时返回 (nil, nil),由调用方继续走常规调度。 func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( ctx context.Context, groupID *int64, previousResponseID string, requestedModel string, excludedIDs map[int64]struct{}, ) (*AccountSelectionResult, error) { if s == nil { return nil, nil } responseID := strings.TrimSpace(previousResponseID) if responseID == "" { return nil, nil } store := s.getOpenAIWSStateStore() if store == nil { return nil, nil } accountID, err := store.GetResponseAccount(ctx, derefGroupID(groupID), responseID) if err != nil || accountID <= 0 { return nil, nil } if excludedIDs != nil { if _, excluded := excludedIDs[accountID]; excluded { return nil, nil } } account, err := s.getSchedulableAccount(ctx, accountID) if err != nil || account == nil { _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) return nil, nil } // 非 WSv2 场景(如 force_http/全局关闭)不应使用 previous_response_id 粘连, // 以保持“回滚到 HTTP”后的历史行为一致性。 if s.getOpenAIWSProtocolResolver().Resolve(account).Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { return nil, nil } if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() { _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) return nil, nil } if requestedModel != "" && !account.IsModelSupported(requestedModel) { return nil, nil } result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if acquireErr == nil && result.Acquired { logOpenAIWSBindResponseAccountWarn( derefGroupID(groupID), accountID, responseID, store.BindResponseAccount(ctx, derefGroupID(groupID), responseID, accountID, s.openAIWSResponseStickyTTL()), ) return &AccountSelectionResult{ Account: account, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil } cfg := s.schedulingConfig() if s.concurrencyService != nil { return &AccountSelectionResult{ Account: account, WaitPlan: &AccountWaitPlan{ AccountID: accountID, MaxConcurrency: account.Concurrency, Timeout: cfg.StickySessionWaitTimeout, MaxWaiting: cfg.StickySessionMaxWaiting, }, }, nil } return nil, nil } func classifyOpenAIWSAcquireError(err error) string { if err == nil { return "acquire_conn" } var dialErr *openAIWSDialError if errors.As(err, &dialErr) { switch dialErr.StatusCode { case 426: return "upgrade_required" case 401, 403: return "auth_failed" case 429: return "upstream_rate_limited" } if dialErr.StatusCode >= 500 { return "upstream_5xx" } return "dial_failed" } if errors.Is(err, errOpenAIWSConnQueueFull) { return "conn_queue_full" } if errors.Is(err, errOpenAIWSPreferredConnUnavailable) { return "preferred_conn_unavailable" } if errors.Is(err, context.DeadlineExceeded) { return "acquire_timeout" } return "acquire_conn" } func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) { code := strings.ToLower(strings.TrimSpace(codeRaw)) errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) msg := strings.ToLower(strings.TrimSpace(msgRaw)) switch code { case "upgrade_required": return "upgrade_required", true case "websocket_not_supported", "websocket_unsupported": return "ws_unsupported", true case "websocket_connection_limit_reached": return "ws_connection_limit_reached", true case "previous_response_not_found": return "previous_response_not_found", true } if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") { return "upgrade_required", true } if strings.Contains(errType, "upgrade") { return "upgrade_required", true } if strings.Contains(msg, "websocket") && strings.Contains(msg, "unsupported") { return "ws_unsupported", true } if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") { return "ws_connection_limit_reached", true } if strings.Contains(msg, "previous_response_not_found") || (strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) { return "previous_response_not_found", true } if strings.Contains(errType, "server_error") || strings.Contains(code, "server_error") { return "upstream_error_event", true } return "event_error", false } func classifyOpenAIWSErrorEvent(message []byte) (string, bool) { if len(message) == 0 { return "event_error", false } return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message)) } func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int { code := strings.ToLower(strings.TrimSpace(codeRaw)) errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) switch { case strings.Contains(errType, "invalid_request"), strings.Contains(code, "invalid_request"), strings.Contains(code, "bad_request"), code == "previous_response_not_found": return http.StatusBadRequest case strings.Contains(errType, "authentication"), strings.Contains(code, "invalid_api_key"), strings.Contains(code, "unauthorized"): return http.StatusUnauthorized case strings.Contains(errType, "permission"), strings.Contains(code, "forbidden"): return http.StatusForbidden case strings.Contains(errType, "rate_limit"), strings.Contains(code, "rate_limit"), strings.Contains(code, "insufficient_quota"): return http.StatusTooManyRequests default: return http.StatusBadGateway } } func openAIWSErrorHTTPStatus(message []byte) int { if len(message) == 0 { return http.StatusBadGateway } codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message) return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) } func (s *OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration { if s == nil || s.cfg == nil { return 30 * time.Second } seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds if seconds <= 0 { return 0 } return time.Duration(seconds) * time.Second } func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool { if s == nil || accountID <= 0 { return false } cooldown := s.openAIWSFallbackCooldown() if cooldown <= 0 { return false } rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID) if !ok || rawUntil == nil { return false } until, ok := rawUntil.(time.Time) if !ok || until.IsZero() { s.openaiWSFallbackUntil.Delete(accountID) return false } if time.Now().Before(until) { return true } s.openaiWSFallbackUntil.Delete(accountID) return false } func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) { if s == nil || accountID <= 0 { return } cooldown := s.openAIWSFallbackCooldown() if cooldown <= 0 { return } s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown)) } func (s *OpenAIGatewayService) clearOpenAIWSFallbackCooling(accountID int64) { if s == nil || accountID <= 0 { return } s.openaiWSFallbackUntil.Delete(accountID) }