diff --git a/backend/internal/service/oauth_refresh_api.go b/backend/internal/service/oauth_refresh_api.go index 5dbba638..571e9ecd 100644 --- a/backend/internal/service/oauth_refresh_api.go +++ b/backend/internal/service/oauth_refresh_api.go @@ -5,6 +5,8 @@ import ( "fmt" "log/slog" "strconv" + "strings" + "sync" "time" ) @@ -17,7 +19,7 @@ type OAuthRefreshExecutor interface { CacheKey(account *Account) string } -const refreshLockTTL = 30 * time.Second +const defaultRefreshLockTTL = 60 * time.Second // OAuthRefreshResult 统一刷新结果 type OAuthRefreshResult struct { @@ -28,20 +30,39 @@ type OAuthRefreshResult struct { } // OAuthRefreshAPI 统一的 OAuth Token 刷新入口 -// 封装分布式锁、DB 重读、已刷新检查等通用逻辑 +// 封装分布式锁、进程内互斥锁、DB 重读、已刷新检查、竞争恢复等通用逻辑 type OAuthRefreshAPI struct { accountRepo AccountRepository - tokenCache GeminiTokenCache // 可选,nil = 无锁 + tokenCache GeminiTokenCache // 可选,nil = 无分布式锁 + lockTTL time.Duration + localLocks sync.Map // key: cacheKey string -> value: *sync.Mutex } // NewOAuthRefreshAPI 创建统一刷新 API -func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI { +// 可选传入 lockTTL 覆盖默认的 60s 分布式锁 TTL +func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache, lockTTL ...time.Duration) *OAuthRefreshAPI { + ttl := defaultRefreshLockTTL + if len(lockTTL) > 0 && lockTTL[0] > 0 { + ttl = lockTTL[0] + } return &OAuthRefreshAPI{ accountRepo: accountRepo, tokenCache: tokenCache, + lockTTL: ttl, } } +// getLocalLock 返回指定 cacheKey 的进程内互斥锁 +func (api *OAuthRefreshAPI) getLocalLock(cacheKey string) *sync.Mutex { + actual, _ := api.localLocks.LoadOrStore(cacheKey, &sync.Mutex{}) + mu, ok := actual.(*sync.Mutex) + if !ok { + mu = &sync.Mutex{} + api.localLocks.Store(cacheKey, mu) + } + return mu +} + // RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token // // 流程: @@ -59,12 +80,17 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded( ) (*OAuthRefreshResult, error) { cacheKey := executor.CacheKey(account) + // 0. 获取进程内互斥锁(防止同一进程内的并发刷新竞争) + localMu := api.getLocalLock(cacheKey) + localMu.Lock() + defer localMu.Unlock() + // 1. 获取分布式锁 lockAcquired := false if api.tokenCache != nil { - acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, refreshLockTTL) + acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, api.lockTTL) if lockErr != nil { - // Redis 错误,降级为无锁刷新 + // Redis 错误,降级为无锁刷新(进程内互斥锁仍生效) slog.Warn("oauth_refresh_lock_failed_degraded", "account_id", account.ID, "cache_key", cacheKey, @@ -102,6 +128,19 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded( // 4. 执行平台特定刷新逻辑 newCredentials, refreshErr := executor.Refresh(ctx, freshAccount) if refreshErr != nil { + // 竞争恢复:invalid_grant 可能是另一个 worker 已消费了旧 refresh_token + // 重新读取 DB,如果 refresh_token 已更新则说明是竞争,返回成功 + if isInvalidGrantError(refreshErr) { + if recoveredAccount, recovered := api.tryRecoverFromRefreshRace(ctx, freshAccount); recovered { + slog.Info("oauth_refresh_race_recovered", + "account_id", freshAccount.ID, + "platform", freshAccount.Platform, + ) + return &OAuthRefreshResult{ + Account: recoveredAccount, + }, nil + } + } return nil, refreshErr } @@ -126,6 +165,33 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded( }, nil } +// isInvalidGrantError 检查错误是否为 invalid_grant +func isInvalidGrantError(err error) bool { + return err != nil && strings.Contains(strings.ToLower(err.Error()), "invalid_grant") +} + +// tryRecoverFromRefreshRace 在 invalid_grant 错误后尝试竞争恢复 +// 重新读取 DB,如果 refresh_token 已改变(说明另一个 worker 成功刷新),则返回更新后的 account +func (api *OAuthRefreshAPI) tryRecoverFromRefreshRace(ctx context.Context, usedAccount *Account) (*Account, bool) { + if api.accountRepo == nil { + return nil, false + } + reReadAccount, err := api.accountRepo.GetByID(ctx, usedAccount.ID) + if err != nil || reReadAccount == nil { + return nil, false + } + usedRT := usedAccount.GetCredential("refresh_token") + currentRT := reReadAccount.GetCredential("refresh_token") + if usedRT == "" || currentRT == "" { + return nil, false + } + // refresh_token 不同 → 另一个 worker 已成功刷新 + if usedRT != currentRT { + return reReadAccount, true + } + return nil, false +} + // MergeCredentials 将旧 credentials 中不存在于新 map 的字段保留到新 map 中 func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any { if newCreds == nil { diff --git a/backend/internal/service/oauth_refresh_api_test.go b/backend/internal/service/oauth_refresh_api_test.go index c3b38ddf..4a60723b 100644 --- a/backend/internal/service/oauth_refresh_api_test.go +++ b/backend/internal/service/oauth_refresh_api_test.go @@ -5,6 +5,7 @@ package service import ( "context" "errors" + "sync" "testing" "time" @@ -385,6 +386,224 @@ func TestBuildClaudeAccountCredentials_Minimal(t *testing.T) { require.False(t, hasScope, "scope should not be set when empty") } +// refreshAPIAccountRepoWithRace supports returning a different account on subsequent GetByID calls +// to simulate race conditions where another worker has refreshed the token. +type refreshAPIAccountRepoWithRace struct { + refreshAPIAccountRepo + raceAccount *Account // returned on 2nd+ GetByID call + getByIDCalls int +} + +func (r *refreshAPIAccountRepoWithRace) GetByID(_ context.Context, _ int64) (*Account, error) { + r.getByIDCalls++ + if r.getByIDCalls > 1 && r.raceAccount != nil { + return r.raceAccount, nil + } + if r.getByIDErr != nil { + return nil, r.getByIDErr + } + return r.account, nil +} + +// ========== Race recovery tests ========== + +func TestRefreshIfNeeded_InvalidGrantRaceRecovered(t *testing.T) { + // Account with old refresh token + account := &Account{ + ID: 10, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "old-rt", "access_token": "old-at"}, + } + // After race, DB has new refresh token from another worker + racedAccount := &Account{ + ID: 10, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "new-rt", "access_token": "new-at"}, + } + repo := &refreshAPIAccountRepoWithRace{ + refreshAPIAccountRepo: refreshAPIAccountRepo{account: account}, + raceAccount: racedAccount, + } + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + err: errors.New("invalid_grant: refresh token not found or invalid"), + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err, "race-recovered invalid_grant should not return error") + require.False(t, result.Refreshed) + require.False(t, result.LockHeld) + require.NotNil(t, result.Account) + require.Equal(t, "new-rt", result.Account.GetCredential("refresh_token")) + require.Equal(t, 0, repo.updateCalls) // no DB update needed, another worker did it +} + +func TestRefreshIfNeeded_InvalidGrantGenuine(t *testing.T) { + // Account with revoked refresh token - DB still has the same token + account := &Account{ + ID: 11, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "revoked-rt", "access_token": "old-at"}, + } + repo := &refreshAPIAccountRepoWithRace{ + refreshAPIAccountRepo: refreshAPIAccountRepo{account: account}, + raceAccount: account, // same refresh_token on re-read + } + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + err: errors.New("invalid_grant: refresh token revoked"), + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.Error(t, err, "genuine invalid_grant should propagate error") + require.Nil(t, result) + require.Contains(t, err.Error(), "invalid_grant") +} + +func TestRefreshIfNeeded_InvalidGrantDBRereadFailsOnRecovery(t *testing.T) { + account := &Account{ + ID: 12, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "old-rt"}, + } + repo := &refreshAPIAccountRepoWithRace{ + refreshAPIAccountRepo: refreshAPIAccountRepo{account: account}, + raceAccount: nil, // GetByID returns nil on recovery attempt + } + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + err: errors.New("invalid_grant"), + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.Error(t, err, "should propagate error when recovery DB re-read fails") + require.Nil(t, result) +} + +func TestRefreshIfNeeded_LocalMutexSerializesConcurrent(t *testing.T) { + // Test that two goroutines for the same account are serialized by the local mutex. + // The first goroutine refreshes successfully; the second sees NeedsRefresh=false. + refreshed := &Account{ + ID: 20, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "new-rt", "access_token": "new-at"}, + } + callCount := 0 + repo := &refreshAPIAccountRepo{account: &Account{ + ID: 20, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "old-rt"}, + }} + + // After first refresh, NeedsRefresh should return false + // We simulate this by using an executor that decrements needsRefresh after first call + var mu sync.Mutex + dynamicExecutor := &dynamicRefreshExecutor{ + canRefresh: true, + cacheKey: "test:mutex:anthropic", + refreshFunc: func(_ context.Context, _ *Account) (map[string]any, error) { + mu.Lock() + callCount++ + mu.Unlock() + time.Sleep(50 * time.Millisecond) // slow refresh + return map[string]any{"access_token": "new-at"}, nil + }, + needsRefreshFunc: func() bool { + mu.Lock() + defer mu.Unlock() + return callCount == 0 // only first call needs refresh + }, + } + + _ = refreshed + + api := NewOAuthRefreshAPI(repo, nil) // no distributed lock, only local mutex + + var wg sync.WaitGroup + results := make([]*OAuthRefreshResult, 2) + errs := make([]error, 2) + + for i := 0; i < 2; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + results[idx], errs[idx] = api.RefreshIfNeeded(context.Background(), repo.account, dynamicExecutor, 3*time.Minute) + }(i) + } + wg.Wait() + + require.NoError(t, errs[0]) + require.NoError(t, errs[1]) + + // Only one goroutine should have actually called Refresh + mu.Lock() + require.Equal(t, 1, callCount, "only one refresh call should have been made") + mu.Unlock() +} + +// dynamicRefreshExecutor is a test helper with function-based NeedsRefresh and Refresh. +type dynamicRefreshExecutor struct { + canRefresh bool + cacheKey string + needsRefreshFunc func() bool + refreshFunc func(context.Context, *Account) (map[string]any, error) +} + +func (e *dynamicRefreshExecutor) CanRefresh(_ *Account) bool { return e.canRefresh } + +func (e *dynamicRefreshExecutor) NeedsRefresh(_ *Account, _ time.Duration) bool { + return e.needsRefreshFunc() +} + +func (e *dynamicRefreshExecutor) Refresh(ctx context.Context, account *Account) (map[string]any, error) { + return e.refreshFunc(ctx, account) +} + +func (e *dynamicRefreshExecutor) CacheKey(_ *Account) string { + return e.cacheKey +} + +// ========== NewOAuthRefreshAPI TTL tests ========== + +func TestNewOAuthRefreshAPI_DefaultTTL(t *testing.T) { + api := NewOAuthRefreshAPI(nil, nil) + require.Equal(t, defaultRefreshLockTTL, api.lockTTL) +} + +func TestNewOAuthRefreshAPI_CustomTTL(t *testing.T) { + api := NewOAuthRefreshAPI(nil, nil, 90*time.Second) + require.Equal(t, 90*time.Second, api.lockTTL) +} + +func TestNewOAuthRefreshAPI_ZeroTTLUsesDefault(t *testing.T) { + api := NewOAuthRefreshAPI(nil, nil, 0) + require.Equal(t, defaultRefreshLockTTL, api.lockTTL) +} + +// ========== isInvalidGrantError tests ========== + +func TestIsInvalidGrantError(t *testing.T) { + require.True(t, isInvalidGrantError(errors.New("invalid_grant: token revoked"))) + require.True(t, isInvalidGrantError(errors.New("INVALID_GRANT"))) + require.False(t, isInvalidGrantError(errors.New("invalid_client"))) + require.False(t, isInvalidGrantError(nil)) +} + // ========== BackgroundRefreshPolicy tests ========== func TestBackgroundRefreshPolicy_DefaultSkips(t *testing.T) {