fix(网关): SSE 缓冲 input_json_delta 反向转换

This commit is contained in:
cyhhao
2026-01-19 03:46:09 +08:00
parent 8917a3ea8f
commit a05b8b56e3

View File

@@ -3315,9 +3315,159 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
needModelReplace := originalModel != mappedModel
rewriteTools := mimicClaudeCode
clientDisconnected := false // 客户端断开标志断开后继续读取上游以获取完整usage
pendingEventLines := make([]string, 0, 4)
toolInputBuffers := make(map[int]string)
// transformToolInputJSON post-processes the fully buffered JSON input of one
// tool-use block before it is sent to the client.
//
// The input is whitespace-trimmed first. Empty input is returned as is. Input
// that does not parse as JSON is handed to replaceToolNamesInText. Parsed input
// is passed through rewriteParamKeysInValue with toolNameMap; the re-encoded
// JSON is returned only when that helper reports a change and encoding
// succeeds, otherwise the trimmed input is returned untouched.
transformToolInputJSON := func(raw string) string {
	trimmed := strings.TrimSpace(raw)
	if len(trimmed) == 0 {
		return trimmed
	}

	var decoded any
	if json.Unmarshal([]byte(trimmed), &decoded) != nil {
		// Not valid JSON: fall back to a plain-text rewrite.
		return replaceToolNamesInText(trimmed, toolNameMap)
	}

	rewritten, changed := rewriteParamKeysInValue(decoded, toolNameMap)
	if !changed {
		return trimmed
	}
	encoded, err := json.Marshal(rewritten)
	if err != nil {
		// Keep the original payload rather than emitting nothing.
		return trimmed
	}
	return string(encoded)
}
// processSSEEvent turns the lines of one complete SSE event (everything
// collected up to the blank separator line) into the blocks to forward.
//
// Returns:
//   - the output blocks, each already terminated by a blank line; nil when
//     nothing should be written yet (an input_json_delta that was buffered);
//   - the data payload belonging to the event ("" when the event carried none);
//   - a non-nil error when the event is named "error".
//
// input_json_delta fragments are accumulated per content-block index in
// toolInputBuffers and emitted as a single synthetic content_block_delta right
// before the matching content_block_stop, so that transformToolInputJSON can
// work on the complete tool input instead of partial JSON.
processSSEEvent := func(lines []string) ([]string, string, error) {
	if len(lines) == 0 {
		return nil, "", nil
	}

	// formatBlock renders one SSE block; the event line is omitted when name is empty.
	formatBlock := func(name, payload string) string {
		if name == "" {
			return "data: " + payload + "\n\n"
		}
		return "event: " + name + "\n" + "data: " + payload + "\n\n"
	}

	eventName := ""
	dataLine := ""
	for _, line := range lines {
		trimmed := strings.TrimSpace(line)
		if strings.HasPrefix(trimmed, "event:") {
			eventName = strings.TrimSpace(strings.TrimPrefix(trimmed, "event:"))
			continue
		}
		// Only the first non-empty data payload of the event is used.
		if dataLine == "" && sseDataRe.MatchString(trimmed) {
			dataLine = sseDataRe.ReplaceAllString(trimmed, "")
		}
	}

	if eventName == "error" {
		return nil, dataLine, errors.New("have error in stream")
	}
	if dataLine == "" {
		// No data payload (e.g. comment lines): forward the event unchanged.
		return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil
	}
	if dataLine == "[DONE]" {
		return []string{formatBlock(eventName, dataLine)}, dataLine, nil
	}

	var event map[string]any
	if err := json.Unmarshal([]byte(dataLine), &event); err != nil {
		// Payload is not a JSON object: fall back to a text-level rewrite.
		replaced := replaceToolNamesInText(dataLine, toolNameMap)
		return []string{formatBlock(eventName, replaced)}, replaced, nil
	}

	eventType, _ := event["type"].(string)
	if eventName == "" {
		eventName = eventType
	}

	// Report the model name the client originally asked for.
	if needModelReplace && eventType == "message_start" {
		if msg, ok := event["message"].(map[string]any); ok {
			if model, ok := msg["model"].(string); ok && model == mappedModel {
				msg["model"] = originalModel
			}
		}
	}

	if eventType == "content_block_delta" {
		if delta, ok := event["delta"].(map[string]any); ok {
			if deltaType, _ := delta["type"].(string); deltaType == "input_json_delta" {
				indexVal, hasIndex := event["index"].(float64)
				partial, hasPartial := delta["partial_json"].(string)
				if hasIndex && hasPartial {
					toolInputBuffers[int(indexVal)] += partial
					return nil, dataLine, nil
				}
				// Malformed delta (no numeric index or no string partial_json):
				// it cannot be buffered, so forward it through the generic path
				// below instead of silently dropping it.
			}
		}
	}

	if eventType == "content_block_stop" {
		if indexVal, ok := event["index"].(float64); ok {
			index := int(indexVal)
			buffered := toolInputBuffers[index]
			// The block is finished; drop its entry even when it is empty.
			delete(toolInputBuffers, index)
			if buffered != "" {
				synthetic := map[string]any{
					"type":  "content_block_delta",
					"index": index,
					"delta": map[string]any{
						"type":         "input_json_delta",
						"partial_json": transformToolInputJSON(buffered),
					},
				}
				if synthBytes, synthErr := json.Marshal(synthetic); synthErr == nil {
					rewriteToolNamesInValue(event, toolNameMap)
					if stopBytes, stopErr := json.Marshal(event); stopErr == nil {
						blocks := []string{
							formatBlock("content_block_delta", string(synthBytes)),
							formatBlock(eventName, string(stopBytes)),
						}
						return blocks, string(stopBytes), nil
					}
				}
				// Encoding failed: fall through and emit the stop event alone.
			}
		}
	}

	rewriteToolNamesInValue(event, toolNameMap)
	newData, err := json.Marshal(event)
	if err != nil {
		replaced := replaceToolNamesInText(dataLine, toolNameMap)
		return []string{formatBlock(eventName, replaced)}, replaced, nil
	}
	return []string{formatBlock(eventName, string(newData))}, string(newData), nil
}
for {
select {
case ev, ok := <-events:
@@ -3346,45 +3496,43 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
if line == "event: error" {
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
trimmed := strings.TrimSpace(line)
if trimmed == "" {
if len(pendingEventLines) == 0 {
continue
}
return nil, errors.New("have error in stream")
outputBlocks, data, err := processSSEEvent(pendingEventLines)
pendingEventLines = pendingEventLines[:0]
if err != nil {
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
return nil, err
}
for _, block := range outputBlocks {
if !clientDisconnected {
if _, werr := fmt.Fprint(w, block); werr != nil {
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
break
}
flusher.Flush()
}
if data != "" {
if firstTokenMs == nil && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
}
}
continue
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
var data string
if sseDataRe.MatchString(line) {
// 如果有模型映射替换响应中的model字段
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
if rewriteTools {
line = s.replaceToolNamesInSSELine(line, toolNameMap)
}
data = sseDataRe.ReplaceAllString(line, "")
}
// 写入客户端(统一处理 data 行和非 data 行)
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
}
}
// 无论客户端是否断开,都解析 usage仅对 data 行)
if data != "" {
if firstTokenMs == nil && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
}
pendingEventLines = append(pendingEventLines, line)
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))