fix(gateway): 优化 ParseGatewayRequest 函数,使用 unsafe 提高性能并增加 JSON 校验

This commit is contained in:
yangjianbo
2026-02-10 22:12:24 +08:00
parent 166080b29c
commit 4b309fa8b5

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"math"
"unsafe"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
@@ -49,6 +50,17 @@ type ParsedRequest struct {
// protocol 指定请求协议格式domain.PlatformAnthropic / domain.PlatformGemini
// 不同协议使用不同的 system/messages 字段名。
func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
// 保持与旧实现一致:请求体必须是合法 JSON。
// 注意gjson.GetBytes 对非法 JSON 不会报错,因此需要显式校验。
if !gjson.ValidBytes(body) {
return nil, fmt.Errorf("invalid json")
}
// 性能:
// - gjson.GetBytes 会把匹配的 Raw/Str 安全复制成 string对于巨大 messages 会产生额外拷贝)。
// - 这里将 body 通过 unsafe 零拷贝视为 string仅在本函数内使用且 body 不会被修改。
jsonStr := *(*string)(unsafe.Pointer(&body))
parsed := &ParsedRequest{
Body: body,
}
@@ -56,7 +68,7 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
// --- gjson 提取简单字段(避免完整 Unmarshal ---
// model: 需要严格类型校验,非 string 返回错误
modelResult := gjson.GetBytes(body, "model")
modelResult := gjson.Get(jsonStr, "model")
if modelResult.Exists() {
if modelResult.Type != gjson.String {
return nil, fmt.Errorf("invalid model field type")
@@ -65,7 +77,7 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
}
// stream: 需要严格类型校验,非 bool 返回错误
streamResult := gjson.GetBytes(body, "stream")
streamResult := gjson.Get(jsonStr, "stream")
if streamResult.Exists() {
if streamResult.Type != gjson.True && streamResult.Type != gjson.False {
return nil, fmt.Errorf("invalid stream field type")
@@ -74,15 +86,15 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
}
// metadata.user_id: 直接路径提取,不需要严格类型校验
parsed.MetadataUserID = gjson.GetBytes(body, "metadata.user_id").String()
parsed.MetadataUserID = gjson.Get(jsonStr, "metadata.user_id").String()
// thinking.type: 直接路径提取
if gjson.GetBytes(body, "thinking.type").String() == "enabled" {
if gjson.Get(jsonStr, "thinking.type").String() == "enabled" {
parsed.ThinkingEnabled = true
}
// max_tokens: 仅接受整数值
maxTokensResult := gjson.GetBytes(body, "max_tokens")
maxTokensResult := gjson.Get(jsonStr, "max_tokens")
if maxTokensResult.Exists() && maxTokensResult.Type == gjson.Number {
f := maxTokensResult.Float()
if !math.IsNaN(f) && !math.IsInf(f, 0) && f == math.Trunc(f) &&
@@ -91,37 +103,54 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
}
}
// --- 保留 Unmarshal 用于 system/messages 提取 ---
// 这些字段需要作为 any/[]any 传递给下游消费者,无法用 gjson 替代
// --- system/messages 提取 ---
// 避免把整个 body Unmarshal 到 map会产生大量 map/接口分配)。
// 使用 gjson 抽取目标字段的 Raw再对该子树进行 Unmarshal。
switch protocol {
case domain.PlatformGemini:
// Gemini 原生格式: systemInstruction.parts / contents
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
if parts, ok := sysInst["parts"].([]any); ok {
parsed.System = parts
if sysParts := gjson.Get(jsonStr, "systemInstruction.parts"); sysParts.Exists() && sysParts.IsArray() {
var parts []any
if err := json.Unmarshal(sliceRawFromBody(body, sysParts), &parts); err != nil {
return nil, err
}
parsed.System = parts
}
if contents, ok := req["contents"].([]any); ok {
parsed.Messages = contents
if contents := gjson.Get(jsonStr, "contents"); contents.Exists() && contents.IsArray() {
var msgs []any
if err := json.Unmarshal(sliceRawFromBody(body, contents), &msgs); err != nil {
return nil, err
}
parsed.Messages = msgs
}
default:
// Anthropic / OpenAI 格式: system / messages
// system 字段只要存在就视为显式提供(即使为 null
// 以避免客户端传 null 时被默认 system 误注入。
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
if system, ok := req["system"]; ok {
if sys := gjson.Get(jsonStr, "system"); sys.Exists() {
parsed.HasSystem = true
parsed.System = system
switch sys.Type {
case gjson.Null:
parsed.System = nil
case gjson.String:
// 与 encoding/json 的 Unmarshal 行为一致:返回解码后的字符串。
parsed.System = sys.String()
default:
var system any
if err := json.Unmarshal(sliceRawFromBody(body, sys), &system); err != nil {
return nil, err
}
parsed.System = system
}
}
if messages, ok := req["messages"].([]any); ok {
if msgs := gjson.Get(jsonStr, "messages"); msgs.Exists() && msgs.IsArray() {
var messages []any
if err := json.Unmarshal(sliceRawFromBody(body, msgs), &messages); err != nil {
return nil, err
}
parsed.Messages = messages
}
}
@@ -129,6 +158,20 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
return parsed, nil
}
// sliceRawFromBody 返回 Result.Raw 对应的原始字节切片。
// 优先使用 Result.Index 直接从 body 切片,避免对大字段(如 messages产生额外拷贝。
// 当 Index 不可用时,退化为复制(理论上极少发生)。
func sliceRawFromBody(body []byte, r gjson.Result) []byte {
if r.Index > 0 {
end := r.Index + len(r.Raw)
if end <= len(body) {
return body[r.Index:end]
}
}
// fallback: 不影响正确性,但会产生一次拷贝
return []byte(r.Raw)
}
// parseIntegralNumber 将 JSON 解码后的数字安全转换为 int。
// 仅接受“整数值”的输入,小数/NaN/Inf/越界值都会返回 false。
func parseIntegralNumber(raw any) (int, bool) {