feat(openai): 极致优化 OAuth 链路并补齐性能守护

- 优化 /v1/responses 热路径,减少重复解析与不必要拷贝
- 优化并发与 token 竞争路径并补齐运行指标
- 补充 OpenAI/Ops 相关单元测试与回归用例
- 新增灰度阈值守护与压测脚本,支撑发布验收
This commit is contained in:
yangjianbo
2026-02-12 09:41:37 +08:00
parent a88bb8684f
commit 61a2bf469a
16 changed files with 1519 additions and 135 deletions

View File

@@ -12,7 +12,6 @@ import (
"io"
"log"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
@@ -34,11 +33,10 @@ const (
// OpenAI Platform API for API Key accounts (fallback)
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
openaiStickySessionTTL = time.Hour // 粘性会话TTL
)
// openaiSSEDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
// OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。
OpenAIParsedRequestBodyKey = "openai_parsed_request_body"
)
// OpenAI allowed headers whitelist (for non-OAuth accounts)
var openaiAllowedHeaders = map[string]bool{
@@ -745,32 +743,37 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
startTime := time.Now()
originalBody := body
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
originalModel := reqModel
// Parse request body once (avoid multiple parse/serialize cycles)
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
return nil, fmt.Errorf("parse request: %w", err)
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI
if passthroughEnabled {
// 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, reqModel)
return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
}
// Extract model and stream from parsed body
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
promptCacheKey := ""
if v, ok := reqBody["prompt_cache_key"].(string); ok {
promptCacheKey = strings.TrimSpace(v)
reqBody, err := getOpenAIRequestBodyMap(c, body)
if err != nil {
return nil, err
}
if v, ok := reqBody["model"].(string); ok {
reqModel = v
originalModel = reqModel
}
if v, ok := reqBody["stream"].(bool); ok {
reqStream = v
}
if promptCacheKey == "" {
if v, ok := reqBody["prompt_cache_key"].(string); ok {
promptCacheKey = strings.TrimSpace(v)
}
}
// Track if body needs re-serialization
bodyModified := false
originalModel := reqModel
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI
if passthroughEnabled {
reasoningEffort := extractOpenAIReasoningEffort(reqBody, reqModel)
return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
}
// 对所有请求执行模型映射(包含 Codex CLI
mappedModel := account.GetMappedModel(reqModel)
@@ -888,12 +891,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// Capture upstream request body for ops retry of this attempt.
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
setOpsUpstreamRequestBody(c, body)
// Send request
upstreamStart := time.Now()
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
if err != nil {
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
safeErr := sanitizeUpstreamErrorMessage(err.Error())
@@ -1019,12 +1022,14 @@ func (s *OpenAIGatewayService) forwardOAuthPassthrough(
proxyURL = account.Proxy.URL()
}
setOpsUpstreamRequestBody(c, body)
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
c.Set("openai_passthrough", true)
}
upstreamStart := time.Now()
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
@@ -1240,8 +1245,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
for scanner.Scan() {
line := scanner.Text()
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data, ok := extractOpenAISSEDataLine(line); ok {
if firstTokenMs == nil && strings.TrimSpace(data) != "" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
@@ -1750,8 +1754,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
lastDataAt = time.Now()
// Extract data from SSE line (supports both "data: " and "data:" formats)
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data, ok := extractOpenAISSEDataLine(line); ok {
// Replace model in response if needed
if needModelReplace {
@@ -1827,11 +1830,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
// extractOpenAISSEDataLine extracts the payload of an SSE `data:` line without
// regexp overhead. It accepts both the standard `data: xxx` form and the
// non-standard `data:xxx` form some upstreams emit, skipping any run of
// spaces/tabs after the colon (mirroring the former `^data:\s*` regex).
// The second return value reports whether the line is a data line at all.
func extractOpenAISSEDataLine(line string) (string, bool) {
	if !strings.HasPrefix(line, "data:") {
		return "", false
	}
	start := len("data:")
	for start < len(line) {
		// Skip both spaces and tabs; the regex this replaces matched \s*.
		if line[start] != ' ' && line[start] != '\t' {
			break
		}
		start++
	}
	return line[start:], true
}
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
if !openaiSSEDataRe.MatchString(line) {
data, ok := extractOpenAISSEDataLine(line)
if !ok {
return line
}
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
return line
}
@@ -1872,25 +1891,20 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt
}
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
// Parse response.completed event for usage (OpenAI Responses format)
var event struct {
Type string `json:"type"`
Response struct {
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
} `json:"input_tokens_details"`
} `json:"usage"`
} `json:"response"`
if usage == nil || data == "" || data == "[DONE]" {
return
}
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
if !strings.Contains(data, `"response.completed"`) {
return
}
if gjson.Get(data, "type").String() != "response.completed" {
return
}
if json.Unmarshal([]byte(data), &event) == nil && event.Type == "response.completed" {
usage.InputTokens = event.Response.Usage.InputTokens
usage.OutputTokens = event.Response.Usage.OutputTokens
usage.CacheReadInputTokens = event.Response.Usage.InputTokenDetails.CachedTokens
}
usage.InputTokens = int(gjson.Get(data, "response.usage.input_tokens").Int())
usage.OutputTokens = int(gjson.Get(data, "response.usage.output_tokens").Int())
usage.CacheReadInputTokens = int(gjson.Get(data, "response.usage.input_tokens_details.cached_tokens").Int())
}
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
@@ -2001,10 +2015,10 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
func extractCodexFinalResponse(body string) ([]byte, bool) {
lines := strings.Split(body, "\n")
for _, line := range lines {
if !openaiSSEDataRe.MatchString(line) {
data, ok := extractOpenAISSEDataLine(line)
if !ok {
continue
}
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
continue
}
@@ -2028,10 +2042,10 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
usage := &OpenAIUsage{}
lines := strings.Split(body, "\n")
for _, line := range lines {
if !openaiSSEDataRe.MatchString(line) {
data, ok := extractOpenAISSEDataLine(line)
if !ok {
continue
}
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
continue
}
@@ -2043,7 +2057,7 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
lines := strings.Split(body, "\n")
for i, line := range lines {
if !openaiSSEDataRe.MatchString(line) {
if _, ok := extractOpenAISSEDataLine(line); !ok {
continue
}
lines[i] = s.replaceModelInSSELine(line, fromModel, toModel)
@@ -2396,6 +2410,53 @@ func deriveOpenAIReasoningEffortFromModel(model string) string {
return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
}
// extractOpenAIRequestMetaFromBody pulls the routing metadata — model name,
// stream flag, and prompt cache key — out of a raw JSON request body via
// selective gjson lookups, avoiding a full Unmarshal on the hot path.
// An empty body yields zero values.
func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) {
	if len(body) == 0 {
		return
	}
	modelField := gjson.GetBytes(body, "model")
	streamField := gjson.GetBytes(body, "stream")
	cacheKeyField := gjson.GetBytes(body, "prompt_cache_key")
	return strings.TrimSpace(modelField.String()), streamField.Bool(), strings.TrimSpace(cacheKeyField.String())
}
// extractOpenAIReasoningEffortFromBody resolves the effective reasoning effort
// for a request directly from the raw JSON body. It prefers an explicit
// `reasoning.effort` field, then the legacy flat `reasoning_effort` field, and
// finally falls back to the effort suffix encoded in the requested model name.
// Returns nil when no valid (normalizable) effort can be determined.
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
	effort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
	if effort == "" {
		effort = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String())
	}
	if effort == "" {
		// No explicit field: derive from the model name suffix, if any.
		if derived := deriveOpenAIReasoningEffortFromModel(requestedModel); derived != "" {
			return &derived
		}
		return nil
	}
	// An explicit value that fails normalization is treated as absent.
	if normalized := normalizeOpenAIReasoningEffort(effort); normalized != "" {
		return &normalized
	}
	return nil
}
// getOpenAIRequestBodyMap returns the request body as a generic JSON map.
// When the handler layer has already parsed the body and cached it on the gin
// context under OpenAIParsedRequestBodyKey, that cached map is reused so the
// hot path does not pay for a second Unmarshal of the same payload.
func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) {
	if c != nil {
		cached, exists := c.Get(OpenAIParsedRequestBodyKey)
		if exists {
			if m, isMap := cached.(map[string]any); isMap && m != nil {
				return m, nil
			}
		}
	}
	// Cache miss (or unusable cached value): parse the raw body once.
	var parsed map[string]any
	if err := json.Unmarshal(body, &parsed); err != nil {
		return nil, fmt.Errorf("parse request: %w", err)
	}
	return parsed, nil
}
func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string {
if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present {
if value == "" {