Replace manual header setting (Content-Type, anthropic-version, anthropic-beta) with full client header passthrough in ForwardUpstream/ForwardUpstreamGemini. Only authentication headers (Authorization, x-api-key) are overridden with upstream account credentials. Hop-by-hop headers are excluded. Add unit tests covering header passthrough, auth override, and hop-by-hop filtering.
286 lines
9.2 KiB
Go
286 lines
9.2 KiB
Go
//go:build unit
|
||
|
||
package service
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"io"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"testing"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
// httpUpstreamCapture captures the outgoing *http.Request for assertion.
|
||
type httpUpstreamCapture struct {
|
||
capturedReq *http.Request
|
||
resp *http.Response
|
||
err error
|
||
}
|
||
|
||
func (s *httpUpstreamCapture) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||
s.capturedReq = req
|
||
return s.resp, s.err
|
||
}
|
||
|
||
func (s *httpUpstreamCapture) DoWithTLS(req *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
|
||
s.capturedReq = req
|
||
return s.resp, s.err
|
||
}
|
||
|
||
func newUpstreamAccount() *Account {
|
||
return &Account{
|
||
ID: 100,
|
||
Name: "upstream-test",
|
||
Platform: PlatformAntigravity,
|
||
Type: AccountTypeUpstream,
|
||
Status: StatusActive,
|
||
Concurrency: 1,
|
||
Credentials: map[string]any{
|
||
"base_url": "https://upstream.example.com",
|
||
"api_key": "sk-upstream-secret",
|
||
},
|
||
}
|
||
}
|
||
|
||
// makeSSEOKResponse builds a minimal SSE response that
|
||
// handleClaudeStreamingResponse / handleGeminiStreamingResponse
|
||
// can consume without error.
|
||
// We return 502 to bypass streaming and hit the error branch instead,
|
||
// which is sufficient for testing header passthrough.
|
||
func makeUpstreamErrorResponse() *http.Response {
|
||
body := []byte(`{"error":{"message":"test error"}}`)
|
||
return &http.Response{
|
||
StatusCode: http.StatusBadGateway,
|
||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||
Body: io.NopCloser(bytes.NewReader(body)),
|
||
}
|
||
}
|
||
|
||
// --- ForwardUpstream tests ---
|
||
|
||
func TestForwardUpstream_PassthroughHeaders(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
w := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(w)
|
||
|
||
body, _ := json.Marshal(map[string]any{
|
||
"model": "claude-sonnet-4-5",
|
||
"messages": []map[string]any{{"role": "user", "content": "hi"}},
|
||
"max_tokens": 1,
|
||
"stream": false,
|
||
})
|
||
|
||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("anthropic-version", "2024-10-22")
|
||
req.Header.Set("anthropic-beta", "output-128k-2025-02-19")
|
||
req.Header.Set("X-Custom-Header", "custom-value")
|
||
c.Request = req
|
||
|
||
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
||
svc := &AntigravityGatewayService{
|
||
tokenProvider: &AntigravityTokenProvider{},
|
||
httpUpstream: stub,
|
||
}
|
||
|
||
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
|
||
|
||
captured := stub.capturedReq
|
||
require.NotNil(t, captured, "upstream request should have been made")
|
||
|
||
// 客户端 header 应被透传
|
||
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
|
||
require.Equal(t, "2024-10-22", captured.Header.Get("anthropic-version"))
|
||
require.Equal(t, "output-128k-2025-02-19", captured.Header.Get("anthropic-beta"))
|
||
require.Equal(t, "custom-value", captured.Header.Get("X-Custom-Header"))
|
||
}
|
||
|
||
func TestForwardUpstream_OverridesAuthHeaders(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
w := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(w)
|
||
|
||
body, _ := json.Marshal(map[string]any{
|
||
"model": "claude-sonnet-4-5",
|
||
"messages": []map[string]any{{"role": "user", "content": "hi"}},
|
||
"max_tokens": 1,
|
||
"stream": false,
|
||
})
|
||
|
||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||
req.Header.Set("Content-Type", "application/json")
|
||
// 客户端发来的认证头应被覆盖
|
||
req.Header.Set("Authorization", "Bearer client-token")
|
||
req.Header.Set("x-api-key", "client-api-key")
|
||
c.Request = req
|
||
|
||
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
||
svc := &AntigravityGatewayService{
|
||
tokenProvider: &AntigravityTokenProvider{},
|
||
httpUpstream: stub,
|
||
}
|
||
|
||
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
|
||
|
||
captured := stub.capturedReq
|
||
require.NotNil(t, captured)
|
||
|
||
// 认证头应使用上游账号的 api_key,而非客户端的
|
||
require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization"))
|
||
require.Equal(t, "sk-upstream-secret", captured.Header.Get("x-api-key"))
|
||
}
|
||
|
||
func TestForwardUpstream_ExcludesHopByHopHeaders(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
w := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(w)
|
||
|
||
body, _ := json.Marshal(map[string]any{
|
||
"model": "claude-sonnet-4-5",
|
||
"messages": []map[string]any{{"role": "user", "content": "hi"}},
|
||
"max_tokens": 1,
|
||
"stream": false,
|
||
})
|
||
|
||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Connection", "keep-alive")
|
||
req.Header.Set("Keep-Alive", "timeout=5")
|
||
req.Header.Set("Transfer-Encoding", "chunked")
|
||
req.Header.Set("Upgrade", "websocket")
|
||
req.Header.Set("Te", "trailers")
|
||
c.Request = req
|
||
|
||
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
||
svc := &AntigravityGatewayService{
|
||
tokenProvider: &AntigravityTokenProvider{},
|
||
httpUpstream: stub,
|
||
}
|
||
|
||
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
|
||
|
||
captured := stub.capturedReq
|
||
require.NotNil(t, captured)
|
||
|
||
// hop-by-hop header 不应出现
|
||
require.Empty(t, captured.Header.Get("Connection"))
|
||
require.Empty(t, captured.Header.Get("Keep-Alive"))
|
||
require.Empty(t, captured.Header.Get("Transfer-Encoding"))
|
||
require.Empty(t, captured.Header.Get("Upgrade"))
|
||
require.Empty(t, captured.Header.Get("Te"))
|
||
|
||
// 但普通 header 应保留
|
||
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
|
||
}
|
||
|
||
// --- ForwardUpstreamGemini tests ---
|
||
|
||
func TestForwardUpstreamGemini_PassthroughHeaders(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
w := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(w)
|
||
|
||
body, _ := json.Marshal(map[string]any{
|
||
"contents": []map[string]any{
|
||
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
|
||
},
|
||
})
|
||
|
||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("X-Custom-Gemini", "gemini-value")
|
||
req.Header.Set("X-Request-Id", "req-abc-123")
|
||
c.Request = req
|
||
|
||
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
||
svc := &AntigravityGatewayService{
|
||
tokenProvider: &AntigravityTokenProvider{},
|
||
httpUpstream: stub,
|
||
}
|
||
|
||
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
|
||
|
||
captured := stub.capturedReq
|
||
require.NotNil(t, captured, "upstream request should have been made")
|
||
|
||
// 客户端 header 应被透传
|
||
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
|
||
require.Equal(t, "gemini-value", captured.Header.Get("X-Custom-Gemini"))
|
||
require.Equal(t, "req-abc-123", captured.Header.Get("X-Request-Id"))
|
||
}
|
||
|
||
func TestForwardUpstreamGemini_OverridesAuthHeaders(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
w := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(w)
|
||
|
||
body, _ := json.Marshal(map[string]any{
|
||
"contents": []map[string]any{
|
||
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
|
||
},
|
||
})
|
||
|
||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Authorization", "Bearer client-gemini-token")
|
||
c.Request = req
|
||
|
||
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
||
svc := &AntigravityGatewayService{
|
||
tokenProvider: &AntigravityTokenProvider{},
|
||
httpUpstream: stub,
|
||
}
|
||
|
||
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
|
||
|
||
captured := stub.capturedReq
|
||
require.NotNil(t, captured)
|
||
|
||
// 认证头应使用上游账号的 api_key
|
||
require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization"))
|
||
}
|
||
|
||
func TestForwardUpstreamGemini_ExcludesHopByHopHeaders(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
w := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(w)
|
||
|
||
body, _ := json.Marshal(map[string]any{
|
||
"contents": []map[string]any{
|
||
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
|
||
},
|
||
})
|
||
|
||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Connection", "keep-alive")
|
||
req.Header.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz")
|
||
req.Header.Set("Host", "evil.example.com")
|
||
c.Request = req
|
||
|
||
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
||
svc := &AntigravityGatewayService{
|
||
tokenProvider: &AntigravityTokenProvider{},
|
||
httpUpstream: stub,
|
||
}
|
||
|
||
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
|
||
|
||
captured := stub.capturedReq
|
||
require.NotNil(t, captured)
|
||
|
||
// hop-by-hop header 不应出现
|
||
require.Empty(t, captured.Header.Get("Connection"))
|
||
require.Empty(t, captured.Header.Get("Proxy-Authorization"))
|
||
// Host header 在 Go http.Request 中特殊处理,但我们的黑名单应阻止透传
|
||
require.Empty(t, captured.Header.Values("Host"))
|
||
|
||
// 普通 header 应保留
|
||
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
|
||
}
|