fix: resolve refresh token race condition causing false invalid_grant errors
When multiple goroutines/workers concurrently refresh the same OAuth token, the first succeeds but invalidates the old refresh_token (rotation). Subsequent attempts using the stale token get invalid_grant, which was incorrectly treated as non-retryable, permanently marking the account as ERROR.

Three complementary fixes:

1. Race-aware recovery: after invalid_grant, re-read the DB to check whether another worker has already refreshed (refresh_token changed); if so, return success instead of an error.
2. In-process mutex (sync.Map of per-account locks): prevents concurrent refreshes within the same process, complementing the Redis distributed lock.
3. Increase the default lock TTL from 30s to 60s to reduce TTL-expiry races.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -5,6 +5,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -385,6 +386,224 @@ func TestBuildClaudeAccountCredentials_Minimal(t *testing.T) {
|
||||
require.False(t, hasScope, "scope should not be set when empty")
|
||||
}
|
||||
|
||||
// refreshAPIAccountRepoWithRace supports returning a different account on subsequent GetByID calls
|
||||
// to simulate race conditions where another worker has refreshed the token.
|
||||
type refreshAPIAccountRepoWithRace struct {
|
||||
refreshAPIAccountRepo
|
||||
raceAccount *Account // returned on 2nd+ GetByID call
|
||||
getByIDCalls int
|
||||
}
|
||||
|
||||
func (r *refreshAPIAccountRepoWithRace) GetByID(_ context.Context, _ int64) (*Account, error) {
|
||||
r.getByIDCalls++
|
||||
if r.getByIDCalls > 1 && r.raceAccount != nil {
|
||||
return r.raceAccount, nil
|
||||
}
|
||||
if r.getByIDErr != nil {
|
||||
return nil, r.getByIDErr
|
||||
}
|
||||
return r.account, nil
|
||||
}
|
||||
|
||||
// ========== Race recovery tests ==========
|
||||
|
||||
func TestRefreshIfNeeded_InvalidGrantRaceRecovered(t *testing.T) {
|
||||
// Account with old refresh token
|
||||
account := &Account{
|
||||
ID: 10,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"refresh_token": "old-rt", "access_token": "old-at"},
|
||||
}
|
||||
// After race, DB has new refresh token from another worker
|
||||
racedAccount := &Account{
|
||||
ID: 10,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"refresh_token": "new-rt", "access_token": "new-at"},
|
||||
}
|
||||
repo := &refreshAPIAccountRepoWithRace{
|
||||
refreshAPIAccountRepo: refreshAPIAccountRepo{account: account},
|
||||
raceAccount: racedAccount,
|
||||
}
|
||||
cache := &refreshAPICacheStub{lockResult: true}
|
||||
executor := &refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
err: errors.New("invalid_grant: refresh token not found or invalid"),
|
||||
}
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||
|
||||
require.NoError(t, err, "race-recovered invalid_grant should not return error")
|
||||
require.False(t, result.Refreshed)
|
||||
require.False(t, result.LockHeld)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, "new-rt", result.Account.GetCredential("refresh_token"))
|
||||
require.Equal(t, 0, repo.updateCalls) // no DB update needed, another worker did it
|
||||
}
|
||||
|
||||
func TestRefreshIfNeeded_InvalidGrantGenuine(t *testing.T) {
|
||||
// Account with revoked refresh token - DB still has the same token
|
||||
account := &Account{
|
||||
ID: 11,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"refresh_token": "revoked-rt", "access_token": "old-at"},
|
||||
}
|
||||
repo := &refreshAPIAccountRepoWithRace{
|
||||
refreshAPIAccountRepo: refreshAPIAccountRepo{account: account},
|
||||
raceAccount: account, // same refresh_token on re-read
|
||||
}
|
||||
cache := &refreshAPICacheStub{lockResult: true}
|
||||
executor := &refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
err: errors.New("invalid_grant: refresh token revoked"),
|
||||
}
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||
|
||||
require.Error(t, err, "genuine invalid_grant should propagate error")
|
||||
require.Nil(t, result)
|
||||
require.Contains(t, err.Error(), "invalid_grant")
|
||||
}
|
||||
|
||||
func TestRefreshIfNeeded_InvalidGrantDBRereadFailsOnRecovery(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 12,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"refresh_token": "old-rt"},
|
||||
}
|
||||
repo := &refreshAPIAccountRepoWithRace{
|
||||
refreshAPIAccountRepo: refreshAPIAccountRepo{account: account},
|
||||
raceAccount: nil, // GetByID returns nil on recovery attempt
|
||||
}
|
||||
cache := &refreshAPICacheStub{lockResult: true}
|
||||
executor := &refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
err: errors.New("invalid_grant"),
|
||||
}
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||
|
||||
require.Error(t, err, "should propagate error when recovery DB re-read fails")
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestRefreshIfNeeded_LocalMutexSerializesConcurrent(t *testing.T) {
|
||||
// Test that two goroutines for the same account are serialized by the local mutex.
|
||||
// The first goroutine refreshes successfully; the second sees NeedsRefresh=false.
|
||||
refreshed := &Account{
|
||||
ID: 20,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"refresh_token": "new-rt", "access_token": "new-at"},
|
||||
}
|
||||
callCount := 0
|
||||
repo := &refreshAPIAccountRepo{account: &Account{
|
||||
ID: 20,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"refresh_token": "old-rt"},
|
||||
}}
|
||||
|
||||
// After first refresh, NeedsRefresh should return false
|
||||
// We simulate this by using an executor that decrements needsRefresh after first call
|
||||
var mu sync.Mutex
|
||||
dynamicExecutor := &dynamicRefreshExecutor{
|
||||
canRefresh: true,
|
||||
cacheKey: "test:mutex:anthropic",
|
||||
refreshFunc: func(_ context.Context, _ *Account) (map[string]any, error) {
|
||||
mu.Lock()
|
||||
callCount++
|
||||
mu.Unlock()
|
||||
time.Sleep(50 * time.Millisecond) // slow refresh
|
||||
return map[string]any{"access_token": "new-at"}, nil
|
||||
},
|
||||
needsRefreshFunc: func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return callCount == 0 // only first call needs refresh
|
||||
},
|
||||
}
|
||||
|
||||
_ = refreshed
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, nil) // no distributed lock, only local mutex
|
||||
|
||||
var wg sync.WaitGroup
|
||||
results := make([]*OAuthRefreshResult, 2)
|
||||
errs := make([]error, 2)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
results[idx], errs[idx] = api.RefreshIfNeeded(context.Background(), repo.account, dynamicExecutor, 3*time.Minute)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
require.NoError(t, errs[0])
|
||||
require.NoError(t, errs[1])
|
||||
|
||||
// Only one goroutine should have actually called Refresh
|
||||
mu.Lock()
|
||||
require.Equal(t, 1, callCount, "only one refresh call should have been made")
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// dynamicRefreshExecutor is a test helper with function-based NeedsRefresh and Refresh.
|
||||
type dynamicRefreshExecutor struct {
|
||||
canRefresh bool
|
||||
cacheKey string
|
||||
needsRefreshFunc func() bool
|
||||
refreshFunc func(context.Context, *Account) (map[string]any, error)
|
||||
}
|
||||
|
||||
// CanRefresh returns the configured canRefresh flag; the account is ignored.
func (e *dynamicRefreshExecutor) CanRefresh(_ *Account) bool {
	return e.canRefresh
}
|
||||
|
||||
func (e *dynamicRefreshExecutor) NeedsRefresh(_ *Account, _ time.Duration) bool {
|
||||
return e.needsRefreshFunc()
|
||||
}
|
||||
|
||||
func (e *dynamicRefreshExecutor) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
return e.refreshFunc(ctx, account)
|
||||
}
|
||||
|
||||
func (e *dynamicRefreshExecutor) CacheKey(_ *Account) string {
|
||||
return e.cacheKey
|
||||
}
|
||||
|
||||
// ========== NewOAuthRefreshAPI TTL tests ==========
|
||||
|
||||
func TestNewOAuthRefreshAPI_DefaultTTL(t *testing.T) {
|
||||
api := NewOAuthRefreshAPI(nil, nil)
|
||||
require.Equal(t, defaultRefreshLockTTL, api.lockTTL)
|
||||
}
|
||||
|
||||
func TestNewOAuthRefreshAPI_CustomTTL(t *testing.T) {
|
||||
api := NewOAuthRefreshAPI(nil, nil, 90*time.Second)
|
||||
require.Equal(t, 90*time.Second, api.lockTTL)
|
||||
}
|
||||
|
||||
func TestNewOAuthRefreshAPI_ZeroTTLUsesDefault(t *testing.T) {
|
||||
api := NewOAuthRefreshAPI(nil, nil, 0)
|
||||
require.Equal(t, defaultRefreshLockTTL, api.lockTTL)
|
||||
}
|
||||
|
||||
// ========== isInvalidGrantError tests ==========
|
||||
|
||||
func TestIsInvalidGrantError(t *testing.T) {
|
||||
require.True(t, isInvalidGrantError(errors.New("invalid_grant: token revoked")))
|
||||
require.True(t, isInvalidGrantError(errors.New("INVALID_GRANT")))
|
||||
require.False(t, isInvalidGrantError(errors.New("invalid_client")))
|
||||
require.False(t, isInvalidGrantError(nil))
|
||||
}
|
||||
|
||||
// ========== BackgroundRefreshPolicy tests ==========
|
||||
|
||||
func TestBackgroundRefreshPolicy_DefaultSkips(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user