fix(Antigravity): 支持无 project_id 的账户类型

- 移除 project_id 强制检查,部分账户类型 API 不返回此字段
- 重构:提取 antigravity.NewAPIRequest() 统一创建 API 请求
- quota_refresher: 无 project_id 时仍可更新 tier 信息
This commit is contained in:
song
2025-12-30 23:42:50 +08:00
parent 5844ea7e6e
commit 1c42403e6d
4 changed files with 31 additions and 38 deletions

View File

@@ -1,6 +1,7 @@
package antigravity
import (
"bytes"
"context"
"encoding/json"
"fmt"
@@ -11,6 +12,19 @@ import (
"time"
)
// NewAPIRequest 创建 Antigravity API 请求v1internal 端点)
func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) {
apiURL := fmt.Sprintf("%s/v1internal:%s", BaseURL, action)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", UserAgent)
return req, nil
}
// TokenResponse Google OAuth token 响应
type TokenResponse struct {
AccessToken string `json:"access_token"`

View File

@@ -148,11 +148,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
}
// 获取 project_id
// 获取 project_id(部分账户类型可能没有)
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID == "" {
return nil, errors.New("project_id not found in credentials")
}
// 模型映射
mappedModel := s.getMappedModel(account, modelID)
@@ -171,14 +168,10 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
}
// 构建 HTTP 请求(非流式)
fullURL := fmt.Sprintf("%s/v1internal:generateContent", antigravity.BaseURL)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(requestBody))
req, err := antigravity.NewAPIRequest(ctx, "generateContent", accessToken, requestBody)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", antigravity.UserAgent)
// 代理 URL
proxyURL := ""
@@ -350,11 +343,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
}
// 获取 project_id
// 获取 project_id(部分账户类型可能没有)
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID == "" {
return nil, errors.New("project_id not found in credentials")
}
// 代理 URL
proxyURL := ""
@@ -368,26 +358,19 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err)
}
// 构建上游 URL
// 构建上游 action
action := "generateContent"
if claudeReq.Stream {
action = "streamGenerateContent"
}
fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, action)
if claudeReq.Stream {
fullURL += "?alt=sse"
action = "streamGenerateContent?alt=sse"
}
// 重试循环
var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiBody))
upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody)
if err != nil {
return nil, err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", antigravity.UserAgent)
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
if err != nil {
@@ -500,11 +483,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
}
// 获取 project_id
// 获取 project_id(部分账户类型可能没有)
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID == "" {
return nil, errors.New("project_id not found in credentials")
}
// 代理 URL
proxyURL := ""
@@ -518,26 +498,22 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, err
}
// 构建上游 URL
// 构建上游 action
upstreamAction := action
if action == "generateContent" && stream {
upstreamAction = "streamGenerateContent"
}
fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, upstreamAction)
if stream || upstreamAction == "streamGenerateContent" {
fullURL += "?alt=sse"
upstreamAction += "?alt=sse"
}
// 重试循环
var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBody))
upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody)
if err != nil {
return nil, err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", antigravity.UserAgent)
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
if err != nil {

View File

@@ -141,7 +141,7 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result.Email = userInfo.Email
}
// 获取 project_id
// 获取 project_id(部分账户类型可能没有)
loadResp, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
if err != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err)

View File

@@ -125,8 +125,8 @@ func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, acc
accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id")
if accessToken == "" || projectID == "" {
return nil // 没有有效凭证,跳过
if accessToken == "" {
return nil // 没有 access_token,跳过
}
// token 过期则跳过,由 TokenRefreshService 负责刷新
@@ -151,7 +151,10 @@ func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, acc
r.updateAccountTier(account, loadResp)
}
// 调用 API 获取配额
// 调用 API 获取配额(需要 projectID
if projectID == "" {
return r.accountRepo.Update(ctx, account) // 没有 projectID只更新 tier
}
modelsResp, err := client.FetchAvailableModels(ctx, accessToken, projectID)
if err != nil {
return err