854 lines
30 KiB
Go
854 lines
30 KiB
Go
package service
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
// antigravityFailingWriter 模拟客户端断开连接的 gin.ResponseWriter
|
||
type antigravityFailingWriter struct {
|
||
gin.ResponseWriter
|
||
failAfter int // 允许成功写入的次数,之后所有写入返回错误
|
||
writes int
|
||
}
|
||
|
||
func (w *antigravityFailingWriter) Write(p []byte) (int, error) {
|
||
if w.writes >= w.failAfter {
|
||
return 0, errors.New("write failed: client disconnected")
|
||
}
|
||
w.writes++
|
||
return w.ResponseWriter.Write(p)
|
||
}
|
||
|
||
// newAntigravityTestService 创建用于流式测试的 AntigravityGatewayService
|
||
func newAntigravityTestService(cfg *config.Config) *AntigravityGatewayService {
|
||
return &AntigravityGatewayService{
|
||
settingService: &SettingService{cfg: cfg},
|
||
}
|
||
}
|
||
|
||
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
|
||
req := &antigravity.ClaudeRequest{
|
||
Model: "claude-sonnet-4-5",
|
||
Thinking: &antigravity.ThinkingConfig{
|
||
Type: "enabled",
|
||
BudgetTokens: 1024,
|
||
},
|
||
Messages: []antigravity.ClaudeMessage{
|
||
{
|
||
Role: "assistant",
|
||
Content: json.RawMessage(`[
|
||
{"type":"thinking","thinking":"secret plan","signature":""},
|
||
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}
|
||
]`),
|
||
},
|
||
{
|
||
Role: "user",
|
||
Content: json.RawMessage(`[
|
||
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false},
|
||
{"type":"redacted_thinking","data":"..."}
|
||
]`),
|
||
},
|
||
},
|
||
}
|
||
|
||
changed, err := stripSignatureSensitiveBlocksFromClaudeRequest(req)
|
||
require.NoError(t, err)
|
||
require.True(t, changed)
|
||
require.Nil(t, req.Thinking)
|
||
|
||
require.Len(t, req.Messages, 2)
|
||
|
||
var blocks0 []map[string]any
|
||
require.NoError(t, json.Unmarshal(req.Messages[0].Content, &blocks0))
|
||
require.Len(t, blocks0, 2)
|
||
require.Equal(t, "text", blocks0[0]["type"])
|
||
require.Equal(t, "secret plan", blocks0[0]["text"])
|
||
require.Equal(t, "text", blocks0[1]["type"])
|
||
|
||
var blocks1 []map[string]any
|
||
require.NoError(t, json.Unmarshal(req.Messages[1].Content, &blocks1))
|
||
require.Len(t, blocks1, 1)
|
||
require.Equal(t, "text", blocks1[0]["type"])
|
||
require.NotEmpty(t, blocks1[0]["text"])
|
||
}
|
||
|
||
func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
|
||
req := &antigravity.ClaudeRequest{
|
||
Model: "claude-sonnet-4-5",
|
||
Thinking: &antigravity.ThinkingConfig{
|
||
Type: "enabled",
|
||
BudgetTokens: 1024,
|
||
},
|
||
Messages: []antigravity.ClaudeMessage{
|
||
{
|
||
Role: "assistant",
|
||
Content: json.RawMessage(`[{"type":"thinking","thinking":"secret plan"},{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}]`),
|
||
},
|
||
},
|
||
}
|
||
|
||
changed, err := stripThinkingFromClaudeRequest(req)
|
||
require.NoError(t, err)
|
||
require.True(t, changed)
|
||
require.Nil(t, req.Thinking)
|
||
|
||
var blocks []map[string]any
|
||
require.NoError(t, json.Unmarshal(req.Messages[0].Content, &blocks))
|
||
require.Len(t, blocks, 2)
|
||
require.Equal(t, "text", blocks[0]["type"])
|
||
require.Equal(t, "secret plan", blocks[0]["text"])
|
||
require.Equal(t, "tool_use", blocks[1]["type"])
|
||
}
|
||
|
||
func TestIsPromptTooLongError(t *testing.T) {
|
||
require.True(t, isPromptTooLongError([]byte(`{"error":{"message":"Prompt is too long"}}`)))
|
||
require.True(t, isPromptTooLongError([]byte(`{"message":"Prompt is too long"}`)))
|
||
require.False(t, isPromptTooLongError([]byte(`{"error":{"message":"other"}}`)))
|
||
}
|
||
|
||
// httpUpstreamStub is a canned upstream for tests: both of its methods ignore
// their arguments and hand back the preconfigured response and error.
type httpUpstreamStub struct {
	resp *http.Response // response returned by every call
	err  error          // error returned by every call
}

// Do ignores all of its arguments and returns the stub's fixed response and
// error.
func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
	resp, err := s.resp, s.err
	return resp, err
}

// DoWithTLS ignores all of its arguments and returns the same fixed response
// and error as Do.
func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
	resp, err := s.resp, s.err
	return resp, err
}
|
||
|
||
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
writer := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(writer)
|
||
|
||
body, err := json.Marshal(map[string]any{
|
||
"model": "claude-opus-4-6",
|
||
"messages": []map[string]any{
|
||
{"role": "user", "content": "hi"},
|
||
},
|
||
"max_tokens": 1,
|
||
"stream": false,
|
||
})
|
||
require.NoError(t, err)
|
||
|
||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||
c.Request = req
|
||
|
||
respBody := []byte(`{"error":{"message":"Prompt is too long"}}`)
|
||
resp := &http.Response{
|
||
StatusCode: http.StatusBadRequest,
|
||
Header: http.Header{"X-Request-Id": []string{"req-1"}},
|
||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||
}
|
||
|
||
svc := &AntigravityGatewayService{
|
||
tokenProvider: &AntigravityTokenProvider{},
|
||
httpUpstream: &httpUpstreamStub{resp: resp},
|
||
}
|
||
|
||
account := &Account{
|
||
ID: 1,
|
||
Name: "acc-1",
|
||
Platform: PlatformAntigravity,
|
||
Type: AccountTypeOAuth,
|
||
Status: StatusActive,
|
||
Concurrency: 1,
|
||
Credentials: map[string]any{
|
||
"access_token": "token",
|
||
},
|
||
}
|
||
|
||
result, err := svc.Forward(context.Background(), c, account, body, false)
|
||
require.Nil(t, result)
|
||
|
||
var promptErr *PromptTooLongError
|
||
require.ErrorAs(t, err, &promptErr)
|
||
require.Equal(t, http.StatusBadRequest, promptErr.StatusCode)
|
||
require.Equal(t, "req-1", promptErr.RequestID)
|
||
require.NotEmpty(t, promptErr.Body)
|
||
|
||
raw, ok := c.Get(OpsUpstreamErrorsKey)
|
||
require.True(t, ok)
|
||
events, ok := raw.([]*OpsUpstreamErrorEvent)
|
||
require.True(t, ok)
|
||
require.Len(t, events, 1)
|
||
require.Equal(t, "prompt_too_long", events[0].Kind)
|
||
}
|
||
|
||
// TestAntigravityGatewayService_Forward_ModelRateLimitTriggersFailover
|
||
// 验证:当账号存在模型限流且剩余时间 >= antigravityRateLimitThreshold 时,
|
||
// Forward 方法应返回 UpstreamFailoverError,触发 Handler 切换账号
|
||
func TestAntigravityGatewayService_Forward_ModelRateLimitTriggersFailover(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
writer := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(writer)
|
||
|
||
body, err := json.Marshal(map[string]any{
|
||
"model": "claude-opus-4-6",
|
||
"messages": []map[string]any{
|
||
{"role": "user", "content": "hi"},
|
||
},
|
||
"max_tokens": 1,
|
||
"stream": false,
|
||
})
|
||
require.NoError(t, err)
|
||
|
||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||
c.Request = req
|
||
|
||
// 不需要真正调用上游,因为预检查会直接返回切换信号
|
||
svc := &AntigravityGatewayService{
|
||
tokenProvider: &AntigravityTokenProvider{},
|
||
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
|
||
}
|
||
|
||
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s)
|
||
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
|
||
account := &Account{
|
||
ID: 1,
|
||
Name: "acc-rate-limited",
|
||
Platform: PlatformAntigravity,
|
||
Type: AccountTypeOAuth,
|
||
Status: StatusActive,
|
||
Concurrency: 1,
|
||
Credentials: map[string]any{
|
||
"access_token": "token",
|
||
},
|
||
Extra: map[string]any{
|
||
modelRateLimitsKey: map[string]any{
|
||
"claude-opus-4-6-thinking": map[string]any{
|
||
"rate_limit_reset_at": futureResetAt,
|
||
},
|
||
},
|
||
},
|
||
}
|
||
|
||
result, err := svc.Forward(context.Background(), c, account, body, false)
|
||
require.Nil(t, result, "Forward should not return result when model rate limited")
|
||
require.NotNil(t, err, "Forward should return error")
|
||
|
||
// 核心验证:错误应该是 UpstreamFailoverError,而不是普通 502 错误
|
||
var failoverErr *UpstreamFailoverError
|
||
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
|
||
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
|
||
// 非粘性会话请求,ForceCacheBilling 应为 false
|
||
require.False(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be false for non-sticky session")
|
||
}
|
||
|
||
// TestAntigravityGatewayService_ForwardGemini_ModelRateLimitTriggersFailover
|
||
// 验证:ForwardGemini 方法同样能正确将 AntigravityAccountSwitchError 转换为 UpstreamFailoverError
|
||
func TestAntigravityGatewayService_ForwardGemini_ModelRateLimitTriggersFailover(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
writer := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(writer)
|
||
|
||
body, err := json.Marshal(map[string]any{
|
||
"contents": []map[string]any{
|
||
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
|
||
},
|
||
})
|
||
require.NoError(t, err)
|
||
|
||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
||
c.Request = req
|
||
|
||
// 不需要真正调用上游,因为预检查会直接返回切换信号
|
||
svc := &AntigravityGatewayService{
|
||
tokenProvider: &AntigravityTokenProvider{},
|
||
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
|
||
}
|
||
|
||
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s)
|
||
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
|
||
account := &Account{
|
||
ID: 2,
|
||
Name: "acc-gemini-rate-limited",
|
||
Platform: PlatformAntigravity,
|
||
Type: AccountTypeOAuth,
|
||
Status: StatusActive,
|
||
Concurrency: 1,
|
||
Credentials: map[string]any{
|
||
"access_token": "token",
|
||
},
|
||
Extra: map[string]any{
|
||
modelRateLimitsKey: map[string]any{
|
||
"gemini-2.5-flash": map[string]any{
|
||
"rate_limit_reset_at": futureResetAt,
|
||
},
|
||
},
|
||
},
|
||
}
|
||
|
||
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, false)
|
||
require.Nil(t, result, "ForwardGemini should not return result when model rate limited")
|
||
require.NotNil(t, err, "ForwardGemini should return error")
|
||
|
||
// 核心验证:错误应该是 UpstreamFailoverError,而不是普通 502 错误
|
||
var failoverErr *UpstreamFailoverError
|
||
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
|
||
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
|
||
// 非粘性会话请求,ForceCacheBilling 应为 false
|
||
require.False(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be false for non-sticky session")
|
||
}
|
||
|
||
// TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling
|
||
// 验证:粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true
|
||
func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
writer := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(writer)
|
||
|
||
body, err := json.Marshal(map[string]any{
|
||
"model": "claude-opus-4-6",
|
||
"messages": []map[string]string{{"role": "user", "content": "hello"}},
|
||
})
|
||
require.NoError(t, err)
|
||
|
||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||
c.Request = req
|
||
|
||
svc := &AntigravityGatewayService{
|
||
tokenProvider: &AntigravityTokenProvider{},
|
||
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
|
||
}
|
||
|
||
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s)
|
||
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
|
||
account := &Account{
|
||
ID: 3,
|
||
Name: "acc-sticky-rate-limited",
|
||
Platform: PlatformAntigravity,
|
||
Type: AccountTypeOAuth,
|
||
Status: StatusActive,
|
||
Concurrency: 1,
|
||
Credentials: map[string]any{
|
||
"access_token": "token",
|
||
},
|
||
Extra: map[string]any{
|
||
modelRateLimitsKey: map[string]any{
|
||
"claude-opus-4-6-thinking": map[string]any{
|
||
"rate_limit_reset_at": futureResetAt,
|
||
},
|
||
},
|
||
},
|
||
}
|
||
|
||
// 传入 isStickySession = true
|
||
result, err := svc.Forward(context.Background(), c, account, body, true)
|
||
require.Nil(t, result, "Forward should not return result when model rate limited")
|
||
require.NotNil(t, err, "Forward should return error")
|
||
|
||
// 核心验证:粘性会话切换时,ForceCacheBilling 应为 true
|
||
var failoverErr *UpstreamFailoverError
|
||
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
|
||
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
|
||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||
}
|
||
|
||
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling verifies
|
||
// that ForwardGemini sets ForceCacheBilling=true for sticky session switch.
|
||
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
writer := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(writer)
|
||
|
||
body, err := json.Marshal(map[string]any{
|
||
"contents": []map[string]any{
|
||
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
|
||
},
|
||
})
|
||
require.NoError(t, err)
|
||
|
||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
||
c.Request = req
|
||
|
||
svc := &AntigravityGatewayService{
|
||
tokenProvider: &AntigravityTokenProvider{},
|
||
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
|
||
}
|
||
|
||
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s)
|
||
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
|
||
account := &Account{
|
||
ID: 4,
|
||
Name: "acc-gemini-sticky-rate-limited",
|
||
Platform: PlatformAntigravity,
|
||
Type: AccountTypeOAuth,
|
||
Status: StatusActive,
|
||
Concurrency: 1,
|
||
Credentials: map[string]any{
|
||
"access_token": "token",
|
||
},
|
||
Extra: map[string]any{
|
||
modelRateLimitsKey: map[string]any{
|
||
"gemini-2.5-flash": map[string]any{
|
||
"rate_limit_reset_at": futureResetAt,
|
||
},
|
||
},
|
||
},
|
||
}
|
||
|
||
// 传入 isStickySession = true
|
||
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, true)
|
||
require.Nil(t, result, "ForwardGemini should not return result when model rate limited")
|
||
require.NotNil(t, err, "ForwardGemini should return error")
|
||
|
||
// 核心验证:粘性会话切换时,ForceCacheBilling 应为 true
|
||
var failoverErr *UpstreamFailoverError
|
||
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
|
||
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
|
||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||
}
|
||
|
||
// --- Streaming happy-path tests ---
|
||
|
||
// TestStreamUpstreamResponse_NormalComplete
|
||
// 验证:正常流式转发完成时,数据正确透传、usage 正确收集、clientDisconnect=false
|
||
func TestStreamUpstreamResponse_NormalComplete(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
svc := newAntigravityTestService(&config.Config{
|
||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||
})
|
||
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||
|
||
pr, pw := io.Pipe()
|
||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||
|
||
go func() {
|
||
defer func() { _ = pw.Close() }()
|
||
fmt.Fprintln(pw, `event: message_start`)
|
||
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
|
||
fmt.Fprintln(pw, "")
|
||
fmt.Fprintln(pw, `event: content_block_delta`)
|
||
fmt.Fprintln(pw, `data: {"type":"content_block_delta","delta":{"text":"hello"}}`)
|
||
fmt.Fprintln(pw, "")
|
||
fmt.Fprintln(pw, `event: message_delta`)
|
||
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":5}}`)
|
||
fmt.Fprintln(pw, "")
|
||
}()
|
||
|
||
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||
_ = pr.Close()
|
||
|
||
require.NotNil(t, result)
|
||
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
|
||
require.NotNil(t, result.usage)
|
||
require.Equal(t, 5, result.usage.OutputTokens, "should collect output_tokens from message_delta")
|
||
require.NotNil(t, result.firstTokenMs, "should record first token time")
|
||
|
||
// 验证数据被透传到客户端
|
||
body := rec.Body.String()
|
||
require.Contains(t, body, "event: message_start")
|
||
require.Contains(t, body, "content_block_delta")
|
||
require.Contains(t, body, "message_delta")
|
||
}
|
||
|
||
// TestHandleGeminiStreamingResponse_NormalComplete
|
||
// 验证:正常 Gemini 流式转发,数据正确透传、usage 正确收集
|
||
func TestHandleGeminiStreamingResponse_NormalComplete(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
svc := newAntigravityTestService(&config.Config{
|
||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||
})
|
||
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||
|
||
pr, pw := io.Pipe()
|
||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||
|
||
go func() {
|
||
defer func() { _ = pw.Close() }()
|
||
// 第一个 chunk(部分内容)
|
||
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3}}`)
|
||
fmt.Fprintln(pw, "")
|
||
// 第二个 chunk(最终内容+完整 usage)
|
||
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"cachedContentTokenCount":2}}`)
|
||
fmt.Fprintln(pw, "")
|
||
}()
|
||
|
||
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||
_ = pr.Close()
|
||
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
|
||
require.NotNil(t, result.usage)
|
||
// Gemini usage: promptTokenCount=10, candidatesTokenCount=8, cachedContentTokenCount=2
|
||
// → InputTokens=10-2=8, OutputTokens=8, CacheReadInputTokens=2
|
||
require.Equal(t, 8, result.usage.InputTokens)
|
||
require.Equal(t, 8, result.usage.OutputTokens)
|
||
require.Equal(t, 2, result.usage.CacheReadInputTokens)
|
||
require.NotNil(t, result.firstTokenMs, "should record first token time")
|
||
|
||
// 验证数据被透传到客户端
|
||
body := rec.Body.String()
|
||
require.Contains(t, body, "Hello")
|
||
require.Contains(t, body, "world")
|
||
// 不应包含错误事件
|
||
require.NotContains(t, body, "event: error")
|
||
}
|
||
|
||
// TestHandleClaudeStreamingResponse_NormalComplete
|
||
// 验证:正常 Claude 流式转发(Gemini→Claude 转换),数据正确转换并输出
|
||
func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
svc := newAntigravityTestService(&config.Config{
|
||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||
})
|
||
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||
|
||
pr, pw := io.Pipe()
|
||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||
|
||
go func() {
|
||
defer func() { _ = pw.Close() }()
|
||
// v1internal 包装格式:Gemini 数据嵌套在 "response" 字段下
|
||
// ProcessLine 先尝试反序列化为 V1InternalResponse,裸格式会导致 Response.UsageMetadata 为空
|
||
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi there"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}}`)
|
||
fmt.Fprintln(pw, "")
|
||
}()
|
||
|
||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||
_ = pr.Close()
|
||
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
|
||
require.NotNil(t, result.usage)
|
||
// Gemini→Claude 转换的 usage:promptTokenCount=5→InputTokens=5, candidatesTokenCount=3→OutputTokens=3
|
||
require.Equal(t, 5, result.usage.InputTokens)
|
||
require.Equal(t, 3, result.usage.OutputTokens)
|
||
require.NotNil(t, result.firstTokenMs, "should record first token time")
|
||
|
||
// 验证输出是 Claude SSE 格式(processor 会转换)
|
||
body := rec.Body.String()
|
||
require.Contains(t, body, "event: message_start", "should contain Claude message_start event")
|
||
require.Contains(t, body, "event: message_stop", "should contain Claude message_stop event")
|
||
// 不应包含错误事件
|
||
require.NotContains(t, body, "event: error")
|
||
}
|
||
|
||
// --- Streaming client-disconnect detection tests ---
|
||
|
||
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
|
||
// 验证:客户端写入失败后,streamUpstreamResponse 继续读取上游以收集 usage
|
||
func TestStreamUpstreamResponse_ClientDisconnectDrainsUsage(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
svc := newAntigravityTestService(&config.Config{
|
||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||
})
|
||
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||
|
||
pr, pw := io.Pipe()
|
||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||
|
||
go func() {
|
||
defer func() { _ = pw.Close() }()
|
||
fmt.Fprintln(pw, `event: message_start`)
|
||
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
|
||
fmt.Fprintln(pw, "")
|
||
fmt.Fprintln(pw, `event: message_delta`)
|
||
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":20}}`)
|
||
fmt.Fprintln(pw, "")
|
||
}()
|
||
|
||
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||
_ = pr.Close()
|
||
|
||
require.NotNil(t, result)
|
||
require.True(t, result.clientDisconnect)
|
||
require.NotNil(t, result.usage)
|
||
require.Equal(t, 20, result.usage.OutputTokens)
|
||
}
|
||
|
||
// TestStreamUpstreamResponse_ContextCanceled
|
||
// 验证:context 取消时返回 usage 且标记 clientDisconnect
|
||
func TestStreamUpstreamResponse_ContextCanceled(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
svc := newAntigravityTestService(&config.Config{
|
||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||
})
|
||
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
cancel()
|
||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||
|
||
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||
|
||
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||
|
||
require.NotNil(t, result)
|
||
require.True(t, result.clientDisconnect)
|
||
require.NotContains(t, rec.Body.String(), "event: error")
|
||
}
|
||
|
||
// TestStreamUpstreamResponse_Timeout
|
||
// 验证:上游超时时返回已收集的 usage
|
||
func TestStreamUpstreamResponse_Timeout(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
svc := newAntigravityTestService(&config.Config{
|
||
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
|
||
})
|
||
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||
|
||
pr, pw := io.Pipe()
|
||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||
|
||
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||
_ = pw.Close()
|
||
_ = pr.Close()
|
||
|
||
require.NotNil(t, result)
|
||
require.False(t, result.clientDisconnect)
|
||
}
|
||
|
||
// TestStreamUpstreamResponse_TimeoutAfterClientDisconnect
|
||
// 验证:客户端断开后上游超时,返回 usage 并标记 clientDisconnect
|
||
func TestStreamUpstreamResponse_TimeoutAfterClientDisconnect(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
svc := newAntigravityTestService(&config.Config{
|
||
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
|
||
})
|
||
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||
|
||
pr, pw := io.Pipe()
|
||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||
|
||
go func() {
|
||
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":5}}}`)
|
||
fmt.Fprintln(pw, "")
|
||
// 不关闭 pw → 等待超时
|
||
}()
|
||
|
||
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||
_ = pw.Close()
|
||
_ = pr.Close()
|
||
|
||
require.NotNil(t, result)
|
||
require.True(t, result.clientDisconnect)
|
||
}
|
||
|
||
// TestHandleGeminiStreamingResponse_ClientDisconnect
|
||
// 验证:Gemini 流式转发中客户端断开后继续 drain 上游
|
||
func TestHandleGeminiStreamingResponse_ClientDisconnect(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
svc := newAntigravityTestService(&config.Config{
|
||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||
})
|
||
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||
|
||
pr, pw := io.Pipe()
|
||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||
|
||
go func() {
|
||
defer func() { _ = pw.Close() }()
|
||
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":10}}`)
|
||
fmt.Fprintln(pw, "")
|
||
}()
|
||
|
||
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||
_ = pr.Close()
|
||
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
require.True(t, result.clientDisconnect)
|
||
require.NotContains(t, rec.Body.String(), "write_failed")
|
||
}
|
||
|
||
// TestHandleGeminiStreamingResponse_ContextCanceled
|
||
// 验证:context 取消时不注入错误事件
|
||
func TestHandleGeminiStreamingResponse_ContextCanceled(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
svc := newAntigravityTestService(&config.Config{
|
||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||
})
|
||
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
cancel()
|
||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||
|
||
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||
|
||
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
require.True(t, result.clientDisconnect)
|
||
require.NotContains(t, rec.Body.String(), "event: error")
|
||
}
|
||
|
||
// TestHandleClaudeStreamingResponse_ClientDisconnect
|
||
// 验证:Claude 流式转发中客户端断开后继续 drain 上游
|
||
func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
svc := newAntigravityTestService(&config.Config{
|
||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||
})
|
||
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||
|
||
pr, pw := io.Pipe()
|
||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||
|
||
go func() {
|
||
defer func() { _ = pw.Close() }()
|
||
// v1internal 包装格式
|
||
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":15}}}`)
|
||
fmt.Fprintln(pw, "")
|
||
}()
|
||
|
||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||
_ = pr.Close()
|
||
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
require.True(t, result.clientDisconnect)
|
||
}
|
||
|
||
// TestHandleClaudeStreamingResponse_ContextCanceled
|
||
// 验证:context 取消时不注入错误事件
|
||
func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
svc := newAntigravityTestService(&config.Config{
|
||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||
})
|
||
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
cancel()
|
||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||
|
||
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||
|
||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
require.True(t, result.clientDisconnect)
|
||
require.NotContains(t, rec.Body.String(), "event: error")
|
||
}
|
||
|
||
// TestExtractSSEUsage 验证 extractSSEUsage 从 SSE data 行正确提取 usage
|
||
func TestExtractSSEUsage(t *testing.T) {
|
||
svc := &AntigravityGatewayService{}
|
||
tests := []struct {
|
||
name string
|
||
line string
|
||
expected ClaudeUsage
|
||
}{
|
||
{
|
||
name: "message_delta with output_tokens",
|
||
line: `data: {"type":"message_delta","usage":{"output_tokens":42}}`,
|
||
expected: ClaudeUsage{OutputTokens: 42},
|
||
},
|
||
{
|
||
name: "non-data line ignored",
|
||
line: `event: message_start`,
|
||
expected: ClaudeUsage{},
|
||
},
|
||
{
|
||
name: "top-level usage with all fields",
|
||
line: `data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`,
|
||
expected: ClaudeUsage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5, CacheCreationInputTokens: 3},
|
||
},
|
||
}
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
usage := &ClaudeUsage{}
|
||
svc.extractSSEUsage(tt.line, usage)
|
||
require.Equal(t, tt.expected, *usage)
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestAntigravityClientWriter 验证 antigravityClientWriter 的断开检测
|
||
func TestAntigravityClientWriter(t *testing.T) {
|
||
t.Run("normal write succeeds", func(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
flusher, _ := c.Writer.(http.Flusher)
|
||
cw := newAntigravityClientWriter(c.Writer, flusher, "test")
|
||
|
||
ok := cw.Write([]byte("hello"))
|
||
require.True(t, ok)
|
||
require.False(t, cw.Disconnected())
|
||
require.Contains(t, rec.Body.String(), "hello")
|
||
})
|
||
|
||
t.Run("write failure marks disconnected", func(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||
flusher, _ := c.Writer.(http.Flusher)
|
||
cw := newAntigravityClientWriter(fw, flusher, "test")
|
||
|
||
ok := cw.Write([]byte("hello"))
|
||
require.False(t, ok)
|
||
require.True(t, cw.Disconnected())
|
||
})
|
||
|
||
t.Run("subsequent writes are no-op", func(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||
flusher, _ := c.Writer.(http.Flusher)
|
||
cw := newAntigravityClientWriter(fw, flusher, "test")
|
||
|
||
cw.Write([]byte("first"))
|
||
ok := cw.Fprintf("second %d", 2)
|
||
require.False(t, ok)
|
||
require.True(t, cw.Disconnected())
|
||
})
|
||
}
|