diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index cc4d2fb9..6b2a19ce 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -1305,6 +1305,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, var usage *ClaudeUsage var firstTokenMs *int + var clientDisconnect bool if claudeReq.Stream { // 客户端要求流式,直接透传转换 streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel) @@ -1314,6 +1315,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } usage = streamRes.usage firstTokenMs = streamRes.firstTokenMs + clientDisconnect = streamRes.clientDisconnect } else { // 客户端要求非流式,收集流式响应后转换返回 streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel) @@ -1326,12 +1328,13 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, // 使用原始模型用于计费和日志 - Stream: claudeReq.Stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: requestID, + Usage: *usage, + Model: originalModel, // 使用原始模型用于计费和日志 + Stream: claudeReq.Stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, }, nil } @@ -1860,6 +1863,7 @@ handleSuccess: var usage *ClaudeUsage var firstTokenMs *int + var clientDisconnect bool if stream { // 客户端要求流式,直接透传 @@ -1870,6 +1874,7 @@ handleSuccess: } usage = streamRes.usage firstTokenMs = streamRes.firstTokenMs + clientDisconnect = streamRes.clientDisconnect } else { // 客户端要求非流式,收集流式响应后返回 streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime) @@ -1893,14 +1898,15 @@ handleSuccess: } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, - Stream: stream, - Duration: 
time.Since(startTime), - FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: imageSize, + RequestID: requestID, + Usage: *usage, + Model: originalModel, + Stream: stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + ImageCount: imageCount, + ImageSize: imageSize, }, nil } @@ -2319,8 +2325,69 @@ func (s *AntigravityGatewayService) handleUpstreamError( } type antigravityStreamResult struct { - usage *ClaudeUsage - firstTokenMs *int + usage *ClaudeUsage + firstTokenMs *int + clientDisconnect bool // 客户端是否在流式传输过程中断开 +} + +// antigravityClientWriter 封装流式响应的客户端写入,自动检测断开并标记。 +// 断开后所有写入操作变为 no-op,调用方通过 Disconnected() 判断是否继续 drain 上游。 +type antigravityClientWriter struct { + w gin.ResponseWriter + flusher http.Flusher + disconnected bool + prefix string // 日志前缀,标识来源方法 +} + +func newAntigravityClientWriter(w gin.ResponseWriter, flusher http.Flusher, prefix string) *antigravityClientWriter { + return &antigravityClientWriter{w: w, flusher: flusher, prefix: prefix} +} + +// Write 写入数据到客户端,写入失败时标记断开并返回 false +func (cw *antigravityClientWriter) Write(p []byte) bool { + if cw.disconnected { + return false + } + if _, err := cw.w.Write(p); err != nil { + cw.markDisconnected() + return false + } + cw.flusher.Flush() + return true +} + +// Fprintf 格式化写入数据到客户端,写入失败时标记断开并返回 false +func (cw *antigravityClientWriter) Fprintf(format string, args ...any) bool { + if cw.disconnected { + return false + } + if _, err := fmt.Fprintf(cw.w, format, args...); err != nil { + cw.markDisconnected() + return false + } + cw.flusher.Flush() + return true +} + +func (cw *antigravityClientWriter) Disconnected() bool { return cw.disconnected } + +func (cw *antigravityClientWriter) markDisconnected() { + cw.disconnected = true + log.Printf("Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix) +} + +// handleStreamReadError 处理上游读取错误的通用逻辑。 +// 返回 (clientDisconnect, handled):handled=true 
表示错误已处理,调用方应返回已收集的 usage。 +func handleStreamReadError(err error, clientDisconnected bool, prefix string) (disconnect bool, handled bool) { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + log.Printf("Context canceled during streaming (%s), returning collected usage", prefix) + return true, true + } + if clientDisconnected { + log.Printf("Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err) + return true, true + } + return false, false } func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) { @@ -2396,10 +2463,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context intervalCh = intervalTicker.C } + cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini") + // 仅发送一次错误事件,避免多次写入导致协议混乱 errorEventSent := false sendErrorEvent := func(reason string) { - if errorEventSent { + if errorEventSent || cw.Disconnected() { return } errorEventSent = true @@ -2411,9 +2480,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context select { case ev, ok := <-events: if !ok { - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil } if ev.err != nil { + if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity gemini"); handled { + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil + } if errors.Is(ev.err, bufio.ErrTooLong) { log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) sendErrorEvent("response_too_large") @@ -2428,11 +2500,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context if strings.HasPrefix(trimmed, "data:") { payload := 
strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) if payload == "" || payload == "[DONE]" { - if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil { - sendErrorEvent("write_failed") - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() + cw.Fprintf("%s\n", line) continue } @@ -2468,27 +2536,22 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context firstTokenMs = &ms } - if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil { - sendErrorEvent("write_failed") - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() + cw.Fprintf("data: %s\n\n", payload) continue } - if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil { - sendErrorEvent("write_failed") - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() + cw.Fprintf("%s\n", line) case <-intervalCh: lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) if time.Since(lastRead) < streamInterval { continue } + if cw.Disconnected() { + log.Printf("Upstream timeout after client disconnect (antigravity gemini), returning collected usage") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } log.Printf("Stream data interval timeout (antigravity)") - // 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } @@ -3186,10 +3249,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context intervalCh = intervalTicker.C } + cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude") + // 仅发送一次错误事件,避免多次写入导致协议混乱 errorEventSent := false sendErrorEvent := func(reason string) { - if errorEventSent { + if errorEventSent || cw.Disconnected() { return } errorEventSent = true @@ -3197,19 
+3262,27 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context flusher.Flush() } + // finishUsage 是获取 processor 最终 usage 的辅助函数 + finishUsage := func() *ClaudeUsage { + _, agUsage := processor.Finish() + return convertUsage(agUsage) + } + for { select { case ev, ok := <-events: if !ok { - // 发送结束事件 + // 上游完成,发送结束事件 finalEvents, agUsage := processor.Finish() if len(finalEvents) > 0 { - _, _ = c.Writer.Write(finalEvents) - flusher.Flush() + cw.Write(finalEvents) } - return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil + return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil } if ev.err != nil { + if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity claude"); handled { + return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil + } if errors.Is(ev.err, bufio.ErrTooLong) { log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) sendErrorEvent("response_too_large") @@ -3219,25 +3292,14 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context return nil, fmt.Errorf("stream read error: %w", ev.err) } - line := ev.line // 处理 SSE 行,转换为 Claude 格式 - claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n")) - + claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n")) if len(claudeEvents) > 0 { if firstTokenMs == nil { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } - - if _, writeErr := c.Writer.Write(claudeEvents); writeErr != nil { - finalEvents, agUsage := processor.Finish() - if len(finalEvents) > 0 { - _, _ = c.Writer.Write(finalEvents) - } - sendErrorEvent("write_failed") - return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr - } - flusher.Flush() + cw.Write(claudeEvents) } case 
<-intervalCh: @@ -3245,13 +3307,15 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if time.Since(lastRead) < streamInterval { continue } + if cw.Disconnected() { + log.Printf("Upstream timeout after client disconnect (antigravity claude), returning collected usage") + return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } log.Printf("Stream data interval timeout (antigravity)") - // 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } } - } // extractImageSize 从 Gemini 请求中提取 image_size 参数 @@ -3390,3 +3454,289 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) { payload["contents"] = filtered return json.Marshal(payload) } + +// ForwardUpstream 使用 base_url + /v1/messages + 双 header 认证透传上游 Claude 请求 +func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { + startTime := time.Now() + sessionID := getSessionID(c) + prefix := logPrefix(sessionID, account.Name) + + // 获取上游配置 + baseURL := strings.TrimSpace(account.GetCredential("base_url")) + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if baseURL == "" || apiKey == "" { + return nil, fmt.Errorf("upstream account missing base_url or api_key") + } + baseURL = strings.TrimSuffix(baseURL, "/") + + // 解析请求获取模型信息 + var claudeReq antigravity.ClaudeRequest + if err := json.Unmarshal(body, &claudeReq); err != nil { + return nil, fmt.Errorf("parse claude request: %w", err) + } + if strings.TrimSpace(claudeReq.Model) == "" { + return nil, fmt.Errorf("missing model") + } + originalModel := claudeReq.Model + billingModel := originalModel + + // 构建上游请求 URL + upstreamURL := baseURL + "/v1/messages" + + // 创建请求 + req, err := http.NewRequestWithContext(ctx, 
http.MethodPost, upstreamURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create upstream request: %w", err) + } + + // 设置请求头 + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) // Claude API 兼容 + + // 透传 Claude 相关 headers + if v := c.GetHeader("anthropic-version"); v != "" { + req.Header.Set("anthropic-version", v) + } + if v := c.GetHeader("anthropic-beta"); v != "" { + req.Header.Set("anthropic-beta", v) + } + + // 代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 发送请求 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + log.Printf("%s upstream request failed: %v", prefix, err) + return nil, fmt.Errorf("upstream request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + // 处理错误响应 + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + // 429 错误时标记账号限流 + if resp.StatusCode == http.StatusTooManyRequests { + quotaScope, _ := resolveAntigravityQuotaScope(originalModel) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", false) + } + + // 透传上游错误 + c.Header("Content-Type", resp.Header.Get("Content-Type")) + c.Status(resp.StatusCode) + _, _ = c.Writer.Write(respBody) + + return &ForwardResult{ + Model: billingModel, + }, nil + } + + // 处理成功响应(流式/非流式) + var usage *ClaudeUsage + var firstTokenMs *int + var clientDisconnect bool + + if claudeReq.Stream { + // 流式响应:透传 + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + c.Status(http.StatusOK) + + streamRes := s.streamUpstreamResponse(c, resp, startTime) + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + clientDisconnect = streamRes.clientDisconnect + } 
else { + // 非流式响应:直接透传 + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read upstream response: %w", err) + } + + // 提取 usage + usage = s.extractClaudeUsage(respBody) + + c.Header("Content-Type", resp.Header.Get("Content-Type")) + c.Status(http.StatusOK) + _, _ = c.Writer.Write(respBody) + } + + // 构建计费结果 + duration := time.Since(startTime) + log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds()) + + return &ForwardResult{ + Model: billingModel, + Stream: claudeReq.Stream, + Duration: duration, + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + Usage: ClaudeUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + CacheReadInputTokens: usage.CacheReadInputTokens, + CacheCreationInputTokens: usage.CacheCreationInputTokens, + }, + }, nil +} + +// streamUpstreamResponse 透传上游 SSE 流并提取 Claude usage +func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) *antigravityStreamResult { + usage := &ClaudeUsage{} + var firstTokenMs *int + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.settingService.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 64*1024), maxLineSize) + + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func() { + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) + + 
streamInterval := time.Duration(0) + if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + flusher, _ := c.Writer.(http.Flusher) + cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream") + + for { + select { + case ev, ok := <-events: + if !ok { + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()} + } + if ev.err != nil { + if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity upstream"); handled { + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect} + } + log.Printf("Stream read error (antigravity upstream): %v", ev.err) + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs} + } + + line := ev.line + + // 记录首 token 时间 + if firstTokenMs == nil && len(line) > 0 { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + // 尝试从 message_delta 或 message_stop 事件提取 usage + s.extractSSEUsage(line, usage) + + // 透传行 + cw.Fprintf("%s\n", line) + + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if cw.Disconnected() { + log.Printf("Upstream timeout after client disconnect (antigravity upstream), returning collected usage") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true} + } + log.Printf("Stream data interval timeout (antigravity upstream)") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs} + } + } +} + +// 
extractSSEUsage 从 SSE data 行中提取 Claude usage(用于流式透传场景) +func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUsage) { + if !strings.HasPrefix(line, "data: ") { + return + } + dataStr := strings.TrimPrefix(line, "data: ") + var event map[string]any + if json.Unmarshal([]byte(dataStr), &event) != nil { + return + } + u, ok := event["usage"].(map[string]any) + if !ok { + return + } + if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 { + usage.InputTokens = int(v) + } + if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 { + usage.OutputTokens = int(v) + } + if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 { + usage.CacheReadInputTokens = int(v) + } + if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 { + usage.CacheCreationInputTokens = int(v) + } +} + +// extractClaudeUsage 从非流式 Claude 响应提取 usage +func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage { + usage := &ClaudeUsage{} + var resp map[string]any + if json.Unmarshal(body, &resp) != nil { + return usage + } + if u, ok := resp["usage"].(map[string]any); ok { + if v, ok := u["input_tokens"].(float64); ok { + usage.InputTokens = int(v) + } + if v, ok := u["output_tokens"].(float64); ok { + usage.OutputTokens = int(v) + } + if v, ok := u["cache_read_input_tokens"].(float64); ok { + usage.CacheReadInputTokens = int(v) + } + if v, ok := u["cache_creation_input_tokens"].(float64); ok { + usage.CacheCreationInputTokens = int(v) + } + } + return usage +} diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 91cefc28..a6a349c1 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -4,17 +4,42 @@ import ( "bytes" "context" "encoding/json" + "errors" + "fmt" "io" "net/http" "net/http/httptest" "testing" "time" + 
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) +// antigravityFailingWriter 模拟客户端断开连接的 gin.ResponseWriter +type antigravityFailingWriter struct { + gin.ResponseWriter + failAfter int // 允许成功写入的次数,之后所有写入返回错误 + writes int +} + +func (w *antigravityFailingWriter) Write(p []byte) (int, error) { + if w.writes >= w.failAfter { + return 0, errors.New("write failed: client disconnected") + } + w.writes++ + return w.ResponseWriter.Write(p) +} + +// newAntigravityTestService 创建用于流式测试的 AntigravityGatewayService +func newAntigravityTestService(cfg *config.Config) *AntigravityGatewayService { + return &AntigravityGatewayService{ + settingService: &SettingService{cfg: cfg}, + } +} + func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) { req := &antigravity.ClaudeRequest{ Model: "claude-sonnet-4-5", @@ -337,8 +362,8 @@ func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *tes require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") } -// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling -// 验证:ForwardGemini 粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true +// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling verifies +// that ForwardGemini sets ForceCacheBilling=true for sticky session switch. 
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) { gin.SetMode(gin.TestMode) writer := httptest.NewRecorder() @@ -391,3 +416,438 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling( require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode) require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") } + +// --- 流式 happy path 测试 --- + +// TestStreamUpstreamResponse_NormalComplete +// 验证:正常流式转发完成时,数据正确透传、usage 正确收集、clientDisconnect=false +func TestStreamUpstreamResponse_NormalComplete(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `event: message_start`) + fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`) + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, `event: content_block_delta`) + fmt.Fprintln(pw, `data: {"type":"content_block_delta","delta":{"text":"hello"}}`) + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, `event: message_delta`) + fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":5}}`) + fmt.Fprintln(pw, "") + }() + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NotNil(t, result) + require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect") + require.NotNil(t, result.usage) + require.Equal(t, 5, result.usage.OutputTokens, "should collect output_tokens from message_delta") + require.NotNil(t, result.firstTokenMs, "should record first token time") + + // 验证数据被透传到客户端 + body := 
rec.Body.String() + require.Contains(t, body, "event: message_start") + require.Contains(t, body, "content_block_delta") + require.Contains(t, body, "message_delta") +} + +// TestHandleGeminiStreamingResponse_NormalComplete +// 验证:正常 Gemini 流式转发,数据正确透传、usage 正确收集 +func TestHandleGeminiStreamingResponse_NormalComplete(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + // 第一个 chunk(部分内容) + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3}}`) + fmt.Fprintln(pw, "") + // 第二个 chunk(最终内容+完整 usage) + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"cachedContentTokenCount":2}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect") + require.NotNil(t, result.usage) + // Gemini usage: promptTokenCount=10, candidatesTokenCount=8, cachedContentTokenCount=2 + // → InputTokens=10-2=8, OutputTokens=8, CacheReadInputTokens=2 + require.Equal(t, 8, result.usage.InputTokens) + require.Equal(t, 8, result.usage.OutputTokens) + require.Equal(t, 2, result.usage.CacheReadInputTokens) + require.NotNil(t, result.firstTokenMs, "should record first token time") + + // 验证数据被透传到客户端 + body := rec.Body.String() + require.Contains(t, body, "Hello") + 
require.Contains(t, body, "world") + // 不应包含错误事件 + require.NotContains(t, body, "event: error") +} + +// TestHandleClaudeStreamingResponse_NormalComplete +// 验证:正常 Claude 流式转发(Gemini→Claude 转换),数据正确转换并输出 +func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + // v1internal 包装格式:Gemini 数据嵌套在 "response" 字段下 + // ProcessLine 先尝试反序列化为 V1InternalResponse,裸格式会导致 Response.UsageMetadata 为空 + fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi there"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect") + require.NotNil(t, result.usage) + // Gemini→Claude 转换的 usage:promptTokenCount=5→InputTokens=5, candidatesTokenCount=3→OutputTokens=3 + require.Equal(t, 5, result.usage.InputTokens) + require.Equal(t, 3, result.usage.OutputTokens) + require.NotNil(t, result.firstTokenMs, "should record first token time") + + // 验证输出是 Claude SSE 格式(processor 会转换) + body := rec.Body.String() + require.Contains(t, body, "event: message_start", "should contain Claude message_start event") + require.Contains(t, body, "event: message_stop", "should contain Claude message_stop event") + // 不应包含错误事件 + require.NotContains(t, body, "event: error") +} + +// --- 流式客户端断开检测测试 --- + +// 
TestStreamUpstreamResponse_ClientDisconnectDrainsUsage +// 验证:客户端写入失败后,streamUpstreamResponse 继续读取上游以收集 usage +func TestStreamUpstreamResponse_ClientDisconnectDrainsUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `event: message_start`) + fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`) + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, `event: message_delta`) + fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":20}}`) + fmt.Fprintln(pw, "") + }() + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotNil(t, result.usage) + require.Equal(t, 20, result.usage.OutputTokens) +} + +// TestStreamUpstreamResponse_ContextCanceled +// 验证:context 取消时返回 usage 且标记 clientDisconnect +func TestStreamUpstreamResponse_ContextCanceled(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx) + + resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}} + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + + require.NotNil(t, result) + 
require.True(t, result.clientDisconnect) + require.NotContains(t, rec.Body.String(), "event: error") +} + +// TestStreamUpstreamResponse_Timeout +// Verifies: when the upstream times out, a non-nil result is still returned and clientDisconnect is false +func TestStreamUpstreamResponse_Timeout(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + _ = pw.Close() + _ = pr.Close() + + require.NotNil(t, result) + require.False(t, result.clientDisconnect) +} + +// TestStreamUpstreamResponse_TimeoutAfterClientDisconnect +// Verifies: when the upstream times out after the client write has failed, a non-nil result is returned and clientDisconnect is set +func TestStreamUpstreamResponse_TimeoutAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":5}}}`) + fmt.Fprintln(pw, "") + // pw is deliberately left open here so that the stream read runs into the timeout + }() + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + _ = pw.Close() + _ = pr.Close() + + require.NotNil(t, result) + require.True(t, result.clientDisconnect) +} + +// TestHandleGeminiStreamingResponse_ClientDisconnect +// Verifies: in Gemini stream forwarding, a client disconnect yields no error and clientDisconnect is set (upstream is expected to keep being drained) +func TestHandleGeminiStreamingResponse_ClientDisconnect(t
*testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":10}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotContains(t, rec.Body.String(), "write_failed") +} + +// TestHandleGeminiStreamingResponse_ContextCanceled +// Verifies: no error event is injected into the response when the request context is canceled +func TestHandleGeminiStreamingResponse_ContextCanceled(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx) + + resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}} + + result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now()) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotContains(t, rec.Body.String(), "event: error") +} + +// TestHandleClaudeStreamingResponse_ClientDisconnect +// Verifies: in Claude stream forwarding, a client disconnect yields no error and clientDisconnect is set (upstream is expected to keep being drained) +func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) { +
gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + // v1internal wrapped format: the payload is nested under a top-level "response" key + fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":15}}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) +} + +// TestHandleClaudeStreamingResponse_ContextCanceled +// Verifies: no error event is injected into the response when the request context is canceled +func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx) + + resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}} + + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotContains(t, rec.Body.String(), "event: error") +} + +// TestExtractSSEUsage verifies that extractSSEUsage correctly extracts usage from SSE data lines and ignores non-data lines +func TestExtractSSEUsage(t *testing.T) { + svc := &AntigravityGatewayService{}
+ tests := []struct { + name string + line string + expected ClaudeUsage + }{ + { + name: "message_delta with output_tokens", + line: `data: {"type":"message_delta","usage":{"output_tokens":42}}`, + expected: ClaudeUsage{OutputTokens: 42}, + }, + { + name: "non-data line ignored", + line: `event: message_start`, + expected: ClaudeUsage{}, + }, + { + name: "top-level usage with all fields", + line: `data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`, + expected: ClaudeUsage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5, CacheCreationInputTokens: 3}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + usage := &ClaudeUsage{} + svc.extractSSEUsage(tt.line, usage) + require.Equal(t, tt.expected, *usage) + }) + } +} + +// TestAntigravityClientWriter verifies the disconnect detection of antigravityClientWriter: successful write, failed write, and writes after a failure +func TestAntigravityClientWriter(t *testing.T) { + t.Run("normal write succeeds", func(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + flusher, _ := c.Writer.(http.Flusher) + cw := newAntigravityClientWriter(c.Writer, flusher, "test") + + ok := cw.Write([]byte("hello")) + require.True(t, ok) + require.False(t, cw.Disconnected()) + require.Contains(t, rec.Body.String(), "hello") + }) + + t.Run("write failure marks disconnected", func(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + flusher, _ := c.Writer.(http.Flusher) + cw := newAntigravityClientWriter(fw, flusher, "test") + + ok := cw.Write([]byte("hello")) + require.False(t, ok) + require.True(t, cw.Disconnected()) + }) + + t.Run("subsequent writes are no-op", func(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + fw := &antigravityFailingWriter{ResponseWriter: c.Writer,
failAfter: 0} + flusher, _ := c.Writer.(http.Flusher) + cw := newAntigravityClientWriter(fw, flusher, "test") + + cw.Write([]byte("first")) + ok := cw.Fprintf("second %d", 2) + require.False(t, ok) + require.True(t, cw.Disconnected()) + }) +}