fix(网关): SSE 缓冲 input_json_delta 反向转换
This commit is contained in:
@@ -3315,9 +3315,159 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
rewriteTools := mimicClaudeCode
|
||||
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||
|
||||
pendingEventLines := make([]string, 0, 4)
|
||||
toolInputBuffers := make(map[int]string)
|
||||
|
||||
// transformToolInputJSON takes the complete, concatenated tool-input JSON text
// of one content block and returns the text that should be sent to the client.
//
// Surrounding whitespace is trimmed first; an empty result is returned as is.
// Text that does not parse as JSON is handed to replaceToolNamesInText. Parsed
// JSON is passed through rewriteParamKeysInValue and re-encoded only when that
// helper reports a change; in every other case (no change, or re-encoding
// failed) the trimmed input is returned untouched.
transformToolInputJSON := func(raw string) string {
	raw = strings.TrimSpace(raw)
	if len(raw) == 0 {
		return raw
	}

	var decoded any
	if json.Unmarshal([]byte(raw), &decoded) != nil {
		// Not a valid JSON document: fall back to a purely textual rewrite.
		return replaceToolNamesInText(raw, toolNameMap)
	}

	rewritten, changed := rewriteParamKeysInValue(decoded, toolNameMap)
	if !changed {
		return raw
	}

	encoded, err := json.Marshal(rewritten)
	if err != nil {
		// Keep the original text rather than emitting a broken payload.
		return raw
	}
	return string(encoded)
}
|
||||
|
||||
// processSSEEvent turns one complete upstream SSE event (the lines collected
// before a blank separator line) into the blocks to forward to the client.
//
// Parameters:
//   - lines: the raw lines of a single event, without the blank separator.
//
// Returns:
//   - the blocks to write, each already terminated by an empty line. The slice
//     is nil when lines is empty, when the event is named "error", or when an
//     input_json_delta was stored in toolInputBuffers instead of being sent.
//   - the event's data payload: the rewritten payload when the event was
//     rewritten, the original payload when it was buffered or is an error
//     event, and "" when the event carries no data line.
//   - a non-nil error when the event name is "error".
//
// input_json_delta fragments are accumulated per content-block index and are
// emitted as a single synthetic content_block_delta, after passing through
// transformToolInputJSON, right before the matching content_block_stop.
processSSEEvent := func(lines []string) ([]string, string, error) {
	if len(lines) == 0 {
		return nil, "", nil
	}

	eventName := ""
	dataLine := ""
	for _, line := range lines {
		trimmed := strings.TrimSpace(line)
		if strings.HasPrefix(trimmed, "event:") {
			eventName = strings.TrimSpace(strings.TrimPrefix(trimmed, "event:"))
			continue
		}
		// Only the first data line with a non-empty payload is kept.
		if dataLine == "" && sseDataRe.MatchString(trimmed) {
			dataLine = sseDataRe.ReplaceAllString(trimmed, "")
		}
	}

	if eventName == "error" {
		return nil, dataLine, errors.New("have error in stream")
	}

	// No payload (comments, bare event lines, ...): forward the lines verbatim.
	if dataLine == "" {
		return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil
	}

	// formatBlock renders one SSE block: an optional event line, the data
	// line, and the terminating empty line.
	formatBlock := func(name, data string) string {
		if name == "" {
			return "data: " + data + "\n\n"
		}
		return "event: " + name + "\n" + "data: " + data + "\n\n"
	}

	if dataLine == "[DONE]" {
		return []string{formatBlock(eventName, dataLine)}, dataLine, nil
	}

	var event map[string]any
	if err := json.Unmarshal([]byte(dataLine), &event); err != nil {
		// Payload is not a JSON object: rewrite it as plain text instead.
		replaced := replaceToolNamesInText(dataLine, toolNameMap)
		return []string{formatBlock(eventName, replaced)}, replaced, nil
	}

	eventType, _ := event["type"].(string)
	if eventName == "" {
		// No explicit event line: use the payload's own type as the name.
		eventName = eventType
	}

	if needModelReplace && eventType == "message_start" {
		if msg, ok := event["message"].(map[string]any); ok {
			if model, ok := msg["model"].(string); ok && model == mappedModel {
				msg["model"] = originalModel
			}
		}
	}

	// Blocks that must be written before the (rewritten) event itself.
	var blocks []string

	switch eventType {
	case "content_block_delta":
		// A failed assertion leaves delta nil; reading a nil map is safe.
		delta, _ := event["delta"].(map[string]any)
		if deltaType, _ := delta["type"].(string); deltaType == "input_json_delta" {
			indexVal, hasIndex := event["index"].(float64)
			partial, hasPartial := delta["partial_json"].(string)
			if hasIndex && hasPartial {
				toolInputBuffers[int(indexVal)] += partial
				return nil, dataLine, nil
			}
			// The fragment could not be buffered (no usable index or
			// partial_json). Forward it through the generic path below
			// instead of silently dropping it.
		}

	case "content_block_stop":
		if indexVal, ok := event["index"].(float64); ok {
			index := int(indexVal)
			buffered := toolInputBuffers[index]
			// Always release the slot, including entries that only ever
			// received empty fragments.
			delete(toolInputBuffers, index)

			if buffered != "" {
				synthetic := map[string]any{
					"type":  "content_block_delta",
					"index": index,
					"delta": map[string]any{
						"type":         "input_json_delta",
						"partial_json": transformToolInputJSON(buffered),
					},
				}
				if synthBytes, synthErr := json.Marshal(synthetic); synthErr == nil {
					blocks = append(blocks, formatBlock("content_block_delta", string(synthBytes)))
				}
			}
		}
	}

	// Generic path: rewrite the event and re-encode it. The helper's result is
	// not captured, so it presumably mutates event in place - confirm against
	// its definition.
	rewriteToolNamesInValue(event, toolNameMap)

	payload := ""
	if newData, err := json.Marshal(event); err == nil {
		payload = string(newData)
	} else {
		// Re-encoding failed: fall back to a textual rewrite of the original
		// payload so the event (and any synthetic delta queued above) is
		// still delivered.
		payload = replaceToolNamesInText(dataLine, toolNameMap)
	}

	blocks = append(blocks, formatBlock(eventName, payload))
	return blocks, payload, nil
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
@@ -3346,45 +3496,43 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
|
||||
}
|
||||
line := ev.line
|
||||
if line == "event: error" {
|
||||
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
|
||||
if clientDisconnected {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
trimmed := strings.TrimSpace(line)
|
||||
|
||||
if trimmed == "" {
|
||||
if len(pendingEventLines) == 0 {
|
||||
continue
|
||||
}
|
||||
return nil, errors.New("have error in stream")
|
||||
|
||||
outputBlocks, data, err := processSSEEvent(pendingEventLines)
|
||||
pendingEventLines = pendingEventLines[:0]
|
||||
if err != nil {
|
||||
if clientDisconnected {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, block := range outputBlocks {
|
||||
if !clientDisconnected {
|
||||
if _, werr := fmt.Fprint(w, block); werr != nil {
|
||||
clientDisconnected = true
|
||||
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
break
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if data != "" {
|
||||
if firstTokenMs == nil && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
var data string
|
||||
if sseDataRe.MatchString(line) {
|
||||
// 如果有模型映射,替换响应中的model字段
|
||||
if needModelReplace {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
if rewriteTools {
|
||||
line = s.replaceToolNamesInSSELine(line, toolNameMap)
|
||||
}
|
||||
data = sseDataRe.ReplaceAllString(line, "")
|
||||
}
|
||||
|
||||
// 写入客户端(统一处理 data 行和非 data 行)
|
||||
if !clientDisconnected {
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
clientDisconnected = true
|
||||
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
} else {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// 无论客户端是否断开,都解析 usage(仅对 data 行)
|
||||
if data != "" {
|
||||
if firstTokenMs == nil && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
}
|
||||
pendingEventLines = append(pendingEventLines, line)
|
||||
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
|
||||
Reference in New Issue
Block a user