feat: detect client disconnect during streaming and continue draining upstream for billing
This commit is contained in:
@@ -1305,6 +1305,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
var clientDisconnect bool
|
||||
if claudeReq.Stream {
|
||||
// 客户端要求流式,直接透传转换
|
||||
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
|
||||
@@ -1314,6 +1315,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
usage = streamRes.usage
|
||||
firstTokenMs = streamRes.firstTokenMs
|
||||
clientDisconnect = streamRes.clientDisconnect
|
||||
} else {
|
||||
// 客户端要求非流式,收集流式响应后转换返回
|
||||
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
|
||||
@@ -1326,12 +1328,13 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Stream: claudeReq.Stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Stream: claudeReq.Stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ClientDisconnect: clientDisconnect,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1860,6 +1863,7 @@ handleSuccess:
|
||||
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
var clientDisconnect bool
|
||||
|
||||
if stream {
|
||||
// 客户端要求流式,直接透传
|
||||
@@ -1870,6 +1874,7 @@ handleSuccess:
|
||||
}
|
||||
usage = streamRes.usage
|
||||
firstTokenMs = streamRes.firstTokenMs
|
||||
clientDisconnect = streamRes.clientDisconnect
|
||||
} else {
|
||||
// 客户端要求非流式,收集流式响应后返回
|
||||
streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime)
|
||||
@@ -1893,14 +1898,15 @@ handleSuccess:
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
Stream: stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ImageCount: imageCount,
|
||||
ImageSize: imageSize,
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
Stream: stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ClientDisconnect: clientDisconnect,
|
||||
ImageCount: imageCount,
|
||||
ImageSize: imageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -2319,8 +2325,69 @@ func (s *AntigravityGatewayService) handleUpstreamError(
|
||||
}
|
||||
|
||||
type antigravityStreamResult struct {
|
||||
usage *ClaudeUsage
|
||||
firstTokenMs *int
|
||||
usage *ClaudeUsage
|
||||
firstTokenMs *int
|
||||
clientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||
}
|
||||
|
||||
// antigravityClientWriter 封装流式响应的客户端写入,自动检测断开并标记。
|
||||
// 断开后所有写入操作变为 no-op,调用方通过 Disconnected() 判断是否继续 drain 上游。
|
||||
type antigravityClientWriter struct {
|
||||
w gin.ResponseWriter
|
||||
flusher http.Flusher
|
||||
disconnected bool
|
||||
prefix string // 日志前缀,标识来源方法
|
||||
}
|
||||
|
||||
func newAntigravityClientWriter(w gin.ResponseWriter, flusher http.Flusher, prefix string) *antigravityClientWriter {
|
||||
return &antigravityClientWriter{w: w, flusher: flusher, prefix: prefix}
|
||||
}
|
||||
|
||||
// Write 写入数据到客户端,写入失败时标记断开并返回 false
|
||||
func (cw *antigravityClientWriter) Write(p []byte) bool {
|
||||
if cw.disconnected {
|
||||
return false
|
||||
}
|
||||
if _, err := cw.w.Write(p); err != nil {
|
||||
cw.markDisconnected()
|
||||
return false
|
||||
}
|
||||
cw.flusher.Flush()
|
||||
return true
|
||||
}
|
||||
|
||||
// Fprintf 格式化写入数据到客户端,写入失败时标记断开并返回 false
|
||||
func (cw *antigravityClientWriter) Fprintf(format string, args ...any) bool {
|
||||
if cw.disconnected {
|
||||
return false
|
||||
}
|
||||
if _, err := fmt.Fprintf(cw.w, format, args...); err != nil {
|
||||
cw.markDisconnected()
|
||||
return false
|
||||
}
|
||||
cw.flusher.Flush()
|
||||
return true
|
||||
}
|
||||
|
||||
func (cw *antigravityClientWriter) Disconnected() bool { return cw.disconnected }
|
||||
|
||||
func (cw *antigravityClientWriter) markDisconnected() {
|
||||
cw.disconnected = true
|
||||
log.Printf("Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix)
|
||||
}
|
||||
|
||||
// handleStreamReadError 处理上游读取错误的通用逻辑。
|
||||
// 返回 (clientDisconnect, handled):handled=true 表示错误已处理,调用方应返回已收集的 usage。
|
||||
func handleStreamReadError(err error, clientDisconnected bool, prefix string) (disconnect bool, handled bool) {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
log.Printf("Context canceled during streaming (%s), returning collected usage", prefix)
|
||||
return true, true
|
||||
}
|
||||
if clientDisconnected {
|
||||
log.Printf("Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err)
|
||||
return true, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
|
||||
@@ -2396,10 +2463,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini")
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent {
|
||||
if errorEventSent || cw.Disconnected() {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
@@ -2411,9 +2480,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity gemini"); handled {
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
|
||||
}
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
@@ -2428,11 +2500,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
if strings.HasPrefix(trimmed, "data:") {
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||||
if payload == "" || payload == "[DONE]" {
|
||||
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
cw.Fprintf("%s\n", line)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -2468,27 +2536,22 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
cw.Fprintf("data: %s\n\n", payload)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
cw.Fprintf("%s\n", line)
|
||||
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if cw.Disconnected() {
|
||||
log.Printf("Upstream timeout after client disconnect (antigravity gemini), returning collected usage")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
log.Printf("Stream data interval timeout (antigravity)")
|
||||
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
@@ -3186,10 +3249,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude")
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent {
|
||||
if errorEventSent || cw.Disconnected() {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
@@ -3197,19 +3262,27 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// finishUsage 是获取 processor 最终 usage 的辅助函数
|
||||
finishUsage := func() *ClaudeUsage {
|
||||
_, agUsage := processor.Finish()
|
||||
return convertUsage(agUsage)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
// 发送结束事件
|
||||
// 上游完成,发送结束事件
|
||||
finalEvents, agUsage := processor.Finish()
|
||||
if len(finalEvents) > 0 {
|
||||
_, _ = c.Writer.Write(finalEvents)
|
||||
flusher.Flush()
|
||||
cw.Write(finalEvents)
|
||||
}
|
||||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
|
||||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity claude"); handled {
|
||||
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
|
||||
}
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
@@ -3219,25 +3292,14 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
return nil, fmt.Errorf("stream read error: %w", ev.err)
|
||||
}
|
||||
|
||||
line := ev.line
|
||||
// 处理 SSE 行,转换为 Claude 格式
|
||||
claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n"))
|
||||
|
||||
claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n"))
|
||||
if len(claudeEvents) > 0 {
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
if _, writeErr := c.Writer.Write(claudeEvents); writeErr != nil {
|
||||
finalEvents, agUsage := processor.Finish()
|
||||
if len(finalEvents) > 0 {
|
||||
_, _ = c.Writer.Write(finalEvents)
|
||||
}
|
||||
sendErrorEvent("write_failed")
|
||||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr
|
||||
}
|
||||
flusher.Flush()
|
||||
cw.Write(claudeEvents)
|
||||
}
|
||||
|
||||
case <-intervalCh:
|
||||
@@ -3245,13 +3307,15 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if cw.Disconnected() {
|
||||
log.Printf("Upstream timeout after client disconnect (antigravity claude), returning collected usage")
|
||||
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
log.Printf("Stream data interval timeout (antigravity)")
|
||||
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// extractImageSize 从 Gemini 请求中提取 image_size 参数
|
||||
@@ -3390,3 +3454,289 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) {
|
||||
payload["contents"] = filtered
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
// ForwardUpstream 使用 base_url + /v1/messages + 双 header 认证透传上游 Claude 请求
|
||||
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
sessionID := getSessionID(c)
|
||||
prefix := logPrefix(sessionID, account.Name)
|
||||
|
||||
// 获取上游配置
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||
if baseURL == "" || apiKey == "" {
|
||||
return nil, fmt.Errorf("upstream account missing base_url or api_key")
|
||||
}
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
|
||||
// 解析请求获取模型信息
|
||||
var claudeReq antigravity.ClaudeRequest
|
||||
if err := json.Unmarshal(body, &claudeReq); err != nil {
|
||||
return nil, fmt.Errorf("parse claude request: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(claudeReq.Model) == "" {
|
||||
return nil, fmt.Errorf("missing model")
|
||||
}
|
||||
originalModel := claudeReq.Model
|
||||
billingModel := originalModel
|
||||
|
||||
// 构建上游请求 URL
|
||||
upstreamURL := baseURL + "/v1/messages"
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create upstream request: %w", err)
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("x-api-key", apiKey) // Claude API 兼容
|
||||
|
||||
// 透传 Claude 相关 headers
|
||||
if v := c.GetHeader("anthropic-version"); v != "" {
|
||||
req.Header.Set("anthropic-version", v)
|
||||
}
|
||||
if v := c.GetHeader("anthropic-beta"); v != "" {
|
||||
req.Header.Set("anthropic-beta", v)
|
||||
}
|
||||
|
||||
// 代理 URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
log.Printf("%s upstream request failed: %v", prefix, err)
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
// 429 错误时标记账号限流
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", false)
|
||||
}
|
||||
|
||||
// 透传上游错误
|
||||
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||
c.Status(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(respBody)
|
||||
|
||||
return &ForwardResult{
|
||||
Model: billingModel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 处理成功响应(流式/非流式)
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
var clientDisconnect bool
|
||||
|
||||
if claudeReq.Stream {
|
||||
// 流式响应:透传
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
streamRes := s.streamUpstreamResponse(c, resp, startTime)
|
||||
usage = streamRes.usage
|
||||
firstTokenMs = streamRes.firstTokenMs
|
||||
clientDisconnect = streamRes.clientDisconnect
|
||||
} else {
|
||||
// 非流式响应:直接透传
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read upstream response: %w", err)
|
||||
}
|
||||
|
||||
// 提取 usage
|
||||
usage = s.extractClaudeUsage(respBody)
|
||||
|
||||
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||
c.Status(http.StatusOK)
|
||||
_, _ = c.Writer.Write(respBody)
|
||||
}
|
||||
|
||||
// 构建计费结果
|
||||
duration := time.Since(startTime)
|
||||
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
|
||||
|
||||
return &ForwardResult{
|
||||
Model: billingModel,
|
||||
Stream: claudeReq.Stream,
|
||||
Duration: duration,
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ClientDisconnect: clientDisconnect,
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: usage.InputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
CacheReadInputTokens: usage.CacheReadInputTokens,
|
||||
CacheCreationInputTokens: usage.CacheCreationInputTokens,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// streamUpstreamResponse 透传上游 SSE 流并提取 Claude usage
|
||||
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) *antigravityStreamResult {
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
err error
|
||||
}
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
flusher, _ := c.Writer.(http.Flusher)
|
||||
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream")
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}
|
||||
}
|
||||
if ev.err != nil {
|
||||
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity upstream"); handled {
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}
|
||||
}
|
||||
log.Printf("Stream read error (antigravity upstream): %v", ev.err)
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
|
||||
}
|
||||
|
||||
line := ev.line
|
||||
|
||||
// 记录首 token 时间
|
||||
if firstTokenMs == nil && len(line) > 0 {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
// 尝试从 message_delta 或 message_stop 事件提取 usage
|
||||
s.extractSSEUsage(line, usage)
|
||||
|
||||
// 透传行
|
||||
cw.Fprintf("%s\n", line)
|
||||
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if cw.Disconnected() {
|
||||
log.Printf("Upstream timeout after client disconnect (antigravity upstream), returning collected usage")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}
|
||||
}
|
||||
log.Printf("Stream data interval timeout (antigravity upstream)")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractSSEUsage 从 SSE data 行中提取 Claude usage(用于流式透传场景)
|
||||
func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUsage) {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
return
|
||||
}
|
||||
dataStr := strings.TrimPrefix(line, "data: ")
|
||||
var event map[string]any
|
||||
if json.Unmarshal([]byte(dataStr), &event) != nil {
|
||||
return
|
||||
}
|
||||
u, ok := event["usage"].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.InputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.OutputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.CacheReadInputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.CacheCreationInputTokens = int(v)
|
||||
}
|
||||
}
|
||||
|
||||
// extractClaudeUsage 从非流式 Claude 响应提取 usage
|
||||
func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage {
|
||||
usage := &ClaudeUsage{}
|
||||
var resp map[string]any
|
||||
if json.Unmarshal(body, &resp) != nil {
|
||||
return usage
|
||||
}
|
||||
if u, ok := resp["usage"].(map[string]any); ok {
|
||||
if v, ok := u["input_tokens"].(float64); ok {
|
||||
usage.InputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["output_tokens"].(float64); ok {
|
||||
usage.OutputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["cache_read_input_tokens"].(float64); ok {
|
||||
usage.CacheReadInputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
|
||||
usage.CacheCreationInputTokens = int(v)
|
||||
}
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user