fix(Antigravity): 支持无 project_id 的账户类型

- 移除 project_id 强制检查,部分账户类型 API 不返回此字段
- 重构:提取 antigravity.NewAPIRequest() 统一创建 API 请求
- quota_refresher: 无 project_id 时仍可更新 tier 信息
This commit is contained in:
song
2025-12-30 23:42:50 +08:00
parent 5844ea7e6e
commit 1c42403e6d
4 changed files with 31 additions and 38 deletions

View File

@@ -1,6 +1,7 @@
package antigravity package antigravity
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -11,6 +12,19 @@ import (
"time" "time"
) )
// NewAPIRequest 创建 Antigravity API 请求v1internal 端点)
func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) {
apiURL := fmt.Sprintf("%s/v1internal:%s", BaseURL, action)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", UserAgent)
return req, nil
}
// TokenResponse Google OAuth token 响应 // TokenResponse Google OAuth token 响应
type TokenResponse struct { type TokenResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`

View File

@@ -148,11 +148,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
return nil, fmt.Errorf("获取 access_token 失败: %w", err) return nil, fmt.Errorf("获取 access_token 失败: %w", err)
} }
// 获取 project_id // 获取 project_id(部分账户类型可能没有)
projectID := strings.TrimSpace(account.GetCredential("project_id")) projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID == "" {
return nil, errors.New("project_id not found in credentials")
}
// 模型映射 // 模型映射
mappedModel := s.getMappedModel(account, modelID) mappedModel := s.getMappedModel(account, modelID)
@@ -171,14 +168,10 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
} }
// 构建 HTTP 请求(非流式) // 构建 HTTP 请求(非流式)
fullURL := fmt.Sprintf("%s/v1internal:generateContent", antigravity.BaseURL) req, err := antigravity.NewAPIRequest(ctx, "generateContent", accessToken, requestBody)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(requestBody))
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", antigravity.UserAgent)
// 代理 URL // 代理 URL
proxyURL := "" proxyURL := ""
@@ -350,11 +343,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("获取 access_token 失败: %w", err) return nil, fmt.Errorf("获取 access_token 失败: %w", err)
} }
// 获取 project_id // 获取 project_id(部分账户类型可能没有)
projectID := strings.TrimSpace(account.GetCredential("project_id")) projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID == "" {
return nil, errors.New("project_id not found in credentials")
}
// 代理 URL // 代理 URL
proxyURL := "" proxyURL := ""
@@ -368,26 +358,19 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err) return nil, fmt.Errorf("transform request: %w", err)
} }
// 构建上游 URL // 构建上游 action
action := "generateContent" action := "generateContent"
if claudeReq.Stream { if claudeReq.Stream {
action = "streamGenerateContent" action = "streamGenerateContent?alt=sse"
}
fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, action)
if claudeReq.Stream {
fullURL += "?alt=sse"
} }
// 重试循环 // 重试循环
var resp *http.Response var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiBody)) upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody)
if err != nil { if err != nil {
return nil, err return nil, err
} }
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", antigravity.UserAgent)
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL) resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
if err != nil { if err != nil {
@@ -500,11 +483,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, fmt.Errorf("获取 access_token 失败: %w", err) return nil, fmt.Errorf("获取 access_token 失败: %w", err)
} }
// 获取 project_id // 获取 project_id(部分账户类型可能没有)
projectID := strings.TrimSpace(account.GetCredential("project_id")) projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID == "" {
return nil, errors.New("project_id not found in credentials")
}
// 代理 URL // 代理 URL
proxyURL := "" proxyURL := ""
@@ -518,26 +498,22 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, err return nil, err
} }
// 构建上游 URL // 构建上游 action
upstreamAction := action upstreamAction := action
if action == "generateContent" && stream { if action == "generateContent" && stream {
upstreamAction = "streamGenerateContent" upstreamAction = "streamGenerateContent"
} }
fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, upstreamAction)
if stream || upstreamAction == "streamGenerateContent" { if stream || upstreamAction == "streamGenerateContent" {
fullURL += "?alt=sse" upstreamAction += "?alt=sse"
} }
// 重试循环 // 重试循环
var resp *http.Response var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBody)) upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody)
if err != nil { if err != nil {
return nil, err return nil, err
} }
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", antigravity.UserAgent)
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL) resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
if err != nil { if err != nil {

View File

@@ -141,7 +141,7 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result.Email = userInfo.Email result.Email = userInfo.Email
} }
// 获取 project_id // 获取 project_id(部分账户类型可能没有)
loadResp, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken) loadResp, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
if err != nil { if err != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err) fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err)

View File

@@ -125,8 +125,8 @@ func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, acc
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id") projectID := account.GetCredential("project_id")
if accessToken == "" || projectID == "" { if accessToken == "" {
return nil // 没有有效凭证,跳过 return nil // 没有 access_token,跳过
} }
// token 过期则跳过,由 TokenRefreshService 负责刷新 // token 过期则跳过,由 TokenRefreshService 负责刷新
@@ -151,7 +151,10 @@ func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, acc
r.updateAccountTier(account, loadResp) r.updateAccountTier(account, loadResp)
} }
// 调用 API 获取配额 // 调用 API 获取配额(需要 projectID
if projectID == "" {
return r.accountRepo.Update(ctx, account) // 没有 projectID只更新 tier
}
modelsResp, err := client.FetchAvailableModels(ctx, accessToken, projectID) modelsResp, err := client.FetchAvailableModels(ctx, accessToken, projectID)
if err != nil { if err != nil {
return err return err