feat(openai): 极致优化 OAuth 链路并补齐性能守护
- 优化 /v1/responses 热路径,减少重复解析与不必要拷贝
- 优化并发与 token 竞争路径并补齐运行指标
- 补充 OpenAI/Ops 相关单元测试与回归用例
- 新增灰度阈值守护与压测脚本,支撑发布验收
This commit is contained in:
@@ -12,7 +12,6 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -34,11 +33,10 @@ const (
|
||||
// OpenAI Platform API for API Key accounts (fallback)
|
||||
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
|
||||
openaiStickySessionTTL = time.Hour // 粘性会话TTL
|
||||
)
|
||||
|
||||
// openaiSSEDataRe matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
// OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。
|
||||
OpenAIParsedRequestBodyKey = "openai_parsed_request_body"
|
||||
)
|
||||
|
||||
// OpenAI allowed headers whitelist (for non-OAuth accounts)
|
||||
var openaiAllowedHeaders = map[string]bool{
|
||||
@@ -745,32 +743,37 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
startTime := time.Now()
|
||||
|
||||
originalBody := body
|
||||
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
|
||||
originalModel := reqModel
|
||||
|
||||
// Parse request body once (avoid multiple parse/serialize cycles)
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
return nil, fmt.Errorf("parse request: %w", err)
|
||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||
passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI
|
||||
if passthroughEnabled {
|
||||
// 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。
|
||||
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, reqModel)
|
||||
return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
|
||||
}
|
||||
|
||||
// Extract model and stream from parsed body
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
promptCacheKey := ""
|
||||
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||
promptCacheKey = strings.TrimSpace(v)
|
||||
reqBody, err := getOpenAIRequestBodyMap(c, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if v, ok := reqBody["model"].(string); ok {
|
||||
reqModel = v
|
||||
originalModel = reqModel
|
||||
}
|
||||
if v, ok := reqBody["stream"].(bool); ok {
|
||||
reqStream = v
|
||||
}
|
||||
if promptCacheKey == "" {
|
||||
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||
promptCacheKey = strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
|
||||
// Track if body needs re-serialization
|
||||
bodyModified := false
|
||||
originalModel := reqModel
|
||||
|
||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||
|
||||
passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI
|
||||
if passthroughEnabled {
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, reqModel)
|
||||
return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
|
||||
}
|
||||
|
||||
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
@@ -888,12 +891,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
|
||||
// Capture upstream request body for ops retry of this attempt.
|
||||
if c != nil {
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
}
|
||||
setOpsUpstreamRequestBody(c, body)
|
||||
|
||||
// Send request
|
||||
upstreamStart := time.Now()
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
|
||||
if err != nil {
|
||||
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
@@ -1019,12 +1022,14 @@ func (s *OpenAIGatewayService) forwardOAuthPassthrough(
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
setOpsUpstreamRequestBody(c, body)
|
||||
if c != nil {
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
c.Set("openai_passthrough", true)
|
||||
}
|
||||
|
||||
upstreamStart := time.Now()
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
@@ -1240,8 +1245,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if openaiSSEDataRe.MatchString(line) {
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
if data, ok := extractOpenAISSEDataLine(line); ok {
|
||||
if firstTokenMs == nil && strings.TrimSpace(data) != "" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
@@ -1750,8 +1754,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
lastDataAt = time.Now()
|
||||
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
if openaiSSEDataRe.MatchString(line) {
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
if data, ok := extractOpenAISSEDataLine(line); ok {
|
||||
|
||||
// Replace model in response if needed
|
||||
if needModelReplace {
|
||||
@@ -1827,11 +1830,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
|
||||
}
|
||||
|
||||
// extractOpenAISSEDataLine extracts the payload of an SSE `data:` line without
// using a regular expression.
//
// Both the standard `data: xxx` form and the non-standard `data:xxx` form are
// accepted. Any run of spaces and tabs directly after the colon is skipped;
// the rest of the line (including trailing whitespace) is returned unchanged.
//
// The boolean result is false when line does not start with "data:", in which
// case the returned string is empty. A bare "data:" line yields ("", true).
func extractOpenAISSEDataLine(line string) (string, bool) {
	const prefix = "data:"
	if !strings.HasPrefix(line, prefix) {
		return "", false
	}
	// Skip horizontal whitespace only (space and tab); the payload itself is
	// left untouched so callers can still detect empty or "[DONE]" data.
	return strings.TrimLeft(line[len(prefix):], " \t"), true
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
|
||||
if !openaiSSEDataRe.MatchString(line) {
|
||||
data, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
return line
|
||||
}
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
if data == "" || data == "[DONE]" {
|
||||
return line
|
||||
}
|
||||
@@ -1872,25 +1891,20 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
||||
// Parse response.completed event for usage (OpenAI Responses format)
|
||||
var event struct {
|
||||
Type string `json:"type"`
|
||||
Response struct {
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
} `json:"input_tokens_details"`
|
||||
} `json:"usage"`
|
||||
} `json:"response"`
|
||||
if usage == nil || data == "" || data == "[DONE]" {
|
||||
return
|
||||
}
|
||||
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
|
||||
if !strings.Contains(data, `"response.completed"`) {
|
||||
return
|
||||
}
|
||||
if gjson.Get(data, "type").String() != "response.completed" {
|
||||
return
|
||||
}
|
||||
|
||||
if json.Unmarshal([]byte(data), &event) == nil && event.Type == "response.completed" {
|
||||
usage.InputTokens = event.Response.Usage.InputTokens
|
||||
usage.OutputTokens = event.Response.Usage.OutputTokens
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokenDetails.CachedTokens
|
||||
}
|
||||
usage.InputTokens = int(gjson.Get(data, "response.usage.input_tokens").Int())
|
||||
usage.OutputTokens = int(gjson.Get(data, "response.usage.output_tokens").Int())
|
||||
usage.CacheReadInputTokens = int(gjson.Get(data, "response.usage.input_tokens_details.cached_tokens").Int())
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
@@ -2001,10 +2015,10 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
|
||||
func extractCodexFinalResponse(body string) ([]byte, bool) {
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
if !openaiSSEDataRe.MatchString(line) {
|
||||
data, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
if data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
@@ -2028,10 +2042,10 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
|
||||
usage := &OpenAIUsage{}
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
if !openaiSSEDataRe.MatchString(line) {
|
||||
data, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
if data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
@@ -2043,7 +2057,7 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
|
||||
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
|
||||
lines := strings.Split(body, "\n")
|
||||
for i, line := range lines {
|
||||
if !openaiSSEDataRe.MatchString(line) {
|
||||
if _, ok := extractOpenAISSEDataLine(line); !ok {
|
||||
continue
|
||||
}
|
||||
lines[i] = s.replaceModelInSSELine(line, fromModel, toModel)
|
||||
@@ -2396,6 +2410,53 @@ func deriveOpenAIReasoningEffortFromModel(model string) string {
|
||||
return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
|
||||
}
|
||||
|
||||
func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) {
|
||||
if len(body) == 0 {
|
||||
return "", false, ""
|
||||
}
|
||||
|
||||
model = strings.TrimSpace(gjson.GetBytes(body, "model").String())
|
||||
stream = gjson.GetBytes(body, "stream").Bool()
|
||||
promptCacheKey = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
||||
return model, stream, promptCacheKey
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
|
||||
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
|
||||
if reasoningEffort == "" {
|
||||
reasoningEffort = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String())
|
||||
}
|
||||
if reasoningEffort != "" {
|
||||
normalized := normalizeOpenAIReasoningEffort(reasoningEffort)
|
||||
if normalized == "" {
|
||||
return nil
|
||||
}
|
||||
return &normalized
|
||||
}
|
||||
|
||||
value := deriveOpenAIReasoningEffortFromModel(requestedModel)
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
return &value
|
||||
}
|
||||
|
||||
func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) {
|
||||
if c != nil {
|
||||
if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok {
|
||||
if reqBody, ok := cached.(map[string]any); ok && reqBody != nil {
|
||||
return reqBody, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
return nil, fmt.Errorf("parse request: %w", err)
|
||||
}
|
||||
return reqBody, nil
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string {
|
||||
if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present {
|
||||
if value == "" {
|
||||
|
||||
Reference in New Issue
Block a user