fix(billing): 修复客户端取消请求时计费丢失问题

检测 context.Canceled 作为客户端断开信号,返回已收集的 usage 而非错误
This commit is contained in:
shaw
2026-01-08 11:25:17 +08:00
parent acabdc2f99
commit db6f53e2c9

View File

@@ -109,12 +109,13 @@ type ClaudeUsage struct {
// ForwardResult 转发结果 // ForwardResult 转发结果
type ForwardResult struct { type ForwardResult struct {
RequestID string RequestID string
Usage ClaudeUsage Usage ClaudeUsage
Model string Model string
Stream bool Stream bool
Duration time.Duration Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求) FirstTokenMs *int // 首字时间(流式请求)
ClientDisconnect bool // 客户端是否在流式传输过程中断开
// 图片生成计费字段(仅 gemini-3-pro-image 使用) // 图片生成计费字段(仅 gemini-3-pro-image 使用)
ImageCount int // 生成的图片数量 ImageCount int // 生成的图片数量
@@ -1465,6 +1466,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理正常响应 // 处理正常响应
var usage *ClaudeUsage var usage *ClaudeUsage
var firstTokenMs *int var firstTokenMs *int
var clientDisconnect bool
if reqStream { if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel) streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
if err != nil { if err != nil {
@@ -1477,6 +1479,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
} }
usage = streamResult.usage usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
} else { } else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel) usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
if err != nil { if err != nil {
@@ -1485,12 +1488,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
} }
return &ForwardResult{ return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志 Model: originalModel, // 使用原始模型用于计费和日志
Stream: reqStream, Stream: reqStream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
}, nil }, nil
} }
@@ -1845,8 +1849,9 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
// streamingResult 流式响应结果 // streamingResult 流式响应结果
type streamingResult struct { type streamingResult struct {
usage *ClaudeUsage usage *ClaudeUsage
firstTokenMs *int firstTokenMs *int
clientDisconnect bool // 客户端是否在流式传输过程中断开
} }
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) { func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
@@ -1942,14 +1947,27 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
} }
needModelReplace := originalModel != mappedModel needModelReplace := originalModel != mappedModel
clientDisconnected := false // 客户端断开标志断开后继续读取上游以获取完整usage
for { for {
select { select {
case ev, ok := <-events: case ev, ok := <-events:
if !ok { if !ok {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil // 上游完成,返回结果
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
} }
if ev.err != nil { if ev.err != nil {
// 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取)
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
log.Printf("Context canceled during streaming, returning collected usage")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
// 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage
if clientDisconnected {
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
// 客户端未断开,正常的错误处理
if errors.Is(ev.err, bufio.ErrTooLong) { if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large") sendErrorEvent("response_too_large")
@@ -1960,38 +1978,40 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
} }
line := ev.line line := ev.line
if line == "event: error" { if line == "event: error" {
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
return nil, errors.New("have error in stream") return nil, errors.New("have error in stream")
} }
// Extract data from SSE line (supports both "data: " and "data:" formats) // Extract data from SSE line (supports both "data: " and "data:" formats)
var data string
if sseDataRe.MatchString(line) { if sseDataRe.MatchString(line) {
data := sseDataRe.ReplaceAllString(line, "") data = sseDataRe.ReplaceAllString(line, "")
// 如果有模型映射替换响应中的model字段 // 如果有模型映射替换响应中的model字段
if needModelReplace { if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel) line = s.replaceModelInSSELine(line, mappedModel, originalModel)
} }
}
// 转发行 // 写入客户端(统一处理 data 行和非 data 行)
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed") clientDisconnected = true
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
} }
flusher.Flush() }
// 记录首字时间:第一个有效的 content_block_delta 或 message_start // 无论客户端是否断开,都解析 usage仅对 data 行)
if firstTokenMs == nil && data != "" && data != "[DONE]" { if data != "" {
if firstTokenMs == nil && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds()) ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms firstTokenMs = &ms
} }
s.parseSSEUsage(data, usage) s.parseSSEUsage(data, usage)
} else {
// 非 data 行直接转发
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
} }
case <-intervalCh: case <-intervalCh:
@@ -1999,6 +2019,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if time.Since(lastRead) < streamInterval { if time.Since(lastRead) < streamInterval {
continue continue
} }
if clientDisconnected {
// 客户端已断开,上游也超时了,返回已收集的 usage
log.Printf("Upstream timeout after client disconnect, returning collected usage")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
sendErrorEvent("stream_timeout") sendErrorEvent("stream_timeout")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")