perf(service): 优化 model 替换函数,用 gjson/sjson 替代全量 JSON 序列化
SSE 热路径中 replaceModelInSSELine 和 replaceModelInResponseBody 原来 使用 json.Unmarshal/Marshal 对每个事件做全量反序列化再序列化,现改为 gjson.Get/sjson.Set 精确字段操作,消除 O(n) 中间 map 分配,保持 JSON 字段顺序不变。涉及 OpenAIGatewayService 和 GatewayService 两个服务。 新增 23 个单元测试覆盖:顶层/嵌套 model 替换、不匹配跳过、空行/[DONE]/ 非法 JSON 等边界情况。 Fixes: P1-08 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -24,6 +24,8 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -1430,31 +1432,22 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
|
||||
return line
|
||||
}
|
||||
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
return line
|
||||
}
|
||||
|
||||
// Replace model in response
|
||||
if m, ok := event["model"].(string); ok && m == fromModel {
|
||||
event["model"] = toModel
|
||||
newData, err := json.Marshal(event)
|
||||
// 使用 gjson 精确检查 model 字段,避免全量 JSON 反序列化
|
||||
if m := gjson.Get(data, "model"); m.Exists() && m.Str == fromModel {
|
||||
newData, err := sjson.Set(data, "model", toModel)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(newData)
|
||||
return "data: " + newData
|
||||
}
|
||||
|
||||
// Check nested response
|
||||
if response, ok := event["response"].(map[string]any); ok {
|
||||
if m, ok := response["model"].(string); ok && m == fromModel {
|
||||
response["model"] = toModel
|
||||
newData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(newData)
|
||||
// 检查嵌套的 response.model 字段
|
||||
if m := gjson.Get(data, "response.model"); m.Exists() && m.Str == fromModel {
|
||||
newData, err := sjson.Set(data, "response.model", toModel)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + newData
|
||||
}
|
||||
|
||||
return line
|
||||
@@ -1674,23 +1667,15 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return body
|
||||
// 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化
|
||||
if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
|
||||
newBody, err := sjson.SetBytes(body, "model", toModel)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
model, ok := resp["model"].(string)
|
||||
if !ok || model != fromModel {
|
||||
return body
|
||||
}
|
||||
|
||||
resp["model"] = toModel
|
||||
newBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
return newBody
|
||||
return body
|
||||
}
|
||||
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
|
||||
Reference in New Issue
Block a user