From db6f53e2c951a8900b7a2dff5a360044a96a6f83 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 8 Jan 2026 11:25:17 +0800 Subject: [PATCH] =?UTF-8?q?fix(billing):=20=E4=BF=AE=E5=A4=8D=E5=AE=A2?= =?UTF-8?q?=E6=88=B7=E7=AB=AF=E5=8F=96=E6=B6=88=E8=AF=B7=E6=B1=82=E6=97=B6?= =?UTF-8?q?=E8=AE=A1=E8=B4=B9=E4=B8=A2=E5=A4=B1=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 检测 context.Canceled 作为客户端断开信号,返回已收集的 usage 而非错误 --- backend/internal/service/gateway_service.go | 85 +++++++++++++-------- 1 file changed, 55 insertions(+), 30 deletions(-) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 3a68b44c..98c061d4 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -109,12 +109,13 @@ type ClaudeUsage struct { // ForwardResult 转发结果 type ForwardResult struct { - RequestID string - Usage ClaudeUsage - Model string - Stream bool - Duration time.Duration - FirstTokenMs *int // 首字时间(流式请求) + RequestID string + Usage ClaudeUsage + Model string + Stream bool + Duration time.Duration + FirstTokenMs *int // 首字时间(流式请求) + ClientDisconnect bool // 客户端是否在流式传输过程中断开 // 图片生成计费字段(仅 gemini-3-pro-image 使用) ImageCount int // 生成的图片数量 @@ -1465,6 +1466,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 处理正常响应 var usage *ClaudeUsage var firstTokenMs *int + var clientDisconnect bool if reqStream { streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel) if err != nil { @@ -1477,6 +1479,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } usage = streamResult.usage firstTokenMs = streamResult.firstTokenMs + clientDisconnect = streamResult.clientDisconnect } else { usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel) if err != nil { @@ -1485,12 +1488,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } return &ForwardResult{ - RequestID: resp.Header.Get("x-request-id"), - Usage: *usage, - Model: originalModel, // 使用原始模型用于计费和日志 - Stream: reqStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: originalModel, // 使用原始模型用于计费和日志 + Stream: reqStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, }, nil } @@ -1845,8 +1849,9 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht // streamingResult 流式响应结果 type streamingResult struct { - usage *ClaudeUsage - firstTokenMs *int + usage *ClaudeUsage + firstTokenMs *int + clientDisconnect bool // 客户端是否在流式传输过程中断开 } func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) { @@ -1942,14 +1947,27 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } needModelReplace := originalModel != mappedModel + clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage for { select { case ev, ok := <-events: if !ok { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + // 上游完成,返回结果 + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil } if ev.err != nil { + // 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取) + if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { + log.Printf("Context canceled during streaming, returning collected usage") + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + // 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage + if clientDisconnected { + log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + // 客户端未断开,正常的错误处理 if errors.Is(ev.err, bufio.ErrTooLong) { log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) sendErrorEvent("response_too_large") @@ -1960,38 +1978,40 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } line := ev.line if line == "event: error" { + // 上游返回错误事件,如果客户端已断开仍返回已收集的 usage + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } return nil, errors.New("have error in stream") } // Extract data from SSE line (supports both "data: " and "data:" formats) + var data string if sseDataRe.MatchString(line) { - data := sseDataRe.ReplaceAllString(line, "") - + data = sseDataRe.ReplaceAllString(line, "") // 如果有模型映射,替换响应中的model字段 if needModelReplace { line = s.replaceModelInSSELine(line, mappedModel, originalModel) } + } - // 转发行 + // 写入客户端(统一处理 data 行和非 data 行) + if !clientDisconnected { if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - sendErrorEvent("write_failed") - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + clientDisconnected = true + log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + } else { + flusher.Flush() } - flusher.Flush() + } - // 记录首字时间:第一个有效的 content_block_delta 或 message_start - if firstTokenMs == nil && data != "" && data != "[DONE]" { + // 无论客户端是否断开,都解析 usage(仅对 data 行) + if data != "" { + if firstTokenMs == nil && data != "[DONE]" { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } s.parseSSEUsage(data, usage) - } else { - // 非 data 行直接转发 - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - sendErrorEvent("write_failed") - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() } case <-intervalCh: @@ -1999,6 +2019,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if time.Since(lastRead) < streamInterval { continue } + if clientDisconnected { + // 客户端已断开,上游也超时了,返回已收集的 usage + log.Printf("Upstream timeout after client disconnect, returning collected usage") + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) sendErrorEvent("stream_timeout") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")