3956 lines
126 KiB
Go
3956 lines
126 KiB
Go
package service
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"math/rand"
|
||
"net"
|
||
"net/http"
|
||
"net/url"
|
||
"sort"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||
coderws "github.com/coder/websocket"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/tidwall/gjson"
|
||
"github.com/tidwall/sjson"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// OpenAI-Beta header values that opt into the Responses WebSocket protocol.
// buildOpenAIWSHeaders sends V1 for the responses_websockets transport and V2
// for every other transport decision.
const (
	openAIWSBetaV1Value = "responses_websockets=2026-02-04"
	openAIWSBetaV2Value = "responses_websockets=2026-02-06"
)

// Codex turn headers forwarded on the upstream WebSocket handshake.
const (
	openAIWSTurnStateHeader    = "x-codex-turn-state"
	openAIWSTurnMetadataHeader = "x-codex-turn-metadata"
)

// Log-value truncation limits (bytes) and event-log sampling knobs.
const (
	openAIWSLogValueMaxLen      = 160
	openAIWSHeaderValueMaxLen   = 120
	openAIWSIDValueMaxLen       = 64
	openAIWSEventLogHeadLimit   = 20
	openAIWSEventLogEveryN      = 50
	openAIWSBufferLogHeadLimit  = 8
	openAIWSBufferLogEveryN     = 20
	openAIWSPrewarmEventLogHead = 10 // NOTE(review): prewarm log head limit; not referenced in this section
	openAIWSPayloadKeySizeTopN  = 6
)

// Bounds that make estimateOpenAIWSPayloadValueSize give up (return -1).
const (
	openAIWSPayloadSizeEstimateDepth    = 3
	openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024
	openAIWSPayloadSizeEstimateMaxItems = 16
)

// Fallback values used when the corresponding config knobs are not usable.
const (
	openAIWSEventFlushBatchSizeDefault = 4
	openAIWSEventFlushIntervalDefault  = 25 * time.Millisecond
	openAIWSPayloadLogSampleDefault    = 0.2
)

// Accepted values for the store-disabled connection mode setting
// (not referenced in this section of the file).
const (
	openAIWSStoreDisabledConnModeStrict   = "strict"
	openAIWSStoreDisabledConnModeAdaptive = "adaptive"
	openAIWSStoreDisabledConnModeOff      = "off"
)

// Stage label for ingress turn errors where the upstream could not find the
// previous response (checked by isOpenAIWSIngressPreviousResponseNotFound).
const openAIWSIngressStagePreviousResponseNotFound = "previous_response_not_found"

// Upper bound on previous-response-ID delete passes (not referenced in this
// section; confirm semantics at the use site).
const openAIWSMaxPrevResponseIDDeletePasses = 8
|
||
|
||
var (
	// openAIWSLogValueReplacer shortens a few words inside logged values
	// ("error"->"err", "fallback"->"fb", "warning"->"warnx", "failed"->"fail").
	// NOTE(review): presumably meant to keep keyword-based log alerting from
	// firing on informational lines; confirm intent.
	openAIWSLogValueReplacer = strings.NewReplacer(
		"error", "err",
		"fallback", "fb",
		"warning", "warnx",
		"failed", "fail",
	)

	// openAIWSIngressPreflightPingIdle is the idle threshold for the ingress
	// preflight ping. It is a var rather than a const, presumably so it can be
	// overridden; it is not referenced in this section.
	openAIWSIngressPreflightPingIdle = 20 * time.Second
)
|
||
|
||
// openAIWSFallbackError marks a WebSocket error that can safely fall back to
// HTTP because nothing has been written downstream yet.
type openAIWSFallbackError struct {
	Reason string
	Err    error
}

// Error renders "openai ws fallback: <reason>" and appends ": <cause>" when a
// wrapped error is present. A nil receiver renders as "".
func (e *openAIWSFallbackError) Error() string {
	if e == nil {
		return ""
	}
	reason := strings.TrimSpace(e.Reason)
	if e.Err != nil {
		return fmt.Sprintf("openai ws fallback: %s: %v", reason, e.Err)
	}
	return fmt.Sprintf("openai ws fallback: %s", reason)
}

// Unwrap exposes the wrapped cause to errors.Is / errors.As.
func (e *openAIWSFallbackError) Unwrap() error {
	if e != nil {
		return e.Err
	}
	return nil
}
|
||
|
||
func wrapOpenAIWSFallback(reason string, err error) error {
|
||
return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err}
|
||
}
|
||
|
||
// OpenAIWSClientCloseError is an error telling the caller to actively close the
// client WebSocket connection with a specific close code.
type OpenAIWSClientCloseError struct {
	// statusCode is the WebSocket close code to send to the client.
	statusCode coderws.StatusCode
	// reason is the close reason text (stored trimmed by NewOpenAIWSClientCloseError).
	reason string
	// err is the underlying cause, exposed through Unwrap.
	err error
}
|
||
|
||
// openAIWSIngressTurnError annotates a failure inside one inbound WebSocket turn
// with the stage where it happened and whether anything had already been
// written to the downstream client.
type openAIWSIngressTurnError struct {
	stage           string
	cause           error
	wroteDownstream bool
}

// Error returns the cause's message, or the trimmed stage when there is no
// cause. A nil receiver renders as "".
func (e *openAIWSIngressTurnError) Error() string {
	switch {
	case e == nil:
		return ""
	case e.cause != nil:
		return e.cause.Error()
	default:
		return strings.TrimSpace(e.stage)
	}
}

// Unwrap exposes the underlying cause to errors.Is / errors.As.
func (e *openAIWSIngressTurnError) Unwrap() error {
	if e != nil {
		return e.cause
	}
	return nil
}
|
||
|
||
func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error {
|
||
if cause == nil {
|
||
return nil
|
||
}
|
||
return &openAIWSIngressTurnError{
|
||
stage: strings.TrimSpace(stage),
|
||
cause: cause,
|
||
wroteDownstream: wroteDownstream,
|
||
}
|
||
}
|
||
|
||
func isOpenAIWSIngressTurnRetryable(err error) bool {
|
||
var turnErr *openAIWSIngressTurnError
|
||
if !errors.As(err, &turnErr) || turnErr == nil {
|
||
return false
|
||
}
|
||
if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) {
|
||
return false
|
||
}
|
||
if turnErr.wroteDownstream {
|
||
return false
|
||
}
|
||
switch turnErr.stage {
|
||
case "write_upstream", "read_upstream":
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
func openAIWSIngressTurnRetryReason(err error) string {
|
||
var turnErr *openAIWSIngressTurnError
|
||
if !errors.As(err, &turnErr) || turnErr == nil {
|
||
return "unknown"
|
||
}
|
||
if turnErr.stage == "" {
|
||
return "unknown"
|
||
}
|
||
return turnErr.stage
|
||
}
|
||
|
||
func isOpenAIWSIngressPreviousResponseNotFound(err error) bool {
|
||
var turnErr *openAIWSIngressTurnError
|
||
if !errors.As(err, &turnErr) || turnErr == nil {
|
||
return false
|
||
}
|
||
if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound {
|
||
return false
|
||
}
|
||
return !turnErr.wroteDownstream
|
||
}
|
||
|
||
// NewOpenAIWSClientCloseError 创建一个客户端 WS 关闭错误。
|
||
func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err error) error {
|
||
return &OpenAIWSClientCloseError{
|
||
statusCode: statusCode,
|
||
reason: strings.TrimSpace(reason),
|
||
err: err,
|
||
}
|
||
}
|
||
|
||
func (e *OpenAIWSClientCloseError) Error() string {
|
||
if e == nil {
|
||
return ""
|
||
}
|
||
if e.err == nil {
|
||
return fmt.Sprintf("openai ws client close: %d %s", int(e.statusCode), strings.TrimSpace(e.reason))
|
||
}
|
||
return fmt.Sprintf("openai ws client close: %d %s: %v", int(e.statusCode), strings.TrimSpace(e.reason), e.err)
|
||
}
|
||
|
||
func (e *OpenAIWSClientCloseError) Unwrap() error {
|
||
if e == nil {
|
||
return nil
|
||
}
|
||
return e.err
|
||
}
|
||
|
||
func (e *OpenAIWSClientCloseError) StatusCode() coderws.StatusCode {
|
||
if e == nil {
|
||
return coderws.StatusInternalError
|
||
}
|
||
return e.statusCode
|
||
}
|
||
|
||
func (e *OpenAIWSClientCloseError) Reason() string {
|
||
if e == nil {
|
||
return ""
|
||
}
|
||
return strings.TrimSpace(e.reason)
|
||
}
|
||
|
||
// OpenAIWSIngressHooks defines lifecycle callbacks invoked around each turn of
// an inbound (client-facing) WebSocket session.
type OpenAIWSIngressHooks struct {
	// BeforeTurn is called before turn number `turn` is processed.
	// NOTE(review): presumably a non-nil error aborts the turn; confirm at the caller.
	BeforeTurn func(turn int) error
	// AfterTurn is called once turn `turn` finishes, with its forward result
	// and error. NOTE(review): confirm which of result/turnErr may be nil.
	AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
}
|
||
|
||
func normalizeOpenAIWSLogValue(value string) string {
|
||
trimmed := strings.TrimSpace(value)
|
||
if trimmed == "" {
|
||
return "-"
|
||
}
|
||
return openAIWSLogValueReplacer.Replace(trimmed)
|
||
}
|
||
|
||
func truncateOpenAIWSLogValue(value string, maxLen int) string {
|
||
normalized := normalizeOpenAIWSLogValue(value)
|
||
if normalized == "-" || maxLen <= 0 {
|
||
return normalized
|
||
}
|
||
if len(normalized) <= maxLen {
|
||
return normalized
|
||
}
|
||
return normalized[:maxLen] + "..."
|
||
}
|
||
|
||
func openAIWSHeaderValueForLog(headers http.Header, key string) string {
|
||
if headers == nil {
|
||
return "-"
|
||
}
|
||
return truncateOpenAIWSLogValue(headers.Get(key), openAIWSHeaderValueMaxLen)
|
||
}
|
||
|
||
// hasOpenAIWSHeader reports whether headers carries a non-blank value for key
// (case-insensitive lookup through http.Header.Get).
func hasOpenAIWSHeader(headers http.Header, key string) bool {
	return headers != nil && strings.TrimSpace(headers.Get(key)) != ""
}
|
||
|
||
// openAIWSSessionHeaderResolution holds the session/conversation IDs chosen for
// the upstream handshake and where each came from ("header_session_id",
// "header_conversation_id", "prompt_cache_key", or "none").
type openAIWSSessionHeaderResolution struct {
	SessionID, ConversationID         string
	SessionSource, ConversationSource string
}
|
||
|
||
func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution {
|
||
resolution := openAIWSSessionHeaderResolution{
|
||
SessionSource: "none",
|
||
ConversationSource: "none",
|
||
}
|
||
if c != nil && c.Request != nil {
|
||
if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" {
|
||
resolution.SessionID = sessionID
|
||
resolution.SessionSource = "header_session_id"
|
||
}
|
||
if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" {
|
||
resolution.ConversationID = conversationID
|
||
resolution.ConversationSource = "header_conversation_id"
|
||
if resolution.SessionID == "" {
|
||
resolution.SessionID = conversationID
|
||
resolution.SessionSource = "header_conversation_id"
|
||
}
|
||
}
|
||
}
|
||
|
||
cacheKey := strings.TrimSpace(promptCacheKey)
|
||
if cacheKey != "" {
|
||
if resolution.SessionID == "" {
|
||
resolution.SessionID = cacheKey
|
||
resolution.SessionSource = "prompt_cache_key"
|
||
}
|
||
}
|
||
return resolution
|
||
}
|
||
|
||
func shouldLogOpenAIWSEvent(idx int, eventType string) bool {
|
||
if idx <= openAIWSEventLogHeadLimit {
|
||
return true
|
||
}
|
||
if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 {
|
||
return true
|
||
}
|
||
if eventType == "error" || isOpenAIWSTerminalEvent(eventType) {
|
||
return true
|
||
}
|
||
return false
|
||
}
|
||
|
||
func shouldLogOpenAIWSBufferedEvent(idx int) bool {
|
||
if idx <= openAIWSBufferLogHeadLimit {
|
||
return true
|
||
}
|
||
if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN == 0 {
|
||
return true
|
||
}
|
||
return false
|
||
}
|
||
|
||
// openAIWSEventMayContainModel reports whether eventType, ignoring surrounding
// whitespace, is one of the response lifecycle events whose payload may
// include the model name. Both "cancelled" and "canceled" spellings match.
func openAIWSEventMayContainModel(eventType string) bool {
	switch strings.TrimSpace(eventType) {
	case "response.created", "response.in_progress",
		"response.completed", "response.done",
		"response.failed", "response.incomplete",
		"response.cancelled", "response.canceled":
		return true
	}
	return false
}
|
||
|
||
// openAIWSEventMayContainToolCalls reports whether a trimmed event type may
// carry tool/function call data. That covers any type mentioning
// "function_call" or "tool_call", plus output_item added/done and the
// completed/done events.
func openAIWSEventMayContainToolCalls(eventType string) bool {
	trimmed := strings.TrimSpace(eventType)
	switch {
	case trimmed == "":
		return false
	case strings.Contains(trimmed, "function_call"), strings.Contains(trimmed, "tool_call"):
		return true
	case trimmed == "response.output_item.added", trimmed == "response.output_item.done",
		trimmed == "response.completed", trimmed == "response.done":
		return true
	default:
		return false
	}
}
|
||
|
||
// openAIWSEventShouldParseUsage reports whether eventType (ignoring surrounding
// whitespace) is "response.completed", the event usage is parsed from.
func openAIWSEventShouldParseUsage(eventType string) bool {
	const completed = "response.completed"
	if eventType == completed {
		return true // exact match: skip trimming
	}
	return strings.TrimSpace(eventType) == completed
}
|
||
|
||
func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) {
|
||
if len(message) == 0 {
|
||
return "", "", gjson.Result{}
|
||
}
|
||
values := gjson.GetManyBytes(message, "type", "response.id", "id", "response")
|
||
eventType = strings.TrimSpace(values[0].String())
|
||
if id := strings.TrimSpace(values[1].String()); id != "" {
|
||
responseID = id
|
||
} else {
|
||
responseID = strings.TrimSpace(values[2].String())
|
||
}
|
||
return eventType, responseID, values[3]
|
||
}
|
||
|
||
// openAIWSMessageLikelyContainsToolCalls is a cheap byte-level pre-check. It
// reports whether the raw message contains the quoted token "tool_calls",
// "tool_call" or "function_call", without parsing JSON.
func openAIWSMessageLikelyContainsToolCalls(message []byte) bool {
	if len(message) == 0 {
		return false
	}
	for _, marker := range [...]string{`"tool_calls"`, `"tool_call"`, `"function_call"`} {
		if bytes.Contains(message, []byte(marker)) {
			return true
		}
	}
	return false
}
|
||
|
||
func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) {
|
||
if usage == nil || len(message) == 0 {
|
||
return
|
||
}
|
||
values := gjson.GetManyBytes(
|
||
message,
|
||
"response.usage.input_tokens",
|
||
"response.usage.output_tokens",
|
||
"response.usage.input_tokens_details.cached_tokens",
|
||
)
|
||
usage.InputTokens = int(values[0].Int())
|
||
usage.OutputTokens = int(values[1].Int())
|
||
usage.CacheReadInputTokens = int(values[2].Int())
|
||
}
|
||
|
||
func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
|
||
if len(message) == 0 {
|
||
return "", "", ""
|
||
}
|
||
values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message")
|
||
return strings.TrimSpace(values[0].String()), strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String())
|
||
}
|
||
|
||
func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) {
|
||
code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen)
|
||
errType = truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen)
|
||
errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen)
|
||
return code, errType, errMessage
|
||
}
|
||
|
||
func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
|
||
if len(message) == 0 {
|
||
return "-", "-", "-"
|
||
}
|
||
return summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message))
|
||
}
|
||
|
||
func summarizeOpenAIWSPayloadKeySizes(payload map[string]any, topN int) string {
|
||
if len(payload) == 0 {
|
||
return "-"
|
||
}
|
||
type keySize struct {
|
||
Key string
|
||
Size int
|
||
}
|
||
sizes := make([]keySize, 0, len(payload))
|
||
for key, value := range payload {
|
||
size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth)
|
||
sizes = append(sizes, keySize{Key: key, Size: size})
|
||
}
|
||
sort.Slice(sizes, func(i, j int) bool {
|
||
if sizes[i].Size == sizes[j].Size {
|
||
return sizes[i].Key < sizes[j].Key
|
||
}
|
||
return sizes[i].Size > sizes[j].Size
|
||
})
|
||
|
||
if topN <= 0 || topN > len(sizes) {
|
||
topN = len(sizes)
|
||
}
|
||
parts := make([]string, 0, topN)
|
||
for idx := 0; idx < topN; idx++ {
|
||
item := sizes[idx]
|
||
parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size))
|
||
}
|
||
return strings.Join(parts, ",")
|
||
}
|
||
|
||
func estimateOpenAIWSPayloadValueSize(value any, depth int) int {
|
||
if depth <= 0 {
|
||
return -1
|
||
}
|
||
switch v := value.(type) {
|
||
case nil:
|
||
return 0
|
||
case string:
|
||
return len(v)
|
||
case []byte:
|
||
return len(v)
|
||
case bool:
|
||
return 1
|
||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||
return 8
|
||
case float32, float64:
|
||
return 8
|
||
case map[string]any:
|
||
if len(v) == 0 {
|
||
return 2
|
||
}
|
||
total := 2
|
||
count := 0
|
||
for key, item := range v {
|
||
count++
|
||
if count > openAIWSPayloadSizeEstimateMaxItems {
|
||
return -1
|
||
}
|
||
itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1)
|
||
if itemSize < 0 {
|
||
return -1
|
||
}
|
||
total += len(key) + itemSize + 3
|
||
if total > openAIWSPayloadSizeEstimateMaxBytes {
|
||
return -1
|
||
}
|
||
}
|
||
return total
|
||
case []any:
|
||
if len(v) == 0 {
|
||
return 2
|
||
}
|
||
total := 2
|
||
limit := len(v)
|
||
if limit > openAIWSPayloadSizeEstimateMaxItems {
|
||
return -1
|
||
}
|
||
for i := 0; i < limit; i++ {
|
||
itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1)
|
||
if itemSize < 0 {
|
||
return -1
|
||
}
|
||
total += itemSize + 1
|
||
if total > openAIWSPayloadSizeEstimateMaxBytes {
|
||
return -1
|
||
}
|
||
}
|
||
return total
|
||
default:
|
||
raw, err := json.Marshal(v)
|
||
if err != nil {
|
||
return -1
|
||
}
|
||
if len(raw) > openAIWSPayloadSizeEstimateMaxBytes {
|
||
return -1
|
||
}
|
||
return len(raw)
|
||
}
|
||
}
|
||
|
||
// openAIWSPayloadString returns payload[key] as a trimmed string when the value
// is a string or []byte. A missing key, a nil value or any other type yields "".
func openAIWSPayloadString(payload map[string]any, key string) string {
	var text string
	switch v := payload[key].(type) {
	case string:
		text = v
	case []byte:
		text = string(v)
	default:
		return ""
	}
	return strings.TrimSpace(text)
}
|
||
|
||
func openAIWSPayloadStringFromRaw(payload []byte, key string) string {
|
||
if len(payload) == 0 || strings.TrimSpace(key) == "" {
|
||
return ""
|
||
}
|
||
return strings.TrimSpace(gjson.GetBytes(payload, key).String())
|
||
}
|
||
|
||
func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool {
|
||
if len(payload) == 0 || strings.TrimSpace(key) == "" {
|
||
return defaultValue
|
||
}
|
||
value := gjson.GetBytes(payload, key)
|
||
if !value.Exists() {
|
||
return defaultValue
|
||
}
|
||
if value.Type != gjson.True && value.Type != gjson.False {
|
||
return defaultValue
|
||
}
|
||
return value.Bool()
|
||
}
|
||
|
||
func openAIWSSessionHashesFromID(sessionID string) (string, string) {
|
||
return deriveOpenAISessionHashes(sessionID)
|
||
}
|
||
|
||
// extractOpenAIWSImageURL accepts either a bare URL string or an object with a
// string "url" field and returns the trimmed URL. Any other shape yields "".
func extractOpenAIWSImageURL(value any) string {
	if s, ok := value.(string); ok {
		return strings.TrimSpace(s)
	}
	if obj, ok := value.(map[string]any); ok {
		if raw, ok := obj["url"].(string); ok {
			return strings.TrimSpace(raw)
		}
	}
	return ""
}
|
||
|
||
func summarizeOpenAIWSInput(input any) string {
|
||
items, ok := input.([]any)
|
||
if !ok || len(items) == 0 {
|
||
return "-"
|
||
}
|
||
|
||
itemCount := len(items)
|
||
textChars := 0
|
||
imageDataURLs := 0
|
||
imageDataURLChars := 0
|
||
imageRemoteURLs := 0
|
||
|
||
handleContentItem := func(contentItem map[string]any) {
|
||
contentType, _ := contentItem["type"].(string)
|
||
switch strings.TrimSpace(contentType) {
|
||
case "input_text", "output_text", "text":
|
||
if text, ok := contentItem["text"].(string); ok {
|
||
textChars += len(text)
|
||
}
|
||
case "input_image":
|
||
imageURL := extractOpenAIWSImageURL(contentItem["image_url"])
|
||
if imageURL == "" {
|
||
return
|
||
}
|
||
if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") {
|
||
imageDataURLs++
|
||
imageDataURLChars += len(imageURL)
|
||
return
|
||
}
|
||
imageRemoteURLs++
|
||
}
|
||
}
|
||
|
||
handleInputItem := func(inputItem map[string]any) {
|
||
if content, ok := inputItem["content"].([]any); ok {
|
||
for _, rawContent := range content {
|
||
contentItem, ok := rawContent.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
handleContentItem(contentItem)
|
||
}
|
||
return
|
||
}
|
||
|
||
itemType, _ := inputItem["type"].(string)
|
||
switch strings.TrimSpace(itemType) {
|
||
case "input_text", "output_text", "text":
|
||
if text, ok := inputItem["text"].(string); ok {
|
||
textChars += len(text)
|
||
}
|
||
case "input_image":
|
||
imageURL := extractOpenAIWSImageURL(inputItem["image_url"])
|
||
if imageURL == "" {
|
||
return
|
||
}
|
||
if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") {
|
||
imageDataURLs++
|
||
imageDataURLChars += len(imageURL)
|
||
return
|
||
}
|
||
imageRemoteURLs++
|
||
}
|
||
}
|
||
|
||
for _, rawItem := range items {
|
||
inputItem, ok := rawItem.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
handleInputItem(inputItem)
|
||
}
|
||
|
||
return fmt.Sprintf(
|
||
"items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d",
|
||
itemCount,
|
||
textChars,
|
||
imageDataURLs,
|
||
imageDataURLChars,
|
||
imageRemoteURLs,
|
||
)
|
||
}
|
||
|
||
// dropOpenAIWSPayloadKey deletes key from payload and records it in *removed.
// It does nothing for a blank key or a key that is not present.
func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) {
	if strings.TrimSpace(key) == "" {
		return
	}
	if _, present := payload[key]; present {
		delete(payload, key)
		*removed = append(*removed, key)
	}
}
|
||
|
||
// applyOpenAIWSRetryPayloadStrategy 在 WS 连续失败时仅移除无语义字段,
|
||
// 避免重试成功却改变原始请求语义。
|
||
// 注意:prompt_cache_key 不应在重试中移除;它常用于会话稳定标识(session_id 兜底)。
|
||
func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) {
|
||
if len(payload) == 0 {
|
||
return "empty", nil
|
||
}
|
||
if attempt <= 1 {
|
||
return "full", nil
|
||
}
|
||
|
||
removed := make([]string, 0, 2)
|
||
if attempt >= 2 {
|
||
dropOpenAIWSPayloadKey(payload, "include", &removed)
|
||
}
|
||
|
||
if len(removed) == 0 {
|
||
return "full", nil
|
||
}
|
||
sort.Strings(removed)
|
||
return "trim_optional_fields", removed
|
||
}
|
||
|
||
func logOpenAIWSModeInfo(format string, args ...any) {
|
||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
|
||
}
|
||
|
||
func isOpenAIWSModeDebugEnabled() bool {
|
||
return logger.L().Core().Enabled(zap.DebugLevel)
|
||
}
|
||
|
||
func logOpenAIWSModeDebug(format string, args ...any) {
|
||
if !isOpenAIWSModeDebugEnabled() {
|
||
return
|
||
}
|
||
logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
|
||
}
|
||
|
||
func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) {
|
||
if err == nil {
|
||
return
|
||
}
|
||
logger.L().Warn(
|
||
"openai.ws_bind_response_account_failed",
|
||
zap.Int64("group_id", groupID),
|
||
zap.Int64("account_id", accountID),
|
||
zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)),
|
||
zap.Error(err),
|
||
)
|
||
}
|
||
|
||
func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) {
|
||
if err == nil {
|
||
return "-", "-"
|
||
}
|
||
statusCode := coderws.CloseStatus(err)
|
||
if statusCode == -1 {
|
||
return "-", "-"
|
||
}
|
||
closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
|
||
closeReason := "-"
|
||
var closeErr coderws.CloseError
|
||
if errors.As(err, &closeErr) {
|
||
reasonText := strings.TrimSpace(closeErr.Reason)
|
||
if reasonText != "" {
|
||
closeReason = normalizeOpenAIWSLogValue(reasonText)
|
||
}
|
||
}
|
||
return normalizeOpenAIWSLogValue(closeStatus), closeReason
|
||
}
|
||
|
||
func unwrapOpenAIWSDialBaseError(err error) error {
|
||
if err == nil {
|
||
return nil
|
||
}
|
||
var dialErr *openAIWSDialError
|
||
if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil {
|
||
return dialErr.Err
|
||
}
|
||
return err
|
||
}
|
||
|
||
func openAIWSDialRespHeaderForLog(err error, key string) string {
|
||
var dialErr *openAIWSDialError
|
||
if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil {
|
||
return "-"
|
||
}
|
||
return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen)
|
||
}
|
||
|
||
func classifyOpenAIWSDialError(err error) string {
|
||
if err == nil {
|
||
return "-"
|
||
}
|
||
baseErr := unwrapOpenAIWSDialBaseError(err)
|
||
if baseErr == nil {
|
||
return "-"
|
||
}
|
||
if errors.Is(baseErr, context.DeadlineExceeded) {
|
||
return "ctx_deadline_exceeded"
|
||
}
|
||
if errors.Is(baseErr, context.Canceled) {
|
||
return "ctx_canceled"
|
||
}
|
||
var netErr net.Error
|
||
if errors.As(baseErr, &netErr) && netErr.Timeout() {
|
||
return "net_timeout"
|
||
}
|
||
if status := coderws.CloseStatus(baseErr); status != -1 {
|
||
return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status)))
|
||
}
|
||
message := strings.ToLower(strings.TrimSpace(baseErr.Error()))
|
||
switch {
|
||
case strings.Contains(message, "handshake not finished"):
|
||
return "handshake_not_finished"
|
||
case strings.Contains(message, "bad handshake"):
|
||
return "bad_handshake"
|
||
case strings.Contains(message, "connection refused"):
|
||
return "connection_refused"
|
||
case strings.Contains(message, "no such host"):
|
||
return "dns_not_found"
|
||
case strings.Contains(message, "tls"):
|
||
return "tls_error"
|
||
case strings.Contains(message, "i/o timeout"):
|
||
return "io_timeout"
|
||
case strings.Contains(message, "context deadline exceeded"):
|
||
return "ctx_deadline_exceeded"
|
||
default:
|
||
return "dial_error"
|
||
}
|
||
}
|
||
|
||
func summarizeOpenAIWSDialError(err error) (
|
||
statusCode int,
|
||
dialClass string,
|
||
closeStatus string,
|
||
closeReason string,
|
||
respServer string,
|
||
respVia string,
|
||
respCFRay string,
|
||
respRequestID string,
|
||
) {
|
||
dialClass = "-"
|
||
closeStatus = "-"
|
||
closeReason = "-"
|
||
respServer = "-"
|
||
respVia = "-"
|
||
respCFRay = "-"
|
||
respRequestID = "-"
|
||
if err == nil {
|
||
return
|
||
}
|
||
var dialErr *openAIWSDialError
|
||
if errors.As(err, &dialErr) && dialErr != nil {
|
||
statusCode = dialErr.StatusCode
|
||
respServer = openAIWSDialRespHeaderForLog(err, "server")
|
||
respVia = openAIWSDialRespHeaderForLog(err, "via")
|
||
respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray")
|
||
respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id")
|
||
}
|
||
dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err))
|
||
closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err))
|
||
return
|
||
}
|
||
|
||
func isOpenAIWSClientDisconnectError(err error) bool {
|
||
if err == nil {
|
||
return false
|
||
}
|
||
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
|
||
return true
|
||
}
|
||
switch coderws.CloseStatus(err) {
|
||
case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
|
||
return true
|
||
}
|
||
message := strings.ToLower(strings.TrimSpace(err.Error()))
|
||
if message == "" {
|
||
return false
|
||
}
|
||
return strings.Contains(message, "failed to read frame header: eof") ||
|
||
strings.Contains(message, "unexpected eof") ||
|
||
strings.Contains(message, "use of closed network connection") ||
|
||
strings.Contains(message, "connection reset by peer") ||
|
||
strings.Contains(message, "broken pipe")
|
||
}
|
||
|
||
func classifyOpenAIWSReadFallbackReason(err error) string {
|
||
if err == nil {
|
||
return "read_event"
|
||
}
|
||
switch coderws.CloseStatus(err) {
|
||
case coderws.StatusPolicyViolation:
|
||
return "policy_violation"
|
||
case coderws.StatusMessageTooBig:
|
||
return "message_too_big"
|
||
default:
|
||
return "read_event"
|
||
}
|
||
}
|
||
|
||
// sortedKeys returns the keys of m in ascending order, or nil for an empty map.
func sortedKeys(m map[string]any) []string {
	if len(m) == 0 {
		return nil
	}
	keys := make([]string, len(m))
	i := 0
	for k := range m {
		keys[i] = k
		i++
	}
	sort.Strings(keys)
	return keys
}
|
||
|
||
func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool {
|
||
if s == nil {
|
||
return nil
|
||
}
|
||
s.openaiWSPoolOnce.Do(func() {
|
||
if s.openaiWSPool == nil {
|
||
s.openaiWSPool = newOpenAIWSConnPool(s.cfg)
|
||
}
|
||
})
|
||
return s.openaiWSPool
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot {
|
||
pool := s.getOpenAIWSConnPool()
|
||
if pool == nil {
|
||
return OpenAIWSPoolMetricsSnapshot{}
|
||
}
|
||
return pool.SnapshotMetrics()
|
||
}
|
||
|
||
// OpenAIWSPerformanceMetricsSnapshot combines WebSocket pool, retry and
// transport metrics into one JSON-serializable snapshot.
type OpenAIWSPerformanceMetricsSnapshot struct {
	// Pool holds connection-pool metrics; it is left zero when no pool exists.
	Pool OpenAIWSPoolMetricsSnapshot `json:"pool"`
	// Retry holds WebSocket retry metrics.
	Retry OpenAIWSRetryMetricsSnapshot `json:"retry"`
	// Transport holds transport-level metrics; it is left zero when no pool exists.
	Transport OpenAIWSTransportMetricsSnapshot `json:"transport"`
}
|
||
|
||
func (s *OpenAIGatewayService) SnapshotOpenAIWSPerformanceMetrics() OpenAIWSPerformanceMetricsSnapshot {
|
||
pool := s.getOpenAIWSConnPool()
|
||
snapshot := OpenAIWSPerformanceMetricsSnapshot{
|
||
Retry: s.SnapshotOpenAIWSRetryMetrics(),
|
||
}
|
||
if pool == nil {
|
||
return snapshot
|
||
}
|
||
snapshot.Pool = pool.SnapshotMetrics()
|
||
snapshot.Transport = pool.SnapshotTransportMetrics()
|
||
return snapshot
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) getOpenAIWSStateStore() OpenAIWSStateStore {
|
||
if s == nil {
|
||
return nil
|
||
}
|
||
s.openaiWSStateStoreOnce.Do(func() {
|
||
if s.openaiWSStateStore == nil {
|
||
s.openaiWSStateStore = NewOpenAIWSStateStore(s.cache)
|
||
}
|
||
})
|
||
return s.openaiWSStateStore
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) openAIWSResponseStickyTTL() time.Duration {
|
||
if s != nil && s.cfg != nil {
|
||
seconds := s.cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds
|
||
if seconds > 0 {
|
||
return time.Duration(seconds) * time.Second
|
||
}
|
||
}
|
||
return time.Hour
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) openAIWSIngressPreviousResponseRecoveryEnabled() bool {
|
||
if s != nil && s.cfg != nil {
|
||
return s.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled
|
||
}
|
||
return true
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration {
|
||
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds > 0 {
|
||
return time.Duration(s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds) * time.Second
|
||
}
|
||
return 15 * time.Minute
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration {
|
||
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 {
|
||
return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second
|
||
}
|
||
return 2 * time.Minute
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) openAIWSEventFlushBatchSize() int {
|
||
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushBatchSize > 0 {
|
||
return s.cfg.Gateway.OpenAIWS.EventFlushBatchSize
|
||
}
|
||
return openAIWSEventFlushBatchSizeDefault
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) openAIWSEventFlushInterval() time.Duration {
|
||
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS >= 0 {
|
||
if s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS == 0 {
|
||
return 0
|
||
}
|
||
return time.Duration(s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS) * time.Millisecond
|
||
}
|
||
return openAIWSEventFlushIntervalDefault
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) openAIWSPayloadLogSampleRate() float64 {
|
||
if s != nil && s.cfg != nil {
|
||
rate := s.cfg.Gateway.OpenAIWS.PayloadLogSampleRate
|
||
if rate < 0 {
|
||
return 0
|
||
}
|
||
if rate > 1 {
|
||
return 1
|
||
}
|
||
return rate
|
||
}
|
||
return openAIWSPayloadLogSampleDefault
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) shouldLogOpenAIWSPayloadSchema(attempt int) bool {
|
||
// 首次尝试保留一条完整 payload_schema 便于排障。
|
||
if attempt <= 1 {
|
||
return true
|
||
}
|
||
rate := s.openAIWSPayloadLogSampleRate()
|
||
if rate <= 0 {
|
||
return false
|
||
}
|
||
if rate >= 1 {
|
||
return true
|
||
}
|
||
return rand.Float64() < rate
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) shouldEmitOpenAIWSPayloadSchema(attempt int) bool {
|
||
if !s.shouldLogOpenAIWSPayloadSchema(attempt) {
|
||
return false
|
||
}
|
||
return logger.L().Core().Enabled(zap.DebugLevel)
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) openAIWSDialTimeout() time.Duration {
|
||
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 {
|
||
return time.Duration(s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second
|
||
}
|
||
return 10 * time.Second
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) openAIWSAcquireTimeout() time.Duration {
|
||
// Acquire 覆盖“连接复用命中/排队/新建连接”三个阶段。
|
||
// 这里不再叠加 write_timeout,避免高并发排队时把 TTFT 长尾拉到分钟级。
|
||
dial := s.openAIWSDialTimeout()
|
||
if dial <= 0 {
|
||
dial = 10 * time.Second
|
||
}
|
||
return dial + 2*time.Second
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) buildOpenAIResponsesWSURL(account *Account) (string, error) {
|
||
if account == nil {
|
||
return "", errors.New("account is nil")
|
||
}
|
||
var targetURL string
|
||
switch account.Type {
|
||
case AccountTypeOAuth:
|
||
targetURL = chatgptCodexURL
|
||
case AccountTypeAPIKey:
|
||
baseURL := account.GetOpenAIBaseURL()
|
||
if baseURL == "" {
|
||
targetURL = openaiPlatformAPIURL
|
||
} else {
|
||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
targetURL = buildOpenAIResponsesURL(validatedURL)
|
||
}
|
||
default:
|
||
targetURL = openaiPlatformAPIURL
|
||
}
|
||
|
||
parsed, err := url.Parse(strings.TrimSpace(targetURL))
|
||
if err != nil {
|
||
return "", fmt.Errorf("invalid target url: %w", err)
|
||
}
|
||
switch strings.ToLower(parsed.Scheme) {
|
||
case "https":
|
||
parsed.Scheme = "wss"
|
||
case "http":
|
||
parsed.Scheme = "ws"
|
||
case "wss", "ws":
|
||
// 保持不变
|
||
default:
|
||
return "", fmt.Errorf("unsupported scheme for ws: %s", parsed.Scheme)
|
||
}
|
||
return parsed.String(), nil
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) buildOpenAIWSHeaders(
|
||
c *gin.Context,
|
||
account *Account,
|
||
token string,
|
||
decision OpenAIWSProtocolDecision,
|
||
isCodexCLI bool,
|
||
turnState string,
|
||
turnMetadata string,
|
||
promptCacheKey string,
|
||
) (http.Header, openAIWSSessionHeaderResolution) {
|
||
headers := make(http.Header)
|
||
headers.Set("authorization", "Bearer "+token)
|
||
|
||
sessionResolution := resolveOpenAIWSSessionHeaders(c, promptCacheKey)
|
||
if c != nil && c.Request != nil {
|
||
if v := strings.TrimSpace(c.Request.Header.Get("accept-language")); v != "" {
|
||
headers.Set("accept-language", v)
|
||
}
|
||
}
|
||
if sessionResolution.SessionID != "" {
|
||
headers.Set("session_id", sessionResolution.SessionID)
|
||
}
|
||
if sessionResolution.ConversationID != "" {
|
||
headers.Set("conversation_id", sessionResolution.ConversationID)
|
||
}
|
||
if state := strings.TrimSpace(turnState); state != "" {
|
||
headers.Set(openAIWSTurnStateHeader, state)
|
||
}
|
||
if metadata := strings.TrimSpace(turnMetadata); metadata != "" {
|
||
headers.Set(openAIWSTurnMetadataHeader, metadata)
|
||
}
|
||
|
||
if account != nil && account.Type == AccountTypeOAuth {
|
||
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
|
||
headers.Set("chatgpt-account-id", chatgptAccountID)
|
||
}
|
||
if isCodexCLI {
|
||
headers.Set("originator", "codex_cli_rs")
|
||
} else {
|
||
headers.Set("originator", "opencode")
|
||
}
|
||
}
|
||
|
||
betaValue := openAIWSBetaV2Value
|
||
if decision.Transport == OpenAIUpstreamTransportResponsesWebsocket {
|
||
betaValue = openAIWSBetaV1Value
|
||
}
|
||
headers.Set("OpenAI-Beta", betaValue)
|
||
|
||
customUA := ""
|
||
if account != nil {
|
||
customUA = account.GetOpenAIUserAgent()
|
||
}
|
||
if strings.TrimSpace(customUA) != "" {
|
||
headers.Set("user-agent", customUA)
|
||
} else if c != nil {
|
||
if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" {
|
||
headers.Set("user-agent", ua)
|
||
}
|
||
}
|
||
if s != nil && s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
||
headers.Set("user-agent", codexCLIUserAgent)
|
||
}
|
||
if account != nil && account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(headers.Get("user-agent")) {
|
||
headers.Set("user-agent", codexCLIUserAgent)
|
||
}
|
||
|
||
return headers, sessionResolution
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) buildOpenAIWSCreatePayload(reqBody map[string]any, account *Account) map[string]any {
|
||
// OpenAI WS Mode 协议:response.create 字段与 HTTP /responses 基本一致。
|
||
// 保留 stream 字段(与 Codex CLI 一致),仅移除 background。
|
||
payload := make(map[string]any, len(reqBody)+1)
|
||
for k, v := range reqBody {
|
||
payload[k] = v
|
||
}
|
||
|
||
delete(payload, "background")
|
||
if _, exists := payload["stream"]; !exists {
|
||
payload["stream"] = true
|
||
}
|
||
payload["type"] = "response.create"
|
||
|
||
// OAuth 默认保持 store=false,避免误依赖服务端历史。
|
||
if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
|
||
payload["store"] = false
|
||
}
|
||
return payload
|
||
}
|
||
|
||
func setOpenAIWSTurnMetadata(payload map[string]any, turnMetadata string) {
|
||
if len(payload) == 0 {
|
||
return
|
||
}
|
||
metadata := strings.TrimSpace(turnMetadata)
|
||
if metadata == "" {
|
||
return
|
||
}
|
||
|
||
switch existing := payload["client_metadata"].(type) {
|
||
case map[string]any:
|
||
existing[openAIWSTurnMetadataHeader] = metadata
|
||
payload["client_metadata"] = existing
|
||
case map[string]string:
|
||
next := make(map[string]any, len(existing)+1)
|
||
for k, v := range existing {
|
||
next[k] = v
|
||
}
|
||
next[openAIWSTurnMetadataHeader] = metadata
|
||
payload["client_metadata"] = next
|
||
default:
|
||
payload["client_metadata"] = map[string]any{
|
||
openAIWSTurnMetadataHeader: metadata,
|
||
}
|
||
}
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) isOpenAIWSStoreRecoveryAllowed(account *Account) bool {
|
||
if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() {
|
||
return true
|
||
}
|
||
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery {
|
||
return true
|
||
}
|
||
return false
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool {
|
||
if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
|
||
return true
|
||
}
|
||
if len(reqBody) == 0 {
|
||
return false
|
||
}
|
||
rawStore, ok := reqBody["store"]
|
||
if !ok {
|
||
return false
|
||
}
|
||
storeEnabled, ok := rawStore.(bool)
|
||
if !ok {
|
||
return false
|
||
}
|
||
return !storeEnabled
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool {
|
||
if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
|
||
return true
|
||
}
|
||
if len(reqBody) == 0 {
|
||
return false
|
||
}
|
||
storeValue := gjson.GetBytes(reqBody, "store")
|
||
if !storeValue.Exists() {
|
||
return false
|
||
}
|
||
if storeValue.Type != gjson.True && storeValue.Type != gjson.False {
|
||
return false
|
||
}
|
||
return !storeValue.Bool()
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string {
|
||
if s == nil || s.cfg == nil {
|
||
return openAIWSStoreDisabledConnModeStrict
|
||
}
|
||
mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode))
|
||
switch mode {
|
||
case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff:
|
||
return mode
|
||
case "":
|
||
// 兼容旧配置:仅配置了布尔开关时按旧语义推导。
|
||
if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn {
|
||
return openAIWSStoreDisabledConnModeStrict
|
||
}
|
||
return openAIWSStoreDisabledConnModeOff
|
||
default:
|
||
return openAIWSStoreDisabledConnModeStrict
|
||
}
|
||
}
|
||
|
||
func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool {
|
||
switch mode {
|
||
case openAIWSStoreDisabledConnModeOff:
|
||
return false
|
||
case openAIWSStoreDisabledConnModeAdaptive:
|
||
reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_")
|
||
switch reason {
|
||
case "policy_violation", "message_too_big", "auth_failed", "write_request", "write":
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
default:
|
||
return true
|
||
}
|
||
}
|
||
|
||
func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) {
|
||
return dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes)
|
||
}
|
||
|
||
func dropPreviousResponseIDFromRawPayloadWithDeleteFn(
|
||
payload []byte,
|
||
deleteFn func([]byte, string) ([]byte, error),
|
||
) ([]byte, bool, error) {
|
||
if len(payload) == 0 {
|
||
return payload, false, nil
|
||
}
|
||
if !gjson.GetBytes(payload, "previous_response_id").Exists() {
|
||
return payload, false, nil
|
||
}
|
||
if deleteFn == nil {
|
||
deleteFn = sjson.DeleteBytes
|
||
}
|
||
|
||
updated := payload
|
||
for i := 0; i < openAIWSMaxPrevResponseIDDeletePasses &&
|
||
gjson.GetBytes(updated, "previous_response_id").Exists(); i++ {
|
||
next, err := deleteFn(updated, "previous_response_id")
|
||
if err != nil {
|
||
return payload, false, err
|
||
}
|
||
updated = next
|
||
}
|
||
return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil
|
||
}
|
||
|
||
func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) {
|
||
normalizedPrevID := strings.TrimSpace(previousResponseID)
|
||
if len(payload) == 0 || normalizedPrevID == "" {
|
||
return payload, nil
|
||
}
|
||
updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID)
|
||
if err == nil {
|
||
return updated, nil
|
||
}
|
||
|
||
var reqBody map[string]any
|
||
if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil {
|
||
return nil, err
|
||
}
|
||
reqBody["previous_response_id"] = normalizedPrevID
|
||
rebuilt, marshalErr := json.Marshal(reqBody)
|
||
if marshalErr != nil {
|
||
return nil, marshalErr
|
||
}
|
||
return rebuilt, nil
|
||
}
|
||
|
||
// shouldInferIngressFunctionCallOutputPreviousResponseID reports whether the
// gateway should fill in previous_response_id for an ingress turn. All of the
// following must hold:
//   - store is disabled;
//   - this is not the first turn (turn > 1);
//   - the turn carries function_call_output;
//   - the client sent no previous_response_id (blank after trimming);
//   - an expected id is known (non-blank after trimming).
func shouldInferIngressFunctionCallOutputPreviousResponseID(
	storeDisabled bool,
	turn int,
	hasFunctionCallOutput bool,
	currentPreviousResponseID string,
	expectedPreviousResponseID string,
) bool {
	eligible := storeDisabled && turn > 1 && hasFunctionCallOutput
	if !eligible {
		return false
	}
	clientSentID := strings.TrimSpace(currentPreviousResponseID) != ""
	haveExpectedID := strings.TrimSpace(expectedPreviousResponseID) != ""
	return !clientSentID && haveExpectedID
}
|
||
|
||
func alignStoreDisabledPreviousResponseID(
|
||
payload []byte,
|
||
expectedPreviousResponseID string,
|
||
) ([]byte, bool, error) {
|
||
if len(payload) == 0 {
|
||
return payload, false, nil
|
||
}
|
||
expected := strings.TrimSpace(expectedPreviousResponseID)
|
||
if expected == "" {
|
||
return payload, false, nil
|
||
}
|
||
current := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
|
||
if current == "" || current == expected {
|
||
return payload, false, nil
|
||
}
|
||
|
||
withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
|
||
if dropErr != nil {
|
||
return payload, false, dropErr
|
||
}
|
||
if !removed {
|
||
return payload, false, nil
|
||
}
|
||
updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected)
|
||
if setErr != nil {
|
||
return payload, false, setErr
|
||
}
|
||
return updated, true, nil
|
||
}
|
||
|
||
// cloneOpenAIWSPayloadBytes returns an independent copy of payload. Empty
// input (nil or zero-length) yields nil.
func cloneOpenAIWSPayloadBytes(payload []byte) []byte {
	if n := len(payload); n != 0 {
		dup := make([]byte, n)
		copy(dup, payload)
		return dup
	}
	return nil
}
|
||
|
||
func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage {
|
||
if items == nil {
|
||
return nil
|
||
}
|
||
cloned := make([]json.RawMessage, 0, len(items))
|
||
for idx := range items {
|
||
cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx])))
|
||
}
|
||
return cloned
|
||
}
|
||
|
||
// normalizeOpenAIWSJSONForCompare re-encodes raw JSON into a canonical form so
// that two documents can be compared byte-wise. The JSON is decoded into
// generic Go values and re-marshalled: whitespace is dropped and object keys
// come out sorted.
//
// Note: numbers pass through float64 during decoding.
//
// Input that is empty or whitespace-only returns the error "json is empty".
// Invalid JSON returns the decode error.
func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) {
	body := bytes.TrimSpace(raw)
	if len(body) == 0 {
		return nil, errors.New("json is empty")
	}

	var value any
	if decodeErr := json.Unmarshal(body, &value); decodeErr != nil {
		return nil, decodeErr
	}
	canonical, encodeErr := json.Marshal(value)
	return canonical, encodeErr
}
|
||
|
||
func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte {
|
||
normalized, err := normalizeOpenAIWSJSONForCompare(raw)
|
||
if err != nil {
|
||
return bytes.TrimSpace(raw)
|
||
}
|
||
return normalized
|
||
}
|
||
|
||
// normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID decodes a JSON
// object payload, drops its "input" and "previous_response_id" keys, and
// re-encodes the rest. Keys come out sorted, so the result can be compared
// byte-wise across turns.
//
// An empty payload returns the error "payload is empty". A payload that is not
// a JSON object or null returns the decode error. A literal null encodes back
// to "null".
func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) ([]byte, error) {
	if len(payload) == 0 {
		return nil, errors.New("payload is empty")
	}

	var fields map[string]any
	if err := json.Unmarshal(payload, &fields); err != nil {
		return nil, err
	}
	for _, key := range [...]string{"input", "previous_response_id"} {
		delete(fields, key)
	}
	return json.Marshal(fields)
}
|
||
|
||
func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) {
|
||
if len(payload) == 0 {
|
||
return nil, false, nil
|
||
}
|
||
inputValue := gjson.GetBytes(payload, "input")
|
||
if !inputValue.Exists() {
|
||
return nil, false, nil
|
||
}
|
||
if inputValue.Type == gjson.JSON {
|
||
raw := strings.TrimSpace(inputValue.Raw)
|
||
if strings.HasPrefix(raw, "[") {
|
||
var items []json.RawMessage
|
||
if err := json.Unmarshal([]byte(raw), &items); err != nil {
|
||
return nil, true, err
|
||
}
|
||
return items, true, nil
|
||
}
|
||
return []json.RawMessage{json.RawMessage(raw)}, true, nil
|
||
}
|
||
if inputValue.Type == gjson.String {
|
||
encoded, _ := json.Marshal(inputValue.String())
|
||
return []json.RawMessage{encoded}, true, nil
|
||
}
|
||
return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil
|
||
}
|
||
|
||
func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) {
|
||
previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload)
|
||
if prevErr != nil {
|
||
return false, prevErr
|
||
}
|
||
currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
|
||
if currentErr != nil {
|
||
return false, currentErr
|
||
}
|
||
if !previousExists && !currentExists {
|
||
return true, nil
|
||
}
|
||
if !previousExists {
|
||
return len(currentItems) == 0, nil
|
||
}
|
||
if !currentExists {
|
||
return len(previousItems) == 0, nil
|
||
}
|
||
if len(currentItems) < len(previousItems) {
|
||
return false, nil
|
||
}
|
||
|
||
for idx := range previousItems {
|
||
previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx])
|
||
currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx])
|
||
if !bytes.Equal(previousNormalized, currentNormalized) {
|
||
return false, nil
|
||
}
|
||
}
|
||
return true, nil
|
||
}
|
||
|
||
func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool {
|
||
if len(prefix) == 0 {
|
||
return true
|
||
}
|
||
if len(items) < len(prefix) {
|
||
return false
|
||
}
|
||
for idx := range prefix {
|
||
previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx])
|
||
currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(items[idx])
|
||
if !bytes.Equal(previousNormalized, currentNormalized) {
|
||
return false
|
||
}
|
||
}
|
||
return true
|
||
}
|
||
|
||
func buildOpenAIWSReplayInputSequence(
|
||
previousFullInput []json.RawMessage,
|
||
previousFullInputExists bool,
|
||
currentPayload []byte,
|
||
hasPreviousResponseID bool,
|
||
) ([]json.RawMessage, bool, error) {
|
||
currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
|
||
if currentErr != nil {
|
||
return nil, false, currentErr
|
||
}
|
||
if !hasPreviousResponseID {
|
||
return cloneOpenAIWSRawMessages(currentItems), currentExists, nil
|
||
}
|
||
if !previousFullInputExists {
|
||
return cloneOpenAIWSRawMessages(currentItems), currentExists, nil
|
||
}
|
||
if !currentExists || len(currentItems) == 0 {
|
||
return cloneOpenAIWSRawMessages(previousFullInput), true, nil
|
||
}
|
||
if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) {
|
||
return cloneOpenAIWSRawMessages(currentItems), true, nil
|
||
}
|
||
merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems))
|
||
merged = append(merged, cloneOpenAIWSRawMessages(previousFullInput)...)
|
||
merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...)
|
||
return merged, true, nil
|
||
}
|
||
|
||
func setOpenAIWSPayloadInputSequence(
|
||
payload []byte,
|
||
fullInput []json.RawMessage,
|
||
fullInputExists bool,
|
||
) ([]byte, error) {
|
||
if !fullInputExists {
|
||
return payload, nil
|
||
}
|
||
// Preserve [] vs null semantics when input exists but is empty.
|
||
inputForMarshal := fullInput
|
||
if inputForMarshal == nil {
|
||
inputForMarshal = []json.RawMessage{}
|
||
}
|
||
inputRaw, marshalErr := json.Marshal(inputForMarshal)
|
||
if marshalErr != nil {
|
||
return nil, marshalErr
|
||
}
|
||
return sjson.SetRawBytes(payload, "input", inputRaw)
|
||
}
|
||
|
||
func shouldKeepIngressPreviousResponseID(
|
||
previousPayload []byte,
|
||
currentPayload []byte,
|
||
lastTurnResponseID string,
|
||
hasFunctionCallOutput bool,
|
||
) (bool, string, error) {
|
||
if hasFunctionCallOutput {
|
||
return true, "has_function_call_output", nil
|
||
}
|
||
currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
|
||
if currentPreviousResponseID == "" {
|
||
return false, "missing_previous_response_id", nil
|
||
}
|
||
expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
|
||
if expectedPreviousResponseID == "" {
|
||
return false, "missing_last_turn_response_id", nil
|
||
}
|
||
if currentPreviousResponseID != expectedPreviousResponseID {
|
||
return false, "previous_response_id_mismatch", nil
|
||
}
|
||
if len(previousPayload) == 0 {
|
||
return false, "missing_previous_turn_payload", nil
|
||
}
|
||
|
||
previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload)
|
||
if previousComparableErr != nil {
|
||
return false, "non_input_compare_error", previousComparableErr
|
||
}
|
||
currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
|
||
if currentComparableErr != nil {
|
||
return false, "non_input_compare_error", currentComparableErr
|
||
}
|
||
if !bytes.Equal(previousComparable, currentComparable) {
|
||
return false, "non_input_changed", nil
|
||
}
|
||
return true, "strict_incremental_ok", nil
|
||
}
|
||
|
||
// openAIWSIngressPreviousTurnStrictState holds what the strict incremental
// check needs from the previous ingress turn. It lets that check compare
// against a pre-normalized form instead of re-normalizing the previous
// payload each time.
type openAIWSIngressPreviousTurnStrictState struct {
	// nonInputComparable is the previous payload re-encoded without its
	// "input" and "previous_response_id" keys.
	nonInputComparable []byte
}
|
||
|
||
func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) {
|
||
if len(payload) == 0 {
|
||
return nil, nil
|
||
}
|
||
nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload)
|
||
if nonInputErr != nil {
|
||
return nil, nonInputErr
|
||
}
|
||
return &openAIWSIngressPreviousTurnStrictState{
|
||
nonInputComparable: nonInputComparable,
|
||
}, nil
|
||
}
|
||
|
||
func shouldKeepIngressPreviousResponseIDWithStrictState(
|
||
previousState *openAIWSIngressPreviousTurnStrictState,
|
||
currentPayload []byte,
|
||
lastTurnResponseID string,
|
||
hasFunctionCallOutput bool,
|
||
) (bool, string, error) {
|
||
if hasFunctionCallOutput {
|
||
return true, "has_function_call_output", nil
|
||
}
|
||
currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
|
||
if currentPreviousResponseID == "" {
|
||
return false, "missing_previous_response_id", nil
|
||
}
|
||
expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
|
||
if expectedPreviousResponseID == "" {
|
||
return false, "missing_last_turn_response_id", nil
|
||
}
|
||
if currentPreviousResponseID != expectedPreviousResponseID {
|
||
return false, "previous_response_id_mismatch", nil
|
||
}
|
||
if previousState == nil {
|
||
return false, "missing_previous_turn_payload", nil
|
||
}
|
||
|
||
currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
|
||
if currentComparableErr != nil {
|
||
return false, "non_input_compare_error", currentComparableErr
|
||
}
|
||
if !bytes.Equal(previousState.nonInputComparable, currentComparable) {
|
||
return false, "non_input_changed", nil
|
||
}
|
||
return true, "strict_incremental_ok", nil
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||
ctx context.Context,
|
||
c *gin.Context,
|
||
account *Account,
|
||
reqBody map[string]any,
|
||
token string,
|
||
decision OpenAIWSProtocolDecision,
|
||
isCodexCLI bool,
|
||
reqStream bool,
|
||
originalModel string,
|
||
mappedModel string,
|
||
startTime time.Time,
|
||
attempt int,
|
||
lastFailureReason string,
|
||
) (*OpenAIForwardResult, error) {
|
||
if s == nil || account == nil {
|
||
return nil, wrapOpenAIWSFallback("invalid_state", errors.New("service or account is nil"))
|
||
}
|
||
|
||
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
||
if err != nil {
|
||
return nil, wrapOpenAIWSFallback("build_ws_url", err)
|
||
}
|
||
wsHost := "-"
|
||
wsPath := "-"
|
||
if parsed, parseErr := url.Parse(wsURL); parseErr == nil && parsed != nil {
|
||
if h := strings.TrimSpace(parsed.Host); h != "" {
|
||
wsHost = normalizeOpenAIWSLogValue(h)
|
||
}
|
||
if p := strings.TrimSpace(parsed.Path); p != "" {
|
||
wsPath = normalizeOpenAIWSLogValue(p)
|
||
}
|
||
}
|
||
logOpenAIWSModeDebug(
|
||
"dial_target account_id=%d account_type=%s ws_host=%s ws_path=%s",
|
||
account.ID,
|
||
account.Type,
|
||
wsHost,
|
||
wsPath,
|
||
)
|
||
|
||
payload := s.buildOpenAIWSCreatePayload(reqBody, account)
|
||
payloadStrategy, removedKeys := applyOpenAIWSRetryPayloadStrategy(payload, attempt)
|
||
previousResponseID := openAIWSPayloadString(payload, "previous_response_id")
|
||
previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
|
||
promptCacheKey := openAIWSPayloadString(payload, "prompt_cache_key")
|
||
_, hasTools := payload["tools"]
|
||
debugEnabled := isOpenAIWSModeDebugEnabled()
|
||
payloadBytes := -1
|
||
resolvePayloadBytes := func() int {
|
||
if payloadBytes >= 0 {
|
||
return payloadBytes
|
||
}
|
||
payloadBytes = len(payloadAsJSONBytes(payload))
|
||
return payloadBytes
|
||
}
|
||
streamValue := "-"
|
||
if raw, ok := payload["stream"]; ok {
|
||
streamValue = normalizeOpenAIWSLogValue(strings.TrimSpace(fmt.Sprintf("%v", raw)))
|
||
}
|
||
turnState := ""
|
||
turnMetadata := ""
|
||
if c != nil && c.Request != nil {
|
||
turnState = strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader))
|
||
turnMetadata = strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader))
|
||
}
|
||
setOpenAIWSTurnMetadata(payload, turnMetadata)
|
||
payloadEventType := openAIWSPayloadString(payload, "type")
|
||
if payloadEventType == "" {
|
||
payloadEventType = "response.create"
|
||
}
|
||
if s.shouldEmitOpenAIWSPayloadSchema(attempt) {
|
||
logOpenAIWSModeInfo(
|
||
"[debug] payload_schema account_id=%d attempt=%d event=%s payload_keys=%s payload_bytes=%d payload_key_sizes=%s input_summary=%s stream=%s payload_strategy=%s removed_keys=%s has_previous_response_id=%v has_prompt_cache_key=%v has_tools=%v",
|
||
account.ID,
|
||
attempt,
|
||
payloadEventType,
|
||
normalizeOpenAIWSLogValue(strings.Join(sortedKeys(payload), ",")),
|
||
resolvePayloadBytes(),
|
||
normalizeOpenAIWSLogValue(summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN)),
|
||
normalizeOpenAIWSLogValue(summarizeOpenAIWSInput(payload["input"])),
|
||
streamValue,
|
||
normalizeOpenAIWSLogValue(payloadStrategy),
|
||
normalizeOpenAIWSLogValue(strings.Join(removedKeys, ",")),
|
||
previousResponseID != "",
|
||
promptCacheKey != "",
|
||
hasTools,
|
||
)
|
||
}
|
||
|
||
stateStore := s.getOpenAIWSStateStore()
|
||
groupID := getOpenAIGroupIDFromContext(c)
|
||
sessionHash := s.GenerateSessionHash(c, nil)
|
||
if sessionHash == "" {
|
||
var legacySessionHash string
|
||
sessionHash, legacySessionHash = openAIWSSessionHashesFromID(promptCacheKey)
|
||
attachOpenAILegacySessionHashToGin(c, legacySessionHash)
|
||
}
|
||
if turnState == "" && stateStore != nil && sessionHash != "" {
|
||
if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok {
|
||
turnState = savedTurnState
|
||
}
|
||
}
|
||
preferredConnID := ""
|
||
if stateStore != nil && previousResponseID != "" {
|
||
if connID, ok := stateStore.GetResponseConn(previousResponseID); ok {
|
||
preferredConnID = connID
|
||
}
|
||
}
|
||
storeDisabled := s.isOpenAIWSStoreDisabledInRequest(reqBody, account)
|
||
if stateStore != nil && storeDisabled && previousResponseID == "" && sessionHash != "" {
|
||
if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok {
|
||
preferredConnID = connID
|
||
}
|
||
}
|
||
storeDisabledConnMode := s.openAIWSStoreDisabledConnMode()
|
||
forceNewConnByPolicy := shouldForceNewConnOnStoreDisabled(storeDisabledConnMode, lastFailureReason)
|
||
forceNewConn := forceNewConnByPolicy && storeDisabled && previousResponseID == "" && sessionHash != "" && preferredConnID == ""
|
||
wsHeaders, sessionResolution := s.buildOpenAIWSHeaders(c, account, token, decision, isCodexCLI, turnState, turnMetadata, promptCacheKey)
|
||
logOpenAIWSModeDebug(
|
||
"acquire_start account_id=%d account_type=%s transport=%s preferred_conn_id=%s has_previous_response_id=%v session_hash=%s has_turn_state=%v turn_state_len=%d has_turn_metadata=%v turn_metadata_len=%d store_disabled=%v store_disabled_conn_mode=%s retry_last_reason=%s force_new_conn=%v header_user_agent=%s header_openai_beta=%s header_originator=%s header_accept_language=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_prompt_cache_key=%v has_chatgpt_account_id=%v has_authorization=%v has_session_id=%v has_conversation_id=%v proxy_enabled=%v",
|
||
account.ID,
|
||
account.Type,
|
||
normalizeOpenAIWSLogValue(string(decision.Transport)),
|
||
truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
|
||
previousResponseID != "",
|
||
truncateOpenAIWSLogValue(sessionHash, 12),
|
||
turnState != "",
|
||
len(turnState),
|
||
turnMetadata != "",
|
||
len(turnMetadata),
|
||
storeDisabled,
|
||
normalizeOpenAIWSLogValue(storeDisabledConnMode),
|
||
truncateOpenAIWSLogValue(lastFailureReason, openAIWSLogValueMaxLen),
|
||
forceNewConn,
|
||
openAIWSHeaderValueForLog(wsHeaders, "user-agent"),
|
||
openAIWSHeaderValueForLog(wsHeaders, "openai-beta"),
|
||
openAIWSHeaderValueForLog(wsHeaders, "originator"),
|
||
openAIWSHeaderValueForLog(wsHeaders, "accept-language"),
|
||
openAIWSHeaderValueForLog(wsHeaders, "session_id"),
|
||
openAIWSHeaderValueForLog(wsHeaders, "conversation_id"),
|
||
normalizeOpenAIWSLogValue(sessionResolution.SessionSource),
|
||
normalizeOpenAIWSLogValue(sessionResolution.ConversationSource),
|
||
promptCacheKey != "",
|
||
hasOpenAIWSHeader(wsHeaders, "chatgpt-account-id"),
|
||
hasOpenAIWSHeader(wsHeaders, "authorization"),
|
||
hasOpenAIWSHeader(wsHeaders, "session_id"),
|
||
hasOpenAIWSHeader(wsHeaders, "conversation_id"),
|
||
account.ProxyID != nil && account.Proxy != nil,
|
||
)
|
||
|
||
acquireCtx, acquireCancel := context.WithTimeout(ctx, s.openAIWSAcquireTimeout())
|
||
defer acquireCancel()
|
||
|
||
lease, err := s.getOpenAIWSConnPool().Acquire(acquireCtx, openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: wsURL,
|
||
Headers: wsHeaders,
|
||
PreferredConnID: preferredConnID,
|
||
ForceNewConn: forceNewConn,
|
||
ProxyURL: func() string {
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
return account.Proxy.URL()
|
||
}
|
||
return ""
|
||
}(),
|
||
})
|
||
if err != nil {
|
||
dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(err)
|
||
logOpenAIWSModeInfo(
|
||
"acquire_fail account_id=%d account_type=%s transport=%s reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_new_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v",
|
||
account.ID,
|
||
account.Type,
|
||
normalizeOpenAIWSLogValue(string(decision.Transport)),
|
||
normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(err)),
|
||
dialStatus,
|
||
dialClass,
|
||
dialCloseStatus,
|
||
truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen),
|
||
dialRespServer,
|
||
dialRespVia,
|
||
dialRespCFRay,
|
||
dialRespReqID,
|
||
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
|
||
truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
|
||
forceNewConn,
|
||
wsHost,
|
||
wsPath,
|
||
account.ProxyID != nil && account.Proxy != nil,
|
||
)
|
||
return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err)
|
||
}
|
||
defer lease.Release()
|
||
connID := strings.TrimSpace(lease.ConnID())
|
||
logOpenAIWSModeDebug(
|
||
"connected account_id=%d account_type=%s transport=%s conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d has_previous_response_id=%v",
|
||
account.ID,
|
||
account.Type,
|
||
normalizeOpenAIWSLogValue(string(decision.Transport)),
|
||
connID,
|
||
lease.Reused(),
|
||
lease.ConnPickDuration().Milliseconds(),
|
||
lease.QueueWaitDuration().Milliseconds(),
|
||
previousResponseID != "",
|
||
)
|
||
if previousResponseID != "" {
|
||
logOpenAIWSModeInfo(
|
||
"continuation_probe account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s conn_reused=%v store_disabled=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v",
|
||
account.ID,
|
||
account.Type,
|
||
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
|
||
normalizeOpenAIWSLogValue(previousResponseIDKind),
|
||
truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
|
||
lease.Reused(),
|
||
storeDisabled,
|
||
truncateOpenAIWSLogValue(sessionHash, 12),
|
||
openAIWSHeaderValueForLog(wsHeaders, "session_id"),
|
||
openAIWSHeaderValueForLog(wsHeaders, "conversation_id"),
|
||
normalizeOpenAIWSLogValue(sessionResolution.SessionSource),
|
||
normalizeOpenAIWSLogValue(sessionResolution.ConversationSource),
|
||
turnState != "",
|
||
len(turnState),
|
||
promptCacheKey != "",
|
||
)
|
||
}
|
||
if c != nil {
|
||
SetOpsLatencyMs(c, OpsOpenAIWSConnPickMsKey, lease.ConnPickDuration().Milliseconds())
|
||
SetOpsLatencyMs(c, OpsOpenAIWSQueueWaitMsKey, lease.QueueWaitDuration().Milliseconds())
|
||
c.Set(OpsOpenAIWSConnReusedKey, lease.Reused())
|
||
if connID != "" {
|
||
c.Set(OpsOpenAIWSConnIDKey, connID)
|
||
}
|
||
}
|
||
|
||
handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader))
|
||
logOpenAIWSModeDebug(
|
||
"handshake account_id=%d conn_id=%s has_turn_state=%v turn_state_len=%d",
|
||
account.ID,
|
||
connID,
|
||
handshakeTurnState != "",
|
||
len(handshakeTurnState),
|
||
)
|
||
if handshakeTurnState != "" {
|
||
if stateStore != nil && sessionHash != "" {
|
||
stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL())
|
||
}
|
||
if c != nil {
|
||
c.Header(http.CanonicalHeaderKey(openAIWSTurnStateHeader), handshakeTurnState)
|
||
}
|
||
}
|
||
|
||
if err := s.performOpenAIWSGeneratePrewarm(
|
||
ctx,
|
||
lease,
|
||
decision,
|
||
payload,
|
||
previousResponseID,
|
||
reqBody,
|
||
account,
|
||
stateStore,
|
||
groupID,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if err := lease.WriteJSONWithContextTimeout(ctx, payload, s.openAIWSWriteTimeout()); err != nil {
|
||
lease.MarkBroken()
|
||
logOpenAIWSModeInfo(
|
||
"write_request_fail account_id=%d conn_id=%s cause=%s payload_bytes=%d",
|
||
account.ID,
|
||
connID,
|
||
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
|
||
resolvePayloadBytes(),
|
||
)
|
||
return nil, wrapOpenAIWSFallback("write_request", err)
|
||
}
|
||
if debugEnabled {
|
||
logOpenAIWSModeDebug(
|
||
"write_request_sent account_id=%d conn_id=%s stream=%v payload_bytes=%d previous_response_id=%s",
|
||
account.ID,
|
||
connID,
|
||
reqStream,
|
||
resolvePayloadBytes(),
|
||
truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
|
||
)
|
||
}
|
||
|
||
usage := &OpenAIUsage{}
|
||
var firstTokenMs *int
|
||
responseID := ""
|
||
var finalResponse []byte
|
||
wroteDownstream := false
|
||
needModelReplace := originalModel != mappedModel
|
||
var mappedModelBytes []byte
|
||
if needModelReplace && mappedModel != "" {
|
||
mappedModelBytes = []byte(mappedModel)
|
||
}
|
||
bufferedStreamEvents := make([][]byte, 0, 4)
|
||
eventCount := 0
|
||
tokenEventCount := 0
|
||
terminalEventCount := 0
|
||
bufferedEventCount := 0
|
||
flushedBufferedEventCount := 0
|
||
firstEventType := ""
|
||
lastEventType := ""
|
||
|
||
var flusher http.Flusher
|
||
if reqStream {
|
||
if s.responseHeaderFilter != nil {
|
||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), http.Header{}, s.responseHeaderFilter)
|
||
}
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("X-Accel-Buffering", "no")
|
||
f, ok := c.Writer.(http.Flusher)
|
||
if !ok {
|
||
lease.MarkBroken()
|
||
return nil, wrapOpenAIWSFallback("streaming_not_supported", errors.New("streaming not supported"))
|
||
}
|
||
flusher = f
|
||
}
|
||
|
||
clientDisconnected := false
|
||
flushBatchSize := s.openAIWSEventFlushBatchSize()
|
||
flushInterval := s.openAIWSEventFlushInterval()
|
||
pendingFlushEvents := 0
|
||
lastFlushAt := time.Now()
|
||
flushStreamWriter := func(force bool) {
|
||
if clientDisconnected || flusher == nil || pendingFlushEvents <= 0 {
|
||
return
|
||
}
|
||
if !force && flushBatchSize > 1 && pendingFlushEvents < flushBatchSize {
|
||
if flushInterval <= 0 || time.Since(lastFlushAt) < flushInterval {
|
||
return
|
||
}
|
||
}
|
||
flusher.Flush()
|
||
pendingFlushEvents = 0
|
||
lastFlushAt = time.Now()
|
||
}
|
||
emitStreamMessage := func(message []byte, forceFlush bool) {
|
||
if clientDisconnected {
|
||
return
|
||
}
|
||
frame := make([]byte, 0, len(message)+8)
|
||
frame = append(frame, "data: "...)
|
||
frame = append(frame, message...)
|
||
frame = append(frame, '\n', '\n')
|
||
_, wErr := c.Writer.Write(frame)
|
||
if wErr == nil {
|
||
wroteDownstream = true
|
||
pendingFlushEvents++
|
||
flushStreamWriter(forceFlush)
|
||
return
|
||
}
|
||
clientDisconnected = true
|
||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode] client disconnected, continue draining upstream: account=%d", account.ID)
|
||
}
|
||
flushBufferedStreamEvents := func(reason string) {
|
||
if len(bufferedStreamEvents) == 0 {
|
||
return
|
||
}
|
||
flushed := len(bufferedStreamEvents)
|
||
for _, buffered := range bufferedStreamEvents {
|
||
emitStreamMessage(buffered, false)
|
||
}
|
||
bufferedStreamEvents = bufferedStreamEvents[:0]
|
||
flushStreamWriter(true)
|
||
flushedBufferedEventCount += flushed
|
||
if debugEnabled {
|
||
logOpenAIWSModeDebug(
|
||
"buffer_flush account_id=%d conn_id=%s reason=%s flushed=%d total_flushed=%d client_disconnected=%v",
|
||
account.ID,
|
||
connID,
|
||
truncateOpenAIWSLogValue(reason, openAIWSLogValueMaxLen),
|
||
flushed,
|
||
flushedBufferedEventCount,
|
||
clientDisconnected,
|
||
)
|
||
}
|
||
}
|
||
|
||
readTimeout := s.openAIWSReadTimeout()
|
||
|
||
for {
|
||
message, readErr := lease.ReadMessageWithContextTimeout(ctx, readTimeout)
|
||
if readErr != nil {
|
||
lease.MarkBroken()
|
||
closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
|
||
logOpenAIWSModeInfo(
|
||
"read_fail account_id=%d conn_id=%s wrote_downstream=%v close_status=%s close_reason=%s cause=%s events=%d token_events=%d terminal_events=%d buffered_pending=%d buffered_flushed=%d first_event=%s last_event=%s",
|
||
account.ID,
|
||
connID,
|
||
wroteDownstream,
|
||
closeStatus,
|
||
closeReason,
|
||
truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen),
|
||
eventCount,
|
||
tokenEventCount,
|
||
terminalEventCount,
|
||
len(bufferedStreamEvents),
|
||
flushedBufferedEventCount,
|
||
truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen),
|
||
truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen),
|
||
)
|
||
if !wroteDownstream {
|
||
return nil, wrapOpenAIWSFallback(classifyOpenAIWSReadFallbackReason(readErr), readErr)
|
||
}
|
||
if clientDisconnected {
|
||
break
|
||
}
|
||
setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(readErr.Error()), "")
|
||
return nil, fmt.Errorf("openai ws read event: %w", readErr)
|
||
}
|
||
|
||
eventType, eventResponseID, responseField := parseOpenAIWSEventEnvelope(message)
|
||
if eventType == "" {
|
||
continue
|
||
}
|
||
eventCount++
|
||
if firstEventType == "" {
|
||
firstEventType = eventType
|
||
}
|
||
lastEventType = eventType
|
||
|
||
if responseID == "" && eventResponseID != "" {
|
||
responseID = eventResponseID
|
||
}
|
||
|
||
isTokenEvent := isOpenAIWSTokenEvent(eventType)
|
||
if isTokenEvent {
|
||
tokenEventCount++
|
||
}
|
||
isTerminalEvent := isOpenAIWSTerminalEvent(eventType)
|
||
if isTerminalEvent {
|
||
terminalEventCount++
|
||
}
|
||
if firstTokenMs == nil && isTokenEvent {
|
||
ms := int(time.Since(startTime).Milliseconds())
|
||
firstTokenMs = &ms
|
||
}
|
||
if debugEnabled && shouldLogOpenAIWSEvent(eventCount, eventType) {
|
||
logOpenAIWSModeDebug(
|
||
"event_received account_id=%d conn_id=%s idx=%d type=%s bytes=%d token=%v terminal=%v buffered_pending=%d",
|
||
account.ID,
|
||
connID,
|
||
eventCount,
|
||
truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
|
||
len(message),
|
||
isTokenEvent,
|
||
isTerminalEvent,
|
||
len(bufferedStreamEvents),
|
||
)
|
||
}
|
||
|
||
if !clientDisconnected {
|
||
if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(message, mappedModelBytes) {
|
||
message = replaceOpenAIWSMessageModel(message, mappedModel, originalModel)
|
||
}
|
||
if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(message) {
|
||
if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(message); changed {
|
||
message = corrected
|
||
}
|
||
}
|
||
}
|
||
if openAIWSEventShouldParseUsage(eventType) {
|
||
parseOpenAIWSResponseUsageFromCompletedEvent(message, usage)
|
||
}
|
||
|
||
if eventType == "error" {
|
||
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
|
||
errMsg := strings.TrimSpace(errMsgRaw)
|
||
if errMsg == "" {
|
||
errMsg = "Upstream websocket error"
|
||
}
|
||
fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
|
||
errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
|
||
logOpenAIWSModeInfo(
|
||
"error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s",
|
||
account.ID,
|
||
connID,
|
||
eventCount,
|
||
truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen),
|
||
canFallback,
|
||
errCode,
|
||
errType,
|
||
errMessage,
|
||
)
|
||
if fallbackReason == "previous_response_not_found" {
|
||
logOpenAIWSModeInfo(
|
||
"previous_response_not_found_diag account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s event_idx=%d req_stream=%v store_disabled=%v conn_reused=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v err_code=%s err_type=%s err_message=%s",
|
||
account.ID,
|
||
account.Type,
|
||
connID,
|
||
truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
|
||
normalizeOpenAIWSLogValue(previousResponseIDKind),
|
||
truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
|
||
eventCount,
|
||
reqStream,
|
||
storeDisabled,
|
||
lease.Reused(),
|
||
truncateOpenAIWSLogValue(sessionHash, 12),
|
||
openAIWSHeaderValueForLog(wsHeaders, "session_id"),
|
||
openAIWSHeaderValueForLog(wsHeaders, "conversation_id"),
|
||
normalizeOpenAIWSLogValue(sessionResolution.SessionSource),
|
||
normalizeOpenAIWSLogValue(sessionResolution.ConversationSource),
|
||
turnState != "",
|
||
len(turnState),
|
||
promptCacheKey != "",
|
||
errCode,
|
||
errType,
|
||
errMessage,
|
||
)
|
||
}
|
||
// error 事件后连接不再可复用,避免回池后污染下一请求。
|
||
lease.MarkBroken()
|
||
if !wroteDownstream && canFallback {
|
||
return nil, wrapOpenAIWSFallback(fallbackReason, errors.New(errMsg))
|
||
}
|
||
statusCode := openAIWSErrorHTTPStatusFromRaw(errCodeRaw, errTypeRaw)
|
||
setOpsUpstreamError(c, statusCode, errMsg, "")
|
||
if reqStream && !clientDisconnected {
|
||
flushBufferedStreamEvents("error_event")
|
||
emitStreamMessage(message, true)
|
||
}
|
||
if !reqStream {
|
||
c.JSON(statusCode, gin.H{
|
||
"error": gin.H{
|
||
"type": "upstream_error",
|
||
"message": errMsg,
|
||
},
|
||
})
|
||
}
|
||
return nil, fmt.Errorf("openai ws error event: %s", errMsg)
|
||
}
|
||
|
||
if reqStream {
|
||
// 在首个 token 前先缓冲事件(如 response.created),
|
||
// 以便上游早期断连时仍可安全回退到 HTTP,不给下游发送半截流。
|
||
shouldBuffer := firstTokenMs == nil && !isTokenEvent && !isTerminalEvent
|
||
if shouldBuffer {
|
||
buffered := make([]byte, len(message))
|
||
copy(buffered, message)
|
||
bufferedStreamEvents = append(bufferedStreamEvents, buffered)
|
||
bufferedEventCount++
|
||
if debugEnabled && shouldLogOpenAIWSBufferedEvent(bufferedEventCount) {
|
||
logOpenAIWSModeDebug(
|
||
"buffer_enqueue account_id=%d conn_id=%s idx=%d event_idx=%d event_type=%s buffer_size=%d",
|
||
account.ID,
|
||
connID,
|
||
bufferedEventCount,
|
||
eventCount,
|
||
truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
|
||
len(bufferedStreamEvents),
|
||
)
|
||
}
|
||
} else {
|
||
flushBufferedStreamEvents(eventType)
|
||
emitStreamMessage(message, isTerminalEvent)
|
||
}
|
||
} else {
|
||
if responseField.Exists() && responseField.Type == gjson.JSON {
|
||
finalResponse = []byte(responseField.Raw)
|
||
}
|
||
}
|
||
|
||
if isTerminalEvent {
|
||
break
|
||
}
|
||
}
|
||
|
||
if !reqStream {
|
||
if len(finalResponse) == 0 {
|
||
logOpenAIWSModeInfo(
|
||
"missing_final_response account_id=%d conn_id=%s events=%d token_events=%d terminal_events=%d wrote_downstream=%v",
|
||
account.ID,
|
||
connID,
|
||
eventCount,
|
||
tokenEventCount,
|
||
terminalEventCount,
|
||
wroteDownstream,
|
||
)
|
||
if !wroteDownstream {
|
||
return nil, wrapOpenAIWSFallback("missing_final_response", errors.New("no terminal response payload"))
|
||
}
|
||
return nil, errors.New("ws finished without final response")
|
||
}
|
||
|
||
if needModelReplace {
|
||
finalResponse = s.replaceModelInResponseBody(finalResponse, mappedModel, originalModel)
|
||
}
|
||
finalResponse = s.correctToolCallsInResponseBody(finalResponse)
|
||
populateOpenAIUsageFromResponseJSON(finalResponse, usage)
|
||
if responseID == "" {
|
||
responseID = strings.TrimSpace(gjson.GetBytes(finalResponse, "id").String())
|
||
}
|
||
|
||
c.Data(http.StatusOK, "application/json", finalResponse)
|
||
} else {
|
||
flushStreamWriter(true)
|
||
}
|
||
|
||
if responseID != "" && stateStore != nil {
|
||
ttl := s.openAIWSResponseStickyTTL()
|
||
logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl))
|
||
stateStore.BindResponseConn(responseID, lease.ConnID(), ttl)
|
||
}
|
||
if stateStore != nil && storeDisabled && sessionHash != "" {
|
||
stateStore.BindSessionConn(groupID, sessionHash, lease.ConnID(), s.openAIWSSessionStickyTTL())
|
||
}
|
||
firstTokenMsValue := -1
|
||
if firstTokenMs != nil {
|
||
firstTokenMsValue = *firstTokenMs
|
||
}
|
||
logOpenAIWSModeDebug(
|
||
"completed account_id=%d conn_id=%s response_id=%s stream=%v duration_ms=%d events=%d token_events=%d terminal_events=%d buffered_events=%d buffered_flushed=%d first_event=%s last_event=%s first_token_ms=%d wrote_downstream=%v client_disconnected=%v",
|
||
account.ID,
|
||
connID,
|
||
truncateOpenAIWSLogValue(strings.TrimSpace(responseID), openAIWSIDValueMaxLen),
|
||
reqStream,
|
||
time.Since(startTime).Milliseconds(),
|
||
eventCount,
|
||
tokenEventCount,
|
||
terminalEventCount,
|
||
bufferedEventCount,
|
||
flushedBufferedEventCount,
|
||
truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen),
|
||
truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen),
|
||
firstTokenMsValue,
|
||
wroteDownstream,
|
||
clientDisconnected,
|
||
)
|
||
|
||
return &OpenAIForwardResult{
|
||
RequestID: responseID,
|
||
Usage: *usage,
|
||
Model: originalModel,
|
||
ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
|
||
Stream: reqStream,
|
||
OpenAIWSMode: true,
|
||
Duration: time.Since(startTime),
|
||
FirstTokenMs: firstTokenMs,
|
||
}, nil
|
||
}
|
||
|
||
// ProxyResponsesWebSocketFromClient 处理客户端入站 WebSocket(OpenAI Responses WS Mode)并转发到上游。
|
||
// 当前实现按“单请求 -> 终止事件 -> 下一请求”的顺序代理,适配 Codex CLI 的 turn 模式。
|
||
func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||
ctx context.Context,
|
||
c *gin.Context,
|
||
clientConn *coderws.Conn,
|
||
account *Account,
|
||
token string,
|
||
firstClientMessage []byte,
|
||
hooks *OpenAIWSIngressHooks,
|
||
) error {
|
||
if s == nil {
|
||
return errors.New("service is nil")
|
||
}
|
||
if c == nil {
|
||
return errors.New("gin context is nil")
|
||
}
|
||
if clientConn == nil {
|
||
return errors.New("client websocket is nil")
|
||
}
|
||
if account == nil {
|
||
return errors.New("account is nil")
|
||
}
|
||
if strings.TrimSpace(token) == "" {
|
||
return errors.New("token is empty")
|
||
}
|
||
|
||
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
|
||
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
|
||
ingressMode := OpenAIWSIngressModeShared
|
||
if modeRouterV2Enabled {
|
||
ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault)
|
||
if ingressMode == OpenAIWSIngressModeOff {
|
||
return NewOpenAIWSClientCloseError(
|
||
coderws.StatusPolicyViolation,
|
||
"websocket mode is disabled for this account",
|
||
nil,
|
||
)
|
||
}
|
||
}
|
||
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
|
||
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
|
||
}
|
||
dedicatedMode := modeRouterV2Enabled && ingressMode == OpenAIWSIngressModeDedicated
|
||
|
||
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
||
if err != nil {
|
||
return fmt.Errorf("build ws url: %w", err)
|
||
}
|
||
wsHost := "-"
|
||
wsPath := "-"
|
||
if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil {
|
||
wsHost = normalizeOpenAIWSLogValue(parsedURL.Host)
|
||
wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
|
||
}
|
||
debugEnabled := isOpenAIWSModeDebugEnabled()
|
||
|
||
type openAIWSClientPayload struct {
|
||
payloadRaw []byte
|
||
rawForHash []byte
|
||
promptCacheKey string
|
||
previousResponseID string
|
||
originalModel string
|
||
payloadBytes int
|
||
}
|
||
|
||
applyPayloadMutation := func(current []byte, path string, value any) ([]byte, error) {
|
||
next, err := sjson.SetBytes(current, path, value)
|
||
if err == nil {
|
||
return next, nil
|
||
}
|
||
|
||
// 仅在确实需要修改 payload 且 sjson 失败时,退回 map 路径确保兼容性。
|
||
payload := make(map[string]any)
|
||
if unmarshalErr := json.Unmarshal(current, &payload); unmarshalErr != nil {
|
||
return nil, err
|
||
}
|
||
switch path {
|
||
case "type", "model":
|
||
payload[path] = value
|
||
case "client_metadata." + openAIWSTurnMetadataHeader:
|
||
setOpenAIWSTurnMetadata(payload, fmt.Sprintf("%v", value))
|
||
default:
|
||
return nil, err
|
||
}
|
||
rebuilt, marshalErr := json.Marshal(payload)
|
||
if marshalErr != nil {
|
||
return nil, marshalErr
|
||
}
|
||
return rebuilt, nil
|
||
}
|
||
|
||
parseClientPayload := func(raw []byte) (openAIWSClientPayload, error) {
|
||
trimmed := bytes.TrimSpace(raw)
|
||
if len(trimmed) == 0 {
|
||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "empty websocket request payload", nil)
|
||
}
|
||
if !gjson.ValidBytes(trimmed) {
|
||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
|
||
}
|
||
|
||
values := gjson.GetManyBytes(trimmed, "type", "model", "prompt_cache_key", "previous_response_id")
|
||
eventType := strings.TrimSpace(values[0].String())
|
||
normalized := trimmed
|
||
switch eventType {
|
||
case "":
|
||
eventType = "response.create"
|
||
next, setErr := applyPayloadMutation(normalized, "type", eventType)
|
||
if setErr != nil {
|
||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
|
||
}
|
||
normalized = next
|
||
case "response.create":
|
||
case "response.append":
|
||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
|
||
coderws.StatusPolicyViolation,
|
||
"response.append is not supported in ws v2; use response.create with previous_response_id",
|
||
nil,
|
||
)
|
||
default:
|
||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
|
||
coderws.StatusPolicyViolation,
|
||
fmt.Sprintf("unsupported websocket request type: %s", eventType),
|
||
nil,
|
||
)
|
||
}
|
||
|
||
originalModel := strings.TrimSpace(values[1].String())
|
||
if originalModel == "" {
|
||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
|
||
coderws.StatusPolicyViolation,
|
||
"model is required in response.create payload",
|
||
nil,
|
||
)
|
||
}
|
||
promptCacheKey := strings.TrimSpace(values[2].String())
|
||
previousResponseID := strings.TrimSpace(values[3].String())
|
||
previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
|
||
if previousResponseID != "" && previousResponseIDKind == OpenAIPreviousResponseIDKindMessageID {
|
||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
|
||
coderws.StatusPolicyViolation,
|
||
"previous_response_id must be a response.id (resp_*), not a message id",
|
||
nil,
|
||
)
|
||
}
|
||
if turnMetadata := strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)); turnMetadata != "" {
|
||
next, setErr := applyPayloadMutation(normalized, "client_metadata."+openAIWSTurnMetadataHeader, turnMetadata)
|
||
if setErr != nil {
|
||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
|
||
}
|
||
normalized = next
|
||
}
|
||
mappedModel := account.GetMappedModel(originalModel)
|
||
if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" {
|
||
mappedModel = normalizedModel
|
||
}
|
||
if mappedModel != originalModel {
|
||
next, setErr := applyPayloadMutation(normalized, "model", mappedModel)
|
||
if setErr != nil {
|
||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
|
||
}
|
||
normalized = next
|
||
}
|
||
|
||
return openAIWSClientPayload{
|
||
payloadRaw: normalized,
|
||
rawForHash: trimmed,
|
||
promptCacheKey: promptCacheKey,
|
||
previousResponseID: previousResponseID,
|
||
originalModel: originalModel,
|
||
payloadBytes: len(normalized),
|
||
}, nil
|
||
}
|
||
|
||
firstPayload, err := parseClientPayload(firstClientMessage)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
turnState := strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader))
|
||
stateStore := s.getOpenAIWSStateStore()
|
||
groupID := getOpenAIGroupIDFromContext(c)
|
||
sessionHash := s.GenerateSessionHash(c, firstPayload.rawForHash)
|
||
if turnState == "" && stateStore != nil && sessionHash != "" {
|
||
if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok {
|
||
turnState = savedTurnState
|
||
}
|
||
}
|
||
|
||
preferredConnID := ""
|
||
if stateStore != nil && firstPayload.previousResponseID != "" {
|
||
if connID, ok := stateStore.GetResponseConn(firstPayload.previousResponseID); ok {
|
||
preferredConnID = connID
|
||
}
|
||
}
|
||
|
||
storeDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(firstPayload.payloadRaw, account)
|
||
storeDisabledConnMode := s.openAIWSStoreDisabledConnMode()
|
||
if stateStore != nil && storeDisabled && firstPayload.previousResponseID == "" && sessionHash != "" {
|
||
if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok {
|
||
preferredConnID = connID
|
||
}
|
||
}
|
||
|
||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||
wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey)
|
||
baseAcquireReq := openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: wsURL,
|
||
Headers: wsHeaders,
|
||
ProxyURL: func() string {
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
return account.Proxy.URL()
|
||
}
|
||
return ""
|
||
}(),
|
||
ForceNewConn: false,
|
||
}
|
||
pool := s.getOpenAIWSConnPool()
|
||
if pool == nil {
|
||
return errors.New("openai ws conn pool is nil")
|
||
}
|
||
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_protocol_confirm account_id=%d account_type=%s transport=%s ws_host=%s ws_path=%s ws_mode=%s store_disabled=%v has_session_hash=%v has_previous_response_id=%v",
|
||
account.ID,
|
||
account.Type,
|
||
normalizeOpenAIWSLogValue(string(wsDecision.Transport)),
|
||
wsHost,
|
||
wsPath,
|
||
normalizeOpenAIWSLogValue(ingressMode),
|
||
storeDisabled,
|
||
sessionHash != "",
|
||
firstPayload.previousResponseID != "",
|
||
)
|
||
|
||
if debugEnabled {
|
||
logOpenAIWSModeDebug(
|
||
"ingress_ws_start account_id=%d account_type=%s transport=%s ws_host=%s preferred_conn_id=%s has_session_hash=%v has_previous_response_id=%v store_disabled=%v",
|
||
account.ID,
|
||
account.Type,
|
||
normalizeOpenAIWSLogValue(string(wsDecision.Transport)),
|
||
wsHost,
|
||
truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
|
||
sessionHash != "",
|
||
firstPayload.previousResponseID != "",
|
||
storeDisabled,
|
||
)
|
||
}
|
||
if firstPayload.previousResponseID != "" {
|
||
firstPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(firstPayload.previousResponseID)
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_continuation_probe account_id=%d turn=%d previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s session_hash=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v",
|
||
account.ID,
|
||
1,
|
||
truncateOpenAIWSLogValue(firstPayload.previousResponseID, openAIWSIDValueMaxLen),
|
||
normalizeOpenAIWSLogValue(firstPreviousResponseIDKind),
|
||
truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(sessionHash, 12),
|
||
openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"),
|
||
openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"),
|
||
turnState != "",
|
||
len(turnState),
|
||
firstPayload.promptCacheKey != "",
|
||
storeDisabled,
|
||
)
|
||
}
|
||
|
||
acquireTimeout := s.openAIWSAcquireTimeout()
|
||
if acquireTimeout <= 0 {
|
||
acquireTimeout = 30 * time.Second
|
||
}
|
||
|
||
acquireTurnLease := func(turn int, preferred string, forcePreferredConn bool) (*openAIWSConnLease, error) {
|
||
req := cloneOpenAIWSAcquireRequest(baseAcquireReq)
|
||
req.PreferredConnID = strings.TrimSpace(preferred)
|
||
req.ForcePreferredConn = forcePreferredConn
|
||
// dedicated 模式下每次获取均新建连接,避免跨会话复用残留上下文。
|
||
req.ForceNewConn = dedicatedMode
|
||
acquireCtx, acquireCancel := context.WithTimeout(ctx, acquireTimeout)
|
||
lease, acquireErr := pool.Acquire(acquireCtx, req)
|
||
acquireCancel()
|
||
if acquireErr != nil {
|
||
dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(acquireErr)
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_upstream_acquire_fail account_id=%d turn=%d reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_preferred_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v",
|
||
account.ID,
|
||
turn,
|
||
normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(acquireErr)),
|
||
dialStatus,
|
||
dialClass,
|
||
dialCloseStatus,
|
||
truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen),
|
||
dialRespServer,
|
||
dialRespVia,
|
||
dialRespCFRay,
|
||
dialRespReqID,
|
||
truncateOpenAIWSLogValue(acquireErr.Error(), openAIWSLogValueMaxLen),
|
||
truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen),
|
||
forcePreferredConn,
|
||
wsHost,
|
||
wsPath,
|
||
account.ProxyID != nil && account.Proxy != nil,
|
||
)
|
||
if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) {
|
||
return nil, NewOpenAIWSClientCloseError(
|
||
coderws.StatusPolicyViolation,
|
||
"upstream continuation connection is unavailable; please restart the conversation",
|
||
acquireErr,
|
||
)
|
||
}
|
||
if errors.Is(acquireErr, context.DeadlineExceeded) || errors.Is(acquireErr, errOpenAIWSConnQueueFull) {
|
||
return nil, NewOpenAIWSClientCloseError(
|
||
coderws.StatusTryAgainLater,
|
||
"upstream websocket is busy, please retry later",
|
||
acquireErr,
|
||
)
|
||
}
|
||
return nil, acquireErr
|
||
}
|
||
connID := strings.TrimSpace(lease.ConnID())
|
||
if handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader)); handshakeTurnState != "" {
|
||
turnState = handshakeTurnState
|
||
if stateStore != nil && sessionHash != "" {
|
||
stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL())
|
||
}
|
||
updatedHeaders := cloneHeader(baseAcquireReq.Headers)
|
||
if updatedHeaders == nil {
|
||
updatedHeaders = make(http.Header)
|
||
}
|
||
updatedHeaders.Set(openAIWSTurnStateHeader, handshakeTurnState)
|
||
baseAcquireReq.Headers = updatedHeaders
|
||
}
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_upstream_connected account_id=%d turn=%d conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d preferred_conn_id=%s",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
|
||
lease.Reused(),
|
||
lease.ConnPickDuration().Milliseconds(),
|
||
lease.QueueWaitDuration().Milliseconds(),
|
||
truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen),
|
||
)
|
||
return lease, nil
|
||
}
|
||
|
||
writeClientMessage := func(message []byte) error {
|
||
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
|
||
defer cancel()
|
||
return clientConn.Write(writeCtx, coderws.MessageText, message)
|
||
}
|
||
|
||
readClientMessage := func() ([]byte, error) {
|
||
msgType, payload, readErr := clientConn.Read(ctx)
|
||
if readErr != nil {
|
||
return nil, readErr
|
||
}
|
||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||
return nil, NewOpenAIWSClientCloseError(
|
||
coderws.StatusPolicyViolation,
|
||
fmt.Sprintf("unsupported websocket client message type: %s", msgType.String()),
|
||
nil,
|
||
)
|
||
}
|
||
return payload, nil
|
||
}
|
||
|
||
sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) {
|
||
if lease == nil {
|
||
return nil, errors.New("upstream websocket lease is nil")
|
||
}
|
||
turnStart := time.Now()
|
||
wroteDownstream := false
|
||
if err := lease.WriteJSONWithContextTimeout(ctx, json.RawMessage(payload), s.openAIWSWriteTimeout()); err != nil {
|
||
return nil, wrapOpenAIWSIngressTurnError(
|
||
"write_upstream",
|
||
fmt.Errorf("write upstream websocket request: %w", err),
|
||
false,
|
||
)
|
||
}
|
||
if debugEnabled {
|
||
logOpenAIWSModeDebug(
|
||
"ingress_ws_turn_request_sent account_id=%d turn=%d conn_id=%s payload_bytes=%d",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
|
||
payloadBytes,
|
||
)
|
||
}
|
||
|
||
responseID := ""
|
||
usage := OpenAIUsage{}
|
||
var firstTokenMs *int
|
||
reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true)
|
||
turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
|
||
turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID)
|
||
turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key")
|
||
turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account)
|
||
turnHasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()
|
||
eventCount := 0
|
||
tokenEventCount := 0
|
||
terminalEventCount := 0
|
||
firstEventType := ""
|
||
lastEventType := ""
|
||
needModelReplace := false
|
||
clientDisconnected := false
|
||
mappedModel := ""
|
||
var mappedModelBytes []byte
|
||
if originalModel != "" {
|
||
mappedModel = account.GetMappedModel(originalModel)
|
||
if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" {
|
||
mappedModel = normalizedModel
|
||
}
|
||
needModelReplace = mappedModel != "" && mappedModel != originalModel
|
||
if needModelReplace {
|
||
mappedModelBytes = []byte(mappedModel)
|
||
}
|
||
}
|
||
for {
|
||
upstreamMessage, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout())
|
||
if readErr != nil {
|
||
lease.MarkBroken()
|
||
return nil, wrapOpenAIWSIngressTurnError(
|
||
"read_upstream",
|
||
fmt.Errorf("read upstream websocket event: %w", readErr),
|
||
wroteDownstream,
|
||
)
|
||
}
|
||
|
||
eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(upstreamMessage)
|
||
if responseID == "" && eventResponseID != "" {
|
||
responseID = eventResponseID
|
||
}
|
||
if eventType != "" {
|
||
eventCount++
|
||
if firstEventType == "" {
|
||
firstEventType = eventType
|
||
}
|
||
lastEventType = eventType
|
||
}
|
||
if eventType == "error" {
|
||
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage)
|
||
fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
|
||
errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
|
||
recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound &&
|
||
turnPreviousResponseID != "" &&
|
||
!turnHasFunctionCallOutput &&
|
||
s.openAIWSIngressPreviousResponseRecoveryEnabled() &&
|
||
!wroteDownstream
|
||
if recoverablePrevNotFound {
|
||
// 可恢复场景使用非 error 关键字日志,避免被 LegacyPrintf 误判为 ERROR 级别。
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_prev_response_recoverable account_id=%d turn=%d conn_id=%s idx=%d reason=%s code=%s type=%s message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
|
||
eventCount,
|
||
truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen),
|
||
errCode,
|
||
errType,
|
||
errMessage,
|
||
truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen),
|
||
normalizeOpenAIWSLogValue(turnPreviousResponseIDKind),
|
||
truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
|
||
turnStoreDisabled,
|
||
turnPromptCacheKey != "",
|
||
)
|
||
} else {
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_error_event account_id=%d turn=%d conn_id=%s idx=%d fallback_reason=%s err_code=%s err_type=%s err_message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
|
||
eventCount,
|
||
truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen),
|
||
errCode,
|
||
errType,
|
||
errMessage,
|
||
truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen),
|
||
normalizeOpenAIWSLogValue(turnPreviousResponseIDKind),
|
||
truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
|
||
turnStoreDisabled,
|
||
turnPromptCacheKey != "",
|
||
)
|
||
}
|
||
// previous_response_not_found 在 ingress 模式支持单次恢复重试:
|
||
// 不把该 error 直接下发客户端,而是由上层去掉 previous_response_id 后重放当前 turn。
|
||
if recoverablePrevNotFound {
|
||
lease.MarkBroken()
|
||
errMsg := strings.TrimSpace(errMsgRaw)
|
||
if errMsg == "" {
|
||
errMsg = "previous response not found"
|
||
}
|
||
return nil, wrapOpenAIWSIngressTurnError(
|
||
openAIWSIngressStagePreviousResponseNotFound,
|
||
errors.New(errMsg),
|
||
false,
|
||
)
|
||
}
|
||
}
|
||
isTokenEvent := isOpenAIWSTokenEvent(eventType)
|
||
if isTokenEvent {
|
||
tokenEventCount++
|
||
}
|
||
isTerminalEvent := isOpenAIWSTerminalEvent(eventType)
|
||
if isTerminalEvent {
|
||
terminalEventCount++
|
||
}
|
||
if firstTokenMs == nil && isTokenEvent {
|
||
ms := int(time.Since(turnStart).Milliseconds())
|
||
firstTokenMs = &ms
|
||
}
|
||
if openAIWSEventShouldParseUsage(eventType) {
|
||
parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage)
|
||
}
|
||
|
||
if !clientDisconnected {
|
||
if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) {
|
||
upstreamMessage = replaceOpenAIWSMessageModel(upstreamMessage, mappedModel, originalModel)
|
||
}
|
||
if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(upstreamMessage) {
|
||
if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(upstreamMessage); changed {
|
||
upstreamMessage = corrected
|
||
}
|
||
}
|
||
if err := writeClientMessage(upstreamMessage); err != nil {
|
||
if isOpenAIWSClientDisconnectError(err) {
|
||
clientDisconnected = true
|
||
closeStatus, closeReason := summarizeOpenAIWSReadCloseError(err)
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_client_disconnected_drain account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
|
||
closeStatus,
|
||
truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen),
|
||
)
|
||
} else {
|
||
return nil, wrapOpenAIWSIngressTurnError(
|
||
"write_client",
|
||
fmt.Errorf("write client websocket event: %w", err),
|
||
wroteDownstream,
|
||
)
|
||
}
|
||
} else {
|
||
wroteDownstream = true
|
||
}
|
||
}
|
||
if isTerminalEvent {
|
||
// 客户端已断连时,上游连接的 session 状态不可信,标记 broken 避免回池复用。
|
||
if clientDisconnected {
|
||
lease.MarkBroken()
|
||
}
|
||
firstTokenMsValue := -1
|
||
if firstTokenMs != nil {
|
||
firstTokenMsValue = *firstTokenMs
|
||
}
|
||
if debugEnabled {
|
||
logOpenAIWSModeDebug(
|
||
"ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d events=%d token_events=%d terminal_events=%d first_event=%s last_event=%s first_token_ms=%d client_disconnected=%v",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
|
||
time.Since(turnStart).Milliseconds(),
|
||
eventCount,
|
||
tokenEventCount,
|
||
terminalEventCount,
|
||
truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen),
|
||
truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen),
|
||
firstTokenMsValue,
|
||
clientDisconnected,
|
||
)
|
||
}
|
||
return &OpenAIForwardResult{
|
||
RequestID: responseID,
|
||
Usage: usage,
|
||
Model: originalModel,
|
||
ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel),
|
||
Stream: reqStream,
|
||
OpenAIWSMode: true,
|
||
Duration: time.Since(turnStart),
|
||
FirstTokenMs: firstTokenMs,
|
||
}, nil
|
||
}
|
||
}
|
||
}
|
||
|
||
currentPayload := firstPayload.payloadRaw
|
||
currentOriginalModel := firstPayload.originalModel
|
||
currentPayloadBytes := firstPayload.payloadBytes
|
||
isStrictAffinityTurn := func(payload []byte) bool {
|
||
if !storeDisabled {
|
||
return false
|
||
}
|
||
return strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) != ""
|
||
}
|
||
var sessionLease *openAIWSConnLease
|
||
sessionConnID := ""
|
||
pinnedSessionConnID := ""
|
||
unpinSessionConn := func(connID string) {
|
||
connID = strings.TrimSpace(connID)
|
||
if connID == "" || pinnedSessionConnID != connID {
|
||
return
|
||
}
|
||
pool.UnpinConn(account.ID, connID)
|
||
pinnedSessionConnID = ""
|
||
}
|
||
pinSessionConn := func(connID string) {
|
||
if !storeDisabled {
|
||
return
|
||
}
|
||
connID = strings.TrimSpace(connID)
|
||
if connID == "" || pinnedSessionConnID == connID {
|
||
return
|
||
}
|
||
if pinnedSessionConnID != "" {
|
||
pool.UnpinConn(account.ID, pinnedSessionConnID)
|
||
pinnedSessionConnID = ""
|
||
}
|
||
if pool.PinConn(account.ID, connID) {
|
||
pinnedSessionConnID = connID
|
||
}
|
||
}
|
||
releaseSessionLease := func() {
|
||
if sessionLease == nil {
|
||
return
|
||
}
|
||
if dedicatedMode {
|
||
// dedicated 会话结束后主动标记损坏,确保连接不会跨会话复用。
|
||
sessionLease.MarkBroken()
|
||
}
|
||
unpinSessionConn(sessionConnID)
|
||
sessionLease.Release()
|
||
if debugEnabled {
|
||
logOpenAIWSModeDebug(
|
||
"ingress_ws_upstream_released account_id=%d conn_id=%s",
|
||
account.ID,
|
||
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
|
||
)
|
||
}
|
||
}
|
||
defer releaseSessionLease()
|
||
|
||
turn := 1
|
||
turnRetry := 0
|
||
turnPrevRecoveryTried := false
|
||
lastTurnFinishedAt := time.Time{}
|
||
lastTurnResponseID := ""
|
||
lastTurnPayload := []byte(nil)
|
||
var lastTurnStrictState *openAIWSIngressPreviousTurnStrictState
|
||
lastTurnReplayInput := []json.RawMessage(nil)
|
||
lastTurnReplayInputExists := false
|
||
currentTurnReplayInput := []json.RawMessage(nil)
|
||
currentTurnReplayInputExists := false
|
||
skipBeforeTurn := false
|
||
resetSessionLease := func(markBroken bool) {
|
||
if sessionLease == nil {
|
||
return
|
||
}
|
||
if markBroken {
|
||
sessionLease.MarkBroken()
|
||
}
|
||
releaseSessionLease()
|
||
sessionLease = nil
|
||
sessionConnID = ""
|
||
preferredConnID = ""
|
||
}
|
||
recoverIngressPrevResponseNotFound := func(relayErr error, turn int, connID string) bool {
|
||
if !isOpenAIWSIngressPreviousResponseNotFound(relayErr) {
|
||
return false
|
||
}
|
||
if turnPrevRecoveryTried || !s.openAIWSIngressPreviousResponseRecoveryEnabled() {
|
||
return false
|
||
}
|
||
if isStrictAffinityTurn(currentPayload) {
|
||
// Layer 2:严格亲和链路命中 previous_response_not_found 时,降级为“去掉 previous_response_id 后重放一次”。
|
||
// 该错误说明续链锚点已失效,继续 strict fail-close 只会直接中断本轮请求。
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_prev_response_recovery_layer2 account_id=%d turn=%d conn_id=%s store_disabled_conn_mode=%s action=drop_previous_response_id_retry",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
|
||
normalizeOpenAIWSLogValue(storeDisabledConnMode),
|
||
)
|
||
}
|
||
turnPrevRecoveryTried = true
|
||
updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
|
||
if dropErr != nil || !removed {
|
||
reason := "not_removed"
|
||
if dropErr != nil {
|
||
reason = "drop_error"
|
||
}
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
|
||
normalizeOpenAIWSLogValue(reason),
|
||
)
|
||
return false
|
||
}
|
||
updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
|
||
updatedPayload,
|
||
currentTurnReplayInput,
|
||
currentTurnReplayInputExists,
|
||
)
|
||
if setInputErr != nil {
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
|
||
)
|
||
return false
|
||
}
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id retry=1",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
|
||
)
|
||
currentPayload = updatedWithInput
|
||
currentPayloadBytes = len(updatedWithInput)
|
||
resetSessionLease(true)
|
||
skipBeforeTurn = true
|
||
return true
|
||
}
|
||
retryIngressTurn := func(relayErr error, turn int, connID string) bool {
|
||
if !isOpenAIWSIngressTurnRetryable(relayErr) || turnRetry >= 1 {
|
||
return false
|
||
}
|
||
if isStrictAffinityTurn(currentPayload) {
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_turn_retry_skip account_id=%d turn=%d conn_id=%s reason=strict_affinity",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
|
||
)
|
||
return false
|
||
}
|
||
turnRetry++
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_turn_retry account_id=%d turn=%d retry=%d reason=%s conn_id=%s",
|
||
account.ID,
|
||
turn,
|
||
turnRetry,
|
||
truncateOpenAIWSLogValue(openAIWSIngressTurnRetryReason(relayErr), openAIWSLogValueMaxLen),
|
||
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
|
||
)
|
||
resetSessionLease(true)
|
||
skipBeforeTurn = true
|
||
return true
|
||
}
|
||
for {
|
||
if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil {
|
||
if err := hooks.BeforeTurn(turn); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
skipBeforeTurn = false
|
||
currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")
|
||
expectedPrev := strings.TrimSpace(lastTurnResponseID)
|
||
hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
|
||
// store=false + function_call_output 场景必须有续链锚点。
|
||
// 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。
|
||
if shouldInferIngressFunctionCallOutputPreviousResponseID(
|
||
storeDisabled,
|
||
turn,
|
||
hasFunctionCallOutput,
|
||
currentPreviousResponseID,
|
||
expectedPrev,
|
||
) {
|
||
updatedPayload, setPrevErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev)
|
||
if setPrevErr != nil {
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_function_call_output_prev_infer_skip account_id=%d turn=%d conn_id=%s reason=set_previous_response_id_error cause=%s expected_previous_response_id=%s",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(setPrevErr.Error(), openAIWSLogValueMaxLen),
|
||
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
|
||
)
|
||
} else {
|
||
currentPayload = updatedPayload
|
||
currentPayloadBytes = len(updatedPayload)
|
||
currentPreviousResponseID = expectedPrev
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_function_call_output_prev_infer account_id=%d turn=%d conn_id=%s action=set_previous_response_id previous_response_id=%s",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
|
||
)
|
||
}
|
||
}
|
||
nextReplayInput, nextReplayInputExists, replayInputErr := buildOpenAIWSReplayInputSequence(
|
||
lastTurnReplayInput,
|
||
lastTurnReplayInputExists,
|
||
currentPayload,
|
||
currentPreviousResponseID != "",
|
||
)
|
||
if replayInputErr != nil {
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_replay_input_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(replayInputErr.Error(), openAIWSLogValueMaxLen),
|
||
)
|
||
currentTurnReplayInput = nil
|
||
currentTurnReplayInputExists = false
|
||
} else {
|
||
currentTurnReplayInput = nextReplayInput
|
||
currentTurnReplayInputExists = nextReplayInputExists
|
||
}
|
||
if storeDisabled && turn > 1 && currentPreviousResponseID != "" {
|
||
shouldKeepPreviousResponseID := false
|
||
strictReason := ""
|
||
var strictErr error
|
||
if lastTurnStrictState != nil {
|
||
shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseIDWithStrictState(
|
||
lastTurnStrictState,
|
||
currentPayload,
|
||
lastTurnResponseID,
|
||
hasFunctionCallOutput,
|
||
)
|
||
} else {
|
||
shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseID(
|
||
lastTurnPayload,
|
||
currentPayload,
|
||
lastTurnResponseID,
|
||
hasFunctionCallOutput,
|
||
)
|
||
}
|
||
if strictErr != nil {
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s cause=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
|
||
normalizeOpenAIWSLogValue(strictReason),
|
||
truncateOpenAIWSLogValue(strictErr.Error(), openAIWSLogValueMaxLen),
|
||
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
|
||
hasFunctionCallOutput,
|
||
)
|
||
} else if !shouldKeepPreviousResponseID {
|
||
updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
|
||
if dropErr != nil || !removed {
|
||
dropReason := "not_removed"
|
||
if dropErr != nil {
|
||
dropReason = "drop_error"
|
||
}
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
|
||
normalizeOpenAIWSLogValue(strictReason),
|
||
normalizeOpenAIWSLogValue(dropReason),
|
||
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
|
||
hasFunctionCallOutput,
|
||
)
|
||
} else {
|
||
updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
|
||
updatedPayload,
|
||
currentTurnReplayInput,
|
||
currentTurnReplayInputExists,
|
||
)
|
||
if setInputErr != nil {
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=set_full_input_error previous_response_id=%s expected_previous_response_id=%s cause=%s has_function_call_output=%v",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
|
||
normalizeOpenAIWSLogValue(strictReason),
|
||
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
|
||
hasFunctionCallOutput,
|
||
)
|
||
} else {
|
||
currentPayload = updatedWithInput
|
||
currentPayloadBytes = len(updatedWithInput)
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_full_create reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
|
||
normalizeOpenAIWSLogValue(strictReason),
|
||
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
|
||
hasFunctionCallOutput,
|
||
)
|
||
currentPreviousResponseID = ""
|
||
}
|
||
}
|
||
}
|
||
}
|
||
forcePreferredConn := isStrictAffinityTurn(currentPayload)
|
||
if sessionLease == nil {
|
||
acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn)
|
||
if acquireErr != nil {
|
||
return fmt.Errorf("acquire upstream websocket: %w", acquireErr)
|
||
}
|
||
sessionLease = acquiredLease
|
||
sessionConnID = strings.TrimSpace(sessionLease.ConnID())
|
||
if storeDisabled {
|
||
pinSessionConn(sessionConnID)
|
||
} else {
|
||
unpinSessionConn(sessionConnID)
|
||
}
|
||
}
|
||
shouldPreflightPing := turn > 1 && sessionLease != nil && turnRetry == 0
|
||
if shouldPreflightPing && openAIWSIngressPreflightPingIdle > 0 && !lastTurnFinishedAt.IsZero() {
|
||
if time.Since(lastTurnFinishedAt) < openAIWSIngressPreflightPingIdle {
|
||
shouldPreflightPing = false
|
||
}
|
||
}
|
||
if shouldPreflightPing {
|
||
if pingErr := sessionLease.PingWithTimeout(openAIWSConnHealthCheckTO); pingErr != nil {
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_upstream_preflight_ping_fail account_id=%d turn=%d conn_id=%s cause=%s",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(pingErr.Error(), openAIWSLogValueMaxLen),
|
||
)
|
||
if forcePreferredConn {
|
||
if !turnPrevRecoveryTried && currentPreviousResponseID != "" {
|
||
updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
|
||
if dropErr != nil || !removed {
|
||
reason := "not_removed"
|
||
if dropErr != nil {
|
||
reason = "drop_error"
|
||
}
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s previous_response_id=%s",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
|
||
normalizeOpenAIWSLogValue(reason),
|
||
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
|
||
)
|
||
} else {
|
||
updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
|
||
updatedPayload,
|
||
currentTurnReplayInput,
|
||
currentTurnReplayInputExists,
|
||
)
|
||
if setInputErr != nil {
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error previous_response_id=%s cause=%s",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
|
||
)
|
||
} else {
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_preflight_ping_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
|
||
)
|
||
turnPrevRecoveryTried = true
|
||
currentPayload = updatedWithInput
|
||
currentPayloadBytes = len(updatedWithInput)
|
||
resetSessionLease(true)
|
||
skipBeforeTurn = true
|
||
continue
|
||
}
|
||
}
|
||
}
|
||
resetSessionLease(true)
|
||
return NewOpenAIWSClientCloseError(
|
||
coderws.StatusPolicyViolation,
|
||
"upstream continuation connection is unavailable; please restart the conversation",
|
||
pingErr,
|
||
)
|
||
}
|
||
resetSessionLease(true)
|
||
|
||
acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn)
|
||
if acquireErr != nil {
|
||
return fmt.Errorf("acquire upstream websocket after preflight ping fail: %w", acquireErr)
|
||
}
|
||
sessionLease = acquiredLease
|
||
sessionConnID = strings.TrimSpace(sessionLease.ConnID())
|
||
if storeDisabled {
|
||
pinSessionConn(sessionConnID)
|
||
}
|
||
}
|
||
}
|
||
connID := sessionConnID
|
||
if currentPreviousResponseID != "" {
|
||
chainedFromLast := expectedPrev != "" && currentPreviousResponseID == expectedPrev
|
||
currentPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(currentPreviousResponseID)
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_turn_chain account_id=%d turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v preferred_conn_id=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
|
||
normalizeOpenAIWSLogValue(currentPreviousResponseIDKind),
|
||
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
|
||
chainedFromLast,
|
||
truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
|
||
openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"),
|
||
openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"),
|
||
turnState != "",
|
||
len(turnState),
|
||
openAIWSPayloadStringFromRaw(currentPayload, "prompt_cache_key") != "",
|
||
storeDisabled,
|
||
)
|
||
}
|
||
|
||
result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel)
|
||
if relayErr != nil {
|
||
if recoverIngressPrevResponseNotFound(relayErr, turn, connID) {
|
||
continue
|
||
}
|
||
if retryIngressTurn(relayErr, turn, connID) {
|
||
continue
|
||
}
|
||
finalErr := relayErr
|
||
if unwrapped := errors.Unwrap(relayErr); unwrapped != nil {
|
||
finalErr = unwrapped
|
||
}
|
||
if hooks != nil && hooks.AfterTurn != nil {
|
||
hooks.AfterTurn(turn, nil, finalErr)
|
||
}
|
||
sessionLease.MarkBroken()
|
||
return finalErr
|
||
}
|
||
turnRetry = 0
|
||
turnPrevRecoveryTried = false
|
||
lastTurnFinishedAt = time.Now()
|
||
if hooks != nil && hooks.AfterTurn != nil {
|
||
hooks.AfterTurn(turn, result, nil)
|
||
}
|
||
if result == nil {
|
||
return errors.New("websocket turn result is nil")
|
||
}
|
||
responseID := strings.TrimSpace(result.RequestID)
|
||
lastTurnResponseID = responseID
|
||
lastTurnPayload = cloneOpenAIWSPayloadBytes(currentPayload)
|
||
lastTurnReplayInput = cloneOpenAIWSRawMessages(currentTurnReplayInput)
|
||
lastTurnReplayInputExists = currentTurnReplayInputExists
|
||
nextStrictState, strictStateErr := buildOpenAIWSIngressPreviousTurnStrictState(currentPayload)
|
||
if strictStateErr != nil {
|
||
lastTurnStrictState = nil
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_prev_response_strict_state_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(strictStateErr.Error(), openAIWSLogValueMaxLen),
|
||
)
|
||
} else {
|
||
lastTurnStrictState = nextStrictState
|
||
}
|
||
|
||
if responseID != "" && stateStore != nil {
|
||
ttl := s.openAIWSResponseStickyTTL()
|
||
logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl))
|
||
stateStore.BindResponseConn(responseID, connID, ttl)
|
||
}
|
||
if stateStore != nil && storeDisabled && sessionHash != "" {
|
||
stateStore.BindSessionConn(groupID, sessionHash, connID, s.openAIWSSessionStickyTTL())
|
||
}
|
||
if connID != "" {
|
||
preferredConnID = connID
|
||
}
|
||
|
||
nextClientMessage, readErr := readClientMessage()
|
||
if readErr != nil {
|
||
if isOpenAIWSClientDisconnectError(readErr) {
|
||
closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_client_closed account_id=%d conn_id=%s close_status=%s close_reason=%s",
|
||
account.ID,
|
||
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
|
||
closeStatus,
|
||
truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen),
|
||
)
|
||
return nil
|
||
}
|
||
return fmt.Errorf("read client websocket request: %w", readErr)
|
||
}
|
||
|
||
nextPayload, parseErr := parseClientPayload(nextClientMessage)
|
||
if parseErr != nil {
|
||
return parseErr
|
||
}
|
||
if nextPayload.promptCacheKey != "" {
|
||
// ingress 会话在整个客户端 WS 生命周期内复用同一上游连接;
|
||
// prompt_cache_key 对握手头的更新仅在未来需要重新建连时生效。
|
||
updatedHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), nextPayload.promptCacheKey)
|
||
baseAcquireReq.Headers = updatedHeaders
|
||
}
|
||
if nextPayload.previousResponseID != "" {
|
||
expectedPrev := strings.TrimSpace(lastTurnResponseID)
|
||
chainedFromLast := expectedPrev != "" && nextPayload.previousResponseID == expectedPrev
|
||
nextPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(nextPayload.previousResponseID)
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_next_turn_chain account_id=%d turn=%d next_turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v has_prompt_cache_key=%v store_disabled=%v",
|
||
account.ID,
|
||
turn,
|
||
turn+1,
|
||
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen),
|
||
normalizeOpenAIWSLogValue(nextPreviousResponseIDKind),
|
||
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
|
||
chainedFromLast,
|
||
nextPayload.promptCacheKey != "",
|
||
storeDisabled,
|
||
)
|
||
}
|
||
if stateStore != nil && nextPayload.previousResponseID != "" {
|
||
if stickyConnID, ok := stateStore.GetResponseConn(nextPayload.previousResponseID); ok {
|
||
if sessionConnID != "" && stickyConnID != "" && stickyConnID != sessionConnID {
|
||
logOpenAIWSModeInfo(
|
||
"ingress_ws_keep_session_conn account_id=%d turn=%d conn_id=%s sticky_conn_id=%s previous_response_id=%s",
|
||
account.ID,
|
||
turn,
|
||
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(stickyConnID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen),
|
||
)
|
||
} else {
|
||
preferredConnID = stickyConnID
|
||
}
|
||
}
|
||
}
|
||
currentPayload = nextPayload.payloadRaw
|
||
currentOriginalModel = nextPayload.originalModel
|
||
currentPayloadBytes = nextPayload.payloadBytes
|
||
storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account)
|
||
if !storeDisabled {
|
||
unpinSessionConn(sessionConnID)
|
||
}
|
||
turn++
|
||
}
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) isOpenAIWSGeneratePrewarmEnabled() bool {
|
||
return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled
|
||
}
|
||
|
||
// performOpenAIWSGeneratePrewarm 在 WSv2 下执行可选的 generate=false 预热。
|
||
// 预热默认关闭,仅在配置开启后生效;失败时按可恢复错误回退到 HTTP。
|
||
func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm(
|
||
ctx context.Context,
|
||
lease *openAIWSConnLease,
|
||
decision OpenAIWSProtocolDecision,
|
||
payload map[string]any,
|
||
previousResponseID string,
|
||
reqBody map[string]any,
|
||
account *Account,
|
||
stateStore OpenAIWSStateStore,
|
||
groupID int64,
|
||
) error {
|
||
if s == nil {
|
||
return nil
|
||
}
|
||
if lease == nil || account == nil {
|
||
logOpenAIWSModeInfo("prewarm_skip reason=invalid_state has_lease=%v has_account=%v", lease != nil, account != nil)
|
||
return nil
|
||
}
|
||
connID := strings.TrimSpace(lease.ConnID())
|
||
if !s.isOpenAIWSGeneratePrewarmEnabled() {
|
||
return nil
|
||
}
|
||
if decision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
|
||
logOpenAIWSModeInfo(
|
||
"prewarm_skip account_id=%d conn_id=%s reason=transport_not_v2 transport=%s",
|
||
account.ID,
|
||
connID,
|
||
normalizeOpenAIWSLogValue(string(decision.Transport)),
|
||
)
|
||
return nil
|
||
}
|
||
if strings.TrimSpace(previousResponseID) != "" {
|
||
logOpenAIWSModeInfo(
|
||
"prewarm_skip account_id=%d conn_id=%s reason=has_previous_response_id previous_response_id=%s",
|
||
account.ID,
|
||
connID,
|
||
truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
|
||
)
|
||
return nil
|
||
}
|
||
if lease.IsPrewarmed() {
|
||
logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=already_prewarmed", account.ID, connID)
|
||
return nil
|
||
}
|
||
if NeedsToolContinuation(reqBody) {
|
||
logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=tool_continuation", account.ID, connID)
|
||
return nil
|
||
}
|
||
prewarmStart := time.Now()
|
||
logOpenAIWSModeInfo("prewarm_start account_id=%d conn_id=%s", account.ID, connID)
|
||
|
||
prewarmPayload := make(map[string]any, len(payload)+1)
|
||
for k, v := range payload {
|
||
prewarmPayload[k] = v
|
||
}
|
||
prewarmPayload["generate"] = false
|
||
prewarmPayloadJSON := payloadAsJSONBytes(prewarmPayload)
|
||
|
||
if err := lease.WriteJSONWithContextTimeout(ctx, prewarmPayload, s.openAIWSWriteTimeout()); err != nil {
|
||
lease.MarkBroken()
|
||
logOpenAIWSModeInfo(
|
||
"prewarm_write_fail account_id=%d conn_id=%s cause=%s",
|
||
account.ID,
|
||
connID,
|
||
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
|
||
)
|
||
return wrapOpenAIWSFallback("prewarm_write", err)
|
||
}
|
||
logOpenAIWSModeInfo("prewarm_write_sent account_id=%d conn_id=%s payload_bytes=%d", account.ID, connID, len(prewarmPayloadJSON))
|
||
|
||
prewarmResponseID := ""
|
||
prewarmEventCount := 0
|
||
prewarmTerminalCount := 0
|
||
for {
|
||
message, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout())
|
||
if readErr != nil {
|
||
lease.MarkBroken()
|
||
closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
|
||
logOpenAIWSModeInfo(
|
||
"prewarm_read_fail account_id=%d conn_id=%s close_status=%s close_reason=%s cause=%s events=%d",
|
||
account.ID,
|
||
connID,
|
||
closeStatus,
|
||
closeReason,
|
||
truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen),
|
||
prewarmEventCount,
|
||
)
|
||
return wrapOpenAIWSFallback("prewarm_"+classifyOpenAIWSReadFallbackReason(readErr), readErr)
|
||
}
|
||
|
||
eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(message)
|
||
if eventType == "" {
|
||
continue
|
||
}
|
||
prewarmEventCount++
|
||
if prewarmResponseID == "" && eventResponseID != "" {
|
||
prewarmResponseID = eventResponseID
|
||
}
|
||
if prewarmEventCount <= openAIWSPrewarmEventLogHead || eventType == "error" || isOpenAIWSTerminalEvent(eventType) {
|
||
logOpenAIWSModeInfo(
|
||
"prewarm_event account_id=%d conn_id=%s idx=%d type=%s bytes=%d",
|
||
account.ID,
|
||
connID,
|
||
prewarmEventCount,
|
||
truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
|
||
len(message),
|
||
)
|
||
}
|
||
|
||
if eventType == "error" {
|
||
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
|
||
errMsg := strings.TrimSpace(errMsgRaw)
|
||
if errMsg == "" {
|
||
errMsg = "OpenAI websocket prewarm error"
|
||
}
|
||
fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
|
||
errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
|
||
logOpenAIWSModeInfo(
|
||
"prewarm_error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s",
|
||
account.ID,
|
||
connID,
|
||
prewarmEventCount,
|
||
truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen),
|
||
canFallback,
|
||
errCode,
|
||
errType,
|
||
errMessage,
|
||
)
|
||
lease.MarkBroken()
|
||
if canFallback {
|
||
return wrapOpenAIWSFallback("prewarm_"+fallbackReason, errors.New(errMsg))
|
||
}
|
||
return wrapOpenAIWSFallback("prewarm_error_event", errors.New(errMsg))
|
||
}
|
||
|
||
if isOpenAIWSTerminalEvent(eventType) {
|
||
prewarmTerminalCount++
|
||
break
|
||
}
|
||
}
|
||
|
||
lease.MarkPrewarmed()
|
||
if prewarmResponseID != "" && stateStore != nil {
|
||
ttl := s.openAIWSResponseStickyTTL()
|
||
logOpenAIWSBindResponseAccountWarn(groupID, account.ID, prewarmResponseID, stateStore.BindResponseAccount(ctx, groupID, prewarmResponseID, account.ID, ttl))
|
||
stateStore.BindResponseConn(prewarmResponseID, lease.ConnID(), ttl)
|
||
}
|
||
logOpenAIWSModeInfo(
|
||
"prewarm_done account_id=%d conn_id=%s response_id=%s events=%d terminal_events=%d duration_ms=%d",
|
||
account.ID,
|
||
connID,
|
||
truncateOpenAIWSLogValue(prewarmResponseID, openAIWSIDValueMaxLen),
|
||
prewarmEventCount,
|
||
prewarmTerminalCount,
|
||
time.Since(prewarmStart).Milliseconds(),
|
||
)
|
||
return nil
|
||
}
|
||
|
||
func payloadAsJSON(payload map[string]any) string {
|
||
return string(payloadAsJSONBytes(payload))
|
||
}
|
||
|
||
// payloadAsJSONBytes serializes payload with encoding/json.
//
// A nil or empty map yields the literal empty object "{}". Any json.Marshal
// failure also yields "{}", for example a value json cannot encode such as a
// func or channel. Callers therefore always receive a syntactically valid JSON
// object.
func payloadAsJSONBytes(payload map[string]any) []byte {
	if len(payload) > 0 {
		if encoded, marshalErr := json.Marshal(payload); marshalErr == nil {
			return encoded
		}
	}
	// Empty input or an unencodable value: fall back to an empty object.
	return []byte("{}")
}
|
||
|
||
// isOpenAIWSTerminalEvent reports whether eventType, ignoring surrounding
// whitespace, is one of the event types this file treats as the end of a
// response:
//   - completed or done,
//   - failed or incomplete,
//   - cancelled, accepting both "cancelled" and "canceled" spellings.
func isOpenAIWSTerminalEvent(eventType string) bool {
	normalized := strings.TrimSpace(eventType)
	switch normalized {
	case "response.completed", "response.done":
		return true
	case "response.failed", "response.incomplete":
		return true
	case "response.cancelled", "response.canceled":
		return true
	}
	return false
}
|
||
|
||
// isOpenAIWSTokenEvent reports whether eventType (whitespace-trimmed) should be
// counted as an event that carries generated output.
//
// The rules, checked in order:
//   - an empty name, response.created, response.in_progress,
//     response.output_item.added and response.output_item.done are excluded;
//   - any name containing ".delta" counts;
//   - any name starting with "response.output" counts, including every
//     "response.output_text*" event;
//   - otherwise only response.completed and response.done count.
func isOpenAIWSTokenEvent(eventType string) bool {
	name := strings.TrimSpace(eventType)
	switch {
	case name == "":
		return false
	case name == "response.created",
		name == "response.in_progress",
		name == "response.output_item.added",
		name == "response.output_item.done":
		// Explicitly excluded before the broader prefix rules below.
		return false
	case strings.Contains(name, ".delta"):
		return true
	case strings.HasPrefix(name, "response.output"):
		return true
	default:
		return name == "response.completed" || name == "response.done"
	}
}
|
||
|
||
func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte {
|
||
if len(message) == 0 {
|
||
return message
|
||
}
|
||
if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel {
|
||
return message
|
||
}
|
||
if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) {
|
||
return message
|
||
}
|
||
modelValues := gjson.GetManyBytes(message, "model", "response.model")
|
||
replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel
|
||
replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel
|
||
if !replaceModel && !replaceResponseModel {
|
||
return message
|
||
}
|
||
updated := message
|
||
if replaceModel {
|
||
if next, err := sjson.SetBytes(updated, "model", toModel); err == nil {
|
||
updated = next
|
||
}
|
||
}
|
||
if replaceResponseModel {
|
||
if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil {
|
||
updated = next
|
||
}
|
||
}
|
||
return updated
|
||
}
|
||
|
||
func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) {
|
||
if usage == nil || len(body) == 0 {
|
||
return
|
||
}
|
||
values := gjson.GetManyBytes(
|
||
body,
|
||
"usage.input_tokens",
|
||
"usage.output_tokens",
|
||
"usage.input_tokens_details.cached_tokens",
|
||
)
|
||
usage.InputTokens = int(values[0].Int())
|
||
usage.OutputTokens = int(values[1].Int())
|
||
usage.CacheReadInputTokens = int(values[2].Int())
|
||
}
|
||
|
||
func getOpenAIGroupIDFromContext(c *gin.Context) int64 {
|
||
if c == nil {
|
||
return 0
|
||
}
|
||
value, exists := c.Get("api_key")
|
||
if !exists {
|
||
return 0
|
||
}
|
||
apiKey, ok := value.(*APIKey)
|
||
if !ok || apiKey == nil || apiKey.GroupID == nil {
|
||
return 0
|
||
}
|
||
return *apiKey.GroupID
|
||
}
|
||
|
||
// SelectAccountByPreviousResponseID 按 previous_response_id 命中账号粘连。
|
||
// 未命中或账号不可用时返回 (nil, nil),由调用方继续走常规调度。
|
||
func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
|
||
ctx context.Context,
|
||
groupID *int64,
|
||
previousResponseID string,
|
||
requestedModel string,
|
||
excludedIDs map[int64]struct{},
|
||
) (*AccountSelectionResult, error) {
|
||
if s == nil {
|
||
return nil, nil
|
||
}
|
||
responseID := strings.TrimSpace(previousResponseID)
|
||
if responseID == "" {
|
||
return nil, nil
|
||
}
|
||
store := s.getOpenAIWSStateStore()
|
||
if store == nil {
|
||
return nil, nil
|
||
}
|
||
|
||
accountID, err := store.GetResponseAccount(ctx, derefGroupID(groupID), responseID)
|
||
if err != nil || accountID <= 0 {
|
||
return nil, nil
|
||
}
|
||
if excludedIDs != nil {
|
||
if _, excluded := excludedIDs[accountID]; excluded {
|
||
return nil, nil
|
||
}
|
||
}
|
||
|
||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||
if err != nil || account == nil {
|
||
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
|
||
return nil, nil
|
||
}
|
||
// 非 WSv2 场景(如 force_http/全局关闭)不应使用 previous_response_id 粘连,
|
||
// 以保持“回滚到 HTTP”后的历史行为一致性。
|
||
if s.getOpenAIWSProtocolResolver().Resolve(account).Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
|
||
return nil, nil
|
||
}
|
||
if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() {
|
||
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
|
||
return nil, nil
|
||
}
|
||
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
|
||
return nil, nil
|
||
}
|
||
|
||
result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||
if acquireErr == nil && result.Acquired {
|
||
logOpenAIWSBindResponseAccountWarn(
|
||
derefGroupID(groupID),
|
||
accountID,
|
||
responseID,
|
||
store.BindResponseAccount(ctx, derefGroupID(groupID), responseID, accountID, s.openAIWSResponseStickyTTL()),
|
||
)
|
||
return &AccountSelectionResult{
|
||
Account: account,
|
||
Acquired: true,
|
||
ReleaseFunc: result.ReleaseFunc,
|
||
}, nil
|
||
}
|
||
|
||
cfg := s.schedulingConfig()
|
||
if s.concurrencyService != nil {
|
||
return &AccountSelectionResult{
|
||
Account: account,
|
||
WaitPlan: &AccountWaitPlan{
|
||
AccountID: accountID,
|
||
MaxConcurrency: account.Concurrency,
|
||
Timeout: cfg.StickySessionWaitTimeout,
|
||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||
},
|
||
}, nil
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
func classifyOpenAIWSAcquireError(err error) string {
|
||
if err == nil {
|
||
return "acquire_conn"
|
||
}
|
||
var dialErr *openAIWSDialError
|
||
if errors.As(err, &dialErr) {
|
||
switch dialErr.StatusCode {
|
||
case 426:
|
||
return "upgrade_required"
|
||
case 401, 403:
|
||
return "auth_failed"
|
||
case 429:
|
||
return "upstream_rate_limited"
|
||
}
|
||
if dialErr.StatusCode >= 500 {
|
||
return "upstream_5xx"
|
||
}
|
||
return "dial_failed"
|
||
}
|
||
if errors.Is(err, errOpenAIWSConnQueueFull) {
|
||
return "conn_queue_full"
|
||
}
|
||
if errors.Is(err, errOpenAIWSPreferredConnUnavailable) {
|
||
return "preferred_conn_unavailable"
|
||
}
|
||
if errors.Is(err, context.DeadlineExceeded) {
|
||
return "acquire_timeout"
|
||
}
|
||
return "acquire_conn"
|
||
}
|
||
|
||
// classifyOpenAIWSErrorEventFromRaw maps the code, type and message of an
// upstream WebSocket "error" event to a fallback reason. It also reports
// whether that reason allows falling back (the boolean result).
//
// All inputs are trimmed and lower-cased. Rules are evaluated in order:
//   - exact error codes first;
//   - then substring heuristics on the message/type;
//   - then server_error.
//
// Anything unmatched is ("event_error", false).
func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
	normalize := func(v string) string { return strings.ToLower(strings.TrimSpace(v)) }
	code, errType, msg := normalize(codeRaw), normalize(errTypeRaw), normalize(msgRaw)
	has := strings.Contains

	switch {
	// Exact upstream error codes take priority over every heuristic below.
	case code == "upgrade_required":
		return "upgrade_required", true
	case code == "websocket_not_supported" || code == "websocket_unsupported":
		return "ws_unsupported", true
	case code == "websocket_connection_limit_reached":
		return "ws_connection_limit_reached", true
	case code == "previous_response_not_found":
		return "previous_response_not_found", true

	// Free-text heuristics on the message and error type.
	case has(msg, "upgrade required"), has(msg, "status 426"), has(errType, "upgrade"):
		return "upgrade_required", true
	case has(msg, "websocket") && has(msg, "unsupported"):
		return "ws_unsupported", true
	case has(msg, "connection limit") && has(msg, "websocket"):
		return "ws_connection_limit_reached", true
	case has(msg, "previous_response_not_found"),
		has(msg, "previous response") && has(msg, "not found"):
		return "previous_response_not_found", true
	case has(errType, "server_error"), has(code, "server_error"):
		return "upstream_error_event", true
	default:
		return "event_error", false
	}
}
|
||
|
||
func classifyOpenAIWSErrorEvent(message []byte) (string, bool) {
|
||
if len(message) == 0 {
|
||
return "event_error", false
|
||
}
|
||
return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message))
|
||
}
|
||
|
||
// openAIWSErrorHTTPStatusFromRaw maps the code and type fields of an upstream
// WebSocket error event to an HTTP status code.
//
// Both inputs are trimmed and lower-cased. Matching is by substring, except
// that the exact code "previous_response_not_found" maps to 400. Rules are
// tried in order (400, 401, 403, 429); anything unmatched becomes 502 Bad
// Gateway.
func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
	code := strings.ToLower(strings.TrimSpace(codeRaw))
	kind := strings.ToLower(strings.TrimSpace(errTypeRaw))
	codeHas := func(sub string) bool { return strings.Contains(code, sub) }
	kindHas := func(sub string) bool { return strings.Contains(kind, sub) }

	if kindHas("invalid_request") || codeHas("invalid_request") || codeHas("bad_request") ||
		code == "previous_response_not_found" {
		return http.StatusBadRequest
	}
	if kindHas("authentication") || codeHas("invalid_api_key") || codeHas("unauthorized") {
		return http.StatusUnauthorized
	}
	if kindHas("permission") || codeHas("forbidden") {
		return http.StatusForbidden
	}
	if kindHas("rate_limit") || codeHas("rate_limit") || codeHas("insufficient_quota") {
		return http.StatusTooManyRequests
	}
	return http.StatusBadGateway
}
|
||
|
||
func openAIWSErrorHTTPStatus(message []byte) int {
|
||
if len(message) == 0 {
|
||
return http.StatusBadGateway
|
||
}
|
||
codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message)
|
||
return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw)
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration {
|
||
if s == nil || s.cfg == nil {
|
||
return 30 * time.Second
|
||
}
|
||
seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds
|
||
if seconds <= 0 {
|
||
return 0
|
||
}
|
||
return time.Duration(seconds) * time.Second
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool {
|
||
if s == nil || accountID <= 0 {
|
||
return false
|
||
}
|
||
cooldown := s.openAIWSFallbackCooldown()
|
||
if cooldown <= 0 {
|
||
return false
|
||
}
|
||
rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID)
|
||
if !ok || rawUntil == nil {
|
||
return false
|
||
}
|
||
until, ok := rawUntil.(time.Time)
|
||
if !ok || until.IsZero() {
|
||
s.openaiWSFallbackUntil.Delete(accountID)
|
||
return false
|
||
}
|
||
if time.Now().Before(until) {
|
||
return true
|
||
}
|
||
s.openaiWSFallbackUntil.Delete(accountID)
|
||
return false
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) {
|
||
if s == nil || accountID <= 0 {
|
||
return
|
||
}
|
||
cooldown := s.openAIWSFallbackCooldown()
|
||
if cooldown <= 0 {
|
||
return
|
||
}
|
||
s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown))
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) clearOpenAIWSFallbackCooling(accountID int64) {
|
||
if s == nil || accountID <= 0 {
|
||
return
|
||
}
|
||
s.openaiWSFallbackUntil.Delete(accountID)
|
||
}
|