From 5d1c51a37f47dc8213b64f0ca3eb71d5d41425bc Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Tue, 10 Feb 2026 09:13:20 +0800 Subject: [PATCH] =?UTF-8?q?fix(handler):=20=E4=BF=AE=E5=A4=8D=20gjson=20?= =?UTF-8?q?=E8=BF=81=E7=A7=BB=E5=90=8E=E7=9A=84=E8=AF=B7=E6=B1=82=E6=A0=A1?= =?UTF-8?q?=E9=AA=8C=E8=AF=AD=E4=B9=89=E5=9B=9E=E9=80=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - OpenAI handler: 添加 gjson.ValidBytes 校验 JSON 合法性;model 校验改为 检查 gjson.String 类型而非仅判断非空(拒绝 model:123 等非法类型);stream 字段添加 True/False 类型检查;sjson.SetBytes 返回值显式处理错误 - Sora handler: 添加 gjson.ValidBytes 校验;model 校验同上改为类型检查; messages 校验从 Exists+Type==JSON 改为 IsArray+len>0(拒绝空数组和对象) - 补充 TestOpenAIHandler_GjsonValidation 和更新 TestSoraHandler_ValidationExtraction 覆盖新增的边界校验场景 Co-Authored-By: Claude Opus 4.6 --- .../handler/openai_gateway_handler.go | 28 +++++++++--- .../handler/openai_gateway_handler_test.go | 45 ++++++++++++++++++- .../internal/handler/sora_gateway_handler.go | 15 +++++-- .../handler/sora_gateway_handler_test.go | 31 ++++++++++--- 4 files changed, 102 insertions(+), 17 deletions(-) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 81195804..a4c25284 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -95,15 +95,26 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, "", false, body) - // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal - reqModel := gjson.GetBytes(body, "model").String() - reqStream := gjson.GetBytes(body, "stream").Bool() + // 校验请求体 JSON 合法性 + if !gjson.ValidBytes(body) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } - // 验证 model 必填 - if reqModel == "" { + // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") return } + reqModel := modelResult.String() + + streamResult := gjson.GetBytes(body, "stream") + if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type") + return + } + reqStream := streamResult.Bool() userAgent := c.GetHeader("User-Agent") isCodexCLI := openai.IsCodexCLIRequest(userAgent) || (h.cfg != nil && h.cfg.Gateway.ForceCodexCLI) @@ -111,7 +122,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { existingInstructions := gjson.GetBytes(body, "instructions").String() if strings.TrimSpace(existingInstructions) == "" { if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" { - body, _ = sjson.SetBytes(body, "instructions", instructions) + newBody, err := sjson.SetBytes(body, "instructions", instructions) + if err != nil { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") + return + } + body = newBody } } } diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index 782acfbf..65296da4 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -121,7 +121,11 @@ func TestOpenAIHandler_GjsonExtraction(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { body := []byte(tt.body) - model := gjson.GetBytes(body, "model").String() + modelResult := gjson.GetBytes(body, "model") + model := "" + if modelResult.Type == gjson.String { + model = modelResult.String() + } stream := gjson.GetBytes(body, "stream").Bool() require.Equal(t, tt.wantModel, model) require.Equal(t, tt.wantStream, stream) @@ -129,6 +133,38 @@ func TestOpenAIHandler_GjsonExtraction(t *testing.T) { } } +// TestOpenAIHandler_GjsonValidation 验证修复后的 JSON 合法性和类型校验 +func TestOpenAIHandler_GjsonValidation(t *testing.T) { + // 非法 JSON 被 gjson.ValidBytes 拦截 + require.False(t, gjson.ValidBytes([]byte(`{invalid json`))) + + // model 为数字 → 类型不是 gjson.String,应被拒绝 + body := []byte(`{"model":123}`) + modelResult := gjson.GetBytes(body, "model") + require.True(t, modelResult.Exists()) + require.NotEqual(t, gjson.String, modelResult.Type) + + // model 为 null → 类型不是 gjson.String,应被拒绝 + body2 := []byte(`{"model":null}`) + modelResult2 := gjson.GetBytes(body2, "model") + require.True(t, modelResult2.Exists()) + require.NotEqual(t, gjson.String, modelResult2.Type) + + // stream 为 string → 类型既不是 True 也不是 False,应被拒绝 + body3 := []byte(`{"model":"gpt-4","stream":"true"}`) + streamResult := gjson.GetBytes(body3, "stream") + require.True(t, streamResult.Exists()) + require.NotEqual(t, gjson.True, streamResult.Type) + require.NotEqual(t, gjson.False, streamResult.Type) + + // stream 为 int → 同上 + body4 := []byte(`{"model":"gpt-4","stream":1}`) + streamResult2 := gjson.GetBytes(body4, "stream") + require.True(t, streamResult2.Exists()) + require.NotEqual(t, gjson.True, streamResult2.Type) + require.NotEqual(t, gjson.False, streamResult2.Type) +} + // TestOpenAIHandler_InstructionsInjection 验证 instructions 的 gjson/sjson 注入逻辑 func TestOpenAIHandler_InstructionsInjection(t *testing.T) { // 测试 1:无 instructions → 注入 @@ -148,4 +184,11 @@ func TestOpenAIHandler_InstructionsInjection(t *testing.T) { body3 := []byte(`{"model":"gpt-4","instructions":" "}`) existing3 := strings.TrimSpace(gjson.GetBytes(body3, "instructions").String()) require.Empty(t, existing3) + + // 测试 4:sjson.SetBytes 返回错误时不应 panic + // 正常 JSON 不会产生 sjson 错误,验证返回值被正确处理 + validBody := []byte(`{"model":"gpt-4"}`) + result, setErr := sjson.SetBytes(validBody, "instructions", "hello") + require.NoError(t, setErr) + require.True(t, gjson.ValidBytes(result)) } diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index fdf28956..aed54167 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -106,13 +106,22 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { setOpsRequestContext(c, "", false, body) + // 校验请求体 JSON 合法性 + if !gjson.ValidBytes(body) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal - reqModel := gjson.GetBytes(body, "model").String() - if reqModel == "" { + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") return } - if !gjson.GetBytes(body, "messages").Exists() || gjson.GetBytes(body, "messages").Type != gjson.JSON { + reqModel := modelResult.String() + + msgsResult := gjson.GetBytes(body, "messages") + if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required") return } diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index fa321585..3cae5cdd 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -440,18 +440,35 @@ func TestSoraHandler_StreamForcing(t *testing.T) { func TestSoraHandler_ValidationExtraction(t *testing.T) { // model 缺失 body := []byte(`{"messages":[{"role":"user","content":"test"}]}`) - model := gjson.GetBytes(body, "model").String() - require.Empty(t, model) + modelResult := gjson.GetBytes(body, "model") + require.True(t, !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "") + + // model 为数字 → 类型不是 gjson.String,应被拒绝 + body1b := []byte(`{"model":123,"messages":[{"role":"user","content":"test"}]}`) + modelResult1b := gjson.GetBytes(body1b, "model") + require.True(t, modelResult1b.Exists()) + require.NotEqual(t, gjson.String, modelResult1b.Type) // messages 缺失 body2 := []byte(`{"model":"sora"}`) - require.False(t, gjson.GetBytes(body2, "messages").Exists()) + require.False(t, gjson.GetBytes(body2, "messages").IsArray()) - // messages 不是 JSON 数组 + // messages 不是 JSON 数组(字符串) body3 := []byte(`{"model":"sora","messages":"not array"}`) - msgResult := gjson.GetBytes(body3, "messages") - require.True(t, msgResult.Exists()) - require.NotEqual(t, gjson.JSON, msgResult.Type) // string 类型,不是 JSON 数组 + require.False(t, gjson.GetBytes(body3, "messages").IsArray()) + + // messages 是对象而非数组 → IsArray 返回 false + body4 := []byte(`{"model":"sora","messages":{}}`) + require.False(t, gjson.GetBytes(body4, "messages").IsArray()) + + // messages 是空数组 → IsArray 为 true 但 len==0,应被拒绝 + body5 := []byte(`{"model":"sora","messages":[]}`) + msgsResult := gjson.GetBytes(body5, "messages") + require.True(t, msgsResult.IsArray()) + require.Equal(t, 0, len(msgsResult.Array())) + + // 非法 JSON 被 gjson.ValidBytes 拦截 + require.False(t, gjson.ValidBytes([]byte(`{invalid`))) } // TestGenerateOpenAISessionHash_WithBody 验证 generateOpenAISessionHash 的 body/header 解析逻辑