Files
sub2api/backend/internal/service/antigravity_token_provider.go
Edric Li a4a46a8618 feat(antigravity): 添加 onboardUser 支持并修复 project_id 补齐逻辑
- 新增 OnboardUser API 客户端方法,支持账号 onboarding 获取 project_id
- loadProjectIDWithRetry 增加 onboard 回退:LoadCodeAssist 未返回 project_id 时自动触发 onboarding
- GetAccessToken 中 project_id 补齐改用轻量 FillProjectID 替代全量 RefreshAccountToken
- 补齐逻辑增加 5 分钟冷却机制,防止频繁重试
- OnboardUser 轮询等待改为 context 感知,支持提前取消
- 提取 mergeCredentials 辅助方法消除重复代码
- 新增 extractProjectIDFromOnboardResponse 和 resolveDefaultTierID 单元测试
2026-02-11 13:41:55 +08:00

191 lines
6.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"errors"
"log"
"log/slog"
"strconv"
"strings"
"sync"
"time"
)
const (
antigravityTokenRefreshSkew = 3 * time.Minute
antigravityTokenCacheSkew = 5 * time.Minute
antigravityBackfillCooldown = 5 * time.Minute
)
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type AntigravityTokenCache = GeminiTokenCache
// AntigravityTokenProvider 管理 Antigravity 账户的 access_token
type AntigravityTokenProvider struct {
accountRepo AccountRepository
tokenCache AntigravityTokenCache
antigravityOAuthService *AntigravityOAuthService
backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time
}
func NewAntigravityTokenProvider(
accountRepo AccountRepository,
tokenCache AntigravityTokenCache,
antigravityOAuthService *AntigravityOAuthService,
) *AntigravityTokenProvider {
return &AntigravityTokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
antigravityOAuthService: antigravityOAuthService,
}
}
// GetAccessToken 获取有效的 access_token
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAntigravity {
return "", errors.New("not an antigravity account")
}
// upstream 类型:直接从 credentials 读取 api_key不走 OAuth 刷新流程
if account.Type == AccountTypeUpstream {
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", errors.New("upstream account missing api_key in credentials")
}
return apiKey, nil
}
if account.Type != AccountTypeOAuth {
return "", errors.New("not an antigravity oauth account")
}
cacheKey := AntigravityTokenCacheKey(account)
// 1. 先尝试缓存
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
// 2. 如果即将过期则刷新
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
if needsRefresh && p.tokenCache != nil {
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if err == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
// 从数据库获取最新账户信息
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew {
if p.antigravityOAuthService == nil {
return "", errors.New("antigravity oauth service not configured")
}
tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return "", err
}
p.mergeCredentials(account, tokenInfo)
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
accessToken := account.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials")
}
// 如果账号还没有 project_id尝试在线补齐避免请求 daily/sandbox 时出现
// "Invalid project resource name projects/"。
// 仅调用 loadProjectIDWithRetry不刷新 OAuth token带冷却机制防止频繁重试。
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
if p.shouldAttemptBackfill(account.ID) {
p.markBackfillAttempted(account.ID)
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
account.Credentials["project_id"] = projectID
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr)
}
}
}
}
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// 版本过时,使用 DB 中的最新 token
slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// 不写入缓存,让下次请求重新处理
} else {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > antigravityTokenCacheSkew:
ttl = until - antigravityTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
}
return accessToken, nil
}
// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段
func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) {
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
}
// shouldAttemptBackfill 检查是否应该尝试补齐 project_id冷却期内不重复尝试
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
if v, ok := p.backfillCooldown.Load(accountID); ok {
if lastAttempt, ok := v.(time.Time); ok {
return time.Since(lastAttempt) > antigravityBackfillCooldown
}
}
return true
}
func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) {
p.backfillCooldown.Store(accountID, time.Now())
}
func AntigravityTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return "ag:" + projectID
}
return "ag:account:" + strconv.FormatInt(account.ID, 10)
}