fix(billing): 修复客户端取消请求时计费丢失问题
检测 context.Canceled / context.DeadlineExceeded 以及向客户端写入失败作为客户端断开信号;断开后继续读取上游以收集 usage,并返回已收集的 usage 而非错误
This commit is contained in:
@@ -109,12 +109,13 @@ type ClaudeUsage struct {
|
|||||||
|
|
||||||
// ForwardResult 转发结果
|
// ForwardResult 转发结果
|
||||||
type ForwardResult struct {
|
type ForwardResult struct {
|
||||||
RequestID string
|
RequestID string
|
||||||
Usage ClaudeUsage
|
Usage ClaudeUsage
|
||||||
Model string
|
Model string
|
||||||
Stream bool
|
Stream bool
|
||||||
Duration time.Duration
|
Duration time.Duration
|
||||||
FirstTokenMs *int // 首字时间(流式请求)
|
FirstTokenMs *int // 首字时间(流式请求)
|
||||||
|
ClientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||||
|
|
||||||
// 图片生成计费字段(仅 gemini-3-pro-image 使用)
|
// 图片生成计费字段(仅 gemini-3-pro-image 使用)
|
||||||
ImageCount int // 生成的图片数量
|
ImageCount int // 生成的图片数量
|
||||||
@@ -1465,6 +1466,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
// 处理正常响应
|
// 处理正常响应
|
||||||
var usage *ClaudeUsage
|
var usage *ClaudeUsage
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
|
var clientDisconnect bool
|
||||||
if reqStream {
|
if reqStream {
|
||||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
|
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1477,6 +1479,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
}
|
}
|
||||||
usage = streamResult.usage
|
usage = streamResult.usage
|
||||||
firstTokenMs = streamResult.firstTokenMs
|
firstTokenMs = streamResult.firstTokenMs
|
||||||
|
clientDisconnect = streamResult.clientDisconnect
|
||||||
} else {
|
} else {
|
||||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
|
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1485,12 +1488,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: resp.Header.Get("x-request-id"),
|
RequestID: resp.Header.Get("x-request-id"),
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
Model: originalModel, // 使用原始模型用于计费和日志
|
Model: originalModel, // 使用原始模型用于计费和日志
|
||||||
Stream: reqStream,
|
Stream: reqStream,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
|
ClientDisconnect: clientDisconnect,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1845,8 +1849,9 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
|
|||||||
|
|
||||||
// streamingResult 流式响应结果
|
// streamingResult 流式响应结果
|
||||||
type streamingResult struct {
|
type streamingResult struct {
|
||||||
usage *ClaudeUsage
|
usage *ClaudeUsage
|
||||||
firstTokenMs *int
|
firstTokenMs *int
|
||||||
|
clientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
|
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
|
||||||
@@ -1942,14 +1947,27 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
}
|
}
|
||||||
|
|
||||||
needModelReplace := originalModel != mappedModel
|
needModelReplace := originalModel != mappedModel
|
||||||
|
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case ev, ok := <-events:
|
case ev, ok := <-events:
|
||||||
if !ok {
|
if !ok {
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
// 上游完成,返回结果
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||||
}
|
}
|
||||||
if ev.err != nil {
|
if ev.err != nil {
|
||||||
|
// 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取)
|
||||||
|
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||||
|
log.Printf("Context canceled during streaming, returning collected usage")
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||||
|
}
|
||||||
|
// 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage
|
||||||
|
if clientDisconnected {
|
||||||
|
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||||
|
}
|
||||||
|
// 客户端未断开,正常的错误处理
|
||||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||||
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||||
sendErrorEvent("response_too_large")
|
sendErrorEvent("response_too_large")
|
||||||
@@ -1960,38 +1978,40 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
}
|
}
|
||||||
line := ev.line
|
line := ev.line
|
||||||
if line == "event: error" {
|
if line == "event: error" {
|
||||||
|
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
|
||||||
|
if clientDisconnected {
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||||
|
}
|
||||||
return nil, errors.New("have error in stream")
|
return nil, errors.New("have error in stream")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||||
|
var data string
|
||||||
if sseDataRe.MatchString(line) {
|
if sseDataRe.MatchString(line) {
|
||||||
data := sseDataRe.ReplaceAllString(line, "")
|
data = sseDataRe.ReplaceAllString(line, "")
|
||||||
|
|
||||||
// 如果有模型映射,替换响应中的model字段
|
// 如果有模型映射,替换响应中的model字段
|
||||||
if needModelReplace {
|
if needModelReplace {
|
||||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 转发行
|
// 写入客户端(统一处理 data 行和非 data 行)
|
||||||
|
if !clientDisconnected {
|
||||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||||
sendErrorEvent("write_failed")
|
clientDisconnected = true
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||||
|
} else {
|
||||||
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
}
|
||||||
|
|
||||||
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
|
// 无论客户端是否断开,都解析 usage(仅对 data 行)
|
||||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
if data != "" {
|
||||||
|
if firstTokenMs == nil && data != "[DONE]" {
|
||||||
ms := int(time.Since(startTime).Milliseconds())
|
ms := int(time.Since(startTime).Milliseconds())
|
||||||
firstTokenMs = &ms
|
firstTokenMs = &ms
|
||||||
}
|
}
|
||||||
s.parseSSEUsage(data, usage)
|
s.parseSSEUsage(data, usage)
|
||||||
} else {
|
|
||||||
// 非 data 行直接转发
|
|
||||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
|
||||||
sendErrorEvent("write_failed")
|
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
|
||||||
}
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-intervalCh:
|
case <-intervalCh:
|
||||||
@@ -1999,6 +2019,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
if time.Since(lastRead) < streamInterval {
|
if time.Since(lastRead) < streamInterval {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if clientDisconnected {
|
||||||
|
// 客户端已断开,上游也超时了,返回已收集的 usage
|
||||||
|
log.Printf("Upstream timeout after client disconnect, returning collected usage")
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||||
|
}
|
||||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
|
|||||||
Reference in New Issue
Block a user