diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index fd53ba71..fc29eeb3 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -47,6 +47,21 @@ const ( googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED" ) +// upstreamHopByHopHeaders 透传请求头时需要排除的 hop-by-hop 头 +var upstreamHopByHopHeaders = map[string]bool{ + "connection": true, + "keep-alive": true, + "proxy-authenticate": true, + "proxy-authorization": true, + "proxy-connection": true, + "te": true, + "trailer": true, + "transfer-encoding": true, + "upgrade": true, + "host": true, + "content-length": true, +} + // antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写) // 匹配时使用 strings.Contains,无需完全匹配 var antigravityPassthroughErrorMessages = []string{ @@ -3456,10 +3471,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. if mappedModel == "" { return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) } - loadModel := mappedModel - thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" - mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) - quotaScope, _ := resolveAntigravityQuotaScope(originalModel) // 代理 URL proxyURL := "" @@ -3469,98 +3480,38 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. // 统计模型调用次数 if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, loadModel) + _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) } apiURL := baseURL + "/antigravity/v1/messages" log.Printf("%s upstream_forward url=%s model=%s", prefix, apiURL, mappedModel) - // 预检查:模型级限流 - if remaining := account.GetRateLimitRemainingTimeWithContext(ctx, originalModel); remaining > 0 { - if remaining < antigravityRateLimitThreshold { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(remaining): - } - } else { - return nil, &UpstreamFailoverError{ - StatusCode: http.StatusServiceUnavailable, - ForceCacheBilling: isStickySession, - } + // 构建请求:body 原样透传 + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") + } + // 透传客户端所有请求头(排除 hop-by-hop 和认证头) + for key, values := range c.Request.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + req.Header.Add(key, v) } } + // 覆盖认证头 + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) - // 重试循环 - var resp *http.Response - var lastErr error - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) - if err != nil { - return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) - req.Header.Set("x-api-key", apiKey) - - // 透传 anthropic headers - if v := c.GetHeader("anthropic-version"); v != "" { - req.Header.Set("anthropic-version", v) - } else { - req.Header.Set("anthropic-version", "2023-06-01") - } - if v := c.GetHeader("anthropic-beta"); v != "" { - req.Header.Set("anthropic-beta", v) - } - - if c != nil && len(body) > 0 { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - resp, err = s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - lastErr = err - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - return nil, ctx.Err() - } - continue - } - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") - } - - // 429/503 重试 - if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) - - if attempt < antigravityMaxRetries { - log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - return nil, ctx.Err() - } - continue - } - - return nil, &UpstreamFailoverError{ - StatusCode: resp.StatusCode, - ForceCacheBilling: isStickySession, - } - } - - break // 成功或非限流错误,跳出重试 + if c != nil && len(body) > 0 { + c.Set(OpsUpstreamRequestBodyKey, string(body)) } - if resp == nil { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("upstream request failed: %v", lastErr)) + + // 单次发送,不重试 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("Upstream request failed: %v", err)) } defer func() { _ = resp.Body.Close() }() @@ -3568,44 +3519,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - // signature 重试 - if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) { - log.Printf("%s upstream signature error, retrying with thinking stripped", prefix) - retryClaudeReq := claudeReq - retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...) - if stripped, stripErr := stripThinkingFromClaudeRequest(&retryClaudeReq); stripErr == nil && stripped { - retryBody, _ := json.Marshal(&retryClaudeReq) - retryReq, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(retryBody)) - if err == nil { - retryReq.Header.Set("Content-Type", "application/json") - retryReq.Header.Set("Authorization", "Bearer "+apiKey) - retryReq.Header.Set("x-api-key", apiKey) - retryReq.Header.Set("anthropic-version", "2023-06-01") - if v := c.GetHeader("anthropic-beta"); v != "" { - retryReq.Header.Set("anthropic-beta", v) - } - retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) - if retryErr == nil && retryResp != nil && retryResp.StatusCode < 400 { - resp = retryResp - goto upstreamClaudeSuccess - } - if retryResp != nil { - _ = retryResp.Body.Close() - } - } - } - } - - // prompt too long - if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) { - return nil, &PromptTooLongError{ - StatusCode: resp.StatusCode, - RequestID: resp.Header.Get("x-request-id"), - Body: respBody, - } - } - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession) if s.shouldFailoverUpstreamError(resp.StatusCode) { return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} @@ -3614,7 +3528,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) } -upstreamClaudeSuccess: + // 成功响应 requestID := resp.Header.Get("x-request-id") if requestID != "" { c.Header("x-request-id", requestID) @@ -3674,7 +3588,6 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c if len(body) == 0 { return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") } - quotaScope, _ := resolveAntigravityQuotaScope(originalModel) imageSize := s.extractImageSize(body) @@ -3712,143 +3625,52 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c } // 构建 upstream URL: base_url + /antigravity/v1beta/models/MODEL:ACTION - upstreamAction := action - if action == "generateContent" && !stream { - // 非流式也用 streamGenerateContent,与 OAuth 路径行为一致 - upstreamAction = action - } - apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction) - if stream || upstreamAction == "streamGenerateContent" { + apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, action) + if stream || action == "streamGenerateContent" { apiURL += "?alt=sse" } - log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, upstreamAction) + log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, action) - // 预检查:模型级限流 - if remaining := account.GetRateLimitRemainingTimeWithContext(ctx, originalModel); remaining > 0 { - if remaining < antigravityRateLimitThreshold { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(remaining): - } - } else { - return nil, &UpstreamFailoverError{ - StatusCode: http.StatusServiceUnavailable, - ForceCacheBilling: isStickySession, - } + // 构建请求:body 原样透传 + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) + if err != nil { + return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") + } + // 透传客户端所有请求头(排除 hop-by-hop 和认证头) + for key, values := range c.Request.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + req.Header.Add(key, v) } } + // 覆盖认证头 + req.Header.Set("Authorization", "Bearer "+apiKey) - // 重试循环 - var resp *http.Response - var lastErr error - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) - if err != nil { - return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) - - if c != nil && len(body) > 0 { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - resp, err = s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - lastErr = err - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - return nil, ctx.Err() - } - continue - } - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") - } - - // 429/503 重试 - if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) - - if attempt < antigravityMaxRetries { - log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - return nil, ctx.Err() - } - continue - } - - return nil, &UpstreamFailoverError{ - StatusCode: resp.StatusCode, - ForceCacheBilling: isStickySession, - } - } - - break + if c != nil && len(body) > 0 { + c.Set(OpsUpstreamRequestBodyKey, string(body)) } - if resp == nil { - return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("upstream request failed: %v", lastErr)) + + // 单次发送,不重试 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("Upstream request failed: %v", err)) } - defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() - } - }() + defer func() { _ = resp.Body.Close() }() // 错误响应处理 if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) contentType := resp.Header.Get("Content-Type") - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - - // 模型兜底 - if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) && - isModelNotFoundError(resp.StatusCode, respBody) { - fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity) - if fallbackModel != "" && fallbackModel != mappedModel { - log.Printf("[Antigravity-Upstream] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name) - fallbackURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, fallbackModel, upstreamAction) - if stream || upstreamAction == "streamGenerateContent" { - fallbackURL += "?alt=sse" - } - fallbackReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fallbackURL, bytes.NewReader(body)) - if err == nil { - fallbackReq.Header.Set("Content-Type", "application/json") - fallbackReq.Header.Set("Authorization", "Bearer "+apiKey) - fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency) - if err == nil && fallbackResp.StatusCode < 400 { - _ = resp.Body.Close() - resp = fallbackResp - } else if fallbackResp != nil { - _ = fallbackResp.Body.Close() - } - } - } - } - - // fallback 成功 - if resp.StatusCode < 400 { - goto upstreamGeminiSuccess - } requestID := resp.Header.Get("x-request-id") if requestID != "" { c.Header("x-request-id", requestID) } - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession) upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamDetail := s.getUpstreamErrorDetail(respBody) @@ -3886,7 +3708,7 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) } -upstreamGeminiSuccess: + // 成功响应 requestID := resp.Header.Get("x-request-id") if requestID != "" { c.Header("x-request-id", requestID) diff --git a/backend/internal/service/upstream_header_passthrough_test.go b/backend/internal/service/upstream_header_passthrough_test.go new file mode 100644 index 00000000..51d8588b --- /dev/null +++ b/backend/internal/service/upstream_header_passthrough_test.go @@ -0,0 +1,285 @@ +//go:build unit + +package service + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// httpUpstreamCapture captures the outgoing *http.Request for assertion. +type httpUpstreamCapture struct { + capturedReq *http.Request + resp *http.Response + err error +} + +func (s *httpUpstreamCapture) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + s.capturedReq = req + return s.resp, s.err +} + +func (s *httpUpstreamCapture) DoWithTLS(req *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) { + s.capturedReq = req + return s.resp, s.err +} + +func newUpstreamAccount() *Account { + return &Account{ + ID: 100, + Name: "upstream-test", + Platform: PlatformAntigravity, + Type: AccountTypeUpstream, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "base_url": "https://upstream.example.com", + "api_key": "sk-upstream-secret", + }, + } +} + +// makeSSEOKResponse builds a minimal SSE response that +// handleClaudeStreamingResponse / handleGeminiStreamingResponse +// can consume without error. +// We return 502 to bypass streaming and hit the error branch instead, +// which is sufficient for testing header passthrough. +func makeUpstreamErrorResponse() *http.Response { + body := []byte(`{"error":{"message":"test error"}}`) + return &http.Response{ + StatusCode: http.StatusBadGateway, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewReader(body)), + } +} + +// --- ForwardUpstream tests --- + +func TestForwardUpstream_PassthroughHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{{"role": "user", "content": "hi"}}, + "max_tokens": 1, + "stream": false, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("anthropic-version", "2024-10-22") + req.Header.Set("anthropic-beta", "output-128k-2025-02-19") + req.Header.Set("X-Custom-Header", "custom-value") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) + + captured := stub.capturedReq + require.NotNil(t, captured, "upstream request should have been made") + + // 客户端 header 应被透传 + require.Equal(t, "application/json", captured.Header.Get("Content-Type")) + require.Equal(t, "2024-10-22", captured.Header.Get("anthropic-version")) + require.Equal(t, "output-128k-2025-02-19", captured.Header.Get("anthropic-beta")) + require.Equal(t, "custom-value", captured.Header.Get("X-Custom-Header")) +} + +func TestForwardUpstream_OverridesAuthHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{{"role": "user", "content": "hi"}}, + "max_tokens": 1, + "stream": false, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + // 客户端发来的认证头应被覆盖 + req.Header.Set("Authorization", "Bearer client-token") + req.Header.Set("x-api-key", "client-api-key") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) + + captured := stub.capturedReq + require.NotNil(t, captured) + + // 认证头应使用上游账号的 api_key,而非客户端的 + require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization")) + require.Equal(t, "sk-upstream-secret", captured.Header.Get("x-api-key")) +} + +func TestForwardUpstream_ExcludesHopByHopHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{{"role": "user", "content": "hi"}}, + "max_tokens": 1, + "stream": false, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Connection", "keep-alive") + req.Header.Set("Keep-Alive", "timeout=5") + req.Header.Set("Transfer-Encoding", "chunked") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Te", "trailers") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) + + captured := stub.capturedReq + require.NotNil(t, captured) + + // hop-by-hop header 不应出现 + require.Empty(t, captured.Header.Get("Connection")) + require.Empty(t, captured.Header.Get("Keep-Alive")) + require.Empty(t, captured.Header.Get("Transfer-Encoding")) + require.Empty(t, captured.Header.Get("Upgrade")) + require.Empty(t, captured.Header.Get("Te")) + + // 但普通 header 应保留 + require.Equal(t, "application/json", captured.Header.Get("Content-Type")) +} + +// --- ForwardUpstreamGemini tests --- + +func TestForwardUpstreamGemini_PassthroughHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, + }, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Custom-Gemini", "gemini-value") + req.Header.Set("X-Request-Id", "req-abc-123") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) + + captured := stub.capturedReq + require.NotNil(t, captured, "upstream request should have been made") + + // 客户端 header 应被透传 + require.Equal(t, "application/json", captured.Header.Get("Content-Type")) + require.Equal(t, "gemini-value", captured.Header.Get("X-Custom-Gemini")) + require.Equal(t, "req-abc-123", captured.Header.Get("X-Request-Id")) +} + +func TestForwardUpstreamGemini_OverridesAuthHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, + }, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer client-gemini-token") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) + + captured := stub.capturedReq + require.NotNil(t, captured) + + // 认证头应使用上游账号的 api_key + require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization")) +} + +func TestForwardUpstreamGemini_ExcludesHopByHopHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, + }, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Connection", "keep-alive") + req.Header.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz") + req.Header.Set("Host", "evil.example.com") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) + + captured := stub.capturedReq + require.NotNil(t, captured) + + // hop-by-hop header 不应出现 + require.Empty(t, captured.Header.Get("Connection")) + require.Empty(t, captured.Header.Get("Proxy-Authorization")) + // Host header 在 Go http.Request 中特殊处理,但我们的黑名单应阻止透传 + require.Empty(t, captured.Header.Values("Host")) + + // 普通 header 应保留 + require.Equal(t, "application/json", captured.Header.Get("Content-Type")) +}