feat: detect client disconnect during streaming and continue draining upstream for billing
This commit is contained in:
@@ -1305,6 +1305,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
|
|
||||||
var usage *ClaudeUsage
|
var usage *ClaudeUsage
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
|
var clientDisconnect bool
|
||||||
if claudeReq.Stream {
|
if claudeReq.Stream {
|
||||||
// 客户端要求流式,直接透传转换
|
// 客户端要求流式,直接透传转换
|
||||||
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
|
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
|
||||||
@@ -1314,6 +1315,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
}
|
}
|
||||||
usage = streamRes.usage
|
usage = streamRes.usage
|
||||||
firstTokenMs = streamRes.firstTokenMs
|
firstTokenMs = streamRes.firstTokenMs
|
||||||
|
clientDisconnect = streamRes.clientDisconnect
|
||||||
} else {
|
} else {
|
||||||
// 客户端要求非流式,收集流式响应后转换返回
|
// 客户端要求非流式,收集流式响应后转换返回
|
||||||
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
|
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
|
||||||
@@ -1326,12 +1328,13 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
Model: originalModel, // 使用原始模型用于计费和日志
|
Model: originalModel, // 使用原始模型用于计费和日志
|
||||||
Stream: claudeReq.Stream,
|
Stream: claudeReq.Stream,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
|
ClientDisconnect: clientDisconnect,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1860,6 +1863,7 @@ handleSuccess:
|
|||||||
|
|
||||||
var usage *ClaudeUsage
|
var usage *ClaudeUsage
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
|
var clientDisconnect bool
|
||||||
|
|
||||||
if stream {
|
if stream {
|
||||||
// 客户端要求流式,直接透传
|
// 客户端要求流式,直接透传
|
||||||
@@ -1870,6 +1874,7 @@ handleSuccess:
|
|||||||
}
|
}
|
||||||
usage = streamRes.usage
|
usage = streamRes.usage
|
||||||
firstTokenMs = streamRes.firstTokenMs
|
firstTokenMs = streamRes.firstTokenMs
|
||||||
|
clientDisconnect = streamRes.clientDisconnect
|
||||||
} else {
|
} else {
|
||||||
// 客户端要求非流式,收集流式响应后返回
|
// 客户端要求非流式,收集流式响应后返回
|
||||||
streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime)
|
streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime)
|
||||||
@@ -1893,14 +1898,15 @@ handleSuccess:
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
Model: originalModel,
|
Model: originalModel,
|
||||||
Stream: stream,
|
Stream: stream,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
ImageCount: imageCount,
|
ClientDisconnect: clientDisconnect,
|
||||||
ImageSize: imageSize,
|
ImageCount: imageCount,
|
||||||
|
ImageSize: imageSize,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2319,8 +2325,69 @@ func (s *AntigravityGatewayService) handleUpstreamError(
|
|||||||
}
|
}
|
||||||
|
|
||||||
type antigravityStreamResult struct {
|
type antigravityStreamResult struct {
|
||||||
usage *ClaudeUsage
|
usage *ClaudeUsage
|
||||||
firstTokenMs *int
|
firstTokenMs *int
|
||||||
|
clientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||||
|
}
|
||||||
|
|
||||||
|
// antigravityClientWriter wraps client-side writes for a streaming response
// and records when the client has gone away.
// Once a write fails, every later write becomes a no-op; callers check
// Disconnected() to decide whether to keep draining the upstream.
type antigravityClientWriter struct {
	w            gin.ResponseWriter // destination for bytes sent to the client
	flusher      http.Flusher       // flushed after each successful write
	disconnected bool               // set once a write to w has failed
	prefix       string             // log prefix identifying the calling method
}
|
||||||
|
|
||||||
|
func newAntigravityClientWriter(w gin.ResponseWriter, flusher http.Flusher, prefix string) *antigravityClientWriter {
|
||||||
|
return &antigravityClientWriter{w: w, flusher: flusher, prefix: prefix}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write 写入数据到客户端,写入失败时标记断开并返回 false
|
||||||
|
func (cw *antigravityClientWriter) Write(p []byte) bool {
|
||||||
|
if cw.disconnected {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, err := cw.w.Write(p); err != nil {
|
||||||
|
cw.markDisconnected()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
cw.flusher.Flush()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fprintf 格式化写入数据到客户端,写入失败时标记断开并返回 false
|
||||||
|
func (cw *antigravityClientWriter) Fprintf(format string, args ...any) bool {
|
||||||
|
if cw.disconnected {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprintf(cw.w, format, args...); err != nil {
|
||||||
|
cw.markDisconnected()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
cw.flusher.Flush()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disconnected reports whether an earlier write to the client has failed.
func (cw *antigravityClientWriter) Disconnected() bool {
	return cw.disconnected
}
|
||||||
|
|
||||||
|
func (cw *antigravityClientWriter) markDisconnected() {
|
||||||
|
cw.disconnected = true
|
||||||
|
log.Printf("Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleStreamReadError classifies an upstream read error seen while streaming.
//
// It returns (disconnect, handled). handled=true means the error is treated as
// an expected end of the stream and the caller should return the usage
// collected so far instead of reporting a failure; disconnect is the value the
// caller records as the result's clientDisconnect flag. Two situations are
// handled:
//   - err wraps context.Canceled or context.DeadlineExceeded
//   - the client had already disconnected (clientDisconnected is true)
//
// Any other error yields (false, false) and is left to the caller.
func handleStreamReadError(err error, clientDisconnected bool, prefix string) (disconnect bool, handled bool) {
	switch {
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		// Log the concrete error so a deadline expiry is not reported as a cancellation.
		log.Printf("Context done during streaming (%s): %v, returning collected usage", prefix, err)
		return true, true
	case clientDisconnected:
		log.Printf("Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err)
		return true, true
	default:
		return false, false
	}
}
|
||||||
|
|
||||||
func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
|
func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
|
||||||
@@ -2396,10 +2463,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
|||||||
intervalCh = intervalTicker.C
|
intervalCh = intervalTicker.C
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini")
|
||||||
|
|
||||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||||
errorEventSent := false
|
errorEventSent := false
|
||||||
sendErrorEvent := func(reason string) {
|
sendErrorEvent := func(reason string) {
|
||||||
if errorEventSent {
|
if errorEventSent || cw.Disconnected() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
errorEventSent = true
|
errorEventSent = true
|
||||||
@@ -2411,9 +2480,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
|||||||
select {
|
select {
|
||||||
case ev, ok := <-events:
|
case ev, ok := <-events:
|
||||||
if !ok {
|
if !ok {
|
||||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil
|
||||||
}
|
}
|
||||||
if ev.err != nil {
|
if ev.err != nil {
|
||||||
|
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity gemini"); handled {
|
||||||
|
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
|
||||||
|
}
|
||||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||||
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||||||
sendErrorEvent("response_too_large")
|
sendErrorEvent("response_too_large")
|
||||||
@@ -2428,11 +2500,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
|||||||
if strings.HasPrefix(trimmed, "data:") {
|
if strings.HasPrefix(trimmed, "data:") {
|
||||||
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||||||
if payload == "" || payload == "[DONE]" {
|
if payload == "" || payload == "[DONE]" {
|
||||||
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
|
cw.Fprintf("%s\n", line)
|
||||||
sendErrorEvent("write_failed")
|
|
||||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
|
||||||
}
|
|
||||||
flusher.Flush()
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2468,27 +2536,22 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
|||||||
firstTokenMs = &ms
|
firstTokenMs = &ms
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil {
|
cw.Fprintf("data: %s\n\n", payload)
|
||||||
sendErrorEvent("write_failed")
|
|
||||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
|
||||||
}
|
|
||||||
flusher.Flush()
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
|
cw.Fprintf("%s\n", line)
|
||||||
sendErrorEvent("write_failed")
|
|
||||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
|
||||||
}
|
|
||||||
flusher.Flush()
|
|
||||||
|
|
||||||
case <-intervalCh:
|
case <-intervalCh:
|
||||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||||
if time.Since(lastRead) < streamInterval {
|
if time.Since(lastRead) < streamInterval {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if cw.Disconnected() {
|
||||||
|
log.Printf("Upstream timeout after client disconnect (antigravity gemini), returning collected usage")
|
||||||
|
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||||
|
}
|
||||||
log.Printf("Stream data interval timeout (antigravity)")
|
log.Printf("Stream data interval timeout (antigravity)")
|
||||||
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
|
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
}
|
}
|
||||||
@@ -3186,10 +3249,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
|||||||
intervalCh = intervalTicker.C
|
intervalCh = intervalTicker.C
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude")
|
||||||
|
|
||||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||||
errorEventSent := false
|
errorEventSent := false
|
||||||
sendErrorEvent := func(reason string) {
|
sendErrorEvent := func(reason string) {
|
||||||
if errorEventSent {
|
if errorEventSent || cw.Disconnected() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
errorEventSent = true
|
errorEventSent = true
|
||||||
@@ -3197,19 +3262,27 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
|||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// finishUsage 是获取 processor 最终 usage 的辅助函数
|
||||||
|
finishUsage := func() *ClaudeUsage {
|
||||||
|
_, agUsage := processor.Finish()
|
||||||
|
return convertUsage(agUsage)
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case ev, ok := <-events:
|
case ev, ok := <-events:
|
||||||
if !ok {
|
if !ok {
|
||||||
// 发送结束事件
|
// 上游完成,发送结束事件
|
||||||
finalEvents, agUsage := processor.Finish()
|
finalEvents, agUsage := processor.Finish()
|
||||||
if len(finalEvents) > 0 {
|
if len(finalEvents) > 0 {
|
||||||
_, _ = c.Writer.Write(finalEvents)
|
cw.Write(finalEvents)
|
||||||
flusher.Flush()
|
|
||||||
}
|
}
|
||||||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
|
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil
|
||||||
}
|
}
|
||||||
if ev.err != nil {
|
if ev.err != nil {
|
||||||
|
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity claude"); handled {
|
||||||
|
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
|
||||||
|
}
|
||||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||||
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||||||
sendErrorEvent("response_too_large")
|
sendErrorEvent("response_too_large")
|
||||||
@@ -3219,25 +3292,14 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
|||||||
return nil, fmt.Errorf("stream read error: %w", ev.err)
|
return nil, fmt.Errorf("stream read error: %w", ev.err)
|
||||||
}
|
}
|
||||||
|
|
||||||
line := ev.line
|
|
||||||
// 处理 SSE 行,转换为 Claude 格式
|
// 处理 SSE 行,转换为 Claude 格式
|
||||||
claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n"))
|
claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n"))
|
||||||
|
|
||||||
if len(claudeEvents) > 0 {
|
if len(claudeEvents) > 0 {
|
||||||
if firstTokenMs == nil {
|
if firstTokenMs == nil {
|
||||||
ms := int(time.Since(startTime).Milliseconds())
|
ms := int(time.Since(startTime).Milliseconds())
|
||||||
firstTokenMs = &ms
|
firstTokenMs = &ms
|
||||||
}
|
}
|
||||||
|
cw.Write(claudeEvents)
|
||||||
if _, writeErr := c.Writer.Write(claudeEvents); writeErr != nil {
|
|
||||||
finalEvents, agUsage := processor.Finish()
|
|
||||||
if len(finalEvents) > 0 {
|
|
||||||
_, _ = c.Writer.Write(finalEvents)
|
|
||||||
}
|
|
||||||
sendErrorEvent("write_failed")
|
|
||||||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr
|
|
||||||
}
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-intervalCh:
|
case <-intervalCh:
|
||||||
@@ -3245,13 +3307,15 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
|||||||
if time.Since(lastRead) < streamInterval {
|
if time.Since(lastRead) < streamInterval {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if cw.Disconnected() {
|
||||||
|
log.Printf("Upstream timeout after client disconnect (antigravity claude), returning collected usage")
|
||||||
|
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||||
|
}
|
||||||
log.Printf("Stream data interval timeout (antigravity)")
|
log.Printf("Stream data interval timeout (antigravity)")
|
||||||
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
|
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractImageSize 从 Gemini 请求中提取 image_size 参数
|
// extractImageSize 从 Gemini 请求中提取 image_size 参数
|
||||||
@@ -3390,3 +3454,289 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) {
|
|||||||
payload["contents"] = filtered
|
payload["contents"] = filtered
|
||||||
return json.Marshal(payload)
|
return json.Marshal(payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ForwardUpstream 使用 base_url + /v1/messages + 双 header 认证透传上游 Claude 请求
|
||||||
|
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||||
|
startTime := time.Now()
|
||||||
|
sessionID := getSessionID(c)
|
||||||
|
prefix := logPrefix(sessionID, account.Name)
|
||||||
|
|
||||||
|
// 获取上游配置
|
||||||
|
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||||
|
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||||
|
if baseURL == "" || apiKey == "" {
|
||||||
|
return nil, fmt.Errorf("upstream account missing base_url or api_key")
|
||||||
|
}
|
||||||
|
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||||
|
|
||||||
|
// 解析请求获取模型信息
|
||||||
|
var claudeReq antigravity.ClaudeRequest
|
||||||
|
if err := json.Unmarshal(body, &claudeReq); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse claude request: %w", err)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(claudeReq.Model) == "" {
|
||||||
|
return nil, fmt.Errorf("missing model")
|
||||||
|
}
|
||||||
|
originalModel := claudeReq.Model
|
||||||
|
billingModel := originalModel
|
||||||
|
|
||||||
|
// 构建上游请求 URL
|
||||||
|
upstreamURL := baseURL + "/v1/messages"
|
||||||
|
|
||||||
|
// 创建请求
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create upstream request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置请求头
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
req.Header.Set("x-api-key", apiKey) // Claude API 兼容
|
||||||
|
|
||||||
|
// 透传 Claude 相关 headers
|
||||||
|
if v := c.GetHeader("anthropic-version"); v != "" {
|
||||||
|
req.Header.Set("anthropic-version", v)
|
||||||
|
}
|
||||||
|
if v := c.GetHeader("anthropic-beta"); v != "" {
|
||||||
|
req.Header.Set("anthropic-beta", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 代理 URL
|
||||||
|
proxyURL := ""
|
||||||
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("%s upstream request failed: %v", prefix, err)
|
||||||
|
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
// 处理错误响应
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
|
||||||
|
// 429 错误时标记账号限流
|
||||||
|
if resp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
|
||||||
|
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 透传上游错误
|
||||||
|
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||||
|
c.Status(resp.StatusCode)
|
||||||
|
_, _ = c.Writer.Write(respBody)
|
||||||
|
|
||||||
|
return &ForwardResult{
|
||||||
|
Model: billingModel,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理成功响应(流式/非流式)
|
||||||
|
var usage *ClaudeUsage
|
||||||
|
var firstTokenMs *int
|
||||||
|
var clientDisconnect bool
|
||||||
|
|
||||||
|
if claudeReq.Stream {
|
||||||
|
// 流式响应:透传
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("X-Accel-Buffering", "no")
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
|
||||||
|
streamRes := s.streamUpstreamResponse(c, resp, startTime)
|
||||||
|
usage = streamRes.usage
|
||||||
|
firstTokenMs = streamRes.firstTokenMs
|
||||||
|
clientDisconnect = streamRes.clientDisconnect
|
||||||
|
} else {
|
||||||
|
// 非流式响应:直接透传
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read upstream response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取 usage
|
||||||
|
usage = s.extractClaudeUsage(respBody)
|
||||||
|
|
||||||
|
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
_, _ = c.Writer.Write(respBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建计费结果
|
||||||
|
duration := time.Since(startTime)
|
||||||
|
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
|
||||||
|
|
||||||
|
return &ForwardResult{
|
||||||
|
Model: billingModel,
|
||||||
|
Stream: claudeReq.Stream,
|
||||||
|
Duration: duration,
|
||||||
|
FirstTokenMs: firstTokenMs,
|
||||||
|
ClientDisconnect: clientDisconnect,
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: usage.InputTokens,
|
||||||
|
OutputTokens: usage.OutputTokens,
|
||||||
|
CacheReadInputTokens: usage.CacheReadInputTokens,
|
||||||
|
CacheCreationInputTokens: usage.CacheCreationInputTokens,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// streamUpstreamResponse passes the upstream SSE stream through to the client
// line by line and collects Claude usage from the lines it sees.
//
// A background goroutine scans resp.Body and feeds lines into a channel; the
// main loop forwards each line to the client through an antigravityClientWriter.
// When a client write fails the loop keeps reading the upstream so that usage
// can still be collected. The function always returns a non-nil result; read
// errors and timeouts are logged rather than returned.
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) *antigravityStreamResult {
	usage := &ClaudeUsage{}
	var firstTokenMs *int

	// Line scanner over the upstream body; the maximum line size comes from the
	// gateway config when set, otherwise defaultMaxLineSize.
	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 64*1024), maxLineSize)

	// scanEvent carries either one scanned line or the scanner's terminal error.
	type scanEvent struct {
		line string
		err  error
	}
	events := make(chan scanEvent, 16)
	done := make(chan struct{})
	// sendEvent delivers ev to the main loop, or gives up (returning false)
	// once done has been closed because this function returned.
	sendEvent := func(ev scanEvent) bool {
		select {
		case events <- ev:
			return true
		case <-done:
			return false
		}
	}
	// lastReadAt holds the UnixNano time of the most recent successful scan; it
	// is written by the reader goroutine and read by the timeout check below.
	var lastReadAt int64
	atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
	go func() {
		// Closing events tells the main loop that the upstream is finished.
		defer close(events)
		for scanner.Scan() {
			atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
			if !sendEvent(scanEvent{line: scanner.Text()}) {
				return
			}
		}
		if err := scanner.Err(); err != nil {
			_ = sendEvent(scanEvent{err: err})
		}
	}()
	defer close(done)

	// Optional inactivity timeout: the ticker is only created when a positive
	// interval is configured.
	streamInterval := time.Duration(0)
	if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
		streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
	}
	var intervalTicker *time.Ticker
	if streamInterval > 0 {
		intervalTicker = time.NewTicker(streamInterval)
		defer intervalTicker.Stop()
	}
	// intervalCh stays nil when no ticker exists, so its select case never fires.
	var intervalCh <-chan time.Time
	if intervalTicker != nil {
		intervalCh = intervalTicker.C
	}

	// The assertion result is not checked, so flusher is nil if c.Writer does
	// not implement http.Flusher.
	flusher, _ := c.Writer.(http.Flusher)
	cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream")

	for {
		select {
		case ev, ok := <-events:
			if !ok {
				// Upstream finished: report whether the client dropped along the way.
				return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}
			}
			if ev.err != nil {
				// Context cancellation and errors after a client disconnect are
				// expected endings: return what was collected.
				if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity upstream"); handled {
					return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}
				}
				log.Printf("Stream read error (antigravity upstream): %v", ev.err)
				return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
			}

			line := ev.line

			// Record the time to first token on the first non-empty line.
			if firstTokenMs == nil && len(line) > 0 {
				ms := int(time.Since(startTime).Milliseconds())
				firstTokenMs = &ms
			}

			// Try to pick up usage from message_delta or message_stop events.
			s.extractSSEUsage(line, usage)

			// Pass the line through; the result is ignored because cw tracks a
			// failed write itself and turns later writes into no-ops.
			cw.Fprintf("%s\n", line)

		case <-intervalCh:
			// Ignore ticks while data arrived within the last interval.
			lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
			if time.Since(lastRead) < streamInterval {
				continue
			}
			if cw.Disconnected() {
				log.Printf("Upstream timeout after client disconnect (antigravity upstream), returning collected usage")
				return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}
			}
			log.Printf("Stream data interval timeout (antigravity upstream)")
			return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
		}
	}
}
|
||||||
|
|
||||||
|
// extractSSEUsage 从 SSE data 行中提取 Claude usage(用于流式透传场景)
|
||||||
|
func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUsage) {
|
||||||
|
if !strings.HasPrefix(line, "data: ") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dataStr := strings.TrimPrefix(line, "data: ")
|
||||||
|
var event map[string]any
|
||||||
|
if json.Unmarshal([]byte(dataStr), &event) != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
u, ok := event["usage"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 {
|
||||||
|
usage.InputTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 {
|
||||||
|
usage.OutputTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 {
|
||||||
|
usage.CacheReadInputTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
|
||||||
|
usage.CacheCreationInputTokens = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractClaudeUsage 从非流式 Claude 响应提取 usage
|
||||||
|
func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage {
|
||||||
|
usage := &ClaudeUsage{}
|
||||||
|
var resp map[string]any
|
||||||
|
if json.Unmarshal(body, &resp) != nil {
|
||||||
|
return usage
|
||||||
|
}
|
||||||
|
if u, ok := resp["usage"].(map[string]any); ok {
|
||||||
|
if v, ok := u["input_tokens"].(float64); ok {
|
||||||
|
usage.InputTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := u["output_tokens"].(float64); ok {
|
||||||
|
usage.OutputTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := u["cache_read_input_tokens"].(float64); ok {
|
||||||
|
usage.CacheReadInputTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
|
||||||
|
usage.CacheCreationInputTokens = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return usage
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,17 +4,42 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// antigravityFailingWriter is a gin.ResponseWriter that simulates a client
// disconnect by failing writes after a set number of them have gone through.
type antigravityFailingWriter struct {
	gin.ResponseWriter
	failAfter int // number of writes allowed through; every later write returns an error
	writes    int // number of writes forwarded to the embedded writer so far
}
|
||||||
|
|
||||||
|
func (w *antigravityFailingWriter) Write(p []byte) (int, error) {
|
||||||
|
if w.writes >= w.failAfter {
|
||||||
|
return 0, errors.New("write failed: client disconnected")
|
||||||
|
}
|
||||||
|
w.writes++
|
||||||
|
return w.ResponseWriter.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newAntigravityTestService 创建用于流式测试的 AntigravityGatewayService
|
||||||
|
func newAntigravityTestService(cfg *config.Config) *AntigravityGatewayService {
|
||||||
|
return &AntigravityGatewayService{
|
||||||
|
settingService: &SettingService{cfg: cfg},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
|
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
|
||||||
req := &antigravity.ClaudeRequest{
|
req := &antigravity.ClaudeRequest{
|
||||||
Model: "claude-sonnet-4-5",
|
Model: "claude-sonnet-4-5",
|
||||||
@@ -337,8 +362,8 @@ func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *tes
|
|||||||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
|
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling verifies
|
||||||
// 验证:ForwardGemini 粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true
|
// that ForwardGemini sets ForceCacheBilling=true for sticky session switch.
|
||||||
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) {
|
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
writer := httptest.NewRecorder()
|
writer := httptest.NewRecorder()
|
||||||
@@ -391,3 +416,438 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
|
|||||||
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
|
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
|
||||||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- 流式 happy path 测试 ---
|
||||||
|
|
||||||
|
// TestStreamUpstreamResponse_NormalComplete
|
||||||
|
// 验证:正常流式转发完成时,数据正确透传、usage 正确收集、clientDisconnect=false
|
||||||
|
func TestStreamUpstreamResponse_NormalComplete(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
fmt.Fprintln(pw, `event: message_start`)
|
||||||
|
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
fmt.Fprintln(pw, `event: content_block_delta`)
|
||||||
|
fmt.Fprintln(pw, `data: {"type":"content_block_delta","delta":{"text":"hello"}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
fmt.Fprintln(pw, `event: message_delta`)
|
||||||
|
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":5}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
|
||||||
|
require.NotNil(t, result.usage)
|
||||||
|
require.Equal(t, 5, result.usage.OutputTokens, "should collect output_tokens from message_delta")
|
||||||
|
require.NotNil(t, result.firstTokenMs, "should record first token time")
|
||||||
|
|
||||||
|
// 验证数据被透传到客户端
|
||||||
|
body := rec.Body.String()
|
||||||
|
require.Contains(t, body, "event: message_start")
|
||||||
|
require.Contains(t, body, "content_block_delta")
|
||||||
|
require.Contains(t, body, "message_delta")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleGeminiStreamingResponse_NormalComplete
|
||||||
|
// 验证:正常 Gemini 流式转发,数据正确透传、usage 正确收集
|
||||||
|
func TestHandleGeminiStreamingResponse_NormalComplete(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
// 第一个 chunk(部分内容)
|
||||||
|
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
// 第二个 chunk(最终内容+完整 usage)
|
||||||
|
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"cachedContentTokenCount":2}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
|
||||||
|
require.NotNil(t, result.usage)
|
||||||
|
// Gemini usage: promptTokenCount=10, candidatesTokenCount=8, cachedContentTokenCount=2
|
||||||
|
// → InputTokens=10-2=8, OutputTokens=8, CacheReadInputTokens=2
|
||||||
|
require.Equal(t, 8, result.usage.InputTokens)
|
||||||
|
require.Equal(t, 8, result.usage.OutputTokens)
|
||||||
|
require.Equal(t, 2, result.usage.CacheReadInputTokens)
|
||||||
|
require.NotNil(t, result.firstTokenMs, "should record first token time")
|
||||||
|
|
||||||
|
// 验证数据被透传到客户端
|
||||||
|
body := rec.Body.String()
|
||||||
|
require.Contains(t, body, "Hello")
|
||||||
|
require.Contains(t, body, "world")
|
||||||
|
// 不应包含错误事件
|
||||||
|
require.NotContains(t, body, "event: error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleClaudeStreamingResponse_NormalComplete
|
||||||
|
// 验证:正常 Claude 流式转发(Gemini→Claude 转换),数据正确转换并输出
|
||||||
|
func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
// v1internal 包装格式:Gemini 数据嵌套在 "response" 字段下
|
||||||
|
// ProcessLine 先尝试反序列化为 V1InternalResponse,裸格式会导致 Response.UsageMetadata 为空
|
||||||
|
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi there"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
|
||||||
|
require.NotNil(t, result.usage)
|
||||||
|
// Gemini→Claude 转换的 usage:promptTokenCount=5→InputTokens=5, candidatesTokenCount=3→OutputTokens=3
|
||||||
|
require.Equal(t, 5, result.usage.InputTokens)
|
||||||
|
require.Equal(t, 3, result.usage.OutputTokens)
|
||||||
|
require.NotNil(t, result.firstTokenMs, "should record first token time")
|
||||||
|
|
||||||
|
// 验证输出是 Claude SSE 格式(processor 会转换)
|
||||||
|
body := rec.Body.String()
|
||||||
|
require.Contains(t, body, "event: message_start", "should contain Claude message_start event")
|
||||||
|
require.Contains(t, body, "event: message_stop", "should contain Claude message_stop event")
|
||||||
|
// 不应包含错误事件
|
||||||
|
require.NotContains(t, body, "event: error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- 流式客户端断开检测测试 ---
|
||||||
|
|
||||||
|
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
|
||||||
|
// 验证:客户端写入失败后,streamUpstreamResponse 继续读取上游以收集 usage
|
||||||
|
func TestStreamUpstreamResponse_ClientDisconnectDrainsUsage(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
fmt.Fprintln(pw, `event: message_start`)
|
||||||
|
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
fmt.Fprintln(pw, `event: message_delta`)
|
||||||
|
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":20}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.clientDisconnect)
|
||||||
|
require.NotNil(t, result.usage)
|
||||||
|
require.Equal(t, 20, result.usage.OutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStreamUpstreamResponse_ContextCanceled
|
||||||
|
// 验证:context 取消时返回 usage 且标记 clientDisconnect
|
||||||
|
func TestStreamUpstreamResponse_ContextCanceled(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||||
|
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||||
|
|
||||||
|
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.clientDisconnect)
|
||||||
|
require.NotContains(t, rec.Body.String(), "event: error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStreamUpstreamResponse_Timeout
|
||||||
|
// 验证:上游超时时返回已收集的 usage
|
||||||
|
func TestStreamUpstreamResponse_Timeout(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||||
|
_ = pw.Close()
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.False(t, result.clientDisconnect)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStreamUpstreamResponse_TimeoutAfterClientDisconnect
|
||||||
|
// 验证:客户端断开后上游超时,返回 usage 并标记 clientDisconnect
|
||||||
|
func TestStreamUpstreamResponse_TimeoutAfterClientDisconnect(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":5}}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
// 不关闭 pw → 等待超时
|
||||||
|
}()
|
||||||
|
|
||||||
|
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||||
|
_ = pw.Close()
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.clientDisconnect)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleGeminiStreamingResponse_ClientDisconnect
|
||||||
|
// 验证:Gemini 流式转发中客户端断开后继续 drain 上游
|
||||||
|
func TestHandleGeminiStreamingResponse_ClientDisconnect(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":10}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.clientDisconnect)
|
||||||
|
require.NotContains(t, rec.Body.String(), "write_failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleGeminiStreamingResponse_ContextCanceled
|
||||||
|
// 验证:context 取消时不注入错误事件
|
||||||
|
func TestHandleGeminiStreamingResponse_ContextCanceled(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||||
|
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||||
|
|
||||||
|
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.clientDisconnect)
|
||||||
|
require.NotContains(t, rec.Body.String(), "event: error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleClaudeStreamingResponse_ClientDisconnect
|
||||||
|
// 验证:Claude 流式转发中客户端断开后继续 drain 上游
|
||||||
|
func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
// v1internal 包装格式
|
||||||
|
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":15}}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.clientDisconnect)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleClaudeStreamingResponse_ContextCanceled
|
||||||
|
// 验证:context 取消时不注入错误事件
|
||||||
|
func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||||
|
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||||
|
|
||||||
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.clientDisconnect)
|
||||||
|
require.NotContains(t, rec.Body.String(), "event: error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExtractSSEUsage 验证 extractSSEUsage 从 SSE data 行正确提取 usage
|
||||||
|
func TestExtractSSEUsage(t *testing.T) {
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
line string
|
||||||
|
expected ClaudeUsage
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "message_delta with output_tokens",
|
||||||
|
line: `data: {"type":"message_delta","usage":{"output_tokens":42}}`,
|
||||||
|
expected: ClaudeUsage{OutputTokens: 42},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-data line ignored",
|
||||||
|
line: `event: message_start`,
|
||||||
|
expected: ClaudeUsage{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "top-level usage with all fields",
|
||||||
|
line: `data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`,
|
||||||
|
expected: ClaudeUsage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5, CacheCreationInputTokens: 3},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
usage := &ClaudeUsage{}
|
||||||
|
svc.extractSSEUsage(tt.line, usage)
|
||||||
|
require.Equal(t, tt.expected, *usage)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAntigravityClientWriter 验证 antigravityClientWriter 的断开检测
|
||||||
|
func TestAntigravityClientWriter(t *testing.T) {
|
||||||
|
t.Run("normal write succeeds", func(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
flusher, _ := c.Writer.(http.Flusher)
|
||||||
|
cw := newAntigravityClientWriter(c.Writer, flusher, "test")
|
||||||
|
|
||||||
|
ok := cw.Write([]byte("hello"))
|
||||||
|
require.True(t, ok)
|
||||||
|
require.False(t, cw.Disconnected())
|
||||||
|
require.Contains(t, rec.Body.String(), "hello")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("write failure marks disconnected", func(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
flusher, _ := c.Writer.(http.Flusher)
|
||||||
|
cw := newAntigravityClientWriter(fw, flusher, "test")
|
||||||
|
|
||||||
|
ok := cw.Write([]byte("hello"))
|
||||||
|
require.False(t, ok)
|
||||||
|
require.True(t, cw.Disconnected())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("subsequent writes are no-op", func(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
flusher, _ := c.Writer.(http.Flusher)
|
||||||
|
cw := newAntigravityClientWriter(fw, flusher, "test")
|
||||||
|
|
||||||
|
cw.Write([]byte("first"))
|
||||||
|
ok := cw.Fprintf("second %d", 2)
|
||||||
|
require.False(t, ok)
|
||||||
|
require.True(t, cw.Disconnected())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user