fix(handler): 修复 gjson 迁移后的请求校验语义回退
- OpenAI handler: 添加 gjson.ValidBytes 校验 JSON 合法性;model 校验改为 检查 gjson.String 类型而非仅判断非空(拒绝 model:123 等非法类型);stream 字段添加 True/False 类型检查;sjson.SetBytes 返回值显式处理错误 - Sora handler: 添加 gjson.ValidBytes 校验;model 校验同上改为类型检查; messages 校验从 Exists+Type==JSON 改为 IsArray+len>0(拒绝空数组和对象) - 补充 TestOpenAIHandler_GjsonValidation 和更新 TestSoraHandler_ValidationExtraction 覆盖新增的边界校验场景 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -95,15 +95,26 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
|
|
||||||
setOpsRequestContext(c, "", false, body)
|
setOpsRequestContext(c, "", false, body)
|
||||||
|
|
||||||
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
|
// 校验请求体 JSON 合法性
|
||||||
reqModel := gjson.GetBytes(body, "model").String()
|
if !gjson.ValidBytes(body) {
|
||||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 验证 model 必填
|
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
|
||||||
if reqModel == "" {
|
modelResult := gjson.GetBytes(body, "model")
|
||||||
|
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
reqModel := modelResult.String()
|
||||||
|
|
||||||
|
streamResult := gjson.GetBytes(body, "stream")
|
||||||
|
if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqStream := streamResult.Bool()
|
||||||
|
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
isCodexCLI := openai.IsCodexCLIRequest(userAgent) || (h.cfg != nil && h.cfg.Gateway.ForceCodexCLI)
|
isCodexCLI := openai.IsCodexCLIRequest(userAgent) || (h.cfg != nil && h.cfg.Gateway.ForceCodexCLI)
|
||||||
@@ -111,7 +122,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
existingInstructions := gjson.GetBytes(body, "instructions").String()
|
existingInstructions := gjson.GetBytes(body, "instructions").String()
|
||||||
if strings.TrimSpace(existingInstructions) == "" {
|
if strings.TrimSpace(existingInstructions) == "" {
|
||||||
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
|
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
|
||||||
body, _ = sjson.SetBytes(body, "instructions", instructions)
|
newBody, err := sjson.SetBytes(body, "instructions", instructions)
|
||||||
|
if err != nil {
|
||||||
|
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body = newBody
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -121,7 +121,11 @@ func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
body := []byte(tt.body)
|
body := []byte(tt.body)
|
||||||
model := gjson.GetBytes(body, "model").String()
|
modelResult := gjson.GetBytes(body, "model")
|
||||||
|
model := ""
|
||||||
|
if modelResult.Type == gjson.String {
|
||||||
|
model = modelResult.String()
|
||||||
|
}
|
||||||
stream := gjson.GetBytes(body, "stream").Bool()
|
stream := gjson.GetBytes(body, "stream").Bool()
|
||||||
require.Equal(t, tt.wantModel, model)
|
require.Equal(t, tt.wantModel, model)
|
||||||
require.Equal(t, tt.wantStream, stream)
|
require.Equal(t, tt.wantStream, stream)
|
||||||
@@ -129,6 +133,38 @@ func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestOpenAIHandler_GjsonValidation 验证修复后的 JSON 合法性和类型校验
|
||||||
|
func TestOpenAIHandler_GjsonValidation(t *testing.T) {
|
||||||
|
// 非法 JSON 被 gjson.ValidBytes 拦截
|
||||||
|
require.False(t, gjson.ValidBytes([]byte(`{invalid json`)))
|
||||||
|
|
||||||
|
// model 为数字 → 类型不是 gjson.String,应被拒绝
|
||||||
|
body := []byte(`{"model":123}`)
|
||||||
|
modelResult := gjson.GetBytes(body, "model")
|
||||||
|
require.True(t, modelResult.Exists())
|
||||||
|
require.NotEqual(t, gjson.String, modelResult.Type)
|
||||||
|
|
||||||
|
// model 为 null → 类型不是 gjson.String,应被拒绝
|
||||||
|
body2 := []byte(`{"model":null}`)
|
||||||
|
modelResult2 := gjson.GetBytes(body2, "model")
|
||||||
|
require.True(t, modelResult2.Exists())
|
||||||
|
require.NotEqual(t, gjson.String, modelResult2.Type)
|
||||||
|
|
||||||
|
// stream 为 string → 类型既不是 True 也不是 False,应被拒绝
|
||||||
|
body3 := []byte(`{"model":"gpt-4","stream":"true"}`)
|
||||||
|
streamResult := gjson.GetBytes(body3, "stream")
|
||||||
|
require.True(t, streamResult.Exists())
|
||||||
|
require.NotEqual(t, gjson.True, streamResult.Type)
|
||||||
|
require.NotEqual(t, gjson.False, streamResult.Type)
|
||||||
|
|
||||||
|
// stream 为 int → 同上
|
||||||
|
body4 := []byte(`{"model":"gpt-4","stream":1}`)
|
||||||
|
streamResult2 := gjson.GetBytes(body4, "stream")
|
||||||
|
require.True(t, streamResult2.Exists())
|
||||||
|
require.NotEqual(t, gjson.True, streamResult2.Type)
|
||||||
|
require.NotEqual(t, gjson.False, streamResult2.Type)
|
||||||
|
}
|
||||||
|
|
||||||
// TestOpenAIHandler_InstructionsInjection 验证 instructions 的 gjson/sjson 注入逻辑
|
// TestOpenAIHandler_InstructionsInjection 验证 instructions 的 gjson/sjson 注入逻辑
|
||||||
func TestOpenAIHandler_InstructionsInjection(t *testing.T) {
|
func TestOpenAIHandler_InstructionsInjection(t *testing.T) {
|
||||||
// 测试 1:无 instructions → 注入
|
// 测试 1:无 instructions → 注入
|
||||||
@@ -148,4 +184,11 @@ func TestOpenAIHandler_InstructionsInjection(t *testing.T) {
|
|||||||
body3 := []byte(`{"model":"gpt-4","instructions":" "}`)
|
body3 := []byte(`{"model":"gpt-4","instructions":" "}`)
|
||||||
existing3 := strings.TrimSpace(gjson.GetBytes(body3, "instructions").String())
|
existing3 := strings.TrimSpace(gjson.GetBytes(body3, "instructions").String())
|
||||||
require.Empty(t, existing3)
|
require.Empty(t, existing3)
|
||||||
|
|
||||||
|
// 测试 4:sjson.SetBytes 返回错误时不应 panic
|
||||||
|
// 正常 JSON 不会产生 sjson 错误,验证返回值被正确处理
|
||||||
|
validBody := []byte(`{"model":"gpt-4"}`)
|
||||||
|
result, setErr := sjson.SetBytes(validBody, "instructions", "hello")
|
||||||
|
require.NoError(t, setErr)
|
||||||
|
require.True(t, gjson.ValidBytes(result))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -106,13 +106,22 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
|
|
||||||
setOpsRequestContext(c, "", false, body)
|
setOpsRequestContext(c, "", false, body)
|
||||||
|
|
||||||
|
// 校验请求体 JSON 合法性
|
||||||
|
if !gjson.ValidBytes(body) {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
|
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
|
||||||
reqModel := gjson.GetBytes(body, "model").String()
|
modelResult := gjson.GetBytes(body, "model")
|
||||||
if reqModel == "" {
|
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !gjson.GetBytes(body, "messages").Exists() || gjson.GetBytes(body, "messages").Type != gjson.JSON {
|
reqModel := modelResult.String()
|
||||||
|
|
||||||
|
msgsResult := gjson.GetBytes(body, "messages")
|
||||||
|
if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 {
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -440,18 +440,35 @@ func TestSoraHandler_StreamForcing(t *testing.T) {
|
|||||||
func TestSoraHandler_ValidationExtraction(t *testing.T) {
|
func TestSoraHandler_ValidationExtraction(t *testing.T) {
|
||||||
// model 缺失
|
// model 缺失
|
||||||
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
|
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
|
||||||
model := gjson.GetBytes(body, "model").String()
|
modelResult := gjson.GetBytes(body, "model")
|
||||||
require.Empty(t, model)
|
require.True(t, !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "")
|
||||||
|
|
||||||
|
// model 为数字 → 类型不是 gjson.String,应被拒绝
|
||||||
|
body1b := []byte(`{"model":123,"messages":[{"role":"user","content":"test"}]}`)
|
||||||
|
modelResult1b := gjson.GetBytes(body1b, "model")
|
||||||
|
require.True(t, modelResult1b.Exists())
|
||||||
|
require.NotEqual(t, gjson.String, modelResult1b.Type)
|
||||||
|
|
||||||
// messages 缺失
|
// messages 缺失
|
||||||
body2 := []byte(`{"model":"sora"}`)
|
body2 := []byte(`{"model":"sora"}`)
|
||||||
require.False(t, gjson.GetBytes(body2, "messages").Exists())
|
require.False(t, gjson.GetBytes(body2, "messages").IsArray())
|
||||||
|
|
||||||
// messages 不是 JSON 数组
|
// messages 不是 JSON 数组(字符串)
|
||||||
body3 := []byte(`{"model":"sora","messages":"not array"}`)
|
body3 := []byte(`{"model":"sora","messages":"not array"}`)
|
||||||
msgResult := gjson.GetBytes(body3, "messages")
|
require.False(t, gjson.GetBytes(body3, "messages").IsArray())
|
||||||
require.True(t, msgResult.Exists())
|
|
||||||
require.NotEqual(t, gjson.JSON, msgResult.Type) // string 类型,不是 JSON 数组
|
// messages 是对象而非数组 → IsArray 返回 false
|
||||||
|
body4 := []byte(`{"model":"sora","messages":{}}`)
|
||||||
|
require.False(t, gjson.GetBytes(body4, "messages").IsArray())
|
||||||
|
|
||||||
|
// messages 是空数组 → IsArray 为 true 但 len==0,应被拒绝
|
||||||
|
body5 := []byte(`{"model":"sora","messages":[]}`)
|
||||||
|
msgsResult := gjson.GetBytes(body5, "messages")
|
||||||
|
require.True(t, msgsResult.IsArray())
|
||||||
|
require.Equal(t, 0, len(msgsResult.Array()))
|
||||||
|
|
||||||
|
// 非法 JSON 被 gjson.ValidBytes 拦截
|
||||||
|
require.False(t, gjson.ValidBytes([]byte(`{invalid`)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestGenerateOpenAISessionHash_WithBody 验证 generateOpenAISessionHash 的 body/header 解析逻辑
|
// TestGenerateOpenAISessionHash_WithBody 验证 generateOpenAISessionHash 的 body/header 解析逻辑
|
||||||
|
|||||||
Reference in New Issue
Block a user