✨ feat(antigravity): 添加 onboardUser 支持并修复 project_id 补齐逻辑
- 新增 OnboardUser API 客户端方法,支持账号 onboarding 获取 project_id - loadProjectIDWithRetry 增加 onboard 回退:LoadCodeAssist 未返回 project_id 时自动触发 onboarding - GetAccessToken 中 project_id 补齐改用轻量 FillProjectID 替代全量 RefreshAccountToken - 补齐逻辑增加 5 分钟冷却机制,防止频繁重试 - OnboardUser 轮询等待改为 context 感知,支持提前取消 - 提取 mergeCredentials 辅助方法消除重复代码 - 新增 extractProjectIDFromOnboardResponse 和 resolveDefaultTierID 单元测试
This commit is contained in:
@@ -273,12 +273,21 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
|
||||
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||
return loadResp.CloudAICompanionProject, nil
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" {
|
||||
return projectID, nil
|
||||
} else if onboardErr != nil {
|
||||
lastErr = onboardErr
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 记录错误
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
@@ -292,6 +301,65 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
|
||||
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) {
|
||||
tierID := resolveDefaultTierID(loadRaw)
|
||||
if tierID == "" {
|
||||
return "", fmt.Errorf("loadCodeAssist 未返回可用的默认 tier")
|
||||
}
|
||||
|
||||
projectID, err := client.OnboardUser(ctx, accessToken, tierID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("onboardUser 失败 (tier=%s): %w", tierID, err)
|
||||
}
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
func resolveDefaultTierID(loadRaw map[string]any) string {
|
||||
if len(loadRaw) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
rawTiers, ok := loadRaw["allowedTiers"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
tiers, ok := rawTiers.([]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, rawTier := range tiers {
|
||||
tier, ok := rawTier.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if isDefault, _ := tier["isDefault"].(bool); !isDefault {
|
||||
continue
|
||||
}
|
||||
if id, ok := tier["id"].(string); ok {
|
||||
id = strings.TrimSpace(id)
|
||||
if id != "" {
|
||||
return id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// FillProjectID 仅获取 project_id,不刷新 OAuth token
|
||||
func (s *AntigravityOAuthService) FillProjectID(ctx context.Context, account *Account, accessToken string) (string, error) {
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
return s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3)
|
||||
}
|
||||
|
||||
// BuildAccountCredentials 构建账户凭证
|
||||
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
|
||||
creds := map[string]any{
|
||||
|
||||
82
backend/internal/service/antigravity_oauth_service_test.go
Normal file
82
backend/internal/service/antigravity_oauth_service_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResolveDefaultTierID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
loadRaw map[string]any
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil loadRaw",
|
||||
loadRaw: nil,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "missing allowedTiers",
|
||||
loadRaw: map[string]any{
|
||||
"paidTier": map[string]any{"id": "g1-pro-tier"},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty allowedTiers",
|
||||
loadRaw: map[string]any{"allowedTiers": []any{}},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "tier missing id field",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"isDefault": true},
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "allowedTiers but no default",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"id": "free-tier", "isDefault": false},
|
||||
map[string]any{"id": "standard-tier", "isDefault": false},
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "default tier found",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"id": "free-tier", "isDefault": true},
|
||||
map[string]any{"id": "standard-tier", "isDefault": false},
|
||||
},
|
||||
},
|
||||
want: "free-tier",
|
||||
},
|
||||
{
|
||||
name: "default tier id with spaces",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"id": " standard-tier ", "isDefault": true},
|
||||
},
|
||||
},
|
||||
want: "standard-tier",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := resolveDefaultTierID(tc.loadRaw)
|
||||
if got != tc.want {
|
||||
t.Fatalf("resolveDefaultTierID() = %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,12 +7,14 @@ import (
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityTokenRefreshSkew = 3 * time.Minute
|
||||
antigravityTokenCacheSkew = 5 * time.Minute
|
||||
antigravityBackfillCooldown = 5 * time.Minute
|
||||
)
|
||||
|
||||
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
@@ -23,6 +25,7 @@ type AntigravityTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache AntigravityTokenCache
|
||||
antigravityOAuthService *AntigravityOAuthService
|
||||
backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time
|
||||
}
|
||||
|
||||
func NewAntigravityTokenProvider(
|
||||
@@ -93,13 +96,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
p.mergeCredentials(account, tokenInfo)
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
|
||||
}
|
||||
@@ -113,6 +110,21 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 如果账号还没有 project_id,尝试在线补齐,避免请求 daily/sandbox 时出现
|
||||
// "Invalid project resource name projects/"。
|
||||
// 仅调用 loadProjectIDWithRetry,不刷新 OAuth token;带冷却机制防止频繁重试。
|
||||
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
|
||||
if p.shouldAttemptBackfill(account.ID) {
|
||||
p.markBackfillAttempted(account.ID)
|
||||
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
|
||||
account.Credentials["project_id"] = projectID
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
if p.tokenCache != nil {
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
@@ -144,6 +156,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段
|
||||
func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) {
|
||||
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
}
|
||||
|
||||
// shouldAttemptBackfill 检查是否应该尝试补齐 project_id(冷却期内不重复尝试)
|
||||
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
|
||||
if v, ok := p.backfillCooldown.Load(accountID); ok {
|
||||
if lastAttempt, ok := v.(time.Time); ok {
|
||||
return time.Since(lastAttempt) > antigravityBackfillCooldown
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) {
|
||||
p.backfillCooldown.Store(accountID, time.Now())
|
||||
}
|
||||
|
||||
func AntigravityTokenCacheKey(account *Account) string {
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID != "" {
|
||||
|
||||
Reference in New Issue
Block a user