diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index ca4442e4..255d3fab 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -482,7 +482,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if switchCount > 0 { requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) } - if account.Platform == service.PlatformAntigravity { + if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession) } else { result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index b1477ac6..2b69be2e 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -410,7 +410,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if switchCount > 0 { requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) } - if account.Platform == service.PlatformAntigravity { + if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession) } else { result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index a6ae8a68..138d5bcb 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -425,6 +425,22 @@ func (a *Account) GetBaseURL() string { if baseURL == "" { return "https://api.anthropic.com" } + if a.Platform == PlatformAntigravity { + return strings.TrimRight(baseURL, "/") + "/antigravity" + } + return baseURL +} + +// GetGeminiBaseURL 返回 Gemini 兼容端点的 base URL。 +// Antigravity 平台的 APIKey 账号自动拼接 /antigravity。 +func (a *Account) GetGeminiBaseURL(defaultBaseURL string) string { + baseURL := strings.TrimSpace(a.GetCredential("base_url")) + if baseURL == "" { + return defaultBaseURL + } + if a.Platform == PlatformAntigravity && a.Type == AccountTypeAPIKey { + return strings.TrimRight(baseURL, "/") + "/antigravity" + } return baseURL } diff --git a/backend/internal/service/account_base_url_test.go b/backend/internal/service/account_base_url_test.go new file mode 100644 index 00000000..a1322193 --- /dev/null +++ b/backend/internal/service/account_base_url_test.go @@ -0,0 +1,160 @@ +//go:build unit + +package service + +import ( + "testing" +) + +func TestGetBaseURL(t *testing.T) { + tests := []struct { + name string + account Account + expected string + }{ + { + name: "non-apikey type returns empty", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAnthropic, + }, + expected: "", + }, + { + name: "apikey without base_url returns default anthropic", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAnthropic, + Credentials: map[string]any{}, + }, + expected: "https://api.anthropic.com", + }, + { + name: "apikey with custom base_url", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAnthropic, + Credentials: map[string]any{"base_url": "https://custom.example.com"}, + }, + expected: "https://custom.example.com", + }, + { + name: "antigravity apikey auto-appends /antigravity", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity apikey trims trailing slash before appending", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com/"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity non-apikey returns empty", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetBaseURL() + if result != tt.expected { + t.Errorf("GetBaseURL() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestGetGeminiBaseURL(t *testing.T) { + const defaultGeminiURL = "https://generativelanguage.googleapis.com" + + tests := []struct { + name string + account Account + expected string + }{ + { + name: "apikey without base_url returns default", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{}, + }, + expected: defaultGeminiURL, + }, + { + name: "apikey with custom base_url", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{"base_url": "https://custom-gemini.example.com"}, + }, + expected: "https://custom-gemini.example.com", + }, + { + name: "antigravity apikey auto-appends /antigravity", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity apikey trims trailing slash", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com/"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity oauth does NOT append /antigravity", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com", + }, + { + name: "oauth without base_url returns default", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{}, + }, + expected: defaultGeminiURL, + }, + { + name: "nil credentials returns default", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + expected: defaultGeminiURL, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetGeminiBaseURL(defaultGeminiURL) + if result != tt.expected { + t.Errorf("GetGeminiBaseURL() = %q, want %q", result, tt.expected) + } + }) + } +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 2d96b1ab..4ea73e64 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -665,9 +665,6 @@ type TestConnectionResult struct { // TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费) // 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择 func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { - if account.Type == AccountTypeUpstream { - return s.testUpstreamConnection(ctx, account, modelID) - } // 获取 token if s.tokenProvider == nil { @@ -986,10 +983,6 @@ func isModelNotFoundError(statusCode int, body []byte) bool { func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() - if account.Type == AccountTypeUpstream { - return s.ForwardUpstream(ctx, c, account, body, isStickySession) - } - sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -1610,10 +1603,6 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() - if account.Type == AccountTypeUpstream { - return s.ForwardUpstreamGemini(ctx, c, account, originalModel, action, stream, body, isStickySession) - } - sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -3361,378 +3350,3 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) { payload["contents"] = filtered return json.Marshal(payload) } - -// --------------------------------------------------------------------------- -// Upstream 专用转发方法 -// upstream 账号直接连接上游 Anthropic/Gemini 兼容端点,不走 Antigravity OAuth 协议转换。 -// --------------------------------------------------------------------------- - -// testUpstreamConnection 测试 upstream 账号连接 -func (s *AntigravityGatewayService) testUpstreamConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { - baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") - if baseURL == "" { - return nil, errors.New("upstream account missing base_url in credentials") - } - apiKey := strings.TrimSpace(account.GetCredential("api_key")) - if apiKey == "" { - return nil, errors.New("upstream account missing api_key in credentials") - } - - mappedModel := s.getMappedModel(account, modelID) - if mappedModel == "" { - return nil, fmt.Errorf("model %s not in whitelist", modelID) - } - - // 构建最小 Claude 格式请求 - requestBody, _ := json.Marshal(map[string]any{ - "model": mappedModel, - "max_tokens": 1, - "messages": []map[string]any{ - {"role": "user", "content": "."}, - }, - "stream": false, - }) - - apiURL := baseURL + "/antigravity/v1/messages" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(requestBody)) - if err != nil { - return nil, fmt.Errorf("构建请求失败: %w", err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) - req.Header.Set("x-api-key", apiKey) - req.Header.Set("anthropic-version", "2023-06-01") - - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - log.Printf("[antigravity-Test-Upstream] account=%s url=%s", account.Name, apiURL) - - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, fmt.Errorf("请求失败: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - if err != nil { - return nil, fmt.Errorf("读取响应失败: %w", err) - } - - if resp.StatusCode >= 400 { - return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) - } - - // 从 Claude 格式非流式响应中提取文本 - var claudeResp struct { - Content []struct { - Text string `json:"text"` - } `json:"content"` - } - text := "" - if json.Unmarshal(respBody, &claudeResp) == nil && len(claudeResp.Content) > 0 { - text = claudeResp.Content[0].Text - } - - return &TestConnectionResult{ - Text: text, - MappedModel: mappedModel, - }, nil -} - -// ForwardUpstream 转发 Claude 协议请求到 upstream(不做协议转换) -func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { - startTime := time.Now() - sessionID := getSessionID(c) - prefix := logPrefix(sessionID, account.Name) - - baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") - if baseURL == "" { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "api_error", "Upstream account missing base_url") - } - apiKey := strings.TrimSpace(account.GetCredential("api_key")) - if apiKey == "" { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "authentication_error", "Upstream account missing api_key") - } - - // 解析请求以获取模型和流式标志 - var claudeReq antigravity.ClaudeRequest - if err := json.Unmarshal(body, &claudeReq); err != nil { - return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request body") - } - if strings.TrimSpace(claudeReq.Model) == "" { - return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Missing model") - } - - originalModel := claudeReq.Model - mappedModel := s.getMappedModel(account, claudeReq.Model) - if mappedModel == "" { - return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) - } - - // 代理 URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - // 统计模型调用次数 - if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) - } - - apiURL := baseURL + "/antigravity/v1/messages" - log.Printf("%s upstream_forward url=%s model=%s", prefix, apiURL, mappedModel) - - // 构建请求:body 原样透传 - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) - if err != nil { - return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") - } - // 透传客户端所有请求头(排除 hop-by-hop 和认证头) - if c != nil && c.Request != nil { - for key, values := range c.Request.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - req.Header.Add(key, v) - } - } - } - // 覆盖认证头 - req.Header.Set("Authorization", "Bearer "+apiKey) - req.Header.Set("x-api-key", apiKey) - - if c != nil && len(body) > 0 { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - // 单次发送,不重试 - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("Upstream request failed: %v", err)) - } - defer func() { _ = resp.Body.Close() }() - - // 错误响应处理 - if resp.StatusCode >= 400 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession) - - if s.shouldFailoverUpstreamError(resp.StatusCode) { - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} - } - - return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) - } - - // 成功响应:透传 response header + body - requestID := resp.Header.Get("x-request-id") - - // 透传上游响应头(排除 hop-by-hop) - for key, values := range resp.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - c.Header(key, v) - } - } - - c.Status(resp.StatusCode) - _, copyErr := io.Copy(c.Writer, resp.Body) - if copyErr != nil { - log.Printf("%s status=copy_error error=%v", prefix, copyErr) - } - - return &ForwardResult{ - RequestID: requestID, - Model: originalModel, - Stream: claudeReq.Stream, - Duration: time.Since(startTime), - }, nil -} - -// ForwardUpstreamGemini 转发 Gemini 协议请求到 upstream(不做协议转换) -func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { - startTime := time.Now() - sessionID := getSessionID(c) - prefix := logPrefix(sessionID, account.Name) - - baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") - if baseURL == "" { - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream account missing base_url") - } - apiKey := strings.TrimSpace(account.GetCredential("api_key")) - if apiKey == "" { - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream account missing api_key") - } - - if strings.TrimSpace(originalModel) == "" { - return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL") - } - if strings.TrimSpace(action) == "" { - return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL") - } - if len(body) == 0 { - return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") - } - - imageSize := s.extractImageSize(body) - - switch action { - case "generateContent", "streamGenerateContent": - // ok - case "countTokens": - c.JSON(http.StatusOK, map[string]any{"totalTokens": 0}) - return &ForwardResult{ - RequestID: "", - Usage: ClaudeUsage{}, - Model: originalModel, - Stream: false, - Duration: time.Since(time.Now()), - FirstTokenMs: nil, - }, nil - default: - return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action) - } - - mappedModel := s.getMappedModel(account, originalModel) - if mappedModel == "" { - return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel)) - } - - // 代理 URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - // 统计模型调用次数 - if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) - } - - // 构建 upstream URL: base_url + /antigravity/v1beta/models/MODEL:ACTION - apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, action) - if stream || action == "streamGenerateContent" { - apiURL += "?alt=sse" - } - - log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, action) - - // 构建请求:body 原样透传 - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) - if err != nil { - return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") - } - // 透传客户端所有请求头(排除 hop-by-hop 和认证头) - if c != nil && c.Request != nil { - for key, values := range c.Request.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - req.Header.Add(key, v) - } - } - } - // 覆盖认证头 - req.Header.Set("Authorization", "Bearer "+apiKey) - - if c != nil && len(body) > 0 { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - // 单次发送,不重试 - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("Upstream request failed: %v", err)) - } - defer func() { _ = resp.Body.Close() }() - - // 错误响应处理 - if resp.StatusCode >= 400 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - contentType := resp.Header.Get("Content-Type") - - requestID := resp.Header.Get("x-request-id") - if requestID != "" { - c.Header("x-request-id", requestID) - } - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession) - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := s.getUpstreamErrorDetail(respBody) - - setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) - - if s.shouldFailoverUpstreamError(resp.StatusCode) { - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: requestID, - Kind: "failover", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} - } - if contentType == "" { - contentType = "application/json" - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: requestID, - Kind: "http_error", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - log.Printf("[antigravity-Forward-Upstream] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(respBody, 500)) - c.Data(resp.StatusCode, contentType, respBody) - return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) - } - - // 成功响应:透传 response header + body - requestID := resp.Header.Get("x-request-id") - - // 透传上游响应头(排除 hop-by-hop) - for key, values := range resp.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - c.Header(key, v) - } - } - - c.Status(resp.StatusCode) - _, copyErr := io.Copy(c.Writer, resp.Body) - if copyErr != nil { - log.Printf("%s status=copy_error error=%v", prefix, copyErr) - } - - imageCount := 0 - if isImageGenerationModel(mappedModel) { - imageCount = 1 - } - - return &ForwardResult{ - RequestID: requestID, - Model: originalModel, - Stream: stream, - Duration: time.Since(startTime), - ImageCount: imageCount, - ImageSize: imageSize, - }, nil -} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 0f156c2e..4e0442fd 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -560,10 +560,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex return nil, "", errors.New("gemini api_key not configured") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -640,10 +637,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex return upstreamReq, "x-request-id", nil } else { // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -1026,10 +1020,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return nil, "", errors.New("gemini api_key not configured") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -1097,10 +1088,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return upstreamReq, "x-request-id", nil } else { // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -2420,10 +2408,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac return nil, errors.New("invalid path") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, err diff --git a/backend/internal/service/upstream_header_passthrough_test.go b/backend/internal/service/upstream_header_passthrough_test.go deleted file mode 100644 index 51d8588b..00000000 --- a/backend/internal/service/upstream_header_passthrough_test.go +++ /dev/null @@ -1,285 +0,0 @@ -//go:build unit - -package service - -import ( - "bytes" - "context" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "testing" - - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/require" -) - -// httpUpstreamCapture captures the outgoing *http.Request for assertion. -type httpUpstreamCapture struct { - capturedReq *http.Request - resp *http.Response - err error -} - -func (s *httpUpstreamCapture) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) { - s.capturedReq = req - return s.resp, s.err -} - -func (s *httpUpstreamCapture) DoWithTLS(req *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) { - s.capturedReq = req - return s.resp, s.err -} - -func newUpstreamAccount() *Account { - return &Account{ - ID: 100, - Name: "upstream-test", - Platform: PlatformAntigravity, - Type: AccountTypeUpstream, - Status: StatusActive, - Concurrency: 1, - Credentials: map[string]any{ - "base_url": "https://upstream.example.com", - "api_key": "sk-upstream-secret", - }, - } -} - -// makeSSEOKResponse builds a minimal SSE response that -// handleClaudeStreamingResponse / handleGeminiStreamingResponse -// can consume without error. -// We return 502 to bypass streaming and hit the error branch instead, -// which is sufficient for testing header passthrough. -func makeUpstreamErrorResponse() *http.Response { - body := []byte(`{"error":{"message":"test error"}}`) - return &http.Response{ - StatusCode: http.StatusBadGateway, - Header: http.Header{"Content-Type": []string{"application/json"}}, - Body: io.NopCloser(bytes.NewReader(body)), - } -} - -// --- ForwardUpstream tests --- - -func TestForwardUpstream_PassthroughHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{{"role": "user", "content": "hi"}}, - "max_tokens": 1, - "stream": false, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("anthropic-version", "2024-10-22") - req.Header.Set("anthropic-beta", "output-128k-2025-02-19") - req.Header.Set("X-Custom-Header", "custom-value") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) - - captured := stub.capturedReq - require.NotNil(t, captured, "upstream request should have been made") - - // 客户端 header 应被透传 - require.Equal(t, "application/json", captured.Header.Get("Content-Type")) - require.Equal(t, "2024-10-22", captured.Header.Get("anthropic-version")) - require.Equal(t, "output-128k-2025-02-19", captured.Header.Get("anthropic-beta")) - require.Equal(t, "custom-value", captured.Header.Get("X-Custom-Header")) -} - -func TestForwardUpstream_OverridesAuthHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{{"role": "user", "content": "hi"}}, - "max_tokens": 1, - "stream": false, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - // 客户端发来的认证头应被覆盖 - req.Header.Set("Authorization", "Bearer client-token") - req.Header.Set("x-api-key", "client-api-key") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) - - captured := stub.capturedReq - require.NotNil(t, captured) - - // 认证头应使用上游账号的 api_key,而非客户端的 - require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization")) - require.Equal(t, "sk-upstream-secret", captured.Header.Get("x-api-key")) -} - -func TestForwardUpstream_ExcludesHopByHopHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{{"role": "user", "content": "hi"}}, - "max_tokens": 1, - "stream": false, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Connection", "keep-alive") - req.Header.Set("Keep-Alive", "timeout=5") - req.Header.Set("Transfer-Encoding", "chunked") - req.Header.Set("Upgrade", "websocket") - req.Header.Set("Te", "trailers") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) - - captured := stub.capturedReq - require.NotNil(t, captured) - - // hop-by-hop header 不应出现 - require.Empty(t, captured.Header.Get("Connection")) - require.Empty(t, captured.Header.Get("Keep-Alive")) - require.Empty(t, captured.Header.Get("Transfer-Encoding")) - require.Empty(t, captured.Header.Get("Upgrade")) - require.Empty(t, captured.Header.Get("Te")) - - // 但普通 header 应保留 - require.Equal(t, "application/json", captured.Header.Get("Content-Type")) -} - -// --- ForwardUpstreamGemini tests --- - -func TestForwardUpstreamGemini_PassthroughHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "contents": []map[string]any{ - {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, - }, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-Custom-Gemini", "gemini-value") - req.Header.Set("X-Request-Id", "req-abc-123") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) - - captured := stub.capturedReq - require.NotNil(t, captured, "upstream request should have been made") - - // 客户端 header 应被透传 - require.Equal(t, "application/json", captured.Header.Get("Content-Type")) - require.Equal(t, "gemini-value", captured.Header.Get("X-Custom-Gemini")) - require.Equal(t, "req-abc-123", captured.Header.Get("X-Request-Id")) -} - -func TestForwardUpstreamGemini_OverridesAuthHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "contents": []map[string]any{ - {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, - }, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer client-gemini-token") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) - - captured := stub.capturedReq - require.NotNil(t, captured) - - // 认证头应使用上游账号的 api_key - require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization")) -} - -func TestForwardUpstreamGemini_ExcludesHopByHopHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "contents": []map[string]any{ - {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, - }, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Connection", "keep-alive") - req.Header.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz") - req.Header.Set("Host", "evil.example.com") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) - - captured := stub.capturedReq - require.NotNil(t, captured) - - // hop-by-hop header 不应出现 - require.Empty(t, captured.Header.Get("Connection")) - require.Empty(t, captured.Header.Get("Proxy-Authorization")) - // Host header 在 Go http.Request 中特殊处理,但我们的黑名单应阻止透传 - require.Empty(t, captured.Header.Values("Host")) - - // 普通 header 应保留 - require.Equal(t, "application/json", captured.Header.Get("Content-Type")) -} diff --git a/backend/migrations/052_migrate_upstream_to_apikey.sql b/backend/migrations/052_migrate_upstream_to_apikey.sql new file mode 100644 index 00000000..974f3f3c --- /dev/null +++ b/backend/migrations/052_migrate_upstream_to_apikey.sql @@ -0,0 +1,11 @@ +-- Migrate upstream accounts to apikey type +-- Background: upstream type is no longer needed. Antigravity platform APIKey accounts +-- with base_url pointing to an upstream sub2api instance can reuse the standard +-- APIKey forwarding path. GetBaseURL()/GetGeminiBaseURL() automatically appends +-- /antigravity for Antigravity platform APIKey accounts. + +UPDATE accounts +SET type = 'apikey' +WHERE type = 'upstream' + AND platform = 'antigravity' + AND deleted_at IS NULL; diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 7d759be1..603941c1 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2289,9 +2289,9 @@ watch( watch( [accountCategory, addMethod, antigravityAccountType], ([category, method, agType]) => { - // Antigravity upstream 类型 + // Antigravity upstream 类型(实际创建为 apikey) if (form.platform === 'antigravity' && agType === 'upstream') { - form.type = 'upstream' + form.type = 'apikey' return } if (category === 'oauth-based') { @@ -2715,7 +2715,7 @@ const handleSubmit = async () => { submitting.value = true try { const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined - await createAccountAndFinish(form.platform, 'upstream', credentials, extra) + await createAccountAndFinish(form.platform, 'apikey', credentials, extra) } catch (error: any) { appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate')) } finally {