fix(gateway): 优化 ParseGatewayRequest 函数,使用 unsafe 提高性能并增加 JSON 校验
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
@@ -49,6 +50,17 @@ type ParsedRequest struct {
|
|||||||
// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini),
|
// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini),
|
||||||
// 不同协议使用不同的 system/messages 字段名。
|
// 不同协议使用不同的 system/messages 字段名。
|
||||||
func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
||||||
|
// 保持与旧实现一致:请求体必须是合法 JSON。
|
||||||
|
// 注意:gjson.GetBytes 对非法 JSON 不会报错,因此需要显式校验。
|
||||||
|
if !gjson.ValidBytes(body) {
|
||||||
|
return nil, fmt.Errorf("invalid json")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 性能:
|
||||||
|
// - gjson.GetBytes 会把匹配的 Raw/Str 安全复制成 string(对于巨大 messages 会产生额外拷贝)。
|
||||||
|
// - 这里将 body 通过 unsafe 零拷贝视为 string,仅在本函数内使用,且 body 不会被修改。
|
||||||
|
jsonStr := *(*string)(unsafe.Pointer(&body))
|
||||||
|
|
||||||
parsed := &ParsedRequest{
|
parsed := &ParsedRequest{
|
||||||
Body: body,
|
Body: body,
|
||||||
}
|
}
|
||||||
@@ -56,7 +68,7 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
|||||||
// --- gjson 提取简单字段(避免完整 Unmarshal) ---
|
// --- gjson 提取简单字段(避免完整 Unmarshal) ---
|
||||||
|
|
||||||
// model: 需要严格类型校验,非 string 返回错误
|
// model: 需要严格类型校验,非 string 返回错误
|
||||||
modelResult := gjson.GetBytes(body, "model")
|
modelResult := gjson.Get(jsonStr, "model")
|
||||||
if modelResult.Exists() {
|
if modelResult.Exists() {
|
||||||
if modelResult.Type != gjson.String {
|
if modelResult.Type != gjson.String {
|
||||||
return nil, fmt.Errorf("invalid model field type")
|
return nil, fmt.Errorf("invalid model field type")
|
||||||
@@ -65,7 +77,7 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// stream: 需要严格类型校验,非 bool 返回错误
|
// stream: 需要严格类型校验,非 bool 返回错误
|
||||||
streamResult := gjson.GetBytes(body, "stream")
|
streamResult := gjson.Get(jsonStr, "stream")
|
||||||
if streamResult.Exists() {
|
if streamResult.Exists() {
|
||||||
if streamResult.Type != gjson.True && streamResult.Type != gjson.False {
|
if streamResult.Type != gjson.True && streamResult.Type != gjson.False {
|
||||||
return nil, fmt.Errorf("invalid stream field type")
|
return nil, fmt.Errorf("invalid stream field type")
|
||||||
@@ -74,15 +86,15 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// metadata.user_id: 直接路径提取,不需要严格类型校验
|
// metadata.user_id: 直接路径提取,不需要严格类型校验
|
||||||
parsed.MetadataUserID = gjson.GetBytes(body, "metadata.user_id").String()
|
parsed.MetadataUserID = gjson.Get(jsonStr, "metadata.user_id").String()
|
||||||
|
|
||||||
// thinking.type: 直接路径提取
|
// thinking.type: 直接路径提取
|
||||||
if gjson.GetBytes(body, "thinking.type").String() == "enabled" {
|
if gjson.Get(jsonStr, "thinking.type").String() == "enabled" {
|
||||||
parsed.ThinkingEnabled = true
|
parsed.ThinkingEnabled = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// max_tokens: 仅接受整数值
|
// max_tokens: 仅接受整数值
|
||||||
maxTokensResult := gjson.GetBytes(body, "max_tokens")
|
maxTokensResult := gjson.Get(jsonStr, "max_tokens")
|
||||||
if maxTokensResult.Exists() && maxTokensResult.Type == gjson.Number {
|
if maxTokensResult.Exists() && maxTokensResult.Type == gjson.Number {
|
||||||
f := maxTokensResult.Float()
|
f := maxTokensResult.Float()
|
||||||
if !math.IsNaN(f) && !math.IsInf(f, 0) && f == math.Trunc(f) &&
|
if !math.IsNaN(f) && !math.IsInf(f, 0) && f == math.Trunc(f) &&
|
||||||
@@ -91,37 +103,54 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- 保留 Unmarshal 用于 system/messages 提取 ---
|
// --- system/messages 提取 ---
|
||||||
// 这些字段需要作为 any/[]any 传递给下游消费者,无法用 gjson 替代
|
// 避免把整个 body Unmarshal 到 map(会产生大量 map/接口分配)。
|
||||||
|
// 使用 gjson 抽取目标字段的 Raw,再对该子树进行 Unmarshal。
|
||||||
|
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case domain.PlatformGemini:
|
case domain.PlatformGemini:
|
||||||
// Gemini 原生格式: systemInstruction.parts / contents
|
// Gemini 原生格式: systemInstruction.parts / contents
|
||||||
var req map[string]any
|
if sysParts := gjson.Get(jsonStr, "systemInstruction.parts"); sysParts.Exists() && sysParts.IsArray() {
|
||||||
if err := json.Unmarshal(body, &req); err != nil {
|
var parts []any
|
||||||
return nil, err
|
if err := json.Unmarshal(sliceRawFromBody(body, sysParts), &parts); err != nil {
|
||||||
}
|
return nil, err
|
||||||
if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
|
|
||||||
if parts, ok := sysInst["parts"].([]any); ok {
|
|
||||||
parsed.System = parts
|
|
||||||
}
|
}
|
||||||
|
parsed.System = parts
|
||||||
}
|
}
|
||||||
if contents, ok := req["contents"].([]any); ok {
|
|
||||||
parsed.Messages = contents
|
if contents := gjson.Get(jsonStr, "contents"); contents.Exists() && contents.IsArray() {
|
||||||
|
var msgs []any
|
||||||
|
if err := json.Unmarshal(sliceRawFromBody(body, contents), &msgs); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
parsed.Messages = msgs
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
// Anthropic / OpenAI 格式: system / messages
|
// Anthropic / OpenAI 格式: system / messages
|
||||||
// system 字段只要存在就视为显式提供(即使为 null),
|
// system 字段只要存在就视为显式提供(即使为 null),
|
||||||
// 以避免客户端传 null 时被默认 system 误注入。
|
// 以避免客户端传 null 时被默认 system 误注入。
|
||||||
var req map[string]any
|
if sys := gjson.Get(jsonStr, "system"); sys.Exists() {
|
||||||
if err := json.Unmarshal(body, &req); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if system, ok := req["system"]; ok {
|
|
||||||
parsed.HasSystem = true
|
parsed.HasSystem = true
|
||||||
parsed.System = system
|
switch sys.Type {
|
||||||
|
case gjson.Null:
|
||||||
|
parsed.System = nil
|
||||||
|
case gjson.String:
|
||||||
|
// 与 encoding/json 的 Unmarshal 行为一致:返回解码后的字符串。
|
||||||
|
parsed.System = sys.String()
|
||||||
|
default:
|
||||||
|
var system any
|
||||||
|
if err := json.Unmarshal(sliceRawFromBody(body, sys), &system); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
parsed.System = system
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if messages, ok := req["messages"].([]any); ok {
|
|
||||||
|
if msgs := gjson.Get(jsonStr, "messages"); msgs.Exists() && msgs.IsArray() {
|
||||||
|
var messages []any
|
||||||
|
if err := json.Unmarshal(sliceRawFromBody(body, msgs), &messages); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
parsed.Messages = messages
|
parsed.Messages = messages
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -129,6 +158,20 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
|||||||
return parsed, nil
|
return parsed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sliceRawFromBody 返回 Result.Raw 对应的原始字节切片。
|
||||||
|
// 优先使用 Result.Index 直接从 body 切片,避免对大字段(如 messages)产生额外拷贝。
|
||||||
|
// 当 Index 不可用时,退化为复制(理论上极少发生)。
|
||||||
|
func sliceRawFromBody(body []byte, r gjson.Result) []byte {
|
||||||
|
if r.Index > 0 {
|
||||||
|
end := r.Index + len(r.Raw)
|
||||||
|
if end <= len(body) {
|
||||||
|
return body[r.Index:end]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// fallback: 不影响正确性,但会产生一次拷贝
|
||||||
|
return []byte(r.Raw)
|
||||||
|
}
|
||||||
|
|
||||||
// parseIntegralNumber 将 JSON 解码后的数字安全转换为 int。
|
// parseIntegralNumber 将 JSON 解码后的数字安全转换为 int。
|
||||||
// 仅接受“整数值”的输入,小数/NaN/Inf/越界值都会返回 false。
|
// 仅接受“整数值”的输入,小数/NaN/Inf/越界值都会返回 false。
|
||||||
func parseIntegralNumber(raw any) (int, bool) {
|
func parseIntegralNumber(raw any) (int, bool) {
|
||||||
|
|||||||
Reference in New Issue
Block a user