perf(service): 优化 model 替换函数,用 gjson/sjson 替代全量 JSON 序列化
SSE 热路径中 replaceModelInSSELine 和 replaceModelInResponseBody 原来 使用 json.Unmarshal/Marshal 对每个事件做全量反序列化再序列化,现改为 gjson.Get/sjson.Set 精确字段操作,消除 O(n) 中间 map 分配,保持 JSON 字段顺序不变。涉及 OpenAIGatewayService 和 GatewayService 两个服务。 新增 23 个单元测试覆盖:顶层/嵌套 model 替换、不匹配跳过、空行/[DONE]/ 非法 JSON 等边界情况。 Fixes: P1-08 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -4577,24 +4577,16 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
|||||||
}
|
}
|
||||||
|
|
||||||
// replaceModelInResponseBody 替换响应体中的model字段
|
// replaceModelInResponseBody 替换响应体中的model字段
|
||||||
|
// 使用 gjson/sjson 精确替换,避免全量 JSON 反序列化
|
||||||
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||||
var resp map[string]any
|
if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
|
||||||
if err := json.Unmarshal(body, &resp); err != nil {
|
newBody, err := sjson.SetBytes(body, "model", toModel)
|
||||||
return body
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return newBody
|
||||||
}
|
}
|
||||||
|
return body
|
||||||
model, ok := resp["model"].(string)
|
|
||||||
if !ok || model != fromModel {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
resp["model"] = toModel
|
|
||||||
newBody, err := json.Marshal(resp)
|
|
||||||
if err != nil {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
return newBody
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte {
|
func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte {
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -1430,31 +1432,22 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
|
|||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
|
|
||||||
var event map[string]any
|
// 使用 gjson 精确检查 model 字段,避免全量 JSON 反序列化
|
||||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
if m := gjson.Get(data, "model"); m.Exists() && m.Str == fromModel {
|
||||||
return line
|
newData, err := sjson.Set(data, "model", toModel)
|
||||||
}
|
|
||||||
|
|
||||||
// Replace model in response
|
|
||||||
if m, ok := event["model"].(string); ok && m == fromModel {
|
|
||||||
event["model"] = toModel
|
|
||||||
newData, err := json.Marshal(event)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
return "data: " + string(newData)
|
return "data: " + newData
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check nested response
|
// 检查嵌套的 response.model 字段
|
||||||
if response, ok := event["response"].(map[string]any); ok {
|
if m := gjson.Get(data, "response.model"); m.Exists() && m.Str == fromModel {
|
||||||
if m, ok := response["model"].(string); ok && m == fromModel {
|
newData, err := sjson.Set(data, "response.model", toModel)
|
||||||
response["model"] = toModel
|
if err != nil {
|
||||||
newData, err := json.Marshal(event)
|
return line
|
||||||
if err != nil {
|
|
||||||
return line
|
|
||||||
}
|
|
||||||
return "data: " + string(newData)
|
|
||||||
}
|
}
|
||||||
|
return "data: " + newData
|
||||||
}
|
}
|
||||||
|
|
||||||
return line
|
return line
|
||||||
@@ -1674,23 +1667,15 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||||
var resp map[string]any
|
// 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化
|
||||||
if err := json.Unmarshal(body, &resp); err != nil {
|
if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
|
||||||
return body
|
newBody, err := sjson.SetBytes(body, "model", toModel)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return newBody
|
||||||
}
|
}
|
||||||
|
return body
|
||||||
model, ok := resp["model"].(string)
|
|
||||||
if !ok || model != fromModel {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
resp["model"] = toModel
|
|
||||||
newBody, err := json.Marshal(resp)
|
|
||||||
if err != nil {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
return newBody
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenAIRecordUsageInput input for recording usage
|
// OpenAIRecordUsageInput input for recording usage
|
||||||
|
|||||||
@@ -1187,3 +1187,226 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
|
|||||||
t.Fatalf("expected non-allowlisted host to fail")
|
t.Fatalf("expected non-allowlisted host to fail")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ==================== P1-08 修复:model 替换性能优化测试 ====================
|
||||||
|
|
||||||
|
func TestReplaceModelInSSELine(t *testing.T) {
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
line string
|
||||||
|
from string
|
||||||
|
to string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "顶层 model 字段替换",
|
||||||
|
line: `data: {"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "my-custom-model",
|
||||||
|
expected: `data: {"id":"chatcmpl-123","model":"my-custom-model","choices":[]}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "嵌套 response.model 替换",
|
||||||
|
line: `data: {"type":"response","response":{"id":"resp-1","model":"gpt-4o","output":[]}}`,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "my-model",
|
||||||
|
expected: `data: {"type":"response","response":{"id":"resp-1","model":"my-model","output":[]}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model 不匹配时不替换",
|
||||||
|
line: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "my-model",
|
||||||
|
expected: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无 model 字段时不替换",
|
||||||
|
line: `data: {"id":"chatcmpl-123","choices":[]}`,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "my-model",
|
||||||
|
expected: `data: {"id":"chatcmpl-123","choices":[]}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空 data 行",
|
||||||
|
line: `data: `,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "my-model",
|
||||||
|
expected: `data: `,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "[DONE] 行",
|
||||||
|
line: `data: [DONE]`,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "my-model",
|
||||||
|
expected: `data: [DONE]`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "非 data: 前缀行",
|
||||||
|
line: `event: message`,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "my-model",
|
||||||
|
expected: `event: message`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "非法 JSON 不替换",
|
||||||
|
line: `data: {invalid json}`,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "my-model",
|
||||||
|
expected: `data: {invalid json}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无空格 data: 格式",
|
||||||
|
line: `data:{"id":"x","model":"gpt-4o"}`,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "my-model",
|
||||||
|
expected: `data: {"id":"x","model":"my-model"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model 名含特殊字符",
|
||||||
|
line: `data: {"model":"org/model-v2.1-beta"}`,
|
||||||
|
from: "org/model-v2.1-beta",
|
||||||
|
to: "custom/alias",
|
||||||
|
expected: `data: {"model":"custom/alias"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空行",
|
||||||
|
line: "",
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "my-model",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "保持其他字段不变",
|
||||||
|
line: `data: {"id":"abc","object":"chat.completion.chunk","model":"gpt-4o","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "alias",
|
||||||
|
expected: `data: {"id":"abc","object":"chat.completion.chunk","model":"alias","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "顶层优先于嵌套:同时存在两个 model",
|
||||||
|
line: `data: {"model":"gpt-4o","response":{"model":"gpt-4o"}}`,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "replaced",
|
||||||
|
expected: `data: {"model":"replaced","response":{"model":"gpt-4o"}}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := svc.replaceModelInSSELine(tt.line, tt.from, tt.to)
|
||||||
|
require.Equal(t, tt.expected, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplaceModelInSSEBody(t *testing.T) {
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
from string
|
||||||
|
to string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "多行 SSE body 替换",
|
||||||
|
body: "data: {\"model\":\"gpt-4o\",\"choices\":[]}\n\ndata: {\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n",
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "alias",
|
||||||
|
expected: "data: {\"model\":\"alias\",\"choices\":[]}\n\ndata: {\"model\":\"alias\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无需替换的 body",
|
||||||
|
body: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n",
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "alias",
|
||||||
|
expected: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "混合 event 和 data 行",
|
||||||
|
body: "event: message\ndata: {\"model\":\"gpt-4o\"}\n\n",
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "alias",
|
||||||
|
expected: "event: message\ndata: {\"model\":\"alias\"}\n\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空 body",
|
||||||
|
body: "",
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "alias",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := svc.replaceModelInSSEBody(tt.body, tt.from, tt.to)
|
||||||
|
require.Equal(t, tt.expected, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplaceModelInResponseBody(t *testing.T) {
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
from string
|
||||||
|
to string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "替换顶层 model",
|
||||||
|
body: `{"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "alias",
|
||||||
|
expected: `{"id":"chatcmpl-123","model":"alias","choices":[]}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model 不匹配不替换",
|
||||||
|
body: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "alias",
|
||||||
|
expected: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无 model 字段不替换",
|
||||||
|
body: `{"id":"chatcmpl-123","choices":[]}`,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "alias",
|
||||||
|
expected: `{"id":"chatcmpl-123","choices":[]}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "非法 JSON 返回原值",
|
||||||
|
body: `not json`,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "alias",
|
||||||
|
expected: `not json`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空 body 返回原值",
|
||||||
|
body: ``,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "alias",
|
||||||
|
expected: ``,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "保持嵌套结构不变",
|
||||||
|
body: `{"model":"gpt-4o","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
|
||||||
|
from: "gpt-4o",
|
||||||
|
to: "alias",
|
||||||
|
expected: `{"model":"alias","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := svc.replaceModelInResponseBody([]byte(tt.body), tt.from, tt.to)
|
||||||
|
require.Equal(t, tt.expected, string(got))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user