perf(backend): 使用 gjson/sjson 优化热路径 JSON 处理

将 API 网关热路径中的 json.Unmarshal+json.Marshal 替换为 gjson 零拷贝查询和 sjson 精准写入:
- unwrapV1InternalResponse 性能提升 22x(4009ns→182ns),内存分配减少 28.5x
- unwrapGeminiResponse、extractGeminiUsage、estimateGeminiCountTokens、ParseGeminiRateLimitResetTime 改为接收 []byte 使用 gjson 提取
- ParseGatewayRequest 的 model/stream/metadata/thinking/max_tokens 改用 gjson 类型安全提取
- Handler 层(sora/openai)改用 gjson 提取字段、sjson 注入/修改字段,移除 map[string]any 中间变量
- Sora Client 响应解析改用 gjson ForEach 遍历,减少内存分配
- 新增约 100 个单元测试用例,所有改动函数覆盖率 >85%

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-10 08:59:30 +08:00
parent 29ca1290b3
commit 58912d4ac5
14 changed files with 1686 additions and 324 deletions

View File

@@ -18,6 +18,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// OpenAIGatewayHandler handles OpenAI API gateway requests
@@ -93,16 +95,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, "", false, body)
// Parse request body to map for potential modification
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Extract model and stream
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
reqModel := gjson.GetBytes(body, "model").String()
reqStream := gjson.GetBytes(body, "stream").Bool()
// 验证 model 必填
if reqModel == "" {
@@ -113,16 +108,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
isCodexCLI := openai.IsCodexCLIRequest(userAgent) || (h.cfg != nil && h.cfg.Gateway.ForceCodexCLI)
if !isCodexCLI {
existingInstructions, _ := reqBody["instructions"].(string)
existingInstructions := gjson.GetBytes(body, "instructions").String()
if strings.TrimSpace(existingInstructions) == "" {
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
reqBody["instructions"] = instructions
// Re-serialize body
body, err = json.Marshal(reqBody)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
body, _ = sjson.SetBytes(body, "instructions", instructions)
}
}
}
@@ -132,19 +121,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
// 要求 previous_response_id或 input 内存在带 call_id 的 tool_call/function_call
// 或带 id 且与 call_id 匹配的 item_reference。
if service.HasFunctionCallOutput(reqBody) {
previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
if service.HasFunctionCallOutputMissingCallID(reqBody) {
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
return
}
callIDs := service.FunctionCallOutputCallIDs(reqBody)
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
return
// 此路径需要遍历 input 数组做 call_id 关联检查,保留 Unmarshal
if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err == nil {
if service.HasFunctionCallOutput(reqBody) {
previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
if service.HasFunctionCallOutputMissingCallID(reqBody) {
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
return
}
callIDs := service.FunctionCallOutputCallIDs(reqBody)
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
return
}
}
}
}
}
@@ -207,7 +202,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
// Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0

View File

@@ -10,6 +10,8 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) {
@@ -102,3 +104,48 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
assert.Equal(t, "upstream_error", errorObj["type"])
assert.Equal(t, "test error", errorObj["message"])
}
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
tests := []struct {
name string
body string
wantModel string
wantStream bool
}{
{"正常提取", `{"model":"gpt-4","stream":true,"input":"hello"}`, "gpt-4", true},
{"stream false", `{"model":"gpt-4","stream":false}`, "gpt-4", false},
{"无 stream 字段", `{"model":"gpt-4"}`, "gpt-4", false},
{"model 缺失", `{"stream":true}`, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
body := []byte(tt.body)
model := gjson.GetBytes(body, "model").String()
stream := gjson.GetBytes(body, "stream").Bool()
require.Equal(t, tt.wantModel, model)
require.Equal(t, tt.wantStream, stream)
})
}
}
// TestOpenAIHandler_InstructionsInjection 验证 instructions 的 gjson/sjson 注入逻辑
func TestOpenAIHandler_InstructionsInjection(t *testing.T) {
// 测试 1无 instructions → 注入
body := []byte(`{"model":"gpt-4"}`)
existing := gjson.GetBytes(body, "instructions").String()
require.Empty(t, existing)
newBody, err := sjson.SetBytes(body, "instructions", "test instruction")
require.NoError(t, err)
require.Equal(t, "test instruction", gjson.GetBytes(newBody, "instructions").String())
// 测试 2已有 instructions → 不覆盖
body2 := []byte(`{"model":"gpt-4","instructions":"existing"}`)
existing2 := gjson.GetBytes(body2, "instructions").String()
require.Equal(t, "existing", existing2)
// 测试 3空白 instructions → 注入
body3 := []byte(`{"model":"gpt-4","instructions":" "}`)
existing3 := strings.TrimSpace(gjson.GetBytes(body3, "instructions").String())
require.Empty(t, existing3)
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
@@ -23,6 +22,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// SoraGatewayHandler handles Sora chat completions requests
@@ -105,36 +106,29 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, "", false, body)
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
reqModel, _ := reqBody["model"].(string)
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
reqModel := gjson.GetBytes(body, "model").String()
if reqModel == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqMessages, _ := reqBody["messages"].([]any)
if len(reqMessages) == 0 {
if !gjson.GetBytes(body, "messages").Exists() || gjson.GetBytes(body, "messages").Type != gjson.JSON {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
return
}
clientStream, _ := reqBody["stream"].(bool)
clientStream := gjson.GetBytes(body, "stream").Bool()
if !clientStream {
if h.streamMode == "error" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
return
}
reqBody["stream"] = true
updated, err := json.Marshal(reqBody)
var err error
body, err = sjson.SetBytes(body, "stream", true)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
body = updated
}
setOpsRequestContext(c, reqModel, clientStream, body)
@@ -193,7 +187,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
return
}
sessionHash := generateOpenAISessionHash(c, reqBody)
sessionHash := generateOpenAISessionHash(c, body)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
@@ -302,7 +296,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
}
}
func generateOpenAISessionHash(c *gin.Context, reqBody map[string]any) string {
func generateOpenAISessionHash(c *gin.Context, body []byte) string {
if c == nil {
return ""
}
@@ -310,10 +304,8 @@ func generateOpenAISessionHash(c *gin.Context, reqBody map[string]any) string {
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && reqBody != nil {
if v, ok := reqBody["prompt_cache_key"].(string); ok {
sessionID = strings.TrimSpace(v)
}
if sessionID == "" && len(body) > 0 {
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
if sessionID == "" {
return ""

View File

@@ -19,6 +19,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/testutil"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// 编译期接口断言
@@ -414,3 +416,65 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.NotEmpty(t, resp["media_url"])
}
// TestSoraHandler_StreamForcing 验证 sora handler 的 stream 强制逻辑
func TestSoraHandler_StreamForcing(t *testing.T) {
// 测试 1stream=false 时 sjson 强制修改为 true
body := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":false}`)
clientStream := gjson.GetBytes(body, "stream").Bool()
require.False(t, clientStream)
newBody, err := sjson.SetBytes(body, "stream", true)
require.NoError(t, err)
require.True(t, gjson.GetBytes(newBody, "stream").Bool())
// 测试 2stream=true 时不修改
body2 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":true}`)
require.True(t, gjson.GetBytes(body2, "stream").Bool())
// 测试 3无 stream 字段时 gjson 返回 false零值
body3 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}]}`)
require.False(t, gjson.GetBytes(body3, "stream").Bool())
}
// TestSoraHandler_ValidationExtraction 验证 sora handler 中 gjson 字段校验逻辑
func TestSoraHandler_ValidationExtraction(t *testing.T) {
// model 缺失
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
model := gjson.GetBytes(body, "model").String()
require.Empty(t, model)
// messages 缺失
body2 := []byte(`{"model":"sora"}`)
require.False(t, gjson.GetBytes(body2, "messages").Exists())
// messages 不是 JSON 数组
body3 := []byte(`{"model":"sora","messages":"not array"}`)
msgResult := gjson.GetBytes(body3, "messages")
require.True(t, msgResult.Exists())
require.NotEqual(t, gjson.JSON, msgResult.Type) // string 类型,不是 JSON 数组
}
// TestGenerateOpenAISessionHash_WithBody 验证 generateOpenAISessionHash 的 body/header 解析逻辑
func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
gin.SetMode(gin.TestMode)
// 从 body 提取 prompt_cache_key
body := []byte(`{"model":"sora","prompt_cache_key":"session-abc"}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
hash := generateOpenAISessionHash(c, body)
require.NotEmpty(t, hash)
// 无 prompt_cache_key 且无 header → 空 hash
body2 := []byte(`{"model":"sora"}`)
hash2 := generateOpenAISessionHash(c, body2)
require.Empty(t, hash2)
// header 优先于 body
c.Request.Header.Set("session_id", "from-header")
hash3 := generateOpenAISessionHash(c, body)
require.NotEmpty(t, hash3)
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
}