- Extract PublicSettingsInjectionPayload named struct with drift test
- Add channel_monitor_default_interval_seconds to SSR injection
- Add image_output_price to SupportedModelChip
- Simplify AppSidebar buildSelfNavItems (admins see available channels)
- Add gateway WARN logs for 503 no-available-accounts branches
- Wire ChannelMonitorRunner into provideCleanup for graceful shutdown
- Add migrations 130/131 (CC template userid fix + mimicry field cleanup)
- Clean up fork-only features (sora, claude max simulation, client affinity)
- Remove ~320 obsolete i18n keys
- Add codexUsage utility, WechatServiceButton, BulkEditAccountModal
- Tidy go.sum
660 lines
21 KiB
Go
660 lines
21 KiB
Go
package service
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/tidwall/gjson"
|
||
"github.com/tidwall/sjson"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// cursorResponsesUnsupportedFields are top-level Responses API parameters that
// Codex upstreams reject with "Unsupported parameter: ...". They must be
// stripped when forwarding a raw client body through the Responses-shape
// short-circuit in ForwardAsChatCompletions (see isResponsesShape branch).
// The normal Chat Completions → Responses conversion path is unaffected
// because ChatCompletionsRequest has no fields for these parameters — unknown
// fields are dropped naturally by json.Unmarshal. Kept semantically in sync
// with the list in openai_gateway_service.go:2034 used by the /v1/responses
// passthrough path.
//
// Each entry is a top-level JSON key deleted via sjson.DeleteBytes; order is
// irrelevant since deletions are independent.
var cursorResponsesUnsupportedFields = []string{
	"prompt_cache_retention",
	"safety_identifier",
	"metadata",
	"stream_options",
}
|
||
|
||
// ForwardAsChatCompletions accepts a Chat Completions request body, converts it
|
||
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
|
||
// the response back to Chat Completions format. All account types (OAuth and API
|
||
// Key) go through the Responses API conversion path since the upstream only
|
||
// exposes the /v1/responses endpoint.
|
||
func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||
ctx context.Context,
|
||
c *gin.Context,
|
||
account *Account,
|
||
body []byte,
|
||
promptCacheKey string,
|
||
defaultMappedModel string,
|
||
) (*OpenAIForwardResult, error) {
|
||
startTime := time.Now()
|
||
|
||
// 1. Parse Chat Completions request
|
||
var chatReq apicompat.ChatCompletionsRequest
|
||
if err := json.Unmarshal(body, &chatReq); err != nil {
|
||
return nil, fmt.Errorf("parse chat completions request: %w", err)
|
||
}
|
||
originalModel := chatReq.Model
|
||
clientStream := chatReq.Stream
|
||
includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage
|
||
|
||
// 2. Resolve model mapping early so compat prompt_cache_key injection can
|
||
// derive a stable seed from the final upstream model family.
|
||
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
|
||
|
||
promptCacheKey = strings.TrimSpace(promptCacheKey)
|
||
compatPromptCacheInjected := false
|
||
if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) {
|
||
promptCacheKey = deriveCompatPromptCacheKey(&chatReq, upstreamModel)
|
||
compatPromptCacheInjected = promptCacheKey != ""
|
||
}
|
||
|
||
// 3. Build the upstream (Responses API) body.
|
||
//
|
||
// Cursor compatibility: some clients (notably Cursor cloud) send Responses
|
||
// API shaped bodies — `input: [...]` with no `messages` field — to the
|
||
// /v1/chat/completions URL. Running those through ChatCompletionsToResponses
|
||
// would silently drop Cursor's `input` array (the struct has no Input field)
|
||
// and produce `input: null`, which Codex upstreams reject with
|
||
// "Invalid type for 'input': expected a string, but got an object".
|
||
//
|
||
// Detect that shape and forward the raw body as-is, only rewriting `model`
|
||
// to the resolved upstream model. The downstream codex OAuth transform will
|
||
// still normalize store/stream/instructions/etc.
|
||
isResponsesShape := !gjson.GetBytes(body, "messages").Exists() && gjson.GetBytes(body, "input").Exists()
|
||
|
||
var (
|
||
responsesReq *apicompat.ResponsesRequest
|
||
responsesBody []byte
|
||
err error
|
||
)
|
||
if isResponsesShape {
|
||
responsesBody, err = sjson.SetBytes(body, "model", upstreamModel)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("rewrite model in responses-shape body: %w", err)
|
||
}
|
||
// Strip Responses API parameters that no Codex upstream accepts.
|
||
// Because this branch forwards the raw body (the normal path rebuilds
|
||
// it from ChatCompletionsRequest and drops unknown fields naturally),
|
||
// we must filter these fields explicitly here — otherwise the upstream
|
||
// rejects the request with "Unsupported parameter: ...".
|
||
for _, field := range cursorResponsesUnsupportedFields {
|
||
if stripped, derr := sjson.DeleteBytes(responsesBody, field); derr == nil {
|
||
responsesBody = stripped
|
||
}
|
||
}
|
||
responsesBody, normalizedServiceTier, err := normalizeResponsesBodyServiceTier(responsesBody)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("normalize service_tier in responses-shape body: %w", err)
|
||
}
|
||
// Minimal stub populated from the raw body so downstream billing
|
||
// propagation (ServiceTier, ReasoningEffort) keeps working.
|
||
responsesReq = &apicompat.ResponsesRequest{
|
||
Model: upstreamModel,
|
||
ServiceTier: normalizedServiceTier,
|
||
}
|
||
if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" {
|
||
responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort}
|
||
}
|
||
} else {
|
||
// Normal path: convert Chat Completions → Responses.
|
||
// ChatCompletionsToResponses always sets Stream=true (upstream always streams).
|
||
responsesReq, err = apicompat.ChatCompletionsToResponses(&chatReq)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
|
||
}
|
||
responsesReq.Model = upstreamModel
|
||
normalizeResponsesRequestServiceTier(responsesReq)
|
||
responsesBody, err = json.Marshal(responsesReq)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("marshal responses request: %w", err)
|
||
}
|
||
}
|
||
|
||
logFields := []zap.Field{
|
||
zap.Int64("account_id", account.ID),
|
||
zap.String("original_model", originalModel),
|
||
zap.String("billing_model", billingModel),
|
||
zap.String("upstream_model", upstreamModel),
|
||
zap.Bool("stream", clientStream),
|
||
zap.Bool("responses_shape", isResponsesShape),
|
||
}
|
||
if compatPromptCacheInjected {
|
||
logFields = append(logFields,
|
||
zap.Bool("compat_prompt_cache_key_injected", true),
|
||
zap.String("compat_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey)),
|
||
)
|
||
}
|
||
logger.L().Debug("openai chat_completions: model mapping applied", logFields...)
|
||
|
||
{
|
||
var reqBody map[string]any
|
||
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
|
||
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
|
||
}
|
||
modified := false
|
||
if account.Type == AccountTypeOAuth {
|
||
codexResult := applyCodexOAuthTransform(reqBody, false, false)
|
||
modified = codexResult.Modified
|
||
if codexResult.NormalizedModel != "" {
|
||
upstreamModel = codexResult.NormalizedModel
|
||
}
|
||
if codexResult.PromptCacheKey != "" {
|
||
promptCacheKey = codexResult.PromptCacheKey
|
||
} else if promptCacheKey != "" {
|
||
reqBody["prompt_cache_key"] = promptCacheKey
|
||
}
|
||
} else {
|
||
// 非 OAuth 账号也需要提取 system 消息并注入 instructions,
|
||
// 否则上游 GPT-5/Codex 等模型会报 "Instructions are required"。
|
||
if extractSystemMessagesFromInput(reqBody) {
|
||
modified = true
|
||
}
|
||
if applyInstructions(reqBody, false) {
|
||
modified = true
|
||
}
|
||
}
|
||
if modified {
|
||
responsesBody, err = json.Marshal(reqBody)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("remarshal after codex transform: %w", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 5. Get access token
|
||
token, _, err := s.GetAccessToken(ctx, account)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get access token: %w", err)
|
||
}
|
||
|
||
// 6. Build upstream request
|
||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||
}
|
||
|
||
if promptCacheKey != "" {
|
||
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
|
||
}
|
||
|
||
// 7. Send request
|
||
proxyURL := ""
|
||
if account.Proxy != nil {
|
||
proxyURL = account.Proxy.URL()
|
||
}
|
||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||
if err != nil {
|
||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||
setOpsUpstreamError(c, 0, safeErr, "")
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: 0,
|
||
Kind: "request_error",
|
||
Message: safeErr,
|
||
})
|
||
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
|
||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||
}
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
// 8. Handle error response with failover
|
||
if resp.StatusCode >= 400 {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
_ = resp.Body.Close()
|
||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||
|
||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||
upstreamDetail := ""
|
||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||
if maxBytes <= 0 {
|
||
maxBytes = 2048
|
||
}
|
||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||
}
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||
Kind: "failover",
|
||
Message: upstreamMsg,
|
||
Detail: upstreamDetail,
|
||
})
|
||
if s.rateLimitService != nil {
|
||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||
}
|
||
return nil, &UpstreamFailoverError{
|
||
StatusCode: resp.StatusCode,
|
||
ResponseBody: respBody,
|
||
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||
}
|
||
}
|
||
return s.handleChatCompletionsErrorResponse(resp, c, account)
|
||
}
|
||
|
||
// 9. Handle normal response
|
||
var result *OpenAIForwardResult
|
||
var handleErr error
|
||
if clientStream {
|
||
result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, includeUsage, startTime)
|
||
} else {
|
||
result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime)
|
||
}
|
||
|
||
// Propagate ServiceTier and ReasoningEffort to result for billing
|
||
if handleErr == nil && result != nil {
|
||
if responsesReq.ServiceTier != "" {
|
||
st := responsesReq.ServiceTier
|
||
result.ServiceTier = &st
|
||
}
|
||
if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
|
||
re := responsesReq.Reasoning.Effort
|
||
result.ReasoningEffort = &re
|
||
}
|
||
}
|
||
|
||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||
if handleErr == nil && account.Type == AccountTypeOAuth {
|
||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||
}
|
||
}
|
||
|
||
return result, handleErr
|
||
}
|
||
|
||
func normalizeResponsesRequestServiceTier(req *apicompat.ResponsesRequest) {
|
||
if req == nil {
|
||
return
|
||
}
|
||
req.ServiceTier = normalizedOpenAIServiceTierValue(req.ServiceTier)
|
||
}
|
||
|
||
func normalizeResponsesBodyServiceTier(body []byte) ([]byte, string, error) {
|
||
if len(body) == 0 {
|
||
return body, "", nil
|
||
}
|
||
rawServiceTier := gjson.GetBytes(body, "service_tier").String()
|
||
if rawServiceTier == "" {
|
||
return body, "", nil
|
||
}
|
||
normalizedServiceTier := normalizedOpenAIServiceTierValue(rawServiceTier)
|
||
if normalizedServiceTier == "" {
|
||
trimmed, err := sjson.DeleteBytes(body, "service_tier")
|
||
return trimmed, "", err
|
||
}
|
||
if normalizedServiceTier == rawServiceTier {
|
||
return body, normalizedServiceTier, nil
|
||
}
|
||
trimmed, err := sjson.SetBytes(body, "service_tier", normalizedServiceTier)
|
||
return trimmed, normalizedServiceTier, err
|
||
}
|
||
|
||
func normalizedOpenAIServiceTierValue(raw string) string {
|
||
normalized := normalizeOpenAIServiceTier(raw)
|
||
if normalized == nil {
|
||
return ""
|
||
}
|
||
return *normalized
|
||
}
|
||
|
||
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
// OpenAI Chat Completions error format.
//
// Pure delegation: the shared compat error handler does the reading/sanitizing;
// writeChatCompletionsError supplies the Chat Completions error envelope.
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
	resp *http.Response,
	c *gin.Context,
	account *Account,
) (*OpenAIForwardResult, error) {
	return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError)
}
|
||
|
||
// handleChatBufferedStreamingResponse reads all Responses SSE events from the
// upstream, finds the terminal event, converts to a Chat Completions JSON
// response, and writes it to the client.
//
// Used when the client requested a non-streaming response: the upstream always
// streams, so this buffers the full event stream and emits a single JSON body.
// Returns a billing result on success; writes a 502 Chat Completions error and
// returns an error when the stream ends without a terminal event.
func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	originalModel string,
	billingModel string,
	upstreamModel string,
	startTime time.Time,
) (*OpenAIForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	// Line-oriented SSE scan; buffer sized up to the configured max line size
	// so oversized data: lines don't abort the scan.
	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	var finalResponse *apicompat.ResponsesResponse
	var usage OpenAIUsage
	acc := apicompat.NewBufferedResponseAccumulator()

	for scanner.Scan() {
		line := scanner.Text()
		// Only `data: ` payload lines matter; skip comments, blanks and the
		// [DONE] sentinel.
		if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
			continue
		}
		payload := line[6:]

		var event apicompat.ResponsesStreamEvent
		if err := json.Unmarshal([]byte(payload), &event); err != nil {
			logger.L().Warn("openai chat_completions buffered: failed to parse event",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
			continue
		}

		// Accumulate delta content for fallback when terminal output is empty.
		acc.ProcessEvent(&event)

		// Terminal event set. NOTE(review): this list includes
		// "response.done", which the live-streaming handler's usage
		// extraction does not — confirm whether that asymmetry is
		// intentional. Later terminal events overwrite earlier ones.
		if (event.Type == "response.completed" || event.Type == "response.done" ||
			event.Type == "response.incomplete" || event.Type == "response.failed") &&
			event.Response != nil {
			finalResponse = event.Response
			if event.Response.Usage != nil {
				usage = OpenAIUsage{
					InputTokens:  event.Response.Usage.InputTokens,
					OutputTokens: event.Response.Usage.OutputTokens,
				}
				if event.Response.Usage.InputTokensDetails != nil {
					usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
				}
			}
		}
	}

	// Scanner errors from client/context cancellation are expected; only log
	// genuine read failures.
	if err := scanner.Err(); err != nil {
		if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
			logger.L().Warn("openai chat_completions buffered: read error",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
		}
	}

	if finalResponse == nil {
		writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
		return nil, fmt.Errorf("upstream stream ended without terminal event")
	}

	// When the terminal event has an empty output array, reconstruct from
	// accumulated delta events so the client receives the full content.
	acc.SupplementResponseOutput(finalResponse)

	chatResp := apicompat.ResponsesToChatCompletions(finalResponse, originalModel)

	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	c.JSON(http.StatusOK, chatResp)

	return &OpenAIForwardResult{
		RequestID:     requestID,
		Usage:         usage,
		Model:         originalModel,
		BillingModel:  billingModel,
		UpstreamModel: upstreamModel,
		Stream:        false,
		Duration:      time.Since(startTime),
	}, nil
}
|
||
|
||
// handleChatStreamingResponse reads Responses SSE events from upstream,
// converts each to Chat Completions SSE chunks, and writes them to the client.
//
// Two execution modes: a fast synchronous loop when no keepalive interval is
// configured, and a reader-goroutine + select loop that interleaves SSE
// comment keepalives when one is. In both modes a client disconnect ends the
// stream early but still returns a usage result for billing.
func (s *OpenAIGatewayService) handleChatStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	originalModel string,
	billingModel string,
	upstreamModel string,
	includeUsage bool,
	startTime time.Time,
) (*OpenAIForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	// Headers must be written before the first body byte.
	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Accel-Buffering", "no")
	c.Writer.WriteHeader(http.StatusOK)

	// Conversion state shared across all events of this stream.
	state := apicompat.NewResponsesEventToChatState()
	state.Model = originalModel
	state.IncludeUsage = includeUsage

	var usage OpenAIUsage
	var firstTokenMs *int
	firstChunk := true

	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	// resultWithUsage snapshots the billing result at the moment of return.
	resultWithUsage := func() *OpenAIForwardResult {
		return &OpenAIForwardResult{
			RequestID:     requestID,
			Usage:         usage,
			Model:         originalModel,
			BillingModel:  billingModel,
			UpstreamModel: upstreamModel,
			Stream:        true,
			Duration:      time.Since(startTime),
			FirstTokenMs:  firstTokenMs,
		}
	}

	// processDataLine handles one SSE data payload. It returns true when the
	// client has disconnected and the caller should stop streaming.
	processDataLine := func(payload string) bool {
		// First-token latency is measured on the first data payload, even if
		// that payload later fails to parse.
		if firstChunk {
			firstChunk = false
			ms := int(time.Since(startTime).Milliseconds())
			firstTokenMs = &ms
		}

		var event apicompat.ResponsesStreamEvent
		if err := json.Unmarshal([]byte(payload), &event); err != nil {
			logger.L().Warn("openai chat_completions stream: failed to parse event",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
			return false
		}

		// Extract usage from completion events.
		// NOTE(review): unlike the buffered handler, "response.done" is not in
		// this list — confirm whether done events can carry usage.
		if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
			event.Response != nil && event.Response.Usage != nil {
			usage = OpenAIUsage{
				InputTokens:  event.Response.Usage.InputTokens,
				OutputTokens: event.Response.Usage.OutputTokens,
			}
			if event.Response.Usage.InputTokensDetails != nil {
				usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
			}
		}

		// One upstream event may fan out to zero or more client chunks.
		chunks := apicompat.ResponsesEventToChatChunks(&event, state)
		for _, chunk := range chunks {
			sse, err := apicompat.ChatChunkToSSE(chunk)
			if err != nil {
				logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
					zap.Error(err),
					zap.String("request_id", requestID),
				)
				continue
			}
			if _, err := fmt.Fprint(c.Writer, sse); err != nil {
				// Write error means the client went away; stop streaming.
				logger.L().Info("openai chat_completions stream: client disconnected",
					zap.String("request_id", requestID),
				)
				return true
			}
		}
		if len(chunks) > 0 {
			c.Writer.Flush()
		}
		return false
	}

	// finalizeStream emits any trailing chunks the converter is still holding,
	// then the [DONE] sentinel.
	finalizeStream := func() (*OpenAIForwardResult, error) {
		if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 {
			for _, chunk := range finalChunks {
				sse, err := apicompat.ChatChunkToSSE(chunk)
				if err != nil {
					continue
				}
				fmt.Fprint(c.Writer, sse) //nolint:errcheck
			}
		}
		// Send [DONE] sentinel
		fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
		c.Writer.Flush()
		return resultWithUsage(), nil
	}

	// handleScanErr logs non-cancellation read errors.
	handleScanErr := func(err error) {
		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
			logger.L().Warn("openai chat_completions stream: read error",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
		}
	}

	// Determine keepalive interval
	keepaliveInterval := time.Duration(0)
	if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
		keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
	}

	// No keepalive: fast synchronous path
	if keepaliveInterval <= 0 {
		for scanner.Scan() {
			line := scanner.Text()
			if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
				continue
			}
			if processDataLine(line[6:]) {
				return resultWithUsage(), nil
			}
		}
		handleScanErr(scanner.Err())
		return finalizeStream()
	}

	// With keepalive: goroutine + channel + select
	type scanEvent struct {
		line string
		err  error
	}
	events := make(chan scanEvent, 16)
	done := make(chan struct{})
	// sendEvent delivers to the consumer unless it has already returned
	// (done closed) — prevents the reader goroutine from blocking forever.
	sendEvent := func(ev scanEvent) bool {
		select {
		case events <- ev:
			return true
		case <-done:
			return false
		}
	}
	go func() {
		// Sender closes the channel; a closed channel signals normal EOF.
		defer close(events)
		for scanner.Scan() {
			if !sendEvent(scanEvent{line: scanner.Text()}) {
				return
			}
		}
		if err := scanner.Err(); err != nil {
			_ = sendEvent(scanEvent{err: err})
		}
	}()
	defer close(done)

	keepaliveTicker := time.NewTicker(keepaliveInterval)
	defer keepaliveTicker.Stop()
	lastDataAt := time.Now()

	for {
		select {
		case ev, ok := <-events:
			if !ok {
				// Channel closed: upstream EOF.
				return finalizeStream()
			}
			if ev.err != nil {
				handleScanErr(ev.err)
				return finalizeStream()
			}
			lastDataAt = time.Now()
			line := ev.line
			if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
				continue
			}
			if processDataLine(line[6:]) {
				return resultWithUsage(), nil
			}

		case <-keepaliveTicker.C:
			// Skip the keepalive if real data arrived within the interval.
			if time.Since(lastDataAt) < keepaliveInterval {
				continue
			}
			// Send SSE comment as keepalive
			if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil {
				logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
					zap.String("request_id", requestID),
				)
				return resultWithUsage(), nil
			}
			c.Writer.Flush()
		}
	}
}
|
||
|
||
// writeChatCompletionsError writes an error response in OpenAI Chat Completions format.
|
||
func writeChatCompletionsError(c *gin.Context, statusCode int, errType, message string) {
|
||
c.JSON(statusCode, gin.H{
|
||
"error": gin.H{
|
||
"type": errType,
|
||
"message": message,
|
||
},
|
||
})
|
||
}
|