refactor(upstream): replace upstream account type with apikey, auto-append /antigravity
Upstream accounts now use the standard APIKey type instead of a dedicated upstream type. GetBaseURL() and new GetGeminiBaseURL() automatically append /antigravity for Antigravity platform APIKey accounts, eliminating the need for separate upstream forwarding methods. - Remove ForwardUpstream, ForwardUpstreamGemini, testUpstreamConnection - Remove upstream branch guards in Forward/ForwardGemini/TestConnection - Add migration 052 to convert existing upstream accounts to apikey - Update frontend CreateAccountModal to create apikey type - Add unit tests for GetBaseURL and GetGeminiBaseURL
This commit is contained in:
@@ -482,7 +482,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if switchCount > 0 {
|
if switchCount > 0 {
|
||||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||||
}
|
}
|
||||||
if account.Platform == service.PlatformAntigravity {
|
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||||
} else {
|
} else {
|
||||||
result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)
|
result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)
|
||||||
|
|||||||
@@ -410,7 +410,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
if switchCount > 0 {
|
if switchCount > 0 {
|
||||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||||
}
|
}
|
||||||
if account.Platform == service.PlatformAntigravity {
|
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
|
||||||
} else {
|
} else {
|
||||||
result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body)
|
result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body)
|
||||||
|
|||||||
@@ -425,6 +425,22 @@ func (a *Account) GetBaseURL() string {
|
|||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
return "https://api.anthropic.com"
|
return "https://api.anthropic.com"
|
||||||
}
|
}
|
||||||
|
if a.Platform == PlatformAntigravity {
|
||||||
|
return strings.TrimRight(baseURL, "/") + "/antigravity"
|
||||||
|
}
|
||||||
|
return baseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiBaseURL 返回 Gemini 兼容端点的 base URL。
|
||||||
|
// Antigravity 平台的 APIKey 账号自动拼接 /antigravity。
|
||||||
|
func (a *Account) GetGeminiBaseURL(defaultBaseURL string) string {
|
||||||
|
baseURL := strings.TrimSpace(a.GetCredential("base_url"))
|
||||||
|
if baseURL == "" {
|
||||||
|
return defaultBaseURL
|
||||||
|
}
|
||||||
|
if a.Platform == PlatformAntigravity && a.Type == AccountTypeAPIKey {
|
||||||
|
return strings.TrimRight(baseURL, "/") + "/antigravity"
|
||||||
|
}
|
||||||
return baseURL
|
return baseURL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
160
backend/internal/service/account_base_url_test.go
Normal file
160
backend/internal/service/account_base_url_test.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetBaseURL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
account Account
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "non-apikey type returns empty",
|
||||||
|
account: Account{
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
},
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "apikey without base_url returns default anthropic",
|
||||||
|
account: Account{
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Credentials: map[string]any{},
|
||||||
|
},
|
||||||
|
expected: "https://api.anthropic.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "apikey with custom base_url",
|
||||||
|
account: Account{
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Credentials: map[string]any{"base_url": "https://custom.example.com"},
|
||||||
|
},
|
||||||
|
expected: "https://custom.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "antigravity apikey auto-appends /antigravity",
|
||||||
|
account: Account{
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
|
||||||
|
},
|
||||||
|
expected: "https://upstream.example.com/antigravity",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "antigravity apikey trims trailing slash before appending",
|
||||||
|
account: Account{
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{"base_url": "https://upstream.example.com/"},
|
||||||
|
},
|
||||||
|
expected: "https://upstream.example.com/antigravity",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "antigravity non-apikey returns empty",
|
||||||
|
account: Account{
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
|
||||||
|
},
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.account.GetBaseURL()
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetBaseURL() = %q, want %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetGeminiBaseURL(t *testing.T) {
|
||||||
|
const defaultGeminiURL = "https://generativelanguage.googleapis.com"
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
account Account
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "apikey without base_url returns default",
|
||||||
|
account: Account{
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Credentials: map[string]any{},
|
||||||
|
},
|
||||||
|
expected: defaultGeminiURL,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "apikey with custom base_url",
|
||||||
|
account: Account{
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Credentials: map[string]any{"base_url": "https://custom-gemini.example.com"},
|
||||||
|
},
|
||||||
|
expected: "https://custom-gemini.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "antigravity apikey auto-appends /antigravity",
|
||||||
|
account: Account{
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
|
||||||
|
},
|
||||||
|
expected: "https://upstream.example.com/antigravity",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "antigravity apikey trims trailing slash",
|
||||||
|
account: Account{
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{"base_url": "https://upstream.example.com/"},
|
||||||
|
},
|
||||||
|
expected: "https://upstream.example.com/antigravity",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "antigravity oauth does NOT append /antigravity",
|
||||||
|
account: Account{
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
|
||||||
|
},
|
||||||
|
expected: "https://upstream.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "oauth without base_url returns default",
|
||||||
|
account: Account{
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{},
|
||||||
|
},
|
||||||
|
expected: defaultGeminiURL,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil credentials returns default",
|
||||||
|
account: Account{
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
},
|
||||||
|
expected: defaultGeminiURL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.account.GetGeminiBaseURL(defaultGeminiURL)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetGeminiBaseURL() = %q, want %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -665,9 +665,6 @@ type TestConnectionResult struct {
|
|||||||
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
|
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
|
||||||
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
|
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
|
||||||
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
|
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
|
||||||
if account.Type == AccountTypeUpstream {
|
|
||||||
return s.testUpstreamConnection(ctx, account, modelID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取 token
|
// 获取 token
|
||||||
if s.tokenProvider == nil {
|
if s.tokenProvider == nil {
|
||||||
@@ -986,10 +983,6 @@ func isModelNotFoundError(statusCode int, body []byte) bool {
|
|||||||
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) {
|
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
if account.Type == AccountTypeUpstream {
|
|
||||||
return s.ForwardUpstream(ctx, c, account, body, isStickySession)
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionID := getSessionID(c)
|
sessionID := getSessionID(c)
|
||||||
prefix := logPrefix(sessionID, account.Name)
|
prefix := logPrefix(sessionID, account.Name)
|
||||||
|
|
||||||
@@ -1610,10 +1603,6 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
|
|||||||
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) {
|
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
if account.Type == AccountTypeUpstream {
|
|
||||||
return s.ForwardUpstreamGemini(ctx, c, account, originalModel, action, stream, body, isStickySession)
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionID := getSessionID(c)
|
sessionID := getSessionID(c)
|
||||||
prefix := logPrefix(sessionID, account.Name)
|
prefix := logPrefix(sessionID, account.Name)
|
||||||
|
|
||||||
@@ -3361,378 +3350,3 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) {
|
|||||||
payload["contents"] = filtered
|
payload["contents"] = filtered
|
||||||
return json.Marshal(payload)
|
return json.Marshal(payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
// Upstream 专用转发方法
|
|
||||||
// upstream 账号直接连接上游 Anthropic/Gemini 兼容端点,不走 Antigravity OAuth 协议转换。
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
// testUpstreamConnection 测试 upstream 账号连接
|
|
||||||
func (s *AntigravityGatewayService) testUpstreamConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
|
|
||||||
baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/")
|
|
||||||
if baseURL == "" {
|
|
||||||
return nil, errors.New("upstream account missing base_url in credentials")
|
|
||||||
}
|
|
||||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
|
||||||
if apiKey == "" {
|
|
||||||
return nil, errors.New("upstream account missing api_key in credentials")
|
|
||||||
}
|
|
||||||
|
|
||||||
mappedModel := s.getMappedModel(account, modelID)
|
|
||||||
if mappedModel == "" {
|
|
||||||
return nil, fmt.Errorf("model %s not in whitelist", modelID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建最小 Claude 格式请求
|
|
||||||
requestBody, _ := json.Marshal(map[string]any{
|
|
||||||
"model": mappedModel,
|
|
||||||
"max_tokens": 1,
|
|
||||||
"messages": []map[string]any{
|
|
||||||
{"role": "user", "content": "."},
|
|
||||||
},
|
|
||||||
"stream": false,
|
|
||||||
})
|
|
||||||
|
|
||||||
apiURL := baseURL + "/antigravity/v1/messages"
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(requestBody))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("构建请求失败: %w", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
||||||
req.Header.Set("x-api-key", apiKey)
|
|
||||||
req.Header.Set("anthropic-version", "2023-06-01")
|
|
||||||
|
|
||||||
proxyURL := ""
|
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
|
||||||
proxyURL = account.Proxy.URL()
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("[antigravity-Test-Upstream] account=%s url=%s", account.Name, apiURL)
|
|
||||||
|
|
||||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("请求失败: %w", err)
|
|
||||||
}
|
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
|
|
||||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
|
||||||
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 从 Claude 格式非流式响应中提取文本
|
|
||||||
var claudeResp struct {
|
|
||||||
Content []struct {
|
|
||||||
Text string `json:"text"`
|
|
||||||
} `json:"content"`
|
|
||||||
}
|
|
||||||
text := ""
|
|
||||||
if json.Unmarshal(respBody, &claudeResp) == nil && len(claudeResp.Content) > 0 {
|
|
||||||
text = claudeResp.Content[0].Text
|
|
||||||
}
|
|
||||||
|
|
||||||
return &TestConnectionResult{
|
|
||||||
Text: text,
|
|
||||||
MappedModel: mappedModel,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForwardUpstream 转发 Claude 协议请求到 upstream(不做协议转换)
|
|
||||||
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) {
|
|
||||||
startTime := time.Now()
|
|
||||||
sessionID := getSessionID(c)
|
|
||||||
prefix := logPrefix(sessionID, account.Name)
|
|
||||||
|
|
||||||
baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/")
|
|
||||||
if baseURL == "" {
|
|
||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "api_error", "Upstream account missing base_url")
|
|
||||||
}
|
|
||||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
|
||||||
if apiKey == "" {
|
|
||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "authentication_error", "Upstream account missing api_key")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析请求以获取模型和流式标志
|
|
||||||
var claudeReq antigravity.ClaudeRequest
|
|
||||||
if err := json.Unmarshal(body, &claudeReq); err != nil {
|
|
||||||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request body")
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(claudeReq.Model) == "" {
|
|
||||||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Missing model")
|
|
||||||
}
|
|
||||||
|
|
||||||
originalModel := claudeReq.Model
|
|
||||||
mappedModel := s.getMappedModel(account, claudeReq.Model)
|
|
||||||
if mappedModel == "" {
|
|
||||||
return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 代理 URL
|
|
||||||
proxyURL := ""
|
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
|
||||||
proxyURL = account.Proxy.URL()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 统计模型调用次数
|
|
||||||
if s.cache != nil {
|
|
||||||
_, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel)
|
|
||||||
}
|
|
||||||
|
|
||||||
apiURL := baseURL + "/antigravity/v1/messages"
|
|
||||||
log.Printf("%s upstream_forward url=%s model=%s", prefix, apiURL, mappedModel)
|
|
||||||
|
|
||||||
// 构建请求:body 原样透传
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
|
|
||||||
}
|
|
||||||
// 透传客户端所有请求头(排除 hop-by-hop 和认证头)
|
|
||||||
if c != nil && c.Request != nil {
|
|
||||||
for key, values := range c.Request.Header {
|
|
||||||
if upstreamHopByHopHeaders[strings.ToLower(key)] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, v := range values {
|
|
||||||
req.Header.Add(key, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 覆盖认证头
|
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
||||||
req.Header.Set("x-api-key", apiKey)
|
|
||||||
|
|
||||||
if c != nil && len(body) > 0 {
|
|
||||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 单次发送,不重试
|
|
||||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
|
||||||
if err != nil {
|
|
||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("Upstream request failed: %v", err))
|
|
||||||
}
|
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
|
|
||||||
// 错误响应处理
|
|
||||||
if resp.StatusCode >= 400 {
|
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|
||||||
|
|
||||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession)
|
|
||||||
|
|
||||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 成功响应:透传 response header + body
|
|
||||||
requestID := resp.Header.Get("x-request-id")
|
|
||||||
|
|
||||||
// 透传上游响应头(排除 hop-by-hop)
|
|
||||||
for key, values := range resp.Header {
|
|
||||||
if upstreamHopByHopHeaders[strings.ToLower(key)] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, v := range values {
|
|
||||||
c.Header(key, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Status(resp.StatusCode)
|
|
||||||
_, copyErr := io.Copy(c.Writer, resp.Body)
|
|
||||||
if copyErr != nil {
|
|
||||||
log.Printf("%s status=copy_error error=%v", prefix, copyErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ForwardResult{
|
|
||||||
RequestID: requestID,
|
|
||||||
Model: originalModel,
|
|
||||||
Stream: claudeReq.Stream,
|
|
||||||
Duration: time.Since(startTime),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForwardUpstreamGemini 转发 Gemini 协议请求到 upstream(不做协议转换)
|
|
||||||
func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) {
|
|
||||||
startTime := time.Now()
|
|
||||||
sessionID := getSessionID(c)
|
|
||||||
prefix := logPrefix(sessionID, account.Name)
|
|
||||||
|
|
||||||
baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/")
|
|
||||||
if baseURL == "" {
|
|
||||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream account missing base_url")
|
|
||||||
}
|
|
||||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
|
||||||
if apiKey == "" {
|
|
||||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream account missing api_key")
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.TrimSpace(originalModel) == "" {
|
|
||||||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL")
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(action) == "" {
|
|
||||||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL")
|
|
||||||
}
|
|
||||||
if len(body) == 0 {
|
|
||||||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
imageSize := s.extractImageSize(body)
|
|
||||||
|
|
||||||
switch action {
|
|
||||||
case "generateContent", "streamGenerateContent":
|
|
||||||
// ok
|
|
||||||
case "countTokens":
|
|
||||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": 0})
|
|
||||||
return &ForwardResult{
|
|
||||||
RequestID: "",
|
|
||||||
Usage: ClaudeUsage{},
|
|
||||||
Model: originalModel,
|
|
||||||
Stream: false,
|
|
||||||
Duration: time.Since(time.Now()),
|
|
||||||
FirstTokenMs: nil,
|
|
||||||
}, nil
|
|
||||||
default:
|
|
||||||
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
|
|
||||||
}
|
|
||||||
|
|
||||||
mappedModel := s.getMappedModel(account, originalModel)
|
|
||||||
if mappedModel == "" {
|
|
||||||
return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 代理 URL
|
|
||||||
proxyURL := ""
|
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
|
||||||
proxyURL = account.Proxy.URL()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 统计模型调用次数
|
|
||||||
if s.cache != nil {
|
|
||||||
_, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建 upstream URL: base_url + /antigravity/v1beta/models/MODEL:ACTION
|
|
||||||
apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, action)
|
|
||||||
if stream || action == "streamGenerateContent" {
|
|
||||||
apiURL += "?alt=sse"
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, action)
|
|
||||||
|
|
||||||
// 构建请求:body 原样透传
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request")
|
|
||||||
}
|
|
||||||
// 透传客户端所有请求头(排除 hop-by-hop 和认证头)
|
|
||||||
if c != nil && c.Request != nil {
|
|
||||||
for key, values := range c.Request.Header {
|
|
||||||
if upstreamHopByHopHeaders[strings.ToLower(key)] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, v := range values {
|
|
||||||
req.Header.Add(key, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 覆盖认证头
|
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
||||||
|
|
||||||
if c != nil && len(body) > 0 {
|
|
||||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 单次发送,不重试
|
|
||||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
|
||||||
if err != nil {
|
|
||||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("Upstream request failed: %v", err))
|
|
||||||
}
|
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
|
|
||||||
// 错误响应处理
|
|
||||||
if resp.StatusCode >= 400 {
|
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
|
||||||
|
|
||||||
requestID := resp.Header.Get("x-request-id")
|
|
||||||
if requestID != "" {
|
|
||||||
c.Header("x-request-id", requestID)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession)
|
|
||||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
|
||||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
|
||||||
upstreamDetail := s.getUpstreamErrorDetail(respBody)
|
|
||||||
|
|
||||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
|
||||||
|
|
||||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
|
||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|
||||||
Platform: account.Platform,
|
|
||||||
AccountID: account.ID,
|
|
||||||
AccountName: account.Name,
|
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
|
||||||
UpstreamRequestID: requestID,
|
|
||||||
Kind: "failover",
|
|
||||||
Message: upstreamMsg,
|
|
||||||
Detail: upstreamDetail,
|
|
||||||
})
|
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
|
||||||
}
|
|
||||||
if contentType == "" {
|
|
||||||
contentType = "application/json"
|
|
||||||
}
|
|
||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|
||||||
Platform: account.Platform,
|
|
||||||
AccountID: account.ID,
|
|
||||||
AccountName: account.Name,
|
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
|
||||||
UpstreamRequestID: requestID,
|
|
||||||
Kind: "http_error",
|
|
||||||
Message: upstreamMsg,
|
|
||||||
Detail: upstreamDetail,
|
|
||||||
})
|
|
||||||
log.Printf("[antigravity-Forward-Upstream] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(respBody, 500))
|
|
||||||
c.Data(resp.StatusCode, contentType, respBody)
|
|
||||||
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 成功响应:透传 response header + body
|
|
||||||
requestID := resp.Header.Get("x-request-id")
|
|
||||||
|
|
||||||
// 透传上游响应头(排除 hop-by-hop)
|
|
||||||
for key, values := range resp.Header {
|
|
||||||
if upstreamHopByHopHeaders[strings.ToLower(key)] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, v := range values {
|
|
||||||
c.Header(key, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Status(resp.StatusCode)
|
|
||||||
_, copyErr := io.Copy(c.Writer, resp.Body)
|
|
||||||
if copyErr != nil {
|
|
||||||
log.Printf("%s status=copy_error error=%v", prefix, copyErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
imageCount := 0
|
|
||||||
if isImageGenerationModel(mappedModel) {
|
|
||||||
imageCount = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ForwardResult{
|
|
||||||
RequestID: requestID,
|
|
||||||
Model: originalModel,
|
|
||||||
Stream: stream,
|
|
||||||
Duration: time.Since(startTime),
|
|
||||||
ImageCount: imageCount,
|
|
||||||
ImageSize: imageSize,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -560,10 +560,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
return nil, "", errors.New("gemini api_key not configured")
|
return nil, "", errors.New("gemini api_key not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
|
||||||
if baseURL == "" {
|
|
||||||
baseURL = geminicli.AIStudioBaseURL
|
|
||||||
}
|
|
||||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
@@ -640,10 +637,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
return upstreamReq, "x-request-id", nil
|
return upstreamReq, "x-request-id", nil
|
||||||
} else {
|
} else {
|
||||||
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
|
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
|
||||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
|
||||||
if baseURL == "" {
|
|
||||||
baseURL = geminicli.AIStudioBaseURL
|
|
||||||
}
|
|
||||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
@@ -1026,10 +1020,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
return nil, "", errors.New("gemini api_key not configured")
|
return nil, "", errors.New("gemini api_key not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
|
||||||
if baseURL == "" {
|
|
||||||
baseURL = geminicli.AIStudioBaseURL
|
|
||||||
}
|
|
||||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
@@ -1097,10 +1088,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
return upstreamReq, "x-request-id", nil
|
return upstreamReq, "x-request-id", nil
|
||||||
} else {
|
} else {
|
||||||
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
|
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
|
||||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
|
||||||
if baseURL == "" {
|
|
||||||
baseURL = geminicli.AIStudioBaseURL
|
|
||||||
}
|
|
||||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
@@ -2420,10 +2408,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
|||||||
return nil, errors.New("invalid path")
|
return nil, errors.New("invalid path")
|
||||||
}
|
}
|
||||||
|
|
||||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
|
||||||
if baseURL == "" {
|
|
||||||
baseURL = geminicli.AIStudioBaseURL
|
|
||||||
}
|
|
||||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -1,285 +0,0 @@
|
|||||||
//go:build unit
|
|
||||||
|
|
||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
// httpUpstreamCapture captures the outgoing *http.Request for assertion.
|
|
||||||
type httpUpstreamCapture struct {
|
|
||||||
capturedReq *http.Request
|
|
||||||
resp *http.Response
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *httpUpstreamCapture) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
|
||||||
s.capturedReq = req
|
|
||||||
return s.resp, s.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *httpUpstreamCapture) DoWithTLS(req *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
|
|
||||||
s.capturedReq = req
|
|
||||||
return s.resp, s.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func newUpstreamAccount() *Account {
|
|
||||||
return &Account{
|
|
||||||
ID: 100,
|
|
||||||
Name: "upstream-test",
|
|
||||||
Platform: PlatformAntigravity,
|
|
||||||
Type: AccountTypeUpstream,
|
|
||||||
Status: StatusActive,
|
|
||||||
Concurrency: 1,
|
|
||||||
Credentials: map[string]any{
|
|
||||||
"base_url": "https://upstream.example.com",
|
|
||||||
"api_key": "sk-upstream-secret",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// makeSSEOKResponse builds a minimal SSE response that
|
|
||||||
// handleClaudeStreamingResponse / handleGeminiStreamingResponse
|
|
||||||
// can consume without error.
|
|
||||||
// We return 502 to bypass streaming and hit the error branch instead,
|
|
||||||
// which is sufficient for testing header passthrough.
|
|
||||||
func makeUpstreamErrorResponse() *http.Response {
|
|
||||||
body := []byte(`{"error":{"message":"test error"}}`)
|
|
||||||
return &http.Response{
|
|
||||||
StatusCode: http.StatusBadGateway,
|
|
||||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
||||||
Body: io.NopCloser(bytes.NewReader(body)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- ForwardUpstream tests ---
|
|
||||||
|
|
||||||
func TestForwardUpstream_PassthroughHeaders(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
|
||||||
"model": "claude-sonnet-4-5",
|
|
||||||
"messages": []map[string]any{{"role": "user", "content": "hi"}},
|
|
||||||
"max_tokens": 1,
|
|
||||||
"stream": false,
|
|
||||||
})
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("anthropic-version", "2024-10-22")
|
|
||||||
req.Header.Set("anthropic-beta", "output-128k-2025-02-19")
|
|
||||||
req.Header.Set("X-Custom-Header", "custom-value")
|
|
||||||
c.Request = req
|
|
||||||
|
|
||||||
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
|
||||||
svc := &AntigravityGatewayService{
|
|
||||||
tokenProvider: &AntigravityTokenProvider{},
|
|
||||||
httpUpstream: stub,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
|
|
||||||
|
|
||||||
captured := stub.capturedReq
|
|
||||||
require.NotNil(t, captured, "upstream request should have been made")
|
|
||||||
|
|
||||||
// 客户端 header 应被透传
|
|
||||||
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
|
|
||||||
require.Equal(t, "2024-10-22", captured.Header.Get("anthropic-version"))
|
|
||||||
require.Equal(t, "output-128k-2025-02-19", captured.Header.Get("anthropic-beta"))
|
|
||||||
require.Equal(t, "custom-value", captured.Header.Get("X-Custom-Header"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestForwardUpstream_OverridesAuthHeaders(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
|
||||||
"model": "claude-sonnet-4-5",
|
|
||||||
"messages": []map[string]any{{"role": "user", "content": "hi"}},
|
|
||||||
"max_tokens": 1,
|
|
||||||
"stream": false,
|
|
||||||
})
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
// 客户端发来的认证头应被覆盖
|
|
||||||
req.Header.Set("Authorization", "Bearer client-token")
|
|
||||||
req.Header.Set("x-api-key", "client-api-key")
|
|
||||||
c.Request = req
|
|
||||||
|
|
||||||
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
|
||||||
svc := &AntigravityGatewayService{
|
|
||||||
tokenProvider: &AntigravityTokenProvider{},
|
|
||||||
httpUpstream: stub,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
|
|
||||||
|
|
||||||
captured := stub.capturedReq
|
|
||||||
require.NotNil(t, captured)
|
|
||||||
|
|
||||||
// 认证头应使用上游账号的 api_key,而非客户端的
|
|
||||||
require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization"))
|
|
||||||
require.Equal(t, "sk-upstream-secret", captured.Header.Get("x-api-key"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestForwardUpstream_ExcludesHopByHopHeaders(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
|
||||||
"model": "claude-sonnet-4-5",
|
|
||||||
"messages": []map[string]any{{"role": "user", "content": "hi"}},
|
|
||||||
"max_tokens": 1,
|
|
||||||
"stream": false,
|
|
||||||
})
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Connection", "keep-alive")
|
|
||||||
req.Header.Set("Keep-Alive", "timeout=5")
|
|
||||||
req.Header.Set("Transfer-Encoding", "chunked")
|
|
||||||
req.Header.Set("Upgrade", "websocket")
|
|
||||||
req.Header.Set("Te", "trailers")
|
|
||||||
c.Request = req
|
|
||||||
|
|
||||||
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
|
||||||
svc := &AntigravityGatewayService{
|
|
||||||
tokenProvider: &AntigravityTokenProvider{},
|
|
||||||
httpUpstream: stub,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
|
|
||||||
|
|
||||||
captured := stub.capturedReq
|
|
||||||
require.NotNil(t, captured)
|
|
||||||
|
|
||||||
// hop-by-hop header 不应出现
|
|
||||||
require.Empty(t, captured.Header.Get("Connection"))
|
|
||||||
require.Empty(t, captured.Header.Get("Keep-Alive"))
|
|
||||||
require.Empty(t, captured.Header.Get("Transfer-Encoding"))
|
|
||||||
require.Empty(t, captured.Header.Get("Upgrade"))
|
|
||||||
require.Empty(t, captured.Header.Get("Te"))
|
|
||||||
|
|
||||||
// 但普通 header 应保留
|
|
||||||
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- ForwardUpstreamGemini tests ---
|
|
||||||
|
|
||||||
func TestForwardUpstreamGemini_PassthroughHeaders(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
|
||||||
"contents": []map[string]any{
|
|
||||||
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("X-Custom-Gemini", "gemini-value")
|
|
||||||
req.Header.Set("X-Request-Id", "req-abc-123")
|
|
||||||
c.Request = req
|
|
||||||
|
|
||||||
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
|
||||||
svc := &AntigravityGatewayService{
|
|
||||||
tokenProvider: &AntigravityTokenProvider{},
|
|
||||||
httpUpstream: stub,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
|
|
||||||
|
|
||||||
captured := stub.capturedReq
|
|
||||||
require.NotNil(t, captured, "upstream request should have been made")
|
|
||||||
|
|
||||||
// 客户端 header 应被透传
|
|
||||||
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
|
|
||||||
require.Equal(t, "gemini-value", captured.Header.Get("X-Custom-Gemini"))
|
|
||||||
require.Equal(t, "req-abc-123", captured.Header.Get("X-Request-Id"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestForwardUpstreamGemini_OverridesAuthHeaders(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
|
||||||
"contents": []map[string]any{
|
|
||||||
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Authorization", "Bearer client-gemini-token")
|
|
||||||
c.Request = req
|
|
||||||
|
|
||||||
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
|
||||||
svc := &AntigravityGatewayService{
|
|
||||||
tokenProvider: &AntigravityTokenProvider{},
|
|
||||||
httpUpstream: stub,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
|
|
||||||
|
|
||||||
captured := stub.capturedReq
|
|
||||||
require.NotNil(t, captured)
|
|
||||||
|
|
||||||
// 认证头应使用上游账号的 api_key
|
|
||||||
require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestForwardUpstreamGemini_ExcludesHopByHopHeaders(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
|
||||||
"contents": []map[string]any{
|
|
||||||
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Connection", "keep-alive")
|
|
||||||
req.Header.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz")
|
|
||||||
req.Header.Set("Host", "evil.example.com")
|
|
||||||
c.Request = req
|
|
||||||
|
|
||||||
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
|
||||||
svc := &AntigravityGatewayService{
|
|
||||||
tokenProvider: &AntigravityTokenProvider{},
|
|
||||||
httpUpstream: stub,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
|
|
||||||
|
|
||||||
captured := stub.capturedReq
|
|
||||||
require.NotNil(t, captured)
|
|
||||||
|
|
||||||
// hop-by-hop header 不应出现
|
|
||||||
require.Empty(t, captured.Header.Get("Connection"))
|
|
||||||
require.Empty(t, captured.Header.Get("Proxy-Authorization"))
|
|
||||||
// Host header 在 Go http.Request 中特殊处理,但我们的黑名单应阻止透传
|
|
||||||
require.Empty(t, captured.Header.Values("Host"))
|
|
||||||
|
|
||||||
// 普通 header 应保留
|
|
||||||
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
|
|
||||||
}
|
|
||||||
11
backend/migrations/052_migrate_upstream_to_apikey.sql
Normal file
11
backend/migrations/052_migrate_upstream_to_apikey.sql
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
-- Migrate upstream accounts to apikey type
|
||||||
|
-- Background: upstream type is no longer needed. Antigravity platform APIKey accounts
|
||||||
|
-- with base_url pointing to an upstream sub2api instance can reuse the standard
|
||||||
|
-- APIKey forwarding path. GetBaseURL()/GetGeminiBaseURL() automatically appends
|
||||||
|
-- /antigravity for Antigravity platform APIKey accounts.
|
||||||
|
|
||||||
|
UPDATE accounts
|
||||||
|
SET type = 'apikey'
|
||||||
|
WHERE type = 'upstream'
|
||||||
|
AND platform = 'antigravity'
|
||||||
|
AND deleted_at IS NULL;
|
||||||
@@ -2289,9 +2289,9 @@ watch(
|
|||||||
watch(
|
watch(
|
||||||
[accountCategory, addMethod, antigravityAccountType],
|
[accountCategory, addMethod, antigravityAccountType],
|
||||||
([category, method, agType]) => {
|
([category, method, agType]) => {
|
||||||
// Antigravity upstream 类型
|
// Antigravity upstream 类型(实际创建为 apikey)
|
||||||
if (form.platform === 'antigravity' && agType === 'upstream') {
|
if (form.platform === 'antigravity' && agType === 'upstream') {
|
||||||
form.type = 'upstream'
|
form.type = 'apikey'
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if (category === 'oauth-based') {
|
if (category === 'oauth-based') {
|
||||||
@@ -2715,7 +2715,7 @@ const handleSubmit = async () => {
|
|||||||
submitting.value = true
|
submitting.value = true
|
||||||
try {
|
try {
|
||||||
const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined
|
const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined
|
||||||
await createAccountAndFinish(form.platform, 'upstream', credentials, extra)
|
await createAccountAndFinish(form.platform, 'apikey', credentials, extra)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate'))
|
appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate'))
|
||||||
} finally {
|
} finally {
|
||||||
|
|||||||
Reference in New Issue
Block a user