Files
sub2api/backend/internal/service/antigravity_token_provider.go
erio 528ff5d28c fix(antigravity): fast-fail on proxy unavailable, temp-unschedule account
## Problem

When a proxy is unreachable, token refresh retries up to 4 times with
30s timeout each, causing requests to hang for ~2 minutes before
failing with a generic 502 error. The failed account is not marked,
so subsequent requests keep hitting it.

## Changes

### Proxy connection fast-fail
- Set TCP dial timeout to 5s and TLS handshake timeout to 5s on
  antigravity client, so proxy connectivity issues fail within 5s
  instead of 30s
- Reduce overall HTTP client timeout from 30s to 10s
- Export `IsConnectionError` for service-layer use
- Detect proxy connection errors in `RefreshToken` and return
  immediately with "proxy unavailable" error (no retries)

### Token refresh temp-unschedulable
- Add 8s context timeout for token refresh on request path
- Mark account as temp-unschedulable for 10min when refresh fails
  (both background `TokenRefreshService` and request-path
  `GetAccessToken`)
- Sync temp-unschedulable state to Redis cache for immediate
  scheduler effect
- Inject `TempUnschedCache` into `AntigravityTokenProvider`

### Account failover
- Return `UpstreamFailoverError` on `GetAccessToken` failure in
  `Forward`/`ForwardGemini` to trigger handler-level account switch
  instead of returning 502 directly

### Proxy probe alignment
- Apply same 5s dial/TLS timeout to shared `httpclient` pool
- Reduce proxy probe timeout from 30s to 10s
2026-03-19 23:48:37 +08:00

240 lines
8.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"errors"
"log/slog"
"strconv"
"strings"
"sync"
"time"
)
const (
// antigravityTokenRefreshSkew is the pre-expiry window used by GetAccessToken:
// a token with no expiry, or one expiring within this window, is refreshed.
antigravityTokenRefreshSkew = 3 * time.Minute
// antigravityTokenCacheSkew is subtracted from the token's remaining lifetime
// when computing the cache TTL, so cached entries expire before the token does.
antigravityTokenCacheSkew = 5 * time.Minute
// antigravityBackfillCooldown is the minimum interval between project_id
// backfill attempts for the same account.
antigravityBackfillCooldown = 5 * time.Minute
// antigravityRequestRefreshTimeout is the maximum time a token refresh may take
// on the request path. Past this, the refresh is abandoned and the account is
// marked temporarily unschedulable so the caller can fail over (per the
// original author's note), leaving the background TokenRefreshService to retry
// on its next cycle.
antigravityRequestRefreshTimeout = 8 * time.Second
)
// AntigravityTokenCache is the token cache used by AntigravityTokenProvider.
// It is an alias of GeminiTokenCache; the provider uses it to read and write
// cached access tokens and to acquire/release the refresh lock.
type AntigravityTokenCache = GeminiTokenCache
// AntigravityTokenProvider manages access_token retrieval for antigravity
// accounts: cache lookup, refresh near expiry, project_id backfill and cache
// population (see GetAccessToken).
type AntigravityTokenProvider struct {
// accountRepo persists account changes (project_id backfill, temp-unschedulable state).
accountRepo AccountRepository
// tokenCache caches access tokens and provides the refresh lock; may be nil.
tokenCache AntigravityTokenCache
// antigravityOAuthService is used to backfill a missing project_id; may be nil.
antigravityOAuthService *AntigravityOAuthService
backfillCooldown sync.Map // key: accountID -> last attempt time
// refreshAPI and executor together perform the token refresh; both must be
// set (see SetRefreshAPI) for the refresh path to run.
refreshAPI *OAuthRefreshAPI
executor OAuthRefreshExecutor
// refreshPolicy decides what GetAccessToken does on refresh error or when
// the refresh lock is held elsewhere.
refreshPolicy ProviderRefreshPolicy
tempUnschedCache TempUnschedCache // used to synchronously update the Redis temp-unschedulable cache
}
// NewAntigravityTokenProvider builds a provider from its core dependencies and
// installs the default antigravity refresh policy. The refresh API, executor
// and temp-unschedulable cache are left unset; inject them with the Set*
// methods.
func NewAntigravityTokenProvider(
	accountRepo AccountRepository,
	tokenCache AntigravityTokenCache,
	antigravityOAuthService *AntigravityOAuthService,
) *AntigravityTokenProvider {
	provider := new(AntigravityTokenProvider)
	provider.accountRepo = accountRepo
	provider.tokenCache = tokenCache
	provider.antigravityOAuthService = antigravityOAuthService
	provider.refreshPolicy = AntigravityProviderRefreshPolicy()
	return provider
}
// SetRefreshAPI wires in the unified OAuth refresh API together with the
// executor it should run. GetAccessToken only takes the refresh path when
// both are non-nil.
func (p *AntigravityTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
	p.refreshAPI, p.executor = api, executor
}
// SetRefreshPolicy overrides the caller-side refresh policy that
// GetAccessToken consults on refresh errors and when the refresh lock is held.
func (p *AntigravityTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
	p.refreshPolicy = policy
}
// SetTempUnschedCache supplies the temp-unschedulable cache. When set, a
// temp-unschedulable mark written by this provider is mirrored into the cache
// right after it is persisted.
func (p *AntigravityTokenProvider) SetTempUnschedCache(cache TempUnschedCache) {
	p.tempUnschedCache = cache
}
// GetAccessToken returns a usable access token for an antigravity account.
//
// Resolution order:
//  1. Upstream-type accounts return their static "api_key" credential; no
//     OAuth refresh is ever attempted for them.
//  2. For OAuth accounts, a non-blank token found in the token cache is
//     returned immediately.
//  3. If the stored token has no expiry or expires within
//     antigravityTokenRefreshSkew, a refresh is attempted under a short
//     request-path timeout (antigravityRequestRefreshTimeout). On failure the
//     account is marked temporarily unschedulable, unless the failure was
//     caused by the caller cancelling ctx, and the error is returned when the
//     refresh policy says so; otherwise the existing token is used.
//  4. A missing project_id is backfilled best-effort, rate-limited per account
//     by antigravityBackfillCooldown.
//  5. The token is written to the cache with a TTL derived from its expiry,
//     unless the version check reports the in-memory account as stale, in
//     which case the latest account's token is returned without caching.
//
// It returns an error when account is nil, is not an antigravity account, is
// neither an upstream nor an OAuth account, lacks the required credential, or
// when the refresh fails and the policy is ProviderRefreshErrorReturn.
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
	if account == nil {
		return "", errors.New("account is nil")
	}
	if account.Platform != PlatformAntigravity {
		return "", errors.New("not an antigravity account")
	}
	// Upstream accounts use a static api_key and never refresh an OAuth token.
	if account.Type == AccountTypeUpstream {
		apiKey := account.GetCredential("api_key")
		if apiKey == "" {
			return "", errors.New("upstream account missing api_key in credentials")
		}
		return apiKey, nil
	}
	if account.Type != AccountTypeOAuth {
		return "", errors.New("not an antigravity oauth account")
	}

	cacheKey := AntigravityTokenCacheKey(account)

	// 1) Try cache first.
	if p.tokenCache != nil {
		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
			return token, nil
		}
	}

	// 2) Refresh if needed (pre-expiry skew).
	expiresAt := account.GetCredentialAsTime("expires_at")
	needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
	if needsRefresh && p.refreshAPI != nil && p.executor != nil {
		// Use a short timeout on the request path so an unreachable proxy does
		// not block the request for long; the background refresh service keeps
		// retrying.
		refreshCtx, cancel := context.WithTimeout(ctx, antigravityRequestRefreshTimeout)
		defer cancel()
		result, err := p.refreshAPI.RefreshIfNeeded(refreshCtx, account, p.executor, antigravityTokenRefreshSkew)
		if err != nil {
			// Mark the account temporarily unschedulable so later requests
			// stop hitting it. Skip this when the caller itself cancelled the
			// request (e.g. client disconnect): that failure says nothing about
			// the account's health and must not bench it.
			if !errors.Is(ctx.Err(), context.Canceled) {
				p.markTempUnschedulable(account, err)
			}
			if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
				return "", err
			}
		} else if result.LockHeld {
			if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
				if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
					return token, nil
				}
			}
			// Default policy: continue with the existing token.
		} else if result.Account != nil {
			// Adopt the refreshed account; a nil Account keeps the current one
			// instead of dereferencing a nil pointer below.
			account = result.Account
			expiresAt = account.GetCredentialAsTime("expires_at")
		}
	} else if needsRefresh && p.tokenCache != nil {
		// Backward-compatible test path when refreshAPI is not injected.
		locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
		if err == nil && locked {
			defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
		}
	}

	accessToken := account.GetCredential("access_token")
	if strings.TrimSpace(accessToken) == "" {
		return "", errors.New("access_token not found in credentials")
	}

	// Backfill project_id online when missing, with a cooldown to avoid hammering.
	if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
		if p.shouldAttemptBackfill(account.ID) {
			p.markBackfillAttempted(account.ID)
			// Best-effort: a failed or empty lookup is ignored and retried
			// after the cooldown.
			if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
				account.Credentials["project_id"] = projectID
				// accountRepo is treated as optional elsewhere in this type;
				// without it the backfilled value stays in memory only.
				if p.accountRepo != nil {
					if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
						slog.Warn("antigravity_project_id_backfill_persist_failed",
							"account_id", account.ID,
							"error", updateErr,
						)
					}
				}
			}
		}
	}

	// 3) Populate cache with TTL.
	if p.tokenCache != nil {
		latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
		if isStale && latestAccount != nil {
			// The in-memory account is outdated: return the latest token but do
			// not cache it from this stale view.
			slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
			accessToken = latestAccount.GetCredential("access_token")
			if strings.TrimSpace(accessToken) == "" {
				return "", errors.New("access_token not found after version check")
			}
		} else {
			ttl := 30 * time.Minute
			if expiresAt != nil {
				until := time.Until(*expiresAt)
				switch {
				case until > antigravityTokenCacheSkew:
					ttl = until - antigravityTokenCacheSkew
				case until > 0:
					ttl = until
				default:
					// Already expired: cache only briefly.
					ttl = time.Minute
				}
			}
			// Cache write is best-effort; the token is still returned on failure.
			_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
		}
	}
	return accessToken, nil
}
// shouldAttemptBackfill reports whether a project_id backfill may run for
// accountID. It returns true when no previous attempt is recorded (or the
// stored value is not a time.Time), and otherwise only once more than
// antigravityBackfillCooldown has passed since the last attempt.
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
	stored, found := p.backfillCooldown.Load(accountID)
	if !found {
		return true
	}
	last, isTime := stored.(time.Time)
	if !isTime {
		return true
	}
	return time.Since(last) > antigravityBackfillCooldown
}
// markTempUnschedulable marks the account as temporarily unschedulable after a
// token refresh failed on the request path.
//
// The mark lasts tokenRefreshTempUnschedDuration from now and is written to the
// account repository first. Only if that write succeeds is it mirrored into the
// temp-unschedulable cache (when one is injected) so the scheduler can skip the
// account right away. It is a no-op when no repository is configured or
// account is nil. Failures are logged, never returned.
//
// A context detached from the request is used because the request context may
// already have timed out. It is bounded so a stalled store cannot hold up the
// request path, which this code exists to fail fast.
func (p *AntigravityTokenProvider) markTempUnschedulable(account *Account, refreshErr error) {
	if p.accountRepo == nil || account == nil {
		return
	}

	// Upper bound for the repository and cache writes combined.
	const persistTimeout = 5 * time.Second

	now := time.Now()
	until := now.Add(tokenRefreshTempUnschedDuration)

	reason := "token refresh failed on request path"
	if refreshErr != nil { // tolerate a nil error instead of panicking on Error()
		reason += ": " + refreshErr.Error()
	}

	bgCtx, cancel := context.WithTimeout(context.Background(), persistTimeout)
	defer cancel()

	if err := p.accountRepo.SetTempUnschedulable(bgCtx, account.ID, until, reason); err != nil {
		slog.Warn("antigravity_token_provider.set_temp_unschedulable_failed",
			"account_id", account.ID,
			"error", err,
		)
		return
	}
	slog.Warn("antigravity_token_provider.temp_unschedulable_set",
		"account_id", account.ID,
		"until", until.Format(time.RFC3339),
		"reason", reason,
	)

	// Mirror the state into the cache so the scheduler sees it immediately.
	if p.tempUnschedCache != nil {
		state := &TempUnschedState{
			UntilUnix:       until.Unix(),
			TriggeredAtUnix: now.Unix(),
			ErrorMessage:    reason,
		}
		if err := p.tempUnschedCache.SetTempUnsched(bgCtx, account.ID, state); err != nil {
			slog.Warn("antigravity_token_provider.temp_unsched_cache_set_failed",
				"account_id", account.ID,
				"error", err,
			)
		}
	}
}
// markBackfillAttempted records the current time as the latest project_id
// backfill attempt for accountID; shouldAttemptBackfill reads it back to
// enforce the cooldown.
func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) {
	attemptedAt := time.Now()
	p.backfillCooldown.Store(accountID, attemptedAt)
}
// AntigravityTokenCacheKey derives the token cache key for an account:
// "ag:<project_id>" when the account has a non-blank project_id credential,
// otherwise "ag:account:<account ID>". account must be non-nil.
func AntigravityTokenCacheKey(account *Account) string {
	if id := strings.TrimSpace(account.GetCredential("project_id")); id != "" {
		return "ag:" + id
	}
	return "ag:account:" + strconv.FormatInt(account.ID, 10)
}