fix(网关): SSE 缓冲 input_json_delta 反向转换
This commit is contained in:
@@ -3460,9 +3460,159 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
}
|
}
|
||||||
|
|
||||||
needModelReplace := originalModel != mappedModel
|
needModelReplace := originalModel != mappedModel
|
||||||
rewriteTools := mimicClaudeCode
|
|
||||||
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||||
|
|
||||||
|
pendingEventLines := make([]string, 0, 4)
|
||||||
|
toolInputBuffers := make(map[int]string)
|
||||||
|
|
||||||
|
transformToolInputJSON := func(raw string) string {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
|
||||||
|
var parsed any
|
||||||
|
if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
|
||||||
|
return replaceToolNamesInText(raw, toolNameMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
rewritten, changed := rewriteParamKeysInValue(parsed, toolNameMap)
|
||||||
|
if changed {
|
||||||
|
if bytes, err := json.Marshal(rewritten); err == nil {
|
||||||
|
return string(bytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
|
||||||
|
processSSEEvent := func(lines []string) ([]string, string, error) {
|
||||||
|
if len(lines) == 0 {
|
||||||
|
return nil, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
eventName := ""
|
||||||
|
dataLine := ""
|
||||||
|
for _, line := range lines {
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
if strings.HasPrefix(trimmed, "event:") {
|
||||||
|
eventName = strings.TrimSpace(strings.TrimPrefix(trimmed, "event:"))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if dataLine == "" && sseDataRe.MatchString(trimmed) {
|
||||||
|
dataLine = sseDataRe.ReplaceAllString(trimmed, "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if eventName == "error" {
|
||||||
|
return nil, dataLine, errors.New("have error in stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
if dataLine == "" {
|
||||||
|
return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dataLine == "[DONE]" {
|
||||||
|
block := ""
|
||||||
|
if eventName != "" {
|
||||||
|
block = "event: " + eventName + "\n"
|
||||||
|
}
|
||||||
|
block += "data: " + dataLine + "\n\n"
|
||||||
|
return []string{block}, dataLine, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var event map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(dataLine), &event); err != nil {
|
||||||
|
replaced := replaceToolNamesInText(dataLine, toolNameMap)
|
||||||
|
block := ""
|
||||||
|
if eventName != "" {
|
||||||
|
block = "event: " + eventName + "\n"
|
||||||
|
}
|
||||||
|
block += "data: " + replaced + "\n\n"
|
||||||
|
return []string{block}, replaced, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
eventType, _ := event["type"].(string)
|
||||||
|
if eventName == "" {
|
||||||
|
eventName = eventType
|
||||||
|
}
|
||||||
|
|
||||||
|
if needModelReplace && eventType == "message_start" {
|
||||||
|
if msg, ok := event["message"].(map[string]any); ok {
|
||||||
|
if model, ok := msg["model"].(string); ok && model == mappedModel {
|
||||||
|
msg["model"] = originalModel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if eventType == "content_block_delta" {
|
||||||
|
if delta, ok := event["delta"].(map[string]any); ok {
|
||||||
|
if deltaType, _ := delta["type"].(string); deltaType == "input_json_delta" {
|
||||||
|
if indexVal, ok := event["index"].(float64); ok {
|
||||||
|
index := int(indexVal)
|
||||||
|
if partial, ok := delta["partial_json"].(string); ok {
|
||||||
|
toolInputBuffers[index] += partial
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, dataLine, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if eventType == "content_block_stop" {
|
||||||
|
if indexVal, ok := event["index"].(float64); ok {
|
||||||
|
index := int(indexVal)
|
||||||
|
if buffered := toolInputBuffers[index]; buffered != "" {
|
||||||
|
delete(toolInputBuffers, index)
|
||||||
|
|
||||||
|
transformed := transformToolInputJSON(buffered)
|
||||||
|
synthetic := map[string]any{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": index,
|
||||||
|
"delta": map[string]any{
|
||||||
|
"type": "input_json_delta",
|
||||||
|
"partial_json": transformed,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
synthBytes, synthErr := json.Marshal(synthetic)
|
||||||
|
if synthErr == nil {
|
||||||
|
synthBlock := "event: content_block_delta\n" + "data: " + string(synthBytes) + "\n\n"
|
||||||
|
|
||||||
|
rewriteToolNamesInValue(event, toolNameMap)
|
||||||
|
stopBytes, stopErr := json.Marshal(event)
|
||||||
|
if stopErr == nil {
|
||||||
|
stopBlock := ""
|
||||||
|
if eventName != "" {
|
||||||
|
stopBlock = "event: " + eventName + "\n"
|
||||||
|
}
|
||||||
|
stopBlock += "data: " + string(stopBytes) + "\n\n"
|
||||||
|
return []string{synthBlock, stopBlock}, string(stopBytes), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriteToolNamesInValue(event, toolNameMap)
|
||||||
|
newData, err := json.Marshal(event)
|
||||||
|
if err != nil {
|
||||||
|
replaced := replaceToolNamesInText(dataLine, toolNameMap)
|
||||||
|
block := ""
|
||||||
|
if eventName != "" {
|
||||||
|
block = "event: " + eventName + "\n"
|
||||||
|
}
|
||||||
|
block += "data: " + replaced + "\n\n"
|
||||||
|
return []string{block}, replaced, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
block := ""
|
||||||
|
if eventName != "" {
|
||||||
|
block = "event: " + eventName + "\n"
|
||||||
|
}
|
||||||
|
block += "data: " + string(newData) + "\n\n"
|
||||||
|
return []string{block}, string(newData), nil
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case ev, ok := <-events:
|
case ev, ok := <-events:
|
||||||
@@ -3491,45 +3641,43 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
|
||||||
}
|
}
|
||||||
line := ev.line
|
line := ev.line
|
||||||
if line == "event: error" {
|
trimmed := strings.TrimSpace(line)
|
||||||
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
|
|
||||||
if clientDisconnected {
|
if trimmed == "" {
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
if len(pendingEventLines) == 0 {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
return nil, errors.New("have error in stream")
|
|
||||||
|
outputBlocks, data, err := processSSEEvent(pendingEventLines)
|
||||||
|
pendingEventLines = pendingEventLines[:0]
|
||||||
|
if err != nil {
|
||||||
|
if clientDisconnected {
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, block := range outputBlocks {
|
||||||
|
if !clientDisconnected {
|
||||||
|
if _, werr := fmt.Fprint(w, block); werr != nil {
|
||||||
|
clientDisconnected = true
|
||||||
|
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
if data != "" {
|
||||||
|
if firstTokenMs == nil && data != "[DONE]" {
|
||||||
|
ms := int(time.Since(startTime).Milliseconds())
|
||||||
|
firstTokenMs = &ms
|
||||||
|
}
|
||||||
|
s.parseSSEUsage(data, usage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
pendingEventLines = append(pendingEventLines, line)
|
||||||
var data string
|
|
||||||
if sseDataRe.MatchString(line) {
|
|
||||||
// 如果有模型映射,替换响应中的model字段
|
|
||||||
if needModelReplace {
|
|
||||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
|
||||||
}
|
|
||||||
if rewriteTools {
|
|
||||||
line = s.replaceToolNamesInSSELine(line, toolNameMap)
|
|
||||||
}
|
|
||||||
data = sseDataRe.ReplaceAllString(line, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 写入客户端(统一处理 data 行和非 data 行)
|
|
||||||
if !clientDisconnected {
|
|
||||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
|
||||||
clientDisconnected = true
|
|
||||||
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
|
||||||
} else {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 无论客户端是否断开,都解析 usage(仅对 data 行)
|
|
||||||
if data != "" {
|
|
||||||
if firstTokenMs == nil && data != "[DONE]" {
|
|
||||||
ms := int(time.Since(startTime).Milliseconds())
|
|
||||||
firstTokenMs = &ms
|
|
||||||
}
|
|
||||||
s.parseSSEUsage(data, usage)
|
|
||||||
}
|
|
||||||
|
|
||||||
case <-intervalCh:
|
case <-intervalCh:
|
||||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||||
|
|||||||
Reference in New Issue
Block a user