feat(upstream): passthrough all client headers instead of manual header setting

Replace manual header setting (Content-Type, anthropic-version, anthropic-beta)
with full client header passthrough in ForwardUpstream/ForwardUpstreamGemini.
Only authentication headers (Authorization, x-api-key) are overridden with
upstream account credentials. Hop-by-hop headers are excluded.

Add unit tests covering header passthrough, auth override, and hop-by-hop filtering.
This commit is contained in:
erio
2026-02-08 08:33:09 +08:00
parent df3346387f
commit 1563bd3dda
2 changed files with 352 additions and 245 deletions

View File

@@ -0,0 +1,285 @@
//go:build unit
package service
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// httpUpstreamCapture captures the outgoing *http.Request for assertion.
type httpUpstreamCapture struct {
capturedReq *http.Request
resp *http.Response
err error
}
func (s *httpUpstreamCapture) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
s.capturedReq = req
return s.resp, s.err
}
func (s *httpUpstreamCapture) DoWithTLS(req *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
s.capturedReq = req
return s.resp, s.err
}
func newUpstreamAccount() *Account {
return &Account{
ID: 100,
Name: "upstream-test",
Platform: PlatformAntigravity,
Type: AccountTypeUpstream,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"base_url": "https://upstream.example.com",
"api_key": "sk-upstream-secret",
},
}
}
// makeSSEOKResponse builds a minimal SSE response that
// handleClaudeStreamingResponse / handleGeminiStreamingResponse
// can consume without error.
// We return 502 to bypass streaming and hit the error branch instead,
// which is sufficient for testing header passthrough.
func makeUpstreamErrorResponse() *http.Response {
body := []byte(`{"error":{"message":"test error"}}`)
return &http.Response{
StatusCode: http.StatusBadGateway,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(bytes.NewReader(body)),
}
}
// --- ForwardUpstream tests ---
func TestForwardUpstream_PassthroughHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body, _ := json.Marshal(map[string]any{
"model": "claude-sonnet-4-5",
"messages": []map[string]any{{"role": "user", "content": "hi"}},
"max_tokens": 1,
"stream": false,
})
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("anthropic-version", "2024-10-22")
req.Header.Set("anthropic-beta", "output-128k-2025-02-19")
req.Header.Set("X-Custom-Header", "custom-value")
c.Request = req
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: stub,
}
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
captured := stub.capturedReq
require.NotNil(t, captured, "upstream request should have been made")
// 客户端 header 应被透传
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
require.Equal(t, "2024-10-22", captured.Header.Get("anthropic-version"))
require.Equal(t, "output-128k-2025-02-19", captured.Header.Get("anthropic-beta"))
require.Equal(t, "custom-value", captured.Header.Get("X-Custom-Header"))
}
func TestForwardUpstream_OverridesAuthHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body, _ := json.Marshal(map[string]any{
"model": "claude-sonnet-4-5",
"messages": []map[string]any{{"role": "user", "content": "hi"}},
"max_tokens": 1,
"stream": false,
})
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
// 客户端发来的认证头应被覆盖
req.Header.Set("Authorization", "Bearer client-token")
req.Header.Set("x-api-key", "client-api-key")
c.Request = req
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: stub,
}
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
captured := stub.capturedReq
require.NotNil(t, captured)
// 认证头应使用上游账号的 api_key而非客户端的
require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization"))
require.Equal(t, "sk-upstream-secret", captured.Header.Get("x-api-key"))
}
func TestForwardUpstream_ExcludesHopByHopHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body, _ := json.Marshal(map[string]any{
"model": "claude-sonnet-4-5",
"messages": []map[string]any{{"role": "user", "content": "hi"}},
"max_tokens": 1,
"stream": false,
})
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Connection", "keep-alive")
req.Header.Set("Keep-Alive", "timeout=5")
req.Header.Set("Transfer-Encoding", "chunked")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Te", "trailers")
c.Request = req
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: stub,
}
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
captured := stub.capturedReq
require.NotNil(t, captured)
// hop-by-hop header 不应出现
require.Empty(t, captured.Header.Get("Connection"))
require.Empty(t, captured.Header.Get("Keep-Alive"))
require.Empty(t, captured.Header.Get("Transfer-Encoding"))
require.Empty(t, captured.Header.Get("Upgrade"))
require.Empty(t, captured.Header.Get("Te"))
// 但普通 header 应保留
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
}
// --- ForwardUpstreamGemini tests ---
func TestForwardUpstreamGemini_PassthroughHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body, _ := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
},
})
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Custom-Gemini", "gemini-value")
req.Header.Set("X-Request-Id", "req-abc-123")
c.Request = req
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: stub,
}
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
captured := stub.capturedReq
require.NotNil(t, captured, "upstream request should have been made")
// 客户端 header 应被透传
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
require.Equal(t, "gemini-value", captured.Header.Get("X-Custom-Gemini"))
require.Equal(t, "req-abc-123", captured.Header.Get("X-Request-Id"))
}
func TestForwardUpstreamGemini_OverridesAuthHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body, _ := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
},
})
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer client-gemini-token")
c.Request = req
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: stub,
}
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
captured := stub.capturedReq
require.NotNil(t, captured)
// 认证头应使用上游账号的 api_key
require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization"))
}
func TestForwardUpstreamGemini_ExcludesHopByHopHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body, _ := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
},
})
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Connection", "keep-alive")
req.Header.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz")
req.Header.Set("Host", "evil.example.com")
c.Request = req
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: stub,
}
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
captured := stub.capturedReq
require.NotNil(t, captured)
// hop-by-hop header 不应出现
require.Empty(t, captured.Header.Get("Connection"))
require.Empty(t, captured.Header.Get("Proxy-Authorization"))
// Host header 在 Go http.Request 中特殊处理,但我们的黑名单应阻止透传
require.Empty(t, captured.Header.Values("Host"))
// 普通 header 应保留
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
}