Files
sub2api/backend/internal/service/upstream_header_passthrough_test.go
erio 1563bd3dda feat(upstream): passthrough all client headers instead of manual header setting
Replace manual header setting (Content-Type, anthropic-version, anthropic-beta)
with full client header passthrough in ForwardUpstream/ForwardUpstreamGemini.
Only authentication headers (Authorization, x-api-key) are overridden with
upstream account credentials. Hop-by-hop headers are excluded.

Add unit tests covering header passthrough, auth override, and hop-by-hop filtering.
2026-02-08 08:33:09 +08:00

286 lines
9.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//go:build unit
package service
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// httpUpstreamCapture captures the outgoing *http.Request for assertion.
type httpUpstreamCapture struct {
capturedReq *http.Request
resp *http.Response
err error
}
func (s *httpUpstreamCapture) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
s.capturedReq = req
return s.resp, s.err
}
func (s *httpUpstreamCapture) DoWithTLS(req *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
s.capturedReq = req
return s.resp, s.err
}
func newUpstreamAccount() *Account {
return &Account{
ID: 100,
Name: "upstream-test",
Platform: PlatformAntigravity,
Type: AccountTypeUpstream,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"base_url": "https://upstream.example.com",
"api_key": "sk-upstream-secret",
},
}
}
// makeSSEOKResponse builds a minimal SSE response that
// handleClaudeStreamingResponse / handleGeminiStreamingResponse
// can consume without error.
// We return 502 to bypass streaming and hit the error branch instead,
// which is sufficient for testing header passthrough.
func makeUpstreamErrorResponse() *http.Response {
body := []byte(`{"error":{"message":"test error"}}`)
return &http.Response{
StatusCode: http.StatusBadGateway,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(bytes.NewReader(body)),
}
}
// --- ForwardUpstream tests ---
func TestForwardUpstream_PassthroughHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body, _ := json.Marshal(map[string]any{
"model": "claude-sonnet-4-5",
"messages": []map[string]any{{"role": "user", "content": "hi"}},
"max_tokens": 1,
"stream": false,
})
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("anthropic-version", "2024-10-22")
req.Header.Set("anthropic-beta", "output-128k-2025-02-19")
req.Header.Set("X-Custom-Header", "custom-value")
c.Request = req
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: stub,
}
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
captured := stub.capturedReq
require.NotNil(t, captured, "upstream request should have been made")
// 客户端 header 应被透传
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
require.Equal(t, "2024-10-22", captured.Header.Get("anthropic-version"))
require.Equal(t, "output-128k-2025-02-19", captured.Header.Get("anthropic-beta"))
require.Equal(t, "custom-value", captured.Header.Get("X-Custom-Header"))
}
func TestForwardUpstream_OverridesAuthHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body, _ := json.Marshal(map[string]any{
"model": "claude-sonnet-4-5",
"messages": []map[string]any{{"role": "user", "content": "hi"}},
"max_tokens": 1,
"stream": false,
})
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
// 客户端发来的认证头应被覆盖
req.Header.Set("Authorization", "Bearer client-token")
req.Header.Set("x-api-key", "client-api-key")
c.Request = req
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: stub,
}
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
captured := stub.capturedReq
require.NotNil(t, captured)
// 认证头应使用上游账号的 api_key而非客户端的
require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization"))
require.Equal(t, "sk-upstream-secret", captured.Header.Get("x-api-key"))
}
func TestForwardUpstream_ExcludesHopByHopHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body, _ := json.Marshal(map[string]any{
"model": "claude-sonnet-4-5",
"messages": []map[string]any{{"role": "user", "content": "hi"}},
"max_tokens": 1,
"stream": false,
})
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Connection", "keep-alive")
req.Header.Set("Keep-Alive", "timeout=5")
req.Header.Set("Transfer-Encoding", "chunked")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Te", "trailers")
c.Request = req
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: stub,
}
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
captured := stub.capturedReq
require.NotNil(t, captured)
// hop-by-hop header 不应出现
require.Empty(t, captured.Header.Get("Connection"))
require.Empty(t, captured.Header.Get("Keep-Alive"))
require.Empty(t, captured.Header.Get("Transfer-Encoding"))
require.Empty(t, captured.Header.Get("Upgrade"))
require.Empty(t, captured.Header.Get("Te"))
// 但普通 header 应保留
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
}
// --- ForwardUpstreamGemini tests ---
func TestForwardUpstreamGemini_PassthroughHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body, _ := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
},
})
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Custom-Gemini", "gemini-value")
req.Header.Set("X-Request-Id", "req-abc-123")
c.Request = req
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: stub,
}
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
captured := stub.capturedReq
require.NotNil(t, captured, "upstream request should have been made")
// 客户端 header 应被透传
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
require.Equal(t, "gemini-value", captured.Header.Get("X-Custom-Gemini"))
require.Equal(t, "req-abc-123", captured.Header.Get("X-Request-Id"))
}
func TestForwardUpstreamGemini_OverridesAuthHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body, _ := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
},
})
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer client-gemini-token")
c.Request = req
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: stub,
}
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
captured := stub.capturedReq
require.NotNil(t, captured)
// 认证头应使用上游账号的 api_key
require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization"))
}
func TestForwardUpstreamGemini_ExcludesHopByHopHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body, _ := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
},
})
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Connection", "keep-alive")
req.Header.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz")
req.Header.Set("Host", "evil.example.com")
c.Request = req
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: stub,
}
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
captured := stub.capturedReq
require.NotNil(t, captured)
// hop-by-hop header 不应出现
require.Empty(t, captured.Header.Get("Connection"))
require.Empty(t, captured.Header.Get("Proxy-Authorization"))
// Host header 在 Go http.Request 中特殊处理,但我们的黑名单应阻止透传
require.Empty(t, captured.Header.Values("Host"))
// 普通 header 应保留
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
}