fix(billing): 修复客户端取消请求时计费丢失问题
检测 context.Canceled / context.DeadlineExceeded 以及向客户端写入失败，将其视为客户端断开信号；写入失败后继续读取上游以收集完整 usage，最终返回已收集的 usage 而非错误
This commit is contained in:
@@ -109,12 +109,13 @@ type ClaudeUsage struct {
|
||||
|
||||
// ForwardResult 转发结果
|
||||
type ForwardResult struct {
|
||||
RequestID string
|
||||
Usage ClaudeUsage
|
||||
Model string
|
||||
Stream bool
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int // 首字时间(流式请求)
|
||||
RequestID string
|
||||
Usage ClaudeUsage
|
||||
Model string
|
||||
Stream bool
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int // 首字时间(流式请求)
|
||||
ClientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||
|
||||
// 图片生成计费字段(仅 gemini-3-pro-image 使用)
|
||||
ImageCount int // 生成的图片数量
|
||||
@@ -1465,6 +1466,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// 处理正常响应
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
var clientDisconnect bool
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
|
||||
if err != nil {
|
||||
@@ -1477,6 +1479,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
clientDisconnect = streamResult.clientDisconnect
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
|
||||
if err != nil {
|
||||
@@ -1485,12 +1488,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ClientDisconnect: clientDisconnect,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1845,8 +1849,9 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
|
||||
|
||||
// streamingResult 流式响应结果
|
||||
type streamingResult struct {
|
||||
usage *ClaudeUsage
|
||||
firstTokenMs *int
|
||||
usage *ClaudeUsage
|
||||
firstTokenMs *int
|
||||
clientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
|
||||
@@ -1942,14 +1947,27 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
// 上游完成,返回结果
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
// 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取)
|
||||
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
log.Printf("Context canceled during streaming, returning collected usage")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
// 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage
|
||||
if clientDisconnected {
|
||||
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
// 客户端未断开,正常的错误处理
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
@@ -1960,38 +1978,40 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
line := ev.line
|
||||
if line == "event: error" {
|
||||
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
|
||||
if clientDisconnected {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
return nil, errors.New("have error in stream")
|
||||
}
|
||||
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
var data string
|
||||
if sseDataRe.MatchString(line) {
|
||||
data := sseDataRe.ReplaceAllString(line, "")
|
||||
|
||||
data = sseDataRe.ReplaceAllString(line, "")
|
||||
// 如果有模型映射,替换响应中的model字段
|
||||
if needModelReplace {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
}
|
||||
|
||||
// 转发行
|
||||
// 写入客户端(统一处理 data 行和非 data 行)
|
||||
if !clientDisconnected {
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
clientDisconnected = true
|
||||
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
} else {
|
||||
flusher.Flush()
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
// 无论客户端是否断开,都解析 usage(仅对 data 行)
|
||||
if data != "" {
|
||||
if firstTokenMs == nil && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
} else {
|
||||
// 非 data 行直接转发
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
case <-intervalCh:
|
||||
@@ -1999,6 +2019,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
// 客户端已断开,上游也超时了,返回已收集的 usage
|
||||
log.Printf("Upstream timeout after client disconnect, returning collected usage")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
|
||||
Reference in New Issue
Block a user