diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index d425b881..8b0ba6ec 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -1,6 +1,7 @@ package antigravity import ( + "bytes" "context" "encoding/json" "fmt" @@ -11,6 +12,19 @@ import ( "time" ) +// NewAPIRequest 创建 Antigravity API 请求(v1internal 端点) +func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) { + apiURL := fmt.Sprintf("%s/v1internal:%s", BaseURL, action) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("User-Agent", UserAgent) + return req, nil +} + // TokenResponse Google OAuth token 响应 type TokenResponse struct { AccessToken string `json:"access_token"` diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 670f53ee..94b37371 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -148,11 +148,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account return nil, fmt.Errorf("获取 access_token 失败: %w", err) } - // 获取 project_id + // 获取 project_id(部分账户类型可能没有) projectID := strings.TrimSpace(account.GetCredential("project_id")) - if projectID == "" { - return nil, errors.New("project_id not found in credentials") - } // 模型映射 mappedModel := s.getMappedModel(account, modelID) @@ -171,14 +168,10 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account } // 构建 HTTP 请求(非流式) - fullURL := fmt.Sprintf("%s/v1internal:generateContent", antigravity.BaseURL) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(requestBody)) + req, err := antigravity.NewAPIRequest(ctx, "generateContent", accessToken, requestBody) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("User-Agent", antigravity.UserAgent) // 代理 URL proxyURL := "" @@ -350,11 +343,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return nil, fmt.Errorf("获取 access_token 失败: %w", err) } - // 获取 project_id + // 获取 project_id(部分账户类型可能没有) projectID := strings.TrimSpace(account.GetCredential("project_id")) - if projectID == "" { - return nil, errors.New("project_id not found in credentials") - } // 代理 URL proxyURL := "" @@ -368,26 +358,19 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return nil, fmt.Errorf("transform request: %w", err) } - // 构建上游 URL + // 构建上游 action action := "generateContent" if claudeReq.Stream { - action = "streamGenerateContent" - } - fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, action) - if claudeReq.Stream { - fullURL += "?alt=sse" + action = "streamGenerateContent?alt=sse" } // 重试循环 var resp *http.Response for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiBody)) + upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody) if err != nil { return nil, err } - upstreamReq.Header.Set("Content-Type", "application/json") - upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) - upstreamReq.Header.Set("User-Agent", antigravity.UserAgent) resp, err = s.httpUpstream.Do(upstreamReq, proxyURL) if err != nil { @@ -500,11 +483,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co return nil, fmt.Errorf("获取 access_token 失败: %w", err) } - // 获取 project_id + // 获取 project_id(部分账户类型可能没有) projectID := strings.TrimSpace(account.GetCredential("project_id")) - if projectID == "" { - return nil, errors.New("project_id not found in credentials") - } // 代理 URL proxyURL := "" @@ -518,26 +498,22 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co return nil, err } - // 构建上游 URL + // 构建上游 action upstreamAction := action if action == "generateContent" && stream { upstreamAction = "streamGenerateContent" } - fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, upstreamAction) if stream || upstreamAction == "streamGenerateContent" { - fullURL += "?alt=sse" + upstreamAction += "?alt=sse" } // 重试循环 var resp *http.Response for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBody)) + upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody) if err != nil { return nil, err } - upstreamReq.Header.Set("Content-Type", "application/json") - upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) - upstreamReq.Header.Set("User-Agent", antigravity.UserAgent) resp, err = s.httpUpstream.Do(upstreamReq, proxyURL) if err != nil { diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go index fc6cc74d..0d104043 100644 --- a/backend/internal/service/antigravity_oauth_service.go +++ b/backend/internal/service/antigravity_oauth_service.go @@ -141,7 +141,7 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig result.Email = userInfo.Email } - // 获取 project_id + // 获取 project_id(部分账户类型可能没有) loadResp, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken) if err != nil { fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err) diff --git a/backend/internal/service/antigravity_quota_refresher.go b/backend/internal/service/antigravity_quota_refresher.go index 5ed59d2f..de067b07 100644 --- a/backend/internal/service/antigravity_quota_refresher.go +++ b/backend/internal/service/antigravity_quota_refresher.go @@ -125,8 +125,8 @@ func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, acc accessToken := account.GetCredential("access_token") projectID := account.GetCredential("project_id") - if accessToken == "" || projectID == "" { - return nil // 没有有效凭证,跳过 + if accessToken == "" { + return nil // 没有 access_token,跳过 } // token 过期则跳过,由 TokenRefreshService 负责刷新 @@ -151,7 +151,10 @@ func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, acc r.updateAccountTier(account, loadResp) } - // 调用 API 获取配额 + // 调用 API 获取配额(需要 projectID) + if projectID == "" { + return r.accountRepo.Update(ctx, account) // 没有 projectID,只更新 tier + } modelsResp, err := client.FetchAvailableModels(ctx, accessToken, projectID) if err != nil { return err