From a05b8b56e3f306f5a08a15363d16a2db2c734b50 Mon Sep 17 00:00:00 2001 From: cyhhao Date: Mon, 19 Jan 2026 03:46:09 +0800 Subject: [PATCH] =?UTF-8?q?fix(=E7=BD=91=E5=85=B3):=20SSE=20=E7=BC=93?= =?UTF-8?q?=E5=86=B2=20input=5Fjson=5Fdelta=20=E5=8F=8D=E5=90=91=E8=BD=AC?= =?UTF-8?q?=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/service/gateway_service.go | 222 ++++++++++++++++---- 1 file changed, 185 insertions(+), 37 deletions(-) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index fb2d40a3..8a6770c8 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3315,9 +3315,159 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } needModelReplace := originalModel != mappedModel - rewriteTools := mimicClaudeCode clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage + pendingEventLines := make([]string, 0, 4) + toolInputBuffers := make(map[int]string) + + transformToolInputJSON := func(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return raw + } + + var parsed any + if err := json.Unmarshal([]byte(raw), &parsed); err != nil { + return replaceToolNamesInText(raw, toolNameMap) + } + + rewritten, changed := rewriteParamKeysInValue(parsed, toolNameMap) + if changed { + if bytes, err := json.Marshal(rewritten); err == nil { + return string(bytes) + } + } + return raw + } + + processSSEEvent := func(lines []string) ([]string, string, error) { + if len(lines) == 0 { + return nil, "", nil + } + + eventName := "" + dataLine := "" + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "event:") { + eventName = strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")) + continue + } + if dataLine == "" && sseDataRe.MatchString(trimmed) { + dataLine = sseDataRe.ReplaceAllString(trimmed, "") + } + } + + if eventName == "error" { + return nil, dataLine, errors.New("have error in stream") + } + + if dataLine == "" { + return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil + } + + if dataLine == "[DONE]" { + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + dataLine + "\n\n" + return []string{block}, dataLine, nil + } + + var event map[string]any + if err := json.Unmarshal([]byte(dataLine), &event); err != nil { + replaced := replaceToolNamesInText(dataLine, toolNameMap) + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + replaced + "\n\n" + return []string{block}, replaced, nil + } + + eventType, _ := event["type"].(string) + if eventName == "" { + eventName = eventType + } + + if needModelReplace && eventType == "message_start" { + if msg, ok := event["message"].(map[string]any); ok { + if model, ok := msg["model"].(string); ok && model == mappedModel { + msg["model"] = originalModel + } + } + } + + if eventType == "content_block_delta" { + if delta, ok := event["delta"].(map[string]any); ok { + if deltaType, _ := delta["type"].(string); deltaType == "input_json_delta" { + if indexVal, ok := event["index"].(float64); ok { + index := int(indexVal) + if partial, ok := delta["partial_json"].(string); ok { + toolInputBuffers[index] += partial + } + } + return nil, dataLine, nil + } + } + } + + if eventType == "content_block_stop" { + if indexVal, ok := event["index"].(float64); ok { + index := int(indexVal) + if buffered := toolInputBuffers[index]; buffered != "" { + delete(toolInputBuffers, index) + + transformed := transformToolInputJSON(buffered) + synthetic := map[string]any{ + "type": "content_block_delta", + "index": index, + "delta": map[string]any{ + "type": "input_json_delta", + "partial_json": transformed, + }, + } + + synthBytes, synthErr := json.Marshal(synthetic) + if synthErr == nil { + synthBlock := "event: content_block_delta\n" + "data: " + string(synthBytes) + "\n\n" + + rewriteToolNamesInValue(event, toolNameMap) + stopBytes, stopErr := json.Marshal(event) + if stopErr == nil { + stopBlock := "" + if eventName != "" { + stopBlock = "event: " + eventName + "\n" + } + stopBlock += "data: " + string(stopBytes) + "\n\n" + return []string{synthBlock, stopBlock}, string(stopBytes), nil + } + } + } + } + } + + rewriteToolNamesInValue(event, toolNameMap) + newData, err := json.Marshal(event) + if err != nil { + replaced := replaceToolNamesInText(dataLine, toolNameMap) + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + replaced + "\n\n" + return []string{block}, replaced, nil + } + + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + string(newData) + "\n\n" + return []string{block}, string(newData), nil + } + for { select { case ev, ok := <-events: @@ -3346,45 +3496,43 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) } line := ev.line - if line == "event: error" { - // 上游返回错误事件,如果客户端已断开仍返回已收集的 usage - if clientDisconnected { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + trimmed := strings.TrimSpace(line) + + if trimmed == "" { + if len(pendingEventLines) == 0 { + continue } - return nil, errors.New("have error in stream") + + outputBlocks, data, err := processSSEEvent(pendingEventLines) + pendingEventLines = pendingEventLines[:0] + if err != nil { + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + return nil, err + } + + for _, block := range outputBlocks { + if !clientDisconnected { + if _, werr := fmt.Fprint(w, block); werr != nil { + clientDisconnected = true + log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + break + } + flusher.Flush() + } + if data != "" { + if firstTokenMs == nil && data != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsage(data, usage) + } + } + continue } - // Extract data from SSE line (supports both "data: " and "data:" formats) - var data string - if sseDataRe.MatchString(line) { - // 如果有模型映射,替换响应中的model字段 - if needModelReplace { - line = s.replaceModelInSSELine(line, mappedModel, originalModel) - } - if rewriteTools { - line = s.replaceToolNamesInSSELine(line, toolNameMap) - } - data = sseDataRe.ReplaceAllString(line, "") - } - - // 写入客户端(统一处理 data 行和非 data 行) - if !clientDisconnected { - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - clientDisconnected = true - log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") - } else { - flusher.Flush() - } - } - - // 无论客户端是否断开,都解析 usage(仅对 data 行) - if data != "" { - if firstTokenMs == nil && data != "[DONE]" { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - s.parseSSEUsage(data, usage) - } + pendingEventLines = append(pendingEventLines, line) case <-intervalCh: lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))