diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 417e8aae..46a4ef94 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -10,6 +10,23 @@ import ( "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +var ( + // 这些字节模式用于 fast-path 判断,避免每次 []byte("...") 产生临时分配。 + patternTypeThinking = []byte(`"type":"thinking"`) + patternTypeThinkingSpaced = []byte(`"type": "thinking"`) + patternTypeRedactedThinking = []byte(`"type":"redacted_thinking"`) + patternTypeRedactedSpaced = []byte(`"type": "redacted_thinking"`) + + patternThinkingField = []byte(`"thinking":`) + patternThinkingFieldSpaced = []byte(`"thinking" :`) + + patternEmptyContent = []byte(`"content":[]`) + patternEmptyContentSpaced = []byte(`"content": []`) + patternEmptyContentSp1 = []byte(`"content" : []`) + patternEmptyContentSp2 = []byte(`"content" :[]`) ) // SessionContext 粘性会话上下文,用于区分不同来源的请求。 @@ -238,49 +255,63 @@ func FilterThinkingBlocks(body []byte) []byte { // - Remove `redacted_thinking` blocks (cannot be converted to text). // - Ensure no message ends up with empty content. func FilterThinkingBlocksForRetry(body []byte) []byte { - hasThinkingContent := bytes.Contains(body, []byte(`"type":"thinking"`)) || - bytes.Contains(body, []byte(`"type": "thinking"`)) || - bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) || - bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) || - bytes.Contains(body, []byte(`"thinking":`)) || - bytes.Contains(body, []byte(`"thinking" :`)) + hasThinkingContent := bytes.Contains(body, patternTypeThinking) || + bytes.Contains(body, patternTypeThinkingSpaced) || + bytes.Contains(body, patternTypeRedactedThinking) || + bytes.Contains(body, patternTypeRedactedSpaced) || + bytes.Contains(body, patternThinkingField) || + bytes.Contains(body, patternThinkingFieldSpaced) // Also check for empty content arrays that need fixing. // Note: This is a heuristic check; the actual empty content handling is done below. - hasEmptyContent := bytes.Contains(body, []byte(`"content":[]`)) || - bytes.Contains(body, []byte(`"content": []`)) || - bytes.Contains(body, []byte(`"content" : []`)) || - bytes.Contains(body, []byte(`"content" :[]`)) + hasEmptyContent := bytes.Contains(body, patternEmptyContent) || + bytes.Contains(body, patternEmptyContentSpaced) || + bytes.Contains(body, patternEmptyContentSp1) || + bytes.Contains(body, patternEmptyContentSp2) // Fast path: nothing to process if !hasThinkingContent && !hasEmptyContent { return body } - var req map[string]any - if err := json.Unmarshal(body, &req); err != nil { + // 尽量避免把整个 body Unmarshal 成 map(会产生大量 map/接口分配)。 + // 这里先用 gjson 把 messages 子树摘出来,后续只对 messages 做 Unmarshal/Marshal。 + jsonStr := *(*string)(unsafe.Pointer(&body)) + msgsRes := gjson.Get(jsonStr, "messages") + if !msgsRes.Exists() || !msgsRes.IsArray() { + return body + } + + // Fast path:只需要删除顶层 thinking,不需要改 messages。 + // 注意:patternThinkingField 可能来自嵌套字段(如 tool_use.input.thinking),因此必须用 gjson 判断顶层字段是否存在。 + containsThinkingBlocks := bytes.Contains(body, patternTypeThinking) || + bytes.Contains(body, patternTypeThinkingSpaced) || + bytes.Contains(body, patternTypeRedactedThinking) || + bytes.Contains(body, patternTypeRedactedSpaced) || + bytes.Contains(body, patternThinkingFieldSpaced) + if !hasEmptyContent && !containsThinkingBlocks { + if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() { + if out, err := sjson.DeleteBytes(body, "thinking"); err == nil { + return out + } + return body + } + return body + } + + var messages []any + if err := json.Unmarshal(sliceRawFromBody(body, msgsRes), &messages); err != nil { return body } modified := false - messages, ok := req["messages"].([]any) - if !ok { - return body - } - // Disable top-level thinking mode for retry to avoid structural/signature constraints upstream. - if _, exists := req["thinking"]; exists { - delete(req, "thinking") - modified = true - } + deleteTopLevelThinking := gjson.Get(jsonStr, "thinking").Exists() - newMessages := make([]any, 0, len(messages)) - - for _, msg := range messages { - msgMap, ok := msg.(map[string]any) + for i := 0; i < len(messages); i++ { + msgMap, ok := messages[i].(map[string]any) if !ok { - newMessages = append(newMessages, msg) continue } @@ -288,17 +319,30 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { content, ok := msgMap["content"].([]any) if !ok { // String content or other format - keep as is - newMessages = append(newMessages, msg) continue } - newContent := make([]any, 0, len(content)) + // 延迟分配:只有检测到需要修改的块,才构建新 slice。 + var newContent []any modifiedThisMsg := false - for _, block := range content { + ensureNewContent := func(prefixLen int) { + if newContent != nil { + return + } + newContent = make([]any, 0, len(content)) + if prefixLen > 0 { + newContent = append(newContent, content[:prefixLen]...) + } + } + + for bi := 0; bi < len(content); bi++ { + block := content[bi] blockMap, ok := block.(map[string]any) if !ok { - newContent = append(newContent, block) + if newContent != nil { + newContent = append(newContent, block) + } continue } @@ -308,17 +352,15 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { switch blockType { case "thinking": modifiedThisMsg = true + ensureNewContent(bi) thinkingText, _ := blockMap["thinking"].(string) - if thinkingText == "" { - continue + if thinkingText != "" { + newContent = append(newContent, map[string]any{"type": "text", "text": thinkingText}) } - newContent = append(newContent, map[string]any{ - "type": "text", - "text": thinkingText, - }) continue case "redacted_thinking": modifiedThisMsg = true + ensureNewContent(bi) continue } @@ -326,6 +368,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { if blockType == "" { if rawThinking, hasThinking := blockMap["thinking"]; hasThinking { modifiedThisMsg = true + ensureNewContent(bi) switch v := rawThinking.(type) { case string: if v != "" { @@ -340,40 +383,64 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { } } - newContent = append(newContent, block) + if newContent != nil { + newContent = append(newContent, block) + } } // Handle empty content: either from filtering or originally empty + if newContent == nil { + if len(content) == 0 { + modified = true + placeholder := "(content removed)" + if role == "assistant" { + placeholder = "(assistant content removed)" + } + msgMap["content"] = []any{map[string]any{"type": "text", "text": placeholder}} + } + continue + } + if len(newContent) == 0 { modified = true placeholder := "(content removed)" if role == "assistant" { placeholder = "(assistant content removed)" } - newContent = append(newContent, map[string]any{ - "type": "text", - "text": placeholder, - }) - msgMap["content"] = newContent - } else if modifiedThisMsg { + msgMap["content"] = []any{map[string]any{"type": "text", "text": placeholder}} + continue + } + + if modifiedThisMsg { modified = true msgMap["content"] = newContent } - newMessages = append(newMessages, msgMap) } - if modified { - req["messages"] = newMessages - } else { + if !modified && !deleteTopLevelThinking { // Avoid rewriting JSON when no changes are needed. return body } - newBody, err := json.Marshal(req) - if err != nil { - return body + out := body + if deleteTopLevelThinking { + if b, err := sjson.DeleteBytes(out, "thinking"); err == nil { + out = b + } else { + return body + } } - return newBody + if modified { + msgsBytes, err := json.Marshal(messages) + if err != nil { + return body + } + out, err = sjson.SetRawBytes(out, "messages", msgsBytes) + if err != nil { + return body + } + } + return out } // FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index 28f916e8..42367ebe 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -450,9 +450,9 @@ func TestParseGatewayRequest_TypeValidation(t *testing.T) { errSubstr: "invalid model field type", }, { - name: "model 为 null — gjson Null 类型触发类型校验错误", - body: `{"model":null}`, - wantErr: true, // gjson: Exists()=true, Type=Null != String → 返回错误 + name: "model 为 null — gjson Null 类型触发类型校验错误", + body: `{"model":null}`, + wantErr: true, // gjson: Exists()=true, Type=Null != String → 返回错误 errSubstr: "invalid model field type", }, { @@ -468,9 +468,9 @@ func TestParseGatewayRequest_TypeValidation(t *testing.T) { errSubstr: "invalid stream field type", }, { - name: "stream 为 null — gjson Null 类型触发类型校验错误", - body: `{"stream":null}`, - wantErr: true, // gjson: Exists()=true, Type=Null != True && != False → 返回错误 + name: "stream 为 null — gjson Null 类型触发类型校验错误", + body: `{"stream":null}`, + wantErr: true, // gjson: Exists()=true, Type=Null != True && != False → 返回错误 errSubstr: "invalid stream field type", }, { @@ -499,16 +499,16 @@ func TestParseGatewayRequest_TypeValidation(t *testing.T) { // Task 7.2 — 可选字段缺失测试 func TestParseGatewayRequest_OptionalFieldsMissing(t *testing.T) { tests := []struct { - name string - body string - wantModel string - wantStream bool - wantMetadataUID string - wantHasSystem bool - wantThinking bool - wantMaxTokens int - wantMessagesNil bool - wantMessagesLen int + name string + body string + wantModel string + wantStream bool + wantMetadataUID string + wantHasSystem bool + wantThinking bool + wantMaxTokens int + wantMessagesNil bool + wantMessagesLen int }{ { name: "完全空 JSON — 所有字段零值",