From 4b309fa8b53d6725d643d29faa83fb6698bf926d Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Tue, 10 Feb 2026 22:12:24 +0800 Subject: [PATCH] =?UTF-8?q?fix(gateway):=20=E4=BC=98=E5=8C=96=20ParseGatew?= =?UTF-8?q?ayRequest=20=E5=87=BD=E6=95=B0=EF=BC=8C=E4=BD=BF=E7=94=A8=20uns?= =?UTF-8?q?afe=20=E6=8F=90=E9=AB=98=E6=80=A7=E8=83=BD=E5=B9=B6=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=20JSON=20=E6=A0=A1=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/service/gateway_request.go | 89 +++++++++++++++------ 1 file changed, 66 insertions(+), 23 deletions(-) diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 4708a663..417e8aae 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "math" + "unsafe" "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" @@ -49,6 +50,17 @@ type ParsedRequest struct { // protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini), // 不同协议使用不同的 system/messages 字段名。 func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) { + // 保持与旧实现一致:请求体必须是合法 JSON。 + // 注意:gjson.GetBytes 对非法 JSON 不会报错,因此需要显式校验。 + if !gjson.ValidBytes(body) { + return nil, fmt.Errorf("invalid json") + } + + // 性能: + // - gjson.GetBytes 会把匹配的 Raw/Str 安全复制成 string(对于巨大 messages 会产生额外拷贝)。 + // - 这里将 body 通过 unsafe 零拷贝视为 string,仅在本函数内使用,且 body 不会被修改。 + jsonStr := *(*string)(unsafe.Pointer(&body)) + parsed := &ParsedRequest{ Body: body, } @@ -56,7 +68,7 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) { // --- gjson 提取简单字段(避免完整 Unmarshal) --- // model: 需要严格类型校验,非 string 返回错误 - modelResult := gjson.GetBytes(body, "model") + modelResult := gjson.Get(jsonStr, "model") if modelResult.Exists() { if modelResult.Type != gjson.String { return nil, fmt.Errorf("invalid model field type") @@ -65,7 +77,7 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) { } // stream: 需要严格类型校验,非 bool 返回错误 - streamResult := gjson.GetBytes(body, "stream") + streamResult := gjson.Get(jsonStr, "stream") if streamResult.Exists() { if streamResult.Type != gjson.True && streamResult.Type != gjson.False { return nil, fmt.Errorf("invalid stream field type") @@ -74,15 +86,15 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) { } // metadata.user_id: 直接路径提取,不需要严格类型校验 - parsed.MetadataUserID = gjson.GetBytes(body, "metadata.user_id").String() + parsed.MetadataUserID = gjson.Get(jsonStr, "metadata.user_id").String() // thinking.type: 直接路径提取 - if gjson.GetBytes(body, "thinking.type").String() == "enabled" { + if gjson.Get(jsonStr, "thinking.type").String() == "enabled" { parsed.ThinkingEnabled = true } // max_tokens: 仅接受整数值 - maxTokensResult := gjson.GetBytes(body, "max_tokens") + maxTokensResult := gjson.Get(jsonStr, "max_tokens") if maxTokensResult.Exists() && maxTokensResult.Type == gjson.Number { f := maxTokensResult.Float() if !math.IsNaN(f) && !math.IsInf(f, 0) && f == math.Trunc(f) && @@ -91,37 +103,54 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) { } } - // --- 保留 Unmarshal 用于 system/messages 提取 --- - // 这些字段需要作为 any/[]any 传递给下游消费者,无法用 gjson 替代 + // --- system/messages 提取 --- + // 避免把整个 body Unmarshal 到 map(会产生大量 map/接口分配)。 + // 使用 gjson 抽取目标字段的 Raw,再对该子树进行 Unmarshal。 switch protocol { case domain.PlatformGemini: // Gemini 原生格式: systemInstruction.parts / contents - var req map[string]any - if err := json.Unmarshal(body, &req); err != nil { - return nil, err - } - if sysInst, ok := req["systemInstruction"].(map[string]any); ok { - if parts, ok := sysInst["parts"].([]any); ok { - parsed.System = parts + if sysParts := gjson.Get(jsonStr, "systemInstruction.parts"); sysParts.Exists() && sysParts.IsArray() { + var parts []any + if err := json.Unmarshal(sliceRawFromBody(body, sysParts), &parts); err != nil { + return nil, err } + parsed.System = parts } - if contents, ok := req["contents"].([]any); ok { - parsed.Messages = contents + + if contents := gjson.Get(jsonStr, "contents"); contents.Exists() && contents.IsArray() { + var msgs []any + if err := json.Unmarshal(sliceRawFromBody(body, contents), &msgs); err != nil { + return nil, err + } + parsed.Messages = msgs } default: // Anthropic / OpenAI 格式: system / messages // system 字段只要存在就视为显式提供(即使为 null), // 以避免客户端传 null 时被默认 system 误注入。 - var req map[string]any - if err := json.Unmarshal(body, &req); err != nil { - return nil, err - } - if system, ok := req["system"]; ok { + if sys := gjson.Get(jsonStr, "system"); sys.Exists() { parsed.HasSystem = true - parsed.System = system + switch sys.Type { + case gjson.Null: + parsed.System = nil + case gjson.String: + // 与 encoding/json 的 Unmarshal 行为一致:返回解码后的字符串。 + parsed.System = sys.String() + default: + var system any + if err := json.Unmarshal(sliceRawFromBody(body, sys), &system); err != nil { + return nil, err + } + parsed.System = system + } } - if messages, ok := req["messages"].([]any); ok { + + if msgs := gjson.Get(jsonStr, "messages"); msgs.Exists() && msgs.IsArray() { + var messages []any + if err := json.Unmarshal(sliceRawFromBody(body, msgs), &messages); err != nil { + return nil, err + } parsed.Messages = messages } } @@ -129,6 +158,20 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) { return parsed, nil } +// sliceRawFromBody 返回 Result.Raw 对应的原始字节切片。 +// 优先使用 Result.Index 直接从 body 切片,避免对大字段(如 messages)产生额外拷贝。 +// 当 Index 不可用时,退化为复制(理论上极少发生)。 +func sliceRawFromBody(body []byte, r gjson.Result) []byte { + if r.Index > 0 { + end := r.Index + len(r.Raw) + if end <= len(body) { + return body[r.Index:end] + } + } + // fallback: 不影响正确性,但会产生一次拷贝 + return []byte(r.Raw) +} + // parseIntegralNumber 将 JSON 解码后的数字安全转换为 int。 // 仅接受“整数值”的输入,小数/NaN/Inf/越界值都会返回 false。 func parseIntegralNumber(raw any) (int, bool) {