fix(Antigravity): 支持无 project_id 的账户类型
- 移除 project_id 强制检查,部分账户类型 API 不返回此字段 - 重构:提取 antigravity.NewAPIRequest() 统一创建 API 请求 - quota_refresher: 无 project_id 时仍可更新 tier 信息
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package antigravity
|
package antigravity
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -11,6 +12,19 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// NewAPIRequest 创建 Antigravity API 请求(v1internal 端点)
|
||||||
|
func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) {
|
||||||
|
apiURL := fmt.Sprintf("%s/v1internal:%s", BaseURL, action)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
// TokenResponse Google OAuth token 响应
|
// TokenResponse Google OAuth token 响应
|
||||||
type TokenResponse struct {
|
type TokenResponse struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
|
|||||||
@@ -148,11 +148,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
|||||||
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
|
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取 project_id
|
// 获取 project_id(部分账户类型可能没有)
|
||||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||||
if projectID == "" {
|
|
||||||
return nil, errors.New("project_id not found in credentials")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 模型映射
|
// 模型映射
|
||||||
mappedModel := s.getMappedModel(account, modelID)
|
mappedModel := s.getMappedModel(account, modelID)
|
||||||
@@ -171,14 +168,10 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 构建 HTTP 请求(非流式)
|
// 构建 HTTP 请求(非流式)
|
||||||
fullURL := fmt.Sprintf("%s/v1internal:generateContent", antigravity.BaseURL)
|
req, err := antigravity.NewAPIRequest(ctx, "generateContent", accessToken, requestBody)
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(requestBody))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
req.Header.Set("User-Agent", antigravity.UserAgent)
|
|
||||||
|
|
||||||
// 代理 URL
|
// 代理 URL
|
||||||
proxyURL := ""
|
proxyURL := ""
|
||||||
@@ -350,11 +343,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
|
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取 project_id
|
// 获取 project_id(部分账户类型可能没有)
|
||||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||||
if projectID == "" {
|
|
||||||
return nil, errors.New("project_id not found in credentials")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 代理 URL
|
// 代理 URL
|
||||||
proxyURL := ""
|
proxyURL := ""
|
||||||
@@ -368,26 +358,19 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
return nil, fmt.Errorf("transform request: %w", err)
|
return nil, fmt.Errorf("transform request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建上游 URL
|
// 构建上游 action
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
if claudeReq.Stream {
|
if claudeReq.Stream {
|
||||||
action = "streamGenerateContent"
|
action = "streamGenerateContent?alt=sse"
|
||||||
}
|
|
||||||
fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, action)
|
|
||||||
if claudeReq.Stream {
|
|
||||||
fullURL += "?alt=sse"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 重试循环
|
// 重试循环
|
||||||
var resp *http.Response
|
var resp *http.Response
|
||||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiBody))
|
upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
|
||||||
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
upstreamReq.Header.Set("User-Agent", antigravity.UserAgent)
|
|
||||||
|
|
||||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
|
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -500,11 +483,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
|||||||
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
|
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取 project_id
|
// 获取 project_id(部分账户类型可能没有)
|
||||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||||
if projectID == "" {
|
|
||||||
return nil, errors.New("project_id not found in credentials")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 代理 URL
|
// 代理 URL
|
||||||
proxyURL := ""
|
proxyURL := ""
|
||||||
@@ -518,26 +498,22 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建上游 URL
|
// 构建上游 action
|
||||||
upstreamAction := action
|
upstreamAction := action
|
||||||
if action == "generateContent" && stream {
|
if action == "generateContent" && stream {
|
||||||
upstreamAction = "streamGenerateContent"
|
upstreamAction = "streamGenerateContent"
|
||||||
}
|
}
|
||||||
fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, upstreamAction)
|
|
||||||
if stream || upstreamAction == "streamGenerateContent" {
|
if stream || upstreamAction == "streamGenerateContent" {
|
||||||
fullURL += "?alt=sse"
|
upstreamAction += "?alt=sse"
|
||||||
}
|
}
|
||||||
|
|
||||||
// 重试循环
|
// 重试循环
|
||||||
var resp *http.Response
|
var resp *http.Response
|
||||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBody))
|
upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
|
||||||
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
upstreamReq.Header.Set("User-Agent", antigravity.UserAgent)
|
|
||||||
|
|
||||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
|
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
|
|||||||
result.Email = userInfo.Email
|
result.Email = userInfo.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取 project_id
|
// 获取 project_id(部分账户类型可能没有)
|
||||||
loadResp, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
|
loadResp, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err)
|
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err)
|
||||||
|
|||||||
@@ -125,8 +125,8 @@ func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, acc
|
|||||||
accessToken := account.GetCredential("access_token")
|
accessToken := account.GetCredential("access_token")
|
||||||
projectID := account.GetCredential("project_id")
|
projectID := account.GetCredential("project_id")
|
||||||
|
|
||||||
if accessToken == "" || projectID == "" {
|
if accessToken == "" {
|
||||||
return nil // 没有有效凭证,跳过
|
return nil // 没有 access_token,跳过
|
||||||
}
|
}
|
||||||
|
|
||||||
// token 过期则跳过,由 TokenRefreshService 负责刷新
|
// token 过期则跳过,由 TokenRefreshService 负责刷新
|
||||||
@@ -151,7 +151,10 @@ func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, acc
|
|||||||
r.updateAccountTier(account, loadResp)
|
r.updateAccountTier(account, loadResp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 调用 API 获取配额
|
// 调用 API 获取配额(需要 projectID)
|
||||||
|
if projectID == "" {
|
||||||
|
return r.accountRepo.Update(ctx, account) // 没有 projectID,只更新 tier
|
||||||
|
}
|
||||||
modelsResp, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
modelsResp, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
Reference in New Issue
Block a user