Files
sub2api/backend/internal/service/openai_gateway_chat_completions.go
erio 748a84d871 sync: bring over remaining release/custom-0.1.115 changes
- Extract PublicSettingsInjectionPayload named struct with drift test
- Add channel_monitor_default_interval_seconds to SSR injection
- Add image_output_price to SupportedModelChip
- Simplify AppSidebar buildSelfNavItems (admins see available channels)
- Add gateway WARN logs for 503 no-available-accounts branches
- Wire ChannelMonitorRunner into provideCleanup for graceful shutdown
- Add migrations 130/131 (CC template userid fix + mimicry field cleanup)
- Clean up fork-only features (sora, claude max simulation, client affinity)
- Remove ~320 obsolete i18n keys
- Add codexUsage utility, WechatServiceButton, BulkEditAccountModal
- Tidy go.sum
2026-04-23 20:55:18 +08:00

660 lines
21 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.uber.org/zap"
)
// cursorResponsesUnsupportedFields are top-level Responses API parameters that
// Codex upstreams reject with "Unsupported parameter: ...". They must be
// stripped when forwarding a raw client body through the Responses-shape
// short-circuit in ForwardAsChatCompletions (see isResponsesShape branch).
// The normal Chat Completions → Responses conversion path is unaffected
// because ChatCompletionsRequest has no fields for these parameters — unknown
// fields are dropped naturally by json.Unmarshal. Kept semantically in sync
// with the list in openai_gateway_service.go:2034 used by the /v1/responses
// passthrough path.
var cursorResponsesUnsupportedFields = []string{
"prompt_cache_retention",
"safety_identifier",
"metadata",
"stream_options",
}
// ForwardAsChatCompletions accepts a Chat Completions request body, converts it
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
// the response back to Chat Completions format. All account types (OAuth and API
// Key) go through the Responses API conversion path since the upstream only
// exposes the /v1/responses endpoint.
func (s *OpenAIGatewayService) ForwardAsChatCompletions(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
promptCacheKey string,
defaultMappedModel string,
) (*OpenAIForwardResult, error) {
startTime := time.Now()
// 1. Parse Chat Completions request
var chatReq apicompat.ChatCompletionsRequest
if err := json.Unmarshal(body, &chatReq); err != nil {
return nil, fmt.Errorf("parse chat completions request: %w", err)
}
originalModel := chatReq.Model
clientStream := chatReq.Stream
includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage
// 2. Resolve model mapping early so compat prompt_cache_key injection can
// derive a stable seed from the final upstream model family.
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
promptCacheKey = strings.TrimSpace(promptCacheKey)
compatPromptCacheInjected := false
if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) {
promptCacheKey = deriveCompatPromptCacheKey(&chatReq, upstreamModel)
compatPromptCacheInjected = promptCacheKey != ""
}
// 3. Build the upstream (Responses API) body.
//
// Cursor compatibility: some clients (notably Cursor cloud) send Responses
// API shaped bodies — `input: [...]` with no `messages` field — to the
// /v1/chat/completions URL. Running those through ChatCompletionsToResponses
// would silently drop Cursor's `input` array (the struct has no Input field)
// and produce `input: null`, which Codex upstreams reject with
// "Invalid type for 'input': expected a string, but got an object".
//
// Detect that shape and forward the raw body as-is, only rewriting `model`
// to the resolved upstream model. The downstream codex OAuth transform will
// still normalize store/stream/instructions/etc.
isResponsesShape := !gjson.GetBytes(body, "messages").Exists() && gjson.GetBytes(body, "input").Exists()
var (
responsesReq *apicompat.ResponsesRequest
responsesBody []byte
err error
)
if isResponsesShape {
responsesBody, err = sjson.SetBytes(body, "model", upstreamModel)
if err != nil {
return nil, fmt.Errorf("rewrite model in responses-shape body: %w", err)
}
// Strip Responses API parameters that no Codex upstream accepts.
// Because this branch forwards the raw body (the normal path rebuilds
// it from ChatCompletionsRequest and drops unknown fields naturally),
// we must filter these fields explicitly here — otherwise the upstream
// rejects the request with "Unsupported parameter: ...".
for _, field := range cursorResponsesUnsupportedFields {
if stripped, derr := sjson.DeleteBytes(responsesBody, field); derr == nil {
responsesBody = stripped
}
}
responsesBody, normalizedServiceTier, err := normalizeResponsesBodyServiceTier(responsesBody)
if err != nil {
return nil, fmt.Errorf("normalize service_tier in responses-shape body: %w", err)
}
// Minimal stub populated from the raw body so downstream billing
// propagation (ServiceTier, ReasoningEffort) keeps working.
responsesReq = &apicompat.ResponsesRequest{
Model: upstreamModel,
ServiceTier: normalizedServiceTier,
}
if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" {
responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort}
}
} else {
// Normal path: convert Chat Completions → Responses.
// ChatCompletionsToResponses always sets Stream=true (upstream always streams).
responsesReq, err = apicompat.ChatCompletionsToResponses(&chatReq)
if err != nil {
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
}
responsesReq.Model = upstreamModel
normalizeResponsesRequestServiceTier(responsesReq)
responsesBody, err = json.Marshal(responsesReq)
if err != nil {
return nil, fmt.Errorf("marshal responses request: %w", err)
}
}
logFields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("billing_model", billingModel),
zap.String("upstream_model", upstreamModel),
zap.Bool("stream", clientStream),
zap.Bool("responses_shape", isResponsesShape),
}
if compatPromptCacheInjected {
logFields = append(logFields,
zap.Bool("compat_prompt_cache_key_injected", true),
zap.String("compat_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey)),
)
}
logger.L().Debug("openai chat_completions: model mapping applied", logFields...)
{
var reqBody map[string]any
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
}
modified := false
if account.Type == AccountTypeOAuth {
codexResult := applyCodexOAuthTransform(reqBody, false, false)
modified = codexResult.Modified
if codexResult.NormalizedModel != "" {
upstreamModel = codexResult.NormalizedModel
}
if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey
} else if promptCacheKey != "" {
reqBody["prompt_cache_key"] = promptCacheKey
}
} else {
// 非 OAuth 账号也需要提取 system 消息并注入 instructions
// 否则上游 GPT-5/Codex 等模型会报 "Instructions are required"。
if extractSystemMessagesFromInput(reqBody) {
modified = true
}
if applyInstructions(reqBody, false) {
modified = true
}
}
if modified {
responsesBody, err = json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("remarshal after codex transform: %w", err)
}
}
}
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 6. Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false)
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
if promptCacheKey != "" {
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
}
// 7. Send request
proxyURL := ""
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// 8. Handle error response with failover
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
return s.handleChatCompletionsErrorResponse(resp, c, account)
}
// 9. Handle normal response
var result *OpenAIForwardResult
var handleErr error
if clientStream {
result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, includeUsage, startTime)
} else {
result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime)
}
// Propagate ServiceTier and ReasoningEffort to result for billing
if handleErr == nil && result != nil {
if responsesReq.ServiceTier != "" {
st := responsesReq.ServiceTier
result.ServiceTier = &st
}
if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
re := responsesReq.Reasoning.Effort
result.ReasoningEffort = &re
}
}
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if handleErr == nil && account.Type == AccountTypeOAuth {
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
}
return result, handleErr
}
func normalizeResponsesRequestServiceTier(req *apicompat.ResponsesRequest) {
if req == nil {
return
}
req.ServiceTier = normalizedOpenAIServiceTierValue(req.ServiceTier)
}
func normalizeResponsesBodyServiceTier(body []byte) ([]byte, string, error) {
if len(body) == 0 {
return body, "", nil
}
rawServiceTier := gjson.GetBytes(body, "service_tier").String()
if rawServiceTier == "" {
return body, "", nil
}
normalizedServiceTier := normalizedOpenAIServiceTierValue(rawServiceTier)
if normalizedServiceTier == "" {
trimmed, err := sjson.DeleteBytes(body, "service_tier")
return trimmed, "", err
}
if normalizedServiceTier == rawServiceTier {
return body, normalizedServiceTier, nil
}
trimmed, err := sjson.SetBytes(body, "service_tier", normalizedServiceTier)
return trimmed, normalizedServiceTier, err
}
func normalizedOpenAIServiceTierValue(raw string) string {
normalized := normalizeOpenAIServiceTier(raw)
if normalized == nil {
return ""
}
return *normalized
}
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
// OpenAI Chat Completions error format.
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
resp *http.Response,
c *gin.Context,
account *Account,
) (*OpenAIForwardResult, error) {
return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError)
}
// handleChatBufferedStreamingResponse reads all Responses SSE events from the
// upstream, finds the terminal event, converts to a Chat Completions JSON
// response, and writes it to the client.
func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
billingModel string,
upstreamModel string,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
var finalResponse *apicompat.ResponsesResponse
var usage OpenAIUsage
acc := apicompat.NewBufferedResponseAccumulator()
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
payload := line[6:]
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai chat_completions buffered: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
// Accumulate delta content for fallback when terminal output is empty.
acc.ProcessEvent(&event)
if (event.Type == "response.completed" || event.Type == "response.done" ||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil {
finalResponse = event.Response
if event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai chat_completions buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
if finalResponse == nil {
writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
return nil, fmt.Errorf("upstream stream ended without terminal event")
}
// When the terminal event has an empty output array, reconstruct from
// accumulated delta events so the client receives the full content.
acc.SupplementResponseOutput(finalResponse)
chatResp := apicompat.ResponsesToChatCompletions(finalResponse, originalModel)
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, chatResp)
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: billingModel,
UpstreamModel: upstreamModel,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
// handleChatStreamingResponse reads Responses SSE events from upstream,
// converts each to Chat Completions SSE chunks, and writes them to the client.
func (s *OpenAIGatewayService) handleChatStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
billingModel string,
upstreamModel string,
includeUsage bool,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
state := apicompat.NewResponsesEventToChatState()
state.Model = originalModel
state.IncludeUsage = includeUsage
var usage OpenAIUsage
var firstTokenMs *int
firstChunk := true
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: billingModel,
UpstreamModel: upstreamModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}
}
processDataLine := func(payload string) bool {
if firstChunk {
firstChunk = false
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai chat_completions stream: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
return false
}
// Extract usage from completion events
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil && event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
for _, chunk := range chunks {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
logger.L().Info("openai chat_completions stream: client disconnected",
zap.String("request_id", requestID),
)
return true
}
}
if len(chunks) > 0 {
c.Writer.Flush()
}
return false
}
finalizeStream := func() (*OpenAIForwardResult, error) {
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 {
for _, chunk := range finalChunks {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
continue
}
fmt.Fprint(c.Writer, sse) //nolint:errcheck
}
}
// Send [DONE] sentinel
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
c.Writer.Flush()
return resultWithUsage(), nil
}
handleScanErr := func(err error) {
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai chat_completions stream: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
// Determine keepalive interval
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
// No keepalive: fast synchronous path
if keepaliveInterval <= 0 {
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
if processDataLine(line[6:]) {
return resultWithUsage(), nil
}
}
handleScanErr(scanner.Err())
return finalizeStream()
}
// With keepalive: goroutine + channel + select
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
go func() {
defer close(events)
for scanner.Scan() {
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
keepaliveTicker := time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
lastDataAt := time.Now()
for {
select {
case ev, ok := <-events:
if !ok {
return finalizeStream()
}
if ev.err != nil {
handleScanErr(ev.err)
return finalizeStream()
}
lastDataAt = time.Now()
line := ev.line
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
if processDataLine(line[6:]) {
return resultWithUsage(), nil
}
case <-keepaliveTicker.C:
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// Send SSE comment as keepalive
if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil {
logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
zap.String("request_id", requestID),
)
return resultWithUsage(), nil
}
c.Writer.Flush()
}
}
}
// writeChatCompletionsError writes an error response in OpenAI Chat Completions format.
func writeChatCompletionsError(c *gin.Context, statusCode int, errType, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}