fix: resolve refresh token race condition causing false invalid_grant errors

When multiple goroutines/workers concurrently refresh the same OAuth token,
the first succeeds but invalidates the old refresh_token (rotation). Subsequent
attempts using the stale token get invalid_grant, which was incorrectly treated
as non-retryable, permanently marking the account as ERROR.

Three complementary fixes:
1. Race-aware recovery: after invalid_grant, re-read DB to check if another
   worker already refreshed (refresh_token changed) — return success instead
   of error
2. In-process mutex (sync.Map of per-account locks): prevents concurrent
   refreshes within the same process, complementing the Redis distributed lock
3. Increase default lock TTL from 30s to 60s to reduce TTL-expiry races

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
haruka
2026-03-30 16:23:38 +08:00
parent 6a2cf09ee0
commit ad2cd97618
2 changed files with 286 additions and 6 deletions

View File

@@ -5,6 +5,8 @@ import (
"fmt"
"log/slog"
"strconv"
"strings"
"sync"
"time"
)
@@ -17,7 +19,7 @@ type OAuthRefreshExecutor interface {
CacheKey(account *Account) string
}
// refreshLockTTL is the earlier fixed 30-second TTL that was passed to
// AcquireRefreshLock (the pre-change line of this diff).
const refreshLockTTL = 30 * time.Second
// defaultRefreshLockTTL is the 60-second distributed-lock TTL that
// NewOAuthRefreshAPI uses when the caller supplies no positive override.
const defaultRefreshLockTTL = 60 * time.Second
// OAuthRefreshResult 统一刷新结果
type OAuthRefreshResult struct {
@@ -28,20 +30,34 @@ type OAuthRefreshResult struct {
}
// OAuthRefreshAPI 统一的 OAuth Token 刷新入口
// 封装分布式锁、DB 重读、已刷新检查等通用逻辑
// 封装分布式锁、进程内互斥锁、DB 重读、已刷新检查、竞争恢复等通用逻辑
// OAuthRefreshAPI holds the dependencies and lock state shared by the refresh
// entry point: the account repository, the optional token cache used for the
// distributed lock, the lock TTL, and the per-key in-process mutexes.
type OAuthRefreshAPI struct {
accountRepo AccountRepository // used to re-read an account by ID (GetByID)
tokenCache GeminiTokenCache // optional; nil = no lock
tokenCache GeminiTokenCache // optional; nil = no distributed lock
lockTTL time.Duration // TTL passed to tokenCache.AcquireRefreshLock
localLocks sync.Map // key: cacheKey string -> value: *sync.Mutex
}
// NewOAuthRefreshAPI 创建统一刷新 API
func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI {
// NewOAuthRefreshAPI builds the unified refresh API around the given account
// repository and token cache.
//
// lockTTL is optional. When the first supplied value is positive it replaces
// defaultRefreshLockTTL as the distributed-lock TTL; a missing, zero, or
// negative value keeps the default. Values after the first are ignored.
func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache, lockTTL ...time.Duration) *OAuthRefreshAPI {
	api := &OAuthRefreshAPI{
		accountRepo: accountRepo,
		tokenCache:  tokenCache,
		lockTTL:     defaultRefreshLockTTL,
	}
	if len(lockTTL) != 0 {
		if override := lockTTL[0]; override > 0 {
			api.lockTTL = override
		}
	}
	return api
}
// getLocalLock returns the in-process mutex associated with cacheKey,
// creating and registering one on first use.
//
// Every call with the same cacheKey on the same OAuthRefreshAPI yields the
// same *sync.Mutex. An already-registered key is served by a plain Load, so no
// throwaway mutex is allocated on the common path. LoadOrStore is reached only
// for a key that has not been seen yet, and it resolves concurrent first calls
// to a single stored mutex.
//
// NOTE(review): nothing in this method removes entries from localLocks, so the
// map grows with the number of distinct cache keys — confirm that key
// cardinality is bounded.
func (api *OAuthRefreshAPI) getLocalLock(cacheKey string) *sync.Mutex {
	if existing, ok := api.localLocks.Load(cacheKey); ok {
		return existing.(*sync.Mutex)
	}
	// First sighting of this key: LoadOrStore keeps exactly one winner even if
	// several goroutines race past the Load above.
	actual, _ := api.localLocks.LoadOrStore(cacheKey, &sync.Mutex{})
	return actual.(*sync.Mutex)
}
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token
//
// 流程:
@@ -59,12 +75,17 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
) (*OAuthRefreshResult, error) {
cacheKey := executor.CacheKey(account)
// 0. 获取进程内互斥锁(防止同一进程内的并发刷新竞争)
localMu := api.getLocalLock(cacheKey)
localMu.Lock()
defer localMu.Unlock()
// 1. 获取分布式锁
lockAcquired := false
if api.tokenCache != nil {
acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, refreshLockTTL)
acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, api.lockTTL)
if lockErr != nil {
// Redis 错误,降级为无锁刷新
// Redis 错误,降级为无锁刷新(进程内互斥锁仍生效)
slog.Warn("oauth_refresh_lock_failed_degraded",
"account_id", account.ID,
"cache_key", cacheKey,
@@ -102,6 +123,19 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
// 4. 执行平台特定刷新逻辑
newCredentials, refreshErr := executor.Refresh(ctx, freshAccount)
if refreshErr != nil {
// 竞争恢复: invalid_grant 可能是另一个 worker 已消费了旧 refresh_token
// 重新读取 DB,如果 refresh_token 已更新则说明是竞争,返回成功
if isInvalidGrantError(refreshErr) {
if recoveredAccount, recovered := api.tryRecoverFromRefreshRace(ctx, freshAccount); recovered {
slog.Info("oauth_refresh_race_recovered",
"account_id", freshAccount.ID,
"platform", freshAccount.Platform,
)
return &OAuthRefreshResult{
Account: recoveredAccount,
}, nil
}
}
return nil, refreshErr
}
@@ -126,6 +160,33 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
}, nil
}
// isInvalidGrantError reports whether err's message contains "invalid_grant",
// compared case-insensitively. A nil error is never an invalid_grant error.
func isInvalidGrantError(err error) bool {
	if err == nil {
		return false
	}
	message := strings.ToLower(err.Error())
	return strings.Contains(message, "invalid_grant")
}
// tryRecoverFromRefreshRace attempts a race recovery after an invalid_grant
// failure. It re-reads the account by ID and compares the stored
// "refresh_token" credential with the one on usedAccount.
//
// It returns (re-read account, true) only when both tokens are non-empty and
// differ, which is taken to mean another worker already refreshed the account.
// It returns (nil, false) when there is no repository, the re-read fails or
// yields nil, either token is empty, or the token is unchanged.
func (api *OAuthRefreshAPI) tryRecoverFromRefreshRace(ctx context.Context, usedAccount *Account) (*Account, bool) {
	repo := api.accountRepo
	if repo == nil {
		return nil, false
	}

	latest, err := repo.GetByID(ctx, usedAccount.ID)
	if err != nil || latest == nil {
		// Best effort: a failed re-read just means no recovery.
		return nil, false
	}

	staleToken := usedAccount.GetCredential("refresh_token")
	storedToken := latest.GetCredential("refresh_token")

	switch {
	case staleToken == "", storedToken == "":
		// Without both tokens there is nothing to compare.
		return nil, false
	case staleToken == storedToken:
		// Token unchanged: the invalid_grant was not caused by a race.
		return nil, false
	default:
		// Token rotated in storage: another worker refreshed successfully.
		return latest, true
	}
}
// MergeCredentials 将旧 credentials 中不存在于新 map 的字段保留到新 map 中
func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any {
if newCreds == nil {