From 77b66653ed96ede34fcb99f5d3bfbf8a04864292 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 01:21:02 +0800 Subject: [PATCH 1/8] fix(gateway): restore upstream account forwarding with dedicated methods v0.1.74 merged upstream accounts into the OAuth path, causing requests to hit the wrong protocol and endpoint. Add three upstream-specific methods (testUpstreamConnection, ForwardUpstream, ForwardUpstreamGemini) that use base_url + apiKey auth and passthrough the original body, while reusing the existing response handling and error/retry logic. --- .../service/antigravity_gateway_service.go | 601 ++++++++++++++++++ 1 file changed, 601 insertions(+) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 3d3c9cca..fd53ba71 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -650,6 +650,10 @@ type TestConnectionResult struct { // TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费) // 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择 func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { + if account.Type == AccountTypeUpstream { + return s.testUpstreamConnection(ctx, account, modelID) + } + // 获取 token if s.tokenProvider == nil { return nil, errors.New("antigravity token provider not configured") @@ -966,6 +970,11 @@ func isModelNotFoundError(statusCode int, body []byte) bool { // Forward 转发 Claude 协议请求(Claude → Gemini 转换) func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() + + if account.Type == AccountTypeUpstream { + return s.ForwardUpstream(ctx, c, account, body, isStickySession) + } + sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -1585,6 +1594,11 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque // ForwardGemini 转发 Gemini 协议请求 func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() + + if account.Type == AccountTypeUpstream { + return s.ForwardUpstreamGemini(ctx, c, account, originalModel, action, stream, body, isStickySession) + } + sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -3332,3 +3346,590 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) { payload["contents"] = filtered return json.Marshal(payload) } + +// --------------------------------------------------------------------------- +// Upstream 专用转发方法 +// upstream 账号直接连接上游 Anthropic/Gemini 兼容端点,不走 Antigravity OAuth 协议转换。 +// --------------------------------------------------------------------------- + +// testUpstreamConnection 测试 upstream 账号连接 +func (s *AntigravityGatewayService) testUpstreamConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { + baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") + if baseURL == "" { + return nil, errors.New("upstream account missing base_url in credentials") + } + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if apiKey == "" { + return nil, errors.New("upstream account missing api_key in credentials") + } + + mappedModel := s.getMappedModel(account, modelID) + if mappedModel == "" { + return nil, fmt.Errorf("model %s not in whitelist", modelID) + } + + // 构建最小 Claude 格式请求 + requestBody, _ := json.Marshal(map[string]any{ + "model": mappedModel, + "max_tokens": 1, + "messages": []map[string]any{ + {"role": "user", "content": "."}, + }, + "stream": false, + }) + + apiURL := baseURL + "/antigravity/v1/messages" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(requestBody)) + if err != nil { + return nil, fmt.Errorf("构建请求失败: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) + req.Header.Set("anthropic-version", "2023-06-01") + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + log.Printf("[antigravity-Test-Upstream] account=%s url=%s", account.Name, apiURL) + + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, fmt.Errorf("请求失败: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) + } + + // 从 Claude 格式非流式响应中提取文本 + var claudeResp struct { + Content []struct { + Text string `json:"text"` + } `json:"content"` + } + text := "" + if json.Unmarshal(respBody, &claudeResp) == nil && len(claudeResp.Content) > 0 { + text = claudeResp.Content[0].Text + } + + return &TestConnectionResult{ + Text: text, + MappedModel: mappedModel, + }, nil +} + +// ForwardUpstream 转发 Claude 协议请求到 upstream(不做协议转换) +func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { + startTime := time.Now() + sessionID := getSessionID(c) + prefix := logPrefix(sessionID, account.Name) + + baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") + if baseURL == "" { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "api_error", "Upstream account missing base_url") + } + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if apiKey == "" { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "authentication_error", "Upstream account missing api_key") + } + + // 解析请求以获取模型和流式标志 + var claudeReq antigravity.ClaudeRequest + if err := json.Unmarshal(body, &claudeReq); err != nil { + return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request body") + } + if strings.TrimSpace(claudeReq.Model) == "" { + return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Missing model") + } + + originalModel := claudeReq.Model + mappedModel := s.getMappedModel(account, claudeReq.Model) + if mappedModel == "" { + return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) + } + loadModel := mappedModel + thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" + mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) + quotaScope, _ := resolveAntigravityQuotaScope(originalModel) + + // 代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 统计模型调用次数 + if s.cache != nil { + _, _ = s.cache.IncrModelCallCount(ctx, account.ID, loadModel) + } + + apiURL := baseURL + "/antigravity/v1/messages" + log.Printf("%s upstream_forward url=%s model=%s", prefix, apiURL, mappedModel) + + // 预检查:模型级限流 + if remaining := account.GetRateLimitRemainingTimeWithContext(ctx, originalModel); remaining > 0 { + if remaining < antigravityRateLimitThreshold { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(remaining): + } + } else { + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusServiceUnavailable, + ForceCacheBilling: isStickySession, + } + } + } + + // 重试循环 + var resp *http.Response + var lastErr error + for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) + + // 透传 anthropic headers + if v := c.GetHeader("anthropic-version"); v != "" { + req.Header.Set("anthropic-version", v) + } else { + req.Header.Set("anthropic-version", "2023-06-01") + } + if v := c.GetHeader("anthropic-beta"); v != "" { + req.Header.Set("anthropic-beta", v) + } + + if c != nil && len(body) > 0 { + c.Set(OpsUpstreamRequestBodyKey, string(body)) + } + + resp, err = s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + lastErr = err + if attempt < antigravityMaxRetries { + log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + return nil, ctx.Err() + } + continue + } + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") + } + + // 429/503 重试 + if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + + if attempt < antigravityMaxRetries { + log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + return nil, ctx.Err() + } + continue + } + + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ForceCacheBilling: isStickySession, + } + } + + break // 成功或非限流错误,跳出重试 + } + if resp == nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("upstream request failed: %v", lastErr)) + } + defer func() { _ = resp.Body.Close() }() + + // 错误响应处理 + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + // signature 重试 + if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) { + log.Printf("%s upstream signature error, retrying with thinking stripped", prefix) + retryClaudeReq := claudeReq + retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...) + if stripped, stripErr := stripThinkingFromClaudeRequest(&retryClaudeReq); stripErr == nil && stripped { + retryBody, _ := json.Marshal(&retryClaudeReq) + retryReq, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(retryBody)) + if err == nil { + retryReq.Header.Set("Content-Type", "application/json") + retryReq.Header.Set("Authorization", "Bearer "+apiKey) + retryReq.Header.Set("x-api-key", apiKey) + retryReq.Header.Set("anthropic-version", "2023-06-01") + if v := c.GetHeader("anthropic-beta"); v != "" { + retryReq.Header.Set("anthropic-beta", v) + } + retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) + if retryErr == nil && retryResp != nil && retryResp.StatusCode < 400 { + resp = retryResp + goto upstreamClaudeSuccess + } + if retryResp != nil { + _ = retryResp.Body.Close() + } + } + } + } + + // prompt too long + if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) { + return nil, &PromptTooLongError{ + StatusCode: resp.StatusCode, + RequestID: resp.Header.Get("x-request-id"), + Body: respBody, + } + } + + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + + if s.shouldFailoverUpstreamError(resp.StatusCode) { + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + + return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) + } + +upstreamClaudeSuccess: + requestID := resp.Header.Get("x-request-id") + if requestID != "" { + c.Header("x-request-id", requestID) + } + + var usage *ClaudeUsage + var firstTokenMs *int + if claudeReq.Stream { + streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel) + if err != nil { + log.Printf("%s status=stream_error error=%v", prefix, err) + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } else { + streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel) + if err != nil { + log.Printf("%s status=stream_collect_error error=%v", prefix, err) + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } + + return &ForwardResult{ + RequestID: requestID, + Usage: *usage, + Model: originalModel, + Stream: claudeReq.Stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +// ForwardUpstreamGemini 转发 Gemini 协议请求到 upstream(不做协议转换) +func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { + startTime := time.Now() + sessionID := getSessionID(c) + prefix := logPrefix(sessionID, account.Name) + + baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") + if baseURL == "" { + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream account missing base_url") + } + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if apiKey == "" { + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream account missing api_key") + } + + if strings.TrimSpace(originalModel) == "" { + return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL") + } + if strings.TrimSpace(action) == "" { + return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL") + } + if len(body) == 0 { + return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") + } + quotaScope, _ := resolveAntigravityQuotaScope(originalModel) + + imageSize := s.extractImageSize(body) + + switch action { + case "generateContent", "streamGenerateContent": + // ok + case "countTokens": + c.JSON(http.StatusOK, map[string]any{"totalTokens": 0}) + return &ForwardResult{ + RequestID: "", + Usage: ClaudeUsage{}, + Model: originalModel, + Stream: false, + Duration: time.Since(time.Now()), + FirstTokenMs: nil, + }, nil + default: + return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action) + } + + mappedModel := s.getMappedModel(account, originalModel) + if mappedModel == "" { + return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel)) + } + + // 代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 统计模型调用次数 + if s.cache != nil { + _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) + } + + // 构建 upstream URL: base_url + /antigravity/v1beta/models/MODEL:ACTION + upstreamAction := action + if action == "generateContent" && !stream { + // 非流式也用 streamGenerateContent,与 OAuth 路径行为一致 + upstreamAction = action + } + apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction) + if stream || upstreamAction == "streamGenerateContent" { + apiURL += "?alt=sse" + } + + log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, upstreamAction) + + // 预检查:模型级限流 + if remaining := account.GetRateLimitRemainingTimeWithContext(ctx, originalModel); remaining > 0 { + if remaining < antigravityRateLimitThreshold { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(remaining): + } + } else { + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusServiceUnavailable, + ForceCacheBilling: isStickySession, + } + } + } + + // 重试循环 + var resp *http.Response + var lastErr error + for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) + if err != nil { + return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + if c != nil && len(body) > 0 { + c.Set(OpsUpstreamRequestBodyKey, string(body)) + } + + resp, err = s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + lastErr = err + if attempt < antigravityMaxRetries { + log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + return nil, ctx.Err() + } + continue + } + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") + } + + // 429/503 重试 + if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + + if attempt < antigravityMaxRetries { + log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + return nil, ctx.Err() + } + continue + } + + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ForceCacheBilling: isStickySession, + } + } + + break + } + if resp == nil { + return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("upstream request failed: %v", lastErr)) + } + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + + // 错误响应处理 + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + contentType := resp.Header.Get("Content-Type") + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + // 模型兜底 + if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) && + isModelNotFoundError(resp.StatusCode, respBody) { + fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity) + if fallbackModel != "" && fallbackModel != mappedModel { + log.Printf("[Antigravity-Upstream] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name) + fallbackURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, fallbackModel, upstreamAction) + if stream || upstreamAction == "streamGenerateContent" { + fallbackURL += "?alt=sse" + } + fallbackReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fallbackURL, bytes.NewReader(body)) + if err == nil { + fallbackReq.Header.Set("Content-Type", "application/json") + fallbackReq.Header.Set("Authorization", "Bearer "+apiKey) + fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency) + if err == nil && fallbackResp.StatusCode < 400 { + _ = resp.Body.Close() + resp = fallbackResp + } else if fallbackResp != nil { + _ = fallbackResp.Body.Close() + } + } + } + } + + // fallback 成功 + if resp.StatusCode < 400 { + goto upstreamGeminiSuccess + } + + requestID := resp.Header.Get("x-request-id") + if requestID != "" { + c.Header("x-request-id", requestID) + } + + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := s.getUpstreamErrorDetail(respBody) + + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + + if s.shouldFailoverUpstreamError(resp.StatusCode) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + if contentType == "" { + contentType = "application/json" + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + log.Printf("[antigravity-Forward-Upstream] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(respBody, 500)) + c.Data(resp.StatusCode, contentType, respBody) + return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) + } + +upstreamGeminiSuccess: + requestID := resp.Header.Get("x-request-id") + if requestID != "" { + c.Header("x-request-id", requestID) + } + + var usage *ClaudeUsage + var firstTokenMs *int + + if stream { + streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime) + if err != nil { + log.Printf("%s status=stream_error error=%v", prefix, err) + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } else { + streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime) + if err != nil { + log.Printf("%s status=stream_collect_error error=%v", prefix, err) + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } + + if usage == nil { + usage = &ClaudeUsage{} + } + + imageCount := 0 + if isImageGenerationModel(mappedModel) { + imageCount = 1 + } + + return &ForwardResult{ + RequestID: requestID, + Usage: *usage, + Model: originalModel, + Stream: stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: imageSize, + }, nil +} From df3346387fcd0c758362008de867837bd28811b8 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 01:46:50 +0800 Subject: [PATCH 2/8] fix(frontend): upstream account edit fields and mixed_scheduling on create - EditAccountModal: add Base URL / API Key fields for upstream type - EditAccountModal: initialize editBaseUrl from credentials on upstream account open - EditAccountModal: save upstream credentials (base_url, api_key) on submit - CreateAccountModal: pass mixed_scheduling extra when creating upstream account --- .../components/account/CreateAccountModal.vue | 3 +- .../components/account/EditAccountModal.vue | 43 +++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index ba1daea9..7d759be1 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2714,7 +2714,8 @@ const handleSubmit = async () => { submitting.value = true try { - await createAccountAndFinish(form.platform, 'upstream', credentials) + const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined + await createAccountAndFinish(form.platform, 'upstream', credentials, extra) } catch (error: any) { appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate')) } finally { diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 2e428460..986bd297 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -364,6 +364,30 @@ + +
+
+ + +

{{ t('admin.accounts.upstream.baseUrlHint') }}

+
+
+ + +

{{ t('admin.accounts.leaveEmptyToKeep') }}

+
+
+
@@ -1244,6 +1268,9 @@ watch( } else { selectedErrorCodes.value = [] } + } else if (newAccount.type === 'upstream' && newAccount.credentials) { + const credentials = newAccount.credentials as Record + editBaseUrl.value = (credentials.base_url as string) || '' } else { const platformDefaultUrl = newAccount.platform === 'openai' @@ -1584,6 +1611,22 @@ const handleSubmit = async () => { return } + updatePayload.credentials = newCredentials + } else if (props.account.type === 'upstream') { + const currentCredentials = (props.account.credentials as Record) || {} + const newCredentials: Record = { ...currentCredentials } + + newCredentials.base_url = editBaseUrl.value.trim() + + if (editApiKey.value.trim()) { + newCredentials.api_key = editApiKey.value.trim() + } + + if (!applyTempUnschedConfig(newCredentials)) { + submitting.value = false + return + } + updatePayload.credentials = newCredentials } else { // For oauth/setup-token types, only update intercept_warmup_requests if changed From 1563bd3dda85e7f18058357fc8fcfdc4308c94ef Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 08:33:09 +0800 Subject: [PATCH 3/8] feat(upstream): passthrough all client headers instead of manual header setting Replace manual header setting (Content-Type, anthropic-version, anthropic-beta) with full client header passthrough in ForwardUpstream/ForwardUpstreamGemini. Only authentication headers (Authorization, x-api-key) are overridden with upstream account credentials. Hop-by-hop headers are excluded. Add unit tests covering header passthrough, auth override, and hop-by-hop filtering. --- .../service/antigravity_gateway_service.go | 312 ++++-------------- .../upstream_header_passthrough_test.go | 285 ++++++++++++++++ 2 files changed, 352 insertions(+), 245 deletions(-) create mode 100644 backend/internal/service/upstream_header_passthrough_test.go diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index fd53ba71..fc29eeb3 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -47,6 +47,21 @@ const ( googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED" ) +// upstreamHopByHopHeaders 透传请求头时需要排除的 hop-by-hop 头 +var upstreamHopByHopHeaders = map[string]bool{ + "connection": true, + "keep-alive": true, + "proxy-authenticate": true, + "proxy-authorization": true, + "proxy-connection": true, + "te": true, + "trailer": true, + "transfer-encoding": true, + "upgrade": true, + "host": true, + "content-length": true, +} + // antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写) // 匹配时使用 strings.Contains,无需完全匹配 var antigravityPassthroughErrorMessages = []string{ @@ -3456,10 +3471,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. if mappedModel == "" { return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) } - loadModel := mappedModel - thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" - mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) - quotaScope, _ := resolveAntigravityQuotaScope(originalModel) // 代理 URL proxyURL := "" @@ -3469,98 +3480,38 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. // 统计模型调用次数 if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, loadModel) + _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) } apiURL := baseURL + "/antigravity/v1/messages" log.Printf("%s upstream_forward url=%s model=%s", prefix, apiURL, mappedModel) - // 预检查:模型级限流 - if remaining := account.GetRateLimitRemainingTimeWithContext(ctx, originalModel); remaining > 0 { - if remaining < antigravityRateLimitThreshold { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(remaining): - } - } else { - return nil, &UpstreamFailoverError{ - StatusCode: http.StatusServiceUnavailable, - ForceCacheBilling: isStickySession, - } + // 构建请求:body 原样透传 + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") + } + // 透传客户端所有请求头(排除 hop-by-hop 和认证头) + for key, values := range c.Request.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + req.Header.Add(key, v) } } + // 覆盖认证头 + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) - // 重试循环 - var resp *http.Response - var lastErr error - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) - if err != nil { - return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) - req.Header.Set("x-api-key", apiKey) - - // 透传 anthropic headers - if v := c.GetHeader("anthropic-version"); v != "" { - req.Header.Set("anthropic-version", v) - } else { - req.Header.Set("anthropic-version", "2023-06-01") - } - if v := c.GetHeader("anthropic-beta"); v != "" { - req.Header.Set("anthropic-beta", v) - } - - if c != nil && len(body) > 0 { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - resp, err = s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - lastErr = err - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - return nil, ctx.Err() - } - continue - } - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") - } - - // 429/503 重试 - if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) - - if attempt < antigravityMaxRetries { - log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - return nil, ctx.Err() - } - continue - } - - return nil, &UpstreamFailoverError{ - StatusCode: resp.StatusCode, - ForceCacheBilling: isStickySession, - } - } - - break // 成功或非限流错误,跳出重试 + if c != nil && len(body) > 0 { + c.Set(OpsUpstreamRequestBodyKey, string(body)) } - if resp == nil { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("upstream request failed: %v", lastErr)) + + // 单次发送,不重试 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("Upstream request failed: %v", err)) } defer func() { _ = resp.Body.Close() }() @@ -3568,44 +3519,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - // signature 重试 - if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) { - log.Printf("%s upstream signature error, retrying with thinking stripped", prefix) - retryClaudeReq := claudeReq - retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...) - if stripped, stripErr := stripThinkingFromClaudeRequest(&retryClaudeReq); stripErr == nil && stripped { - retryBody, _ := json.Marshal(&retryClaudeReq) - retryReq, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(retryBody)) - if err == nil { - retryReq.Header.Set("Content-Type", "application/json") - retryReq.Header.Set("Authorization", "Bearer "+apiKey) - retryReq.Header.Set("x-api-key", apiKey) - retryReq.Header.Set("anthropic-version", "2023-06-01") - if v := c.GetHeader("anthropic-beta"); v != "" { - retryReq.Header.Set("anthropic-beta", v) - } - retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) - if retryErr == nil && retryResp != nil && retryResp.StatusCode < 400 { - resp = retryResp - goto upstreamClaudeSuccess - } - if retryResp != nil { - _ = retryResp.Body.Close() - } - } - } - } - - // prompt too long - if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) { - return nil, &PromptTooLongError{ - StatusCode: resp.StatusCode, - RequestID: resp.Header.Get("x-request-id"), - Body: respBody, - } - } - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession) if s.shouldFailoverUpstreamError(resp.StatusCode) { return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} @@ -3614,7 +3528,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) } -upstreamClaudeSuccess: + // 成功响应 requestID := resp.Header.Get("x-request-id") if requestID != "" { c.Header("x-request-id", requestID) @@ -3674,7 +3588,6 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c if len(body) == 0 { return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") } - quotaScope, _ := resolveAntigravityQuotaScope(originalModel) imageSize := s.extractImageSize(body) @@ -3712,143 +3625,52 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c } // 构建 upstream URL: base_url + /antigravity/v1beta/models/MODEL:ACTION - upstreamAction := action - if action == "generateContent" && !stream { - // 非流式也用 streamGenerateContent,与 OAuth 路径行为一致 - upstreamAction = action - } - apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction) - if stream || upstreamAction == "streamGenerateContent" { + apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, action) + if stream || action == "streamGenerateContent" { apiURL += "?alt=sse" } - log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, upstreamAction) + log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, action) - // 预检查:模型级限流 - if remaining := account.GetRateLimitRemainingTimeWithContext(ctx, originalModel); remaining > 0 { - if remaining < antigravityRateLimitThreshold { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(remaining): - } - } else { - return nil, &UpstreamFailoverError{ - StatusCode: http.StatusServiceUnavailable, - ForceCacheBilling: isStickySession, - } + // 构建请求:body 原样透传 + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) + if err != nil { + return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") + } + // 透传客户端所有请求头(排除 hop-by-hop 和认证头) + for key, values := range c.Request.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + req.Header.Add(key, v) } } + // 覆盖认证头 + req.Header.Set("Authorization", "Bearer "+apiKey) - // 重试循环 - var resp *http.Response - var lastErr error - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) - if err != nil { - return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) - - if c != nil && len(body) > 0 { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - resp, err = s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - lastErr = err - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - return nil, ctx.Err() - } - continue - } - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") - } - - // 429/503 重试 - if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) - - if attempt < antigravityMaxRetries { - log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - return nil, ctx.Err() - } - continue - } - - return nil, &UpstreamFailoverError{ - StatusCode: resp.StatusCode, - ForceCacheBilling: isStickySession, - } - } - - break + if c != nil && len(body) > 0 { + c.Set(OpsUpstreamRequestBodyKey, string(body)) } - if resp == nil { - return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("upstream request failed: %v", lastErr)) + + // 单次发送,不重试 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("Upstream request failed: %v", err)) } - defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() - } - }() + defer func() { _ = resp.Body.Close() }() // 错误响应处理 if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) contentType := resp.Header.Get("Content-Type") - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - - // 模型兜底 - if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) && - isModelNotFoundError(resp.StatusCode, respBody) { - fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity) - if fallbackModel != "" && fallbackModel != mappedModel { - log.Printf("[Antigravity-Upstream] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name) - fallbackURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, fallbackModel, upstreamAction) - if stream || upstreamAction == "streamGenerateContent" { - fallbackURL += "?alt=sse" - } - fallbackReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fallbackURL, bytes.NewReader(body)) - if err == nil { - fallbackReq.Header.Set("Content-Type", "application/json") - fallbackReq.Header.Set("Authorization", "Bearer "+apiKey) - fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency) - if err == nil && fallbackResp.StatusCode < 400 { - _ = resp.Body.Close() - resp = fallbackResp - } else if fallbackResp != nil { - _ = fallbackResp.Body.Close() - } - } - } - } - - // fallback 成功 - if resp.StatusCode < 400 { - goto upstreamGeminiSuccess - } requestID := resp.Header.Get("x-request-id") if requestID != "" { c.Header("x-request-id", requestID) } - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession) upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamDetail := s.getUpstreamErrorDetail(respBody) @@ -3886,7 +3708,7 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) } -upstreamGeminiSuccess: + // 成功响应 requestID := resp.Header.Get("x-request-id") if requestID != "" { c.Header("x-request-id", requestID) diff --git a/backend/internal/service/upstream_header_passthrough_test.go b/backend/internal/service/upstream_header_passthrough_test.go new file mode 100644 index 00000000..51d8588b --- /dev/null +++ b/backend/internal/service/upstream_header_passthrough_test.go @@ -0,0 +1,285 @@ +//go:build unit + +package service + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// httpUpstreamCapture captures the outgoing *http.Request for assertion. +type httpUpstreamCapture struct { + capturedReq *http.Request + resp *http.Response + err error +} + +func (s *httpUpstreamCapture) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + s.capturedReq = req + return s.resp, s.err +} + +func (s *httpUpstreamCapture) DoWithTLS(req *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) { + s.capturedReq = req + return s.resp, s.err +} + +func newUpstreamAccount() *Account { + return &Account{ + ID: 100, + Name: "upstream-test", + Platform: PlatformAntigravity, + Type: AccountTypeUpstream, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "base_url": "https://upstream.example.com", + "api_key": "sk-upstream-secret", + }, + } +} + +// makeSSEOKResponse builds a minimal SSE response that +// handleClaudeStreamingResponse / handleGeminiStreamingResponse +// can consume without error. +// We return 502 to bypass streaming and hit the error branch instead, +// which is sufficient for testing header passthrough. +func makeUpstreamErrorResponse() *http.Response { + body := []byte(`{"error":{"message":"test error"}}`) + return &http.Response{ + StatusCode: http.StatusBadGateway, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewReader(body)), + } +} + +// --- ForwardUpstream tests --- + +func TestForwardUpstream_PassthroughHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{{"role": "user", "content": "hi"}}, + "max_tokens": 1, + "stream": false, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("anthropic-version", "2024-10-22") + req.Header.Set("anthropic-beta", "output-128k-2025-02-19") + req.Header.Set("X-Custom-Header", "custom-value") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) + + captured := stub.capturedReq + require.NotNil(t, captured, "upstream request should have been made") + + // 客户端 header 应被透传 + require.Equal(t, "application/json", captured.Header.Get("Content-Type")) + require.Equal(t, "2024-10-22", captured.Header.Get("anthropic-version")) + require.Equal(t, "output-128k-2025-02-19", captured.Header.Get("anthropic-beta")) + require.Equal(t, "custom-value", captured.Header.Get("X-Custom-Header")) +} + +func TestForwardUpstream_OverridesAuthHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{{"role": "user", "content": "hi"}}, + "max_tokens": 1, + "stream": false, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + // 客户端发来的认证头应被覆盖 + req.Header.Set("Authorization", "Bearer client-token") + req.Header.Set("x-api-key", "client-api-key") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) + + captured := stub.capturedReq + require.NotNil(t, captured) + + // 认证头应使用上游账号的 api_key,而非客户端的 + require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization")) + require.Equal(t, "sk-upstream-secret", captured.Header.Get("x-api-key")) +} + +func TestForwardUpstream_ExcludesHopByHopHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{{"role": "user", "content": "hi"}}, + "max_tokens": 1, + "stream": false, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Connection", "keep-alive") + req.Header.Set("Keep-Alive", "timeout=5") + req.Header.Set("Transfer-Encoding", "chunked") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Te", "trailers") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) + + captured := stub.capturedReq + require.NotNil(t, captured) + + // hop-by-hop header 不应出现 + require.Empty(t, captured.Header.Get("Connection")) + require.Empty(t, captured.Header.Get("Keep-Alive")) + require.Empty(t, captured.Header.Get("Transfer-Encoding")) + require.Empty(t, captured.Header.Get("Upgrade")) + require.Empty(t, captured.Header.Get("Te")) + + // 但普通 header 应保留 + require.Equal(t, "application/json", captured.Header.Get("Content-Type")) +} + +// --- ForwardUpstreamGemini tests --- + +func TestForwardUpstreamGemini_PassthroughHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, + }, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Custom-Gemini", "gemini-value") + req.Header.Set("X-Request-Id", "req-abc-123") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) + + captured := stub.capturedReq + require.NotNil(t, captured, "upstream request should have been made") + + // 客户端 header 应被透传 + require.Equal(t, "application/json", captured.Header.Get("Content-Type")) + require.Equal(t, "gemini-value", captured.Header.Get("X-Custom-Gemini")) + require.Equal(t, "req-abc-123", captured.Header.Get("X-Request-Id")) +} + +func TestForwardUpstreamGemini_OverridesAuthHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, + }, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer client-gemini-token") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) + + captured := stub.capturedReq + require.NotNil(t, captured) + + // 认证头应使用上游账号的 api_key + require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization")) +} + +func TestForwardUpstreamGemini_ExcludesHopByHopHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, + }, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Connection", "keep-alive") + req.Header.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz") + req.Header.Set("Host", "evil.example.com") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) + + captured := stub.capturedReq + require.NotNil(t, captured) + + // hop-by-hop header 不应出现 + require.Empty(t, captured.Header.Get("Connection")) + require.Empty(t, captured.Header.Get("Proxy-Authorization")) + // Host header 在 Go http.Request 中特殊处理,但我们的黑名单应阻止透传 + require.Empty(t, captured.Header.Values("Host")) + + // 普通 header 应保留 + require.Equal(t, "application/json", captured.Header.Get("Content-Type")) +} From 4f57d7f76188f2c767060c37d516ceb3fb05cdfe Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 08:36:35 +0800 Subject: [PATCH 4/8] fix: add nil guard for gin.Context in header passthrough to satisfy staticcheck SA5011 --- .../service/antigravity_gateway_service.go | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index fc29eeb3..c2983c47 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -3492,12 +3492,14 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") } // 透传客户端所有请求头(排除 hop-by-hop 和认证头) - for key, values := range c.Request.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - req.Header.Add(key, v) + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } } } // 覆盖认证头 @@ -3638,12 +3640,14 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") } // 透传客户端所有请求头(排除 hop-by-hop 和认证头) - for key, values := range c.Request.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - req.Header.Add(key, v) + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } } } // 覆盖认证头 From 6ab77f5eb5afceb99eb32bba011261866bf6cf14 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 08:49:43 +0800 Subject: [PATCH 5/8] fix(upstream): passthrough response body directly instead of parsing SSE ForwardUpstream/ForwardUpstreamGemini should pipe the upstream response directly to the client (headers + body), not parse it as SSE stream. --- .../service/antigravity_gateway_service.go | 99 +++++++------------ 1 file changed, 38 insertions(+), 61 deletions(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index c2983c47..2d96b1ab 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -3530,39 +3530,30 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) } - // 成功响应 + // 成功响应:透传 response header + body requestID := resp.Header.Get("x-request-id") - if requestID != "" { - c.Header("x-request-id", requestID) + + // 透传上游响应头(排除 hop-by-hop) + for key, values := range resp.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + c.Header(key, v) + } } - var usage *ClaudeUsage - var firstTokenMs *int - if claudeReq.Stream { - streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel) - if err != nil { - log.Printf("%s status=stream_error error=%v", prefix, err) - return nil, err - } - usage = streamRes.usage - firstTokenMs = streamRes.firstTokenMs - } else { - streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel) - if err != nil { - log.Printf("%s status=stream_collect_error error=%v", prefix, err) - return nil, err - } - usage = streamRes.usage - firstTokenMs = streamRes.firstTokenMs + c.Status(resp.StatusCode) + _, copyErr := io.Copy(c.Writer, resp.Body) + if copyErr != nil { + log.Printf("%s status=copy_error error=%v", prefix, copyErr) } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, - Stream: claudeReq.Stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: requestID, + Model: originalModel, + Stream: claudeReq.Stream, + Duration: time.Since(startTime), }, nil } @@ -3712,35 +3703,23 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) } - // 成功响应 + // 成功响应:透传 response header + body requestID := resp.Header.Get("x-request-id") - if requestID != "" { - c.Header("x-request-id", requestID) + + // 透传上游响应头(排除 hop-by-hop) + for key, values := range resp.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + c.Header(key, v) + } } - var usage *ClaudeUsage - var firstTokenMs *int - - if stream { - streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime) - if err != nil { - log.Printf("%s status=stream_error error=%v", prefix, err) - return nil, err - } - usage = streamRes.usage - firstTokenMs = streamRes.firstTokenMs - } else { - streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime) - if err != nil { - log.Printf("%s status=stream_collect_error error=%v", prefix, err) - return nil, err - } - usage = streamRes.usage - firstTokenMs = streamRes.firstTokenMs - } - - if usage == nil { - usage = &ClaudeUsage{} + c.Status(resp.StatusCode) + _, copyErr := io.Copy(c.Writer, resp.Body) + if copyErr != nil { + log.Printf("%s status=copy_error error=%v", prefix, copyErr) } imageCount := 0 @@ -3749,13 +3728,11 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, - Stream: stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: imageSize, + RequestID: requestID, + Model: originalModel, + Stream: stream, + Duration: time.Since(startTime), + ImageCount: imageCount, + ImageSize: imageSize, }, nil } From fb58560d15fa34d2fc14b89f301e946e039861e7 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 13:06:25 +0800 Subject: [PATCH 6/8] refactor(upstream): replace upstream account type with apikey, auto-append /antigravity Upstream accounts now use the standard APIKey type instead of a dedicated upstream type. GetBaseURL() and new GetGeminiBaseURL() automatically append /antigravity for Antigravity platform APIKey accounts, eliminating the need for separate upstream forwarding methods. - Remove ForwardUpstream, ForwardUpstreamGemini, testUpstreamConnection - Remove upstream branch guards in Forward/ForwardGemini/TestConnection - Add migration 052 to convert existing upstream accounts to apikey - Update frontend CreateAccountModal to create apikey type - Add unit tests for GetBaseURL and GetGeminiBaseURL --- backend/internal/handler/gateway_handler.go | 2 +- .../internal/handler/gemini_v1beta_handler.go | 2 +- backend/internal/service/account.go | 16 + .../internal/service/account_base_url_test.go | 160 ++++++++ .../service/antigravity_gateway_service.go | 386 ------------------ .../service/gemini_messages_compat_service.go | 25 +- .../upstream_header_passthrough_test.go | 285 ------------- .../052_migrate_upstream_to_apikey.sql | 11 + .../components/account/CreateAccountModal.vue | 6 +- 9 files changed, 197 insertions(+), 696 deletions(-) create mode 100644 backend/internal/service/account_base_url_test.go delete mode 100644 backend/internal/service/upstream_header_passthrough_test.go create mode 100644 backend/migrations/052_migrate_upstream_to_apikey.sql diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index ca4442e4..255d3fab 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -482,7 +482,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if switchCount > 0 { requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) } - if account.Platform == service.PlatformAntigravity { + if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession) } else { result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index b1477ac6..2b69be2e 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -410,7 +410,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if switchCount > 0 { requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) } - if account.Platform == service.PlatformAntigravity { + if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession) } else { result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index a6ae8a68..138d5bcb 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -425,6 +425,22 @@ func (a *Account) GetBaseURL() string { if baseURL == "" { return "https://api.anthropic.com" } + if a.Platform == PlatformAntigravity { + return strings.TrimRight(baseURL, "/") + "/antigravity" + } + return baseURL +} + +// GetGeminiBaseURL 返回 Gemini 兼容端点的 base URL。 +// Antigravity 平台的 APIKey 账号自动拼接 /antigravity。 +func (a *Account) GetGeminiBaseURL(defaultBaseURL string) string { + baseURL := strings.TrimSpace(a.GetCredential("base_url")) + if baseURL == "" { + return defaultBaseURL + } + if a.Platform == PlatformAntigravity && a.Type == AccountTypeAPIKey { + return strings.TrimRight(baseURL, "/") + "/antigravity" + } return baseURL } diff --git a/backend/internal/service/account_base_url_test.go b/backend/internal/service/account_base_url_test.go new file mode 100644 index 00000000..a1322193 --- /dev/null +++ b/backend/internal/service/account_base_url_test.go @@ -0,0 +1,160 @@ +//go:build unit + +package service + +import ( + "testing" +) + +func TestGetBaseURL(t *testing.T) { + tests := []struct { + name string + account Account + expected string + }{ + { + name: "non-apikey type returns empty", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAnthropic, + }, + expected: "", + }, + { + name: "apikey without base_url returns default anthropic", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAnthropic, + Credentials: map[string]any{}, + }, + expected: "https://api.anthropic.com", + }, + { + name: "apikey with custom base_url", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAnthropic, + Credentials: map[string]any{"base_url": "https://custom.example.com"}, + }, + expected: "https://custom.example.com", + }, + { + name: "antigravity apikey auto-appends /antigravity", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity apikey trims trailing slash before appending", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com/"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity non-apikey returns empty", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetBaseURL() + if result != tt.expected { + t.Errorf("GetBaseURL() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestGetGeminiBaseURL(t *testing.T) { + const defaultGeminiURL = "https://generativelanguage.googleapis.com" + + tests := []struct { + name string + account Account + expected string + }{ + { + name: "apikey without base_url returns default", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{}, + }, + expected: defaultGeminiURL, + }, + { + name: "apikey with custom base_url", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{"base_url": "https://custom-gemini.example.com"}, + }, + expected: "https://custom-gemini.example.com", + }, + { + name: "antigravity apikey auto-appends /antigravity", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity apikey trims trailing slash", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com/"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity oauth does NOT append /antigravity", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com", + }, + { + name: "oauth without base_url returns default", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{}, + }, + expected: defaultGeminiURL, + }, + { + name: "nil credentials returns default", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + expected: defaultGeminiURL, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetGeminiBaseURL(defaultGeminiURL) + if result != tt.expected { + t.Errorf("GetGeminiBaseURL() = %q, want %q", result, tt.expected) + } + }) + } +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 2d96b1ab..4ea73e64 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -665,9 +665,6 @@ type TestConnectionResult struct { // TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费) // 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择 func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { - if account.Type == AccountTypeUpstream { - return s.testUpstreamConnection(ctx, account, modelID) - } // 获取 token if s.tokenProvider == nil { @@ -986,10 +983,6 @@ func isModelNotFoundError(statusCode int, body []byte) bool { func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() - if account.Type == AccountTypeUpstream { - return s.ForwardUpstream(ctx, c, account, body, isStickySession) - } - sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -1610,10 +1603,6 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() - if account.Type == AccountTypeUpstream { - return s.ForwardUpstreamGemini(ctx, c, account, originalModel, action, stream, body, isStickySession) - } - sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -3361,378 +3350,3 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) { payload["contents"] = filtered return json.Marshal(payload) } - -// --------------------------------------------------------------------------- -// Upstream 专用转发方法 -// upstream 账号直接连接上游 Anthropic/Gemini 兼容端点,不走 Antigravity OAuth 协议转换。 -// --------------------------------------------------------------------------- - -// testUpstreamConnection 测试 upstream 账号连接 -func (s *AntigravityGatewayService) testUpstreamConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { - baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") - if baseURL == "" { - return nil, errors.New("upstream account missing base_url in credentials") - } - apiKey := strings.TrimSpace(account.GetCredential("api_key")) - if apiKey == "" { - return nil, errors.New("upstream account missing api_key in credentials") - } - - mappedModel := s.getMappedModel(account, modelID) - if mappedModel == "" { - return nil, fmt.Errorf("model %s not in whitelist", modelID) - } - - // 构建最小 Claude 格式请求 - requestBody, _ := json.Marshal(map[string]any{ - "model": mappedModel, - "max_tokens": 1, - "messages": []map[string]any{ - {"role": "user", "content": "."}, - }, - "stream": false, - }) - - apiURL := baseURL + "/antigravity/v1/messages" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(requestBody)) - if err != nil { - return nil, fmt.Errorf("构建请求失败: %w", err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) - req.Header.Set("x-api-key", apiKey) - req.Header.Set("anthropic-version", "2023-06-01") - - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - log.Printf("[antigravity-Test-Upstream] account=%s url=%s", account.Name, apiURL) - - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, fmt.Errorf("请求失败: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - if err != nil { - return nil, fmt.Errorf("读取响应失败: %w", err) - } - - if resp.StatusCode >= 400 { - return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) - } - - // 从 Claude 格式非流式响应中提取文本 - var claudeResp struct { - Content []struct { - Text string `json:"text"` - } `json:"content"` - } - text := "" - if json.Unmarshal(respBody, &claudeResp) == nil && len(claudeResp.Content) > 0 { - text = claudeResp.Content[0].Text - } - - return &TestConnectionResult{ - Text: text, - MappedModel: mappedModel, - }, nil -} - -// ForwardUpstream 转发 Claude 协议请求到 upstream(不做协议转换) -func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { - startTime := time.Now() - sessionID := getSessionID(c) - prefix := logPrefix(sessionID, account.Name) - - baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") - if baseURL == "" { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "api_error", "Upstream account missing base_url") - } - apiKey := strings.TrimSpace(account.GetCredential("api_key")) - if apiKey == "" { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "authentication_error", "Upstream account missing api_key") - } - - // 解析请求以获取模型和流式标志 - var claudeReq antigravity.ClaudeRequest - if err := json.Unmarshal(body, &claudeReq); err != nil { - return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request body") - } - if strings.TrimSpace(claudeReq.Model) == "" { - return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Missing model") - } - - originalModel := claudeReq.Model - mappedModel := s.getMappedModel(account, claudeReq.Model) - if mappedModel == "" { - return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) - } - - // 代理 URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - // 统计模型调用次数 - if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) - } - - apiURL := baseURL + "/antigravity/v1/messages" - log.Printf("%s upstream_forward url=%s model=%s", prefix, apiURL, mappedModel) - - // 构建请求:body 原样透传 - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) - if err != nil { - return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") - } - // 透传客户端所有请求头(排除 hop-by-hop 和认证头) - if c != nil && c.Request != nil { - for key, values := range c.Request.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - req.Header.Add(key, v) - } - } - } - // 覆盖认证头 - req.Header.Set("Authorization", "Bearer "+apiKey) - req.Header.Set("x-api-key", apiKey) - - if c != nil && len(body) > 0 { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - // 单次发送,不重试 - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("Upstream request failed: %v", err)) - } - defer func() { _ = resp.Body.Close() }() - - // 错误响应处理 - if resp.StatusCode >= 400 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession) - - if s.shouldFailoverUpstreamError(resp.StatusCode) { - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} - } - - return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) - } - - // 成功响应:透传 response header + body - requestID := resp.Header.Get("x-request-id") - - // 透传上游响应头(排除 hop-by-hop) - for key, values := range resp.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - c.Header(key, v) - } - } - - c.Status(resp.StatusCode) - _, copyErr := io.Copy(c.Writer, resp.Body) - if copyErr != nil { - log.Printf("%s status=copy_error error=%v", prefix, copyErr) - } - - return &ForwardResult{ - RequestID: requestID, - Model: originalModel, - Stream: claudeReq.Stream, - Duration: time.Since(startTime), - }, nil -} - -// ForwardUpstreamGemini 转发 Gemini 协议请求到 upstream(不做协议转换) -func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { - startTime := time.Now() - sessionID := getSessionID(c) - prefix := logPrefix(sessionID, account.Name) - - baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") - if baseURL == "" { - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream account missing base_url") - } - apiKey := strings.TrimSpace(account.GetCredential("api_key")) - if apiKey == "" { - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream account missing api_key") - } - - if strings.TrimSpace(originalModel) == "" { - return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL") - } - if strings.TrimSpace(action) == "" { - return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL") - } - if len(body) == 0 { - return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") - } - - imageSize := s.extractImageSize(body) - - switch action { - case "generateContent", "streamGenerateContent": - // ok - case "countTokens": - c.JSON(http.StatusOK, map[string]any{"totalTokens": 0}) - return &ForwardResult{ - RequestID: "", - Usage: ClaudeUsage{}, - Model: originalModel, - Stream: false, - Duration: time.Since(time.Now()), - FirstTokenMs: nil, - }, nil - default: - return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action) - } - - mappedModel := s.getMappedModel(account, originalModel) - if mappedModel == "" { - return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel)) - } - - // 代理 URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - // 统计模型调用次数 - if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) - } - - // 构建 upstream URL: base_url + /antigravity/v1beta/models/MODEL:ACTION - apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, action) - if stream || action == "streamGenerateContent" { - apiURL += "?alt=sse" - } - - log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, action) - - // 构建请求:body 原样透传 - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) - if err != nil { - return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") - } - // 透传客户端所有请求头(排除 hop-by-hop 和认证头) - if c != nil && c.Request != nil { - for key, values := range c.Request.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - req.Header.Add(key, v) - } - } - } - // 覆盖认证头 - req.Header.Set("Authorization", "Bearer "+apiKey) - - if c != nil && len(body) > 0 { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - // 单次发送,不重试 - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("Upstream request failed: %v", err)) - } - defer func() { _ = resp.Body.Close() }() - - // 错误响应处理 - if resp.StatusCode >= 400 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - contentType := resp.Header.Get("Content-Type") - - requestID := resp.Header.Get("x-request-id") - if requestID != "" { - c.Header("x-request-id", requestID) - } - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession) - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := s.getUpstreamErrorDetail(respBody) - - setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) - - if s.shouldFailoverUpstreamError(resp.StatusCode) { - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: requestID, - Kind: "failover", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} - } - if contentType == "" { - contentType = "application/json" - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: requestID, - Kind: "http_error", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - log.Printf("[antigravity-Forward-Upstream] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(respBody, 500)) - c.Data(resp.StatusCode, contentType, respBody) - return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) - } - - // 成功响应:透传 response header + body - requestID := resp.Header.Get("x-request-id") - - // 透传上游响应头(排除 hop-by-hop) - for key, values := range resp.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - c.Header(key, v) - } - } - - c.Status(resp.StatusCode) - _, copyErr := io.Copy(c.Writer, resp.Body) - if copyErr != nil { - log.Printf("%s status=copy_error error=%v", prefix, copyErr) - } - - imageCount := 0 - if isImageGenerationModel(mappedModel) { - imageCount = 1 - } - - return &ForwardResult{ - RequestID: requestID, - Model: originalModel, - Stream: stream, - Duration: time.Since(startTime), - ImageCount: imageCount, - ImageSize: imageSize, - }, nil -} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 0f156c2e..4e0442fd 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -560,10 +560,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex return nil, "", errors.New("gemini api_key not configured") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -640,10 +637,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex return upstreamReq, "x-request-id", nil } else { // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -1026,10 +1020,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return nil, "", errors.New("gemini api_key not configured") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -1097,10 +1088,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return upstreamReq, "x-request-id", nil } else { // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -2420,10 +2408,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac return nil, errors.New("invalid path") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, err diff --git a/backend/internal/service/upstream_header_passthrough_test.go b/backend/internal/service/upstream_header_passthrough_test.go deleted file mode 100644 index 51d8588b..00000000 --- a/backend/internal/service/upstream_header_passthrough_test.go +++ /dev/null @@ -1,285 +0,0 @@ -//go:build unit - -package service - -import ( - "bytes" - "context" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "testing" - - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/require" -) - -// httpUpstreamCapture captures the outgoing *http.Request for assertion. -type httpUpstreamCapture struct { - capturedReq *http.Request - resp *http.Response - err error -} - -func (s *httpUpstreamCapture) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) { - s.capturedReq = req - return s.resp, s.err -} - -func (s *httpUpstreamCapture) DoWithTLS(req *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) { - s.capturedReq = req - return s.resp, s.err -} - -func newUpstreamAccount() *Account { - return &Account{ - ID: 100, - Name: "upstream-test", - Platform: PlatformAntigravity, - Type: AccountTypeUpstream, - Status: StatusActive, - Concurrency: 1, - Credentials: map[string]any{ - "base_url": "https://upstream.example.com", - "api_key": "sk-upstream-secret", - }, - } -} - -// makeSSEOKResponse builds a minimal SSE response that -// handleClaudeStreamingResponse / handleGeminiStreamingResponse -// can consume without error. -// We return 502 to bypass streaming and hit the error branch instead, -// which is sufficient for testing header passthrough. -func makeUpstreamErrorResponse() *http.Response { - body := []byte(`{"error":{"message":"test error"}}`) - return &http.Response{ - StatusCode: http.StatusBadGateway, - Header: http.Header{"Content-Type": []string{"application/json"}}, - Body: io.NopCloser(bytes.NewReader(body)), - } -} - -// --- ForwardUpstream tests --- - -func TestForwardUpstream_PassthroughHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{{"role": "user", "content": "hi"}}, - "max_tokens": 1, - "stream": false, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("anthropic-version", "2024-10-22") - req.Header.Set("anthropic-beta", "output-128k-2025-02-19") - req.Header.Set("X-Custom-Header", "custom-value") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) - - captured := stub.capturedReq - require.NotNil(t, captured, "upstream request should have been made") - - // 客户端 header 应被透传 - require.Equal(t, "application/json", captured.Header.Get("Content-Type")) - require.Equal(t, "2024-10-22", captured.Header.Get("anthropic-version")) - require.Equal(t, "output-128k-2025-02-19", captured.Header.Get("anthropic-beta")) - require.Equal(t, "custom-value", captured.Header.Get("X-Custom-Header")) -} - -func TestForwardUpstream_OverridesAuthHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{{"role": "user", "content": "hi"}}, - "max_tokens": 1, - "stream": false, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - // 客户端发来的认证头应被覆盖 - req.Header.Set("Authorization", "Bearer client-token") - req.Header.Set("x-api-key", "client-api-key") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) - - captured := stub.capturedReq - require.NotNil(t, captured) - - // 认证头应使用上游账号的 api_key,而非客户端的 - require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization")) - require.Equal(t, "sk-upstream-secret", captured.Header.Get("x-api-key")) -} - -func TestForwardUpstream_ExcludesHopByHopHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{{"role": "user", "content": "hi"}}, - "max_tokens": 1, - "stream": false, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Connection", "keep-alive") - req.Header.Set("Keep-Alive", "timeout=5") - req.Header.Set("Transfer-Encoding", "chunked") - req.Header.Set("Upgrade", "websocket") - req.Header.Set("Te", "trailers") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) - - captured := stub.capturedReq - require.NotNil(t, captured) - - // hop-by-hop header 不应出现 - require.Empty(t, captured.Header.Get("Connection")) - require.Empty(t, captured.Header.Get("Keep-Alive")) - require.Empty(t, captured.Header.Get("Transfer-Encoding")) - require.Empty(t, captured.Header.Get("Upgrade")) - require.Empty(t, captured.Header.Get("Te")) - - // 但普通 header 应保留 - require.Equal(t, "application/json", captured.Header.Get("Content-Type")) -} - -// --- ForwardUpstreamGemini tests --- - -func TestForwardUpstreamGemini_PassthroughHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "contents": []map[string]any{ - {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, - }, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-Custom-Gemini", "gemini-value") - req.Header.Set("X-Request-Id", "req-abc-123") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) - - captured := stub.capturedReq - require.NotNil(t, captured, "upstream request should have been made") - - // 客户端 header 应被透传 - require.Equal(t, "application/json", captured.Header.Get("Content-Type")) - require.Equal(t, "gemini-value", captured.Header.Get("X-Custom-Gemini")) - require.Equal(t, "req-abc-123", captured.Header.Get("X-Request-Id")) -} - -func TestForwardUpstreamGemini_OverridesAuthHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "contents": []map[string]any{ - {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, - }, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer client-gemini-token") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) - - captured := stub.capturedReq - require.NotNil(t, captured) - - // 认证头应使用上游账号的 api_key - require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization")) -} - -func TestForwardUpstreamGemini_ExcludesHopByHopHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "contents": []map[string]any{ - {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, - }, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Connection", "keep-alive") - req.Header.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz") - req.Header.Set("Host", "evil.example.com") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) - - captured := stub.capturedReq - require.NotNil(t, captured) - - // hop-by-hop header 不应出现 - require.Empty(t, captured.Header.Get("Connection")) - require.Empty(t, captured.Header.Get("Proxy-Authorization")) - // Host header 在 Go http.Request 中特殊处理,但我们的黑名单应阻止透传 - require.Empty(t, captured.Header.Values("Host")) - - // 普通 header 应保留 - require.Equal(t, "application/json", captured.Header.Get("Content-Type")) -} diff --git a/backend/migrations/052_migrate_upstream_to_apikey.sql b/backend/migrations/052_migrate_upstream_to_apikey.sql new file mode 100644 index 00000000..974f3f3c --- /dev/null +++ b/backend/migrations/052_migrate_upstream_to_apikey.sql @@ -0,0 +1,11 @@ +-- Migrate upstream accounts to apikey type +-- Background: upstream type is no longer needed. Antigravity platform APIKey accounts +-- with base_url pointing to an upstream sub2api instance can reuse the standard +-- APIKey forwarding path. GetBaseURL()/GetGeminiBaseURL() automatically appends +-- /antigravity for Antigravity platform APIKey accounts. + +UPDATE accounts +SET type = 'apikey' +WHERE type = 'upstream' + AND platform = 'antigravity' + AND deleted_at IS NULL; diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 7d759be1..603941c1 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2289,9 +2289,9 @@ watch( watch( [accountCategory, addMethod, antigravityAccountType], ([category, method, agType]) => { - // Antigravity upstream 类型 + // Antigravity upstream 类型(实际创建为 apikey) if (form.platform === 'antigravity' && agType === 'upstream') { - form.type = 'upstream' + form.type = 'apikey' return } if (category === 'oauth-based') { @@ -2715,7 +2715,7 @@ const handleSubmit = async () => { submitting.value = true try { const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined - await createAccountAndFinish(form.platform, 'upstream', credentials, extra) + await createAccountAndFinish(form.platform, 'apikey', credentials, extra) } catch (error: any) { appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate')) } finally { From 3c936441469d9483bd02c2681fcfbea9fa271f9a Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 13:14:58 +0800 Subject: [PATCH 7/8] chore: bump version to 0.1.74.7 --- backend/cmd/server/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index f0768f09..bc88be6e 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.70 +0.1.74.7 From 69816f8691e9374adfafde596c5b5a34ec96ddaf Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 13:30:39 +0800 Subject: [PATCH 8/8] fix: remove unused upstreamHopByHopHeaders variable to pass golangci-lint --- .../service/antigravity_gateway_service.go | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 4ea73e64..26b1c530 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -47,21 +47,6 @@ const ( googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED" ) -// upstreamHopByHopHeaders 透传请求头时需要排除的 hop-by-hop 头 -var upstreamHopByHopHeaders = map[string]bool{ - "connection": true, - "keep-alive": true, - "proxy-authenticate": true, - "proxy-authorization": true, - "proxy-connection": true, - "te": true, - "trailer": true, - "transfer-encoding": true, - "upgrade": true, - "host": true, - "content-length": true, -} - // antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写) // 匹配时使用 strings.Contains,无需完全匹配 var antigravityPassthroughErrorMessages = []string{