Merge pull request #1382 from StarryKira/fix/refresh-token-race-condition
fix: resolve refresh token race condition causing false invalid_grant errors (fixes issue #1381)
This commit is contained in:
@@ -5,6 +5,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -17,7 +19,7 @@ type OAuthRefreshExecutor interface {
|
|||||||
CacheKey(account *Account) string
|
CacheKey(account *Account) string
|
||||||
}
|
}
|
||||||
|
|
||||||
const refreshLockTTL = 30 * time.Second
|
const defaultRefreshLockTTL = 60 * time.Second
|
||||||
|
|
||||||
// OAuthRefreshResult 统一刷新结果
|
// OAuthRefreshResult 统一刷新结果
|
||||||
type OAuthRefreshResult struct {
|
type OAuthRefreshResult struct {
|
||||||
@@ -28,20 +30,39 @@ type OAuthRefreshResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OAuthRefreshAPI 统一的 OAuth Token 刷新入口
|
// OAuthRefreshAPI 统一的 OAuth Token 刷新入口
|
||||||
// 封装分布式锁、DB 重读、已刷新检查等通用逻辑
|
// 封装分布式锁、进程内互斥锁、DB 重读、已刷新检查、竞争恢复等通用逻辑
|
||||||
type OAuthRefreshAPI struct {
|
type OAuthRefreshAPI struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
tokenCache GeminiTokenCache // 可选,nil = 无锁
|
tokenCache GeminiTokenCache // 可选,nil = 无分布式锁
|
||||||
|
lockTTL time.Duration
|
||||||
|
localLocks sync.Map // key: cacheKey string -> value: *sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOAuthRefreshAPI 创建统一刷新 API
|
// NewOAuthRefreshAPI 创建统一刷新 API
|
||||||
func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI {
|
// 可选传入 lockTTL 覆盖默认的 60s 分布式锁 TTL
|
||||||
|
func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache, lockTTL ...time.Duration) *OAuthRefreshAPI {
|
||||||
|
ttl := defaultRefreshLockTTL
|
||||||
|
if len(lockTTL) > 0 && lockTTL[0] > 0 {
|
||||||
|
ttl = lockTTL[0]
|
||||||
|
}
|
||||||
return &OAuthRefreshAPI{
|
return &OAuthRefreshAPI{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
tokenCache: tokenCache,
|
tokenCache: tokenCache,
|
||||||
|
lockTTL: ttl,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getLocalLock 返回指定 cacheKey 的进程内互斥锁
|
||||||
|
func (api *OAuthRefreshAPI) getLocalLock(cacheKey string) *sync.Mutex {
|
||||||
|
actual, _ := api.localLocks.LoadOrStore(cacheKey, &sync.Mutex{})
|
||||||
|
mu, ok := actual.(*sync.Mutex)
|
||||||
|
if !ok {
|
||||||
|
mu = &sync.Mutex{}
|
||||||
|
api.localLocks.Store(cacheKey, mu)
|
||||||
|
}
|
||||||
|
return mu
|
||||||
|
}
|
||||||
|
|
||||||
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token
|
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token
|
||||||
//
|
//
|
||||||
// 流程:
|
// 流程:
|
||||||
@@ -59,12 +80,17 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
|
|||||||
) (*OAuthRefreshResult, error) {
|
) (*OAuthRefreshResult, error) {
|
||||||
cacheKey := executor.CacheKey(account)
|
cacheKey := executor.CacheKey(account)
|
||||||
|
|
||||||
|
// 0. 获取进程内互斥锁(防止同一进程内的并发刷新竞争)
|
||||||
|
localMu := api.getLocalLock(cacheKey)
|
||||||
|
localMu.Lock()
|
||||||
|
defer localMu.Unlock()
|
||||||
|
|
||||||
// 1. 获取分布式锁
|
// 1. 获取分布式锁
|
||||||
lockAcquired := false
|
lockAcquired := false
|
||||||
if api.tokenCache != nil {
|
if api.tokenCache != nil {
|
||||||
acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, refreshLockTTL)
|
acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, api.lockTTL)
|
||||||
if lockErr != nil {
|
if lockErr != nil {
|
||||||
// Redis 错误,降级为无锁刷新
|
// Redis 错误,降级为无锁刷新(进程内互斥锁仍生效)
|
||||||
slog.Warn("oauth_refresh_lock_failed_degraded",
|
slog.Warn("oauth_refresh_lock_failed_degraded",
|
||||||
"account_id", account.ID,
|
"account_id", account.ID,
|
||||||
"cache_key", cacheKey,
|
"cache_key", cacheKey,
|
||||||
@@ -102,6 +128,19 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
|
|||||||
// 4. 执行平台特定刷新逻辑
|
// 4. 执行平台特定刷新逻辑
|
||||||
newCredentials, refreshErr := executor.Refresh(ctx, freshAccount)
|
newCredentials, refreshErr := executor.Refresh(ctx, freshAccount)
|
||||||
if refreshErr != nil {
|
if refreshErr != nil {
|
||||||
|
// 竞争恢复:invalid_grant 可能是另一个 worker 已消费了旧 refresh_token
|
||||||
|
// 重新读取 DB,如果 refresh_token 已更新则说明是竞争,返回成功
|
||||||
|
if isInvalidGrantError(refreshErr) {
|
||||||
|
if recoveredAccount, recovered := api.tryRecoverFromRefreshRace(ctx, freshAccount); recovered {
|
||||||
|
slog.Info("oauth_refresh_race_recovered",
|
||||||
|
"account_id", freshAccount.ID,
|
||||||
|
"platform", freshAccount.Platform,
|
||||||
|
)
|
||||||
|
return &OAuthRefreshResult{
|
||||||
|
Account: recoveredAccount,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil, refreshErr
|
return nil, refreshErr
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,6 +165,33 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isInvalidGrantError reports whether err's message contains "invalid_grant",
// compared case-insensitively. A nil error is never an invalid_grant error.
func isInvalidGrantError(err error) bool {
	if err == nil {
		return false
	}
	msg := strings.ToLower(err.Error())
	return strings.Contains(msg, "invalid_grant")
}
|
||||||
|
|
||||||
|
// tryRecoverFromRefreshRace 在 invalid_grant 错误后尝试竞争恢复
|
||||||
|
// 重新读取 DB,如果 refresh_token 已改变(说明另一个 worker 成功刷新),则返回更新后的 account
|
||||||
|
func (api *OAuthRefreshAPI) tryRecoverFromRefreshRace(ctx context.Context, usedAccount *Account) (*Account, bool) {
|
||||||
|
if api.accountRepo == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
reReadAccount, err := api.accountRepo.GetByID(ctx, usedAccount.ID)
|
||||||
|
if err != nil || reReadAccount == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
usedRT := usedAccount.GetCredential("refresh_token")
|
||||||
|
currentRT := reReadAccount.GetCredential("refresh_token")
|
||||||
|
if usedRT == "" || currentRT == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
// refresh_token 不同 → 另一个 worker 已成功刷新
|
||||||
|
if usedRT != currentRT {
|
||||||
|
return reReadAccount, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
// MergeCredentials 将旧 credentials 中不存在于新 map 的字段保留到新 map 中
|
// MergeCredentials 将旧 credentials 中不存在于新 map 的字段保留到新 map 中
|
||||||
func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any {
|
func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any {
|
||||||
if newCreds == nil {
|
if newCreds == nil {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -385,6 +386,224 @@ func TestBuildClaudeAccountCredentials_Minimal(t *testing.T) {
|
|||||||
require.False(t, hasScope, "scope should not be set when empty")
|
require.False(t, hasScope, "scope should not be set when empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// refreshAPIAccountRepoWithRace supports returning a different account on subsequent GetByID calls
|
||||||
|
// to simulate race conditions where another worker has refreshed the token.
|
||||||
|
type refreshAPIAccountRepoWithRace struct {
|
||||||
|
refreshAPIAccountRepo
|
||||||
|
raceAccount *Account // returned on 2nd+ GetByID call
|
||||||
|
getByIDCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refreshAPIAccountRepoWithRace) GetByID(_ context.Context, _ int64) (*Account, error) {
|
||||||
|
r.getByIDCalls++
|
||||||
|
if r.getByIDCalls > 1 && r.raceAccount != nil {
|
||||||
|
return r.raceAccount, nil
|
||||||
|
}
|
||||||
|
if r.getByIDErr != nil {
|
||||||
|
return nil, r.getByIDErr
|
||||||
|
}
|
||||||
|
return r.account, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== Race recovery tests ==========
|
||||||
|
|
||||||
|
func TestRefreshIfNeeded_InvalidGrantRaceRecovered(t *testing.T) {
|
||||||
|
// Account with old refresh token
|
||||||
|
account := &Account{
|
||||||
|
ID: 10,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{"refresh_token": "old-rt", "access_token": "old-at"},
|
||||||
|
}
|
||||||
|
// After race, DB has new refresh token from another worker
|
||||||
|
racedAccount := &Account{
|
||||||
|
ID: 10,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{"refresh_token": "new-rt", "access_token": "new-at"},
|
||||||
|
}
|
||||||
|
repo := &refreshAPIAccountRepoWithRace{
|
||||||
|
refreshAPIAccountRepo: refreshAPIAccountRepo{account: account},
|
||||||
|
raceAccount: racedAccount,
|
||||||
|
}
|
||||||
|
cache := &refreshAPICacheStub{lockResult: true}
|
||||||
|
executor := &refreshAPIExecutorStub{
|
||||||
|
needsRefresh: true,
|
||||||
|
err: errors.New("invalid_grant: refresh token not found or invalid"),
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||||
|
|
||||||
|
require.NoError(t, err, "race-recovered invalid_grant should not return error")
|
||||||
|
require.False(t, result.Refreshed)
|
||||||
|
require.False(t, result.LockHeld)
|
||||||
|
require.NotNil(t, result.Account)
|
||||||
|
require.Equal(t, "new-rt", result.Account.GetCredential("refresh_token"))
|
||||||
|
require.Equal(t, 0, repo.updateCalls) // no DB update needed, another worker did it
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshIfNeeded_InvalidGrantGenuine(t *testing.T) {
|
||||||
|
// Account with revoked refresh token - DB still has the same token
|
||||||
|
account := &Account{
|
||||||
|
ID: 11,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{"refresh_token": "revoked-rt", "access_token": "old-at"},
|
||||||
|
}
|
||||||
|
repo := &refreshAPIAccountRepoWithRace{
|
||||||
|
refreshAPIAccountRepo: refreshAPIAccountRepo{account: account},
|
||||||
|
raceAccount: account, // same refresh_token on re-read
|
||||||
|
}
|
||||||
|
cache := &refreshAPICacheStub{lockResult: true}
|
||||||
|
executor := &refreshAPIExecutorStub{
|
||||||
|
needsRefresh: true,
|
||||||
|
err: errors.New("invalid_grant: refresh token revoked"),
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||||
|
|
||||||
|
require.Error(t, err, "genuine invalid_grant should propagate error")
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.Contains(t, err.Error(), "invalid_grant")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshIfNeeded_InvalidGrantDBRereadFailsOnRecovery(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 12,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{"refresh_token": "old-rt"},
|
||||||
|
}
|
||||||
|
repo := &refreshAPIAccountRepoWithRace{
|
||||||
|
refreshAPIAccountRepo: refreshAPIAccountRepo{account: account},
|
||||||
|
raceAccount: nil, // GetByID returns nil on recovery attempt
|
||||||
|
}
|
||||||
|
cache := &refreshAPICacheStub{lockResult: true}
|
||||||
|
executor := &refreshAPIExecutorStub{
|
||||||
|
needsRefresh: true,
|
||||||
|
err: errors.New("invalid_grant"),
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||||
|
|
||||||
|
require.Error(t, err, "should propagate error when recovery DB re-read fails")
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshIfNeeded_LocalMutexSerializesConcurrent(t *testing.T) {
|
||||||
|
// Test that two goroutines for the same account are serialized by the local mutex.
|
||||||
|
// The first goroutine refreshes successfully; the second sees NeedsRefresh=false.
|
||||||
|
refreshed := &Account{
|
||||||
|
ID: 20,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{"refresh_token": "new-rt", "access_token": "new-at"},
|
||||||
|
}
|
||||||
|
callCount := 0
|
||||||
|
repo := &refreshAPIAccountRepo{account: &Account{
|
||||||
|
ID: 20,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{"refresh_token": "old-rt"},
|
||||||
|
}}
|
||||||
|
|
||||||
|
// After first refresh, NeedsRefresh should return false
|
||||||
|
// We simulate this by using an executor that decrements needsRefresh after first call
|
||||||
|
var mu sync.Mutex
|
||||||
|
dynamicExecutor := &dynamicRefreshExecutor{
|
||||||
|
canRefresh: true,
|
||||||
|
cacheKey: "test:mutex:anthropic",
|
||||||
|
refreshFunc: func(_ context.Context, _ *Account) (map[string]any, error) {
|
||||||
|
mu.Lock()
|
||||||
|
callCount++
|
||||||
|
mu.Unlock()
|
||||||
|
time.Sleep(50 * time.Millisecond) // slow refresh
|
||||||
|
return map[string]any{"access_token": "new-at"}, nil
|
||||||
|
},
|
||||||
|
needsRefreshFunc: func() bool {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
return callCount == 0 // only first call needs refresh
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = refreshed
|
||||||
|
|
||||||
|
api := NewOAuthRefreshAPI(repo, nil) // no distributed lock, only local mutex
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
results := make([]*OAuthRefreshResult, 2)
|
||||||
|
errs := make([]error, 2)
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
results[idx], errs[idx] = api.RefreshIfNeeded(context.Background(), repo.account, dynamicExecutor, 3*time.Minute)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
require.NoError(t, errs[0])
|
||||||
|
require.NoError(t, errs[1])
|
||||||
|
|
||||||
|
// Only one goroutine should have actually called Refresh
|
||||||
|
mu.Lock()
|
||||||
|
require.Equal(t, 1, callCount, "only one refresh call should have been made")
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// dynamicRefreshExecutor is a test helper with function-based NeedsRefresh and Refresh.
|
||||||
|
type dynamicRefreshExecutor struct {
|
||||||
|
canRefresh bool
|
||||||
|
cacheKey string
|
||||||
|
needsRefreshFunc func() bool
|
||||||
|
refreshFunc func(context.Context, *Account) (map[string]any, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *dynamicRefreshExecutor) CanRefresh(_ *Account) bool { return e.canRefresh }
|
||||||
|
|
||||||
|
func (e *dynamicRefreshExecutor) NeedsRefresh(_ *Account, _ time.Duration) bool {
|
||||||
|
return e.needsRefreshFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *dynamicRefreshExecutor) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||||
|
return e.refreshFunc(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *dynamicRefreshExecutor) CacheKey(_ *Account) string {
|
||||||
|
return e.cacheKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== NewOAuthRefreshAPI TTL tests ==========
|
||||||
|
|
||||||
|
func TestNewOAuthRefreshAPI_DefaultTTL(t *testing.T) {
|
||||||
|
api := NewOAuthRefreshAPI(nil, nil)
|
||||||
|
require.Equal(t, defaultRefreshLockTTL, api.lockTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewOAuthRefreshAPI_CustomTTL(t *testing.T) {
|
||||||
|
api := NewOAuthRefreshAPI(nil, nil, 90*time.Second)
|
||||||
|
require.Equal(t, 90*time.Second, api.lockTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewOAuthRefreshAPI_ZeroTTLUsesDefault(t *testing.T) {
|
||||||
|
api := NewOAuthRefreshAPI(nil, nil, 0)
|
||||||
|
require.Equal(t, defaultRefreshLockTTL, api.lockTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== isInvalidGrantError tests ==========
|
||||||
|
|
||||||
|
func TestIsInvalidGrantError(t *testing.T) {
|
||||||
|
require.True(t, isInvalidGrantError(errors.New("invalid_grant: token revoked")))
|
||||||
|
require.True(t, isInvalidGrantError(errors.New("INVALID_GRANT")))
|
||||||
|
require.False(t, isInvalidGrantError(errors.New("invalid_client")))
|
||||||
|
require.False(t, isInvalidGrantError(nil))
|
||||||
|
}
|
||||||
|
|
||||||
// ========== BackgroundRefreshPolicy tests ==========
|
// ========== BackgroundRefreshPolicy tests ==========
|
||||||
|
|
||||||
func TestBackgroundRefreshPolicy_DefaultSkips(t *testing.T) {
|
func TestBackgroundRefreshPolicy_DefaultSkips(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user