feat: unified OAuth token refresh API with distributed locking
Introduce OAuthRefreshAPI as the single entry point for all OAuth token refresh operations, eliminating the race condition where background refresh and inline refresh could simultaneously use the same refresh_token (fixes #1035). Key changes: - Add OAuthRefreshExecutor interface extending TokenRefresher with CacheKey - Add OAuthRefreshAPI.RefreshIfNeeded with lock → DB re-read → double-check flow - Add ProviderRefreshPolicy / BackgroundRefreshPolicy strategy types - Simplify all 4 TokenProviders to delegate to OAuthRefreshAPI - Rewrite TokenRefreshService.refreshWithRetry to use unified API path - Add MergeCredentials and BuildClaudeAccountCredentials helpers - Add 40 unit tests covering all new and modified code paths
This commit is contained in:
@@ -84,6 +84,10 @@ func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map
|
||||
return r.credentials, nil
|
||||
}
|
||||
|
||||
func (r *tokenRefresherStub) CacheKey(account *Account) string {
|
||||
return "test:stub:" + account.Platform
|
||||
}
|
||||
|
||||
func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
@@ -105,7 +109,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, invalidator.calls)
|
||||
@@ -133,7 +137,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, invalidator.calls)
|
||||
@@ -159,7 +163,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
}
|
||||
@@ -186,7 +190,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, invalidator.calls) // Antigravity 也应触发缓存失效
|
||||
@@ -214,7 +218,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效
|
||||
@@ -242,7 +246,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
|
||||
@@ -270,7 +274,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "failed to save credentials")
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
@@ -297,7 +301,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
|
||||
err: errors.New("refresh failed"),
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
|
||||
require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效
|
||||
@@ -324,7 +328,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin
|
||||
err: errors.New("network error"), // 可重试错误
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 0, repo.updateCalls)
|
||||
require.Equal(t, 0, invalidator.calls)
|
||||
@@ -351,7 +355,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
|
||||
err: errors.New("invalid_grant: token revoked"), // 不可重试错误
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 0, repo.updateCalls)
|
||||
require.Equal(t, 0, invalidator.calls)
|
||||
@@ -383,7 +387,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, repo.clearTempCalls) // DB 清除
|
||||
@@ -422,7 +426,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
|
||||
err: errors.New("invalid_grant: token revoked"),
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 1, repo.setErrorCalls) // 所有平台不可重试错误都应 SetError
|
||||
})
|
||||
@@ -453,3 +457,212 @@ func TestIsNonRetryableRefreshError(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ========== Path A (refreshAPI) 测试用例 ==========
|
||||
|
||||
// mockTokenCacheForRefreshAPI 用于 Path A 测试的 GeminiTokenCache mock
|
||||
type mockTokenCacheForRefreshAPI struct {
|
||||
lockResult bool
|
||||
lockErr error
|
||||
releaseCalls int
|
||||
}
|
||||
|
||||
func (m *mockTokenCacheForRefreshAPI) GetAccessToken(_ context.Context, _ string) (string, error) {
|
||||
return "", errors.New("not cached")
|
||||
}
|
||||
|
||||
func (m *mockTokenCacheForRefreshAPI) SetAccessToken(_ context.Context, _ string, _ string, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockTokenCacheForRefreshAPI) DeleteAccessToken(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockTokenCacheForRefreshAPI) AcquireRefreshLock(_ context.Context, _ string, _ time.Duration) (bool, error) {
|
||||
return m.lockResult, m.lockErr
|
||||
}
|
||||
|
||||
func (m *mockTokenCacheForRefreshAPI) ReleaseRefreshLock(_ context.Context, _ string) error {
|
||||
m.releaseCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildPathAService 构建注入了 refreshAPI 的 service(Path A 测试辅助)
|
||||
func buildPathAService(repo *tokenRefreshAccountRepo, cache GeminiTokenCache, invalidator TokenCacheInvalidator) (*TokenRefreshService, *tokenRefresherStub) {
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 1,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
refreshAPI := NewOAuthRefreshAPI(repo, cache)
|
||||
service.SetRefreshAPI(refreshAPI)
|
||||
|
||||
refresher := &tokenRefresherStub{
|
||||
credentials: map[string]any{
|
||||
"access_token": "refreshed-token",
|
||||
},
|
||||
}
|
||||
return service, refresher
|
||||
}
|
||||
|
||||
// TestPathA_Success 统一 API 路径正常成功:刷新 + DB 更新 + postRefreshActions
|
||||
func TestPathA_Success(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
repo.accountsByID = map[int64]*Account{account.ID: account}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
|
||||
|
||||
service, refresher := buildPathAService(repo, cache, invalidator)
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls) // DB 更新被调用
|
||||
require.Equal(t, 1, invalidator.calls) // 缓存失效被调用
|
||||
require.Equal(t, 1, cache.releaseCalls) // 锁被释放
|
||||
}
|
||||
|
||||
// TestPathA_LockHeld 锁被其他 worker 持有 → 返回 errRefreshSkipped
|
||||
func TestPathA_LockHeld(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cache := &mockTokenCacheForRefreshAPI{lockResult: false} // 锁获取失败(被占)
|
||||
|
||||
service, refresher := buildPathAService(repo, cache, invalidator)
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.ErrorIs(t, err, errRefreshSkipped)
|
||||
require.Equal(t, 0, repo.updateCalls) // 不应更新 DB
|
||||
require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效
|
||||
}
|
||||
|
||||
// TestPathA_AlreadyRefreshed 二次检查发现已被其他路径刷新 → 返回 errRefreshSkipped
|
||||
func TestPathA_AlreadyRefreshed(t *testing.T) {
|
||||
// NeedsRefresh 返回 false → RefreshIfNeeded 返回 {Refreshed: false}
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
repo.accountsByID = map[int64]*Account{account.ID: account}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
|
||||
|
||||
service, _ := buildPathAService(repo, cache, invalidator)
|
||||
|
||||
// 使用一个 NeedsRefresh 返回 false 的 stub
|
||||
noRefreshNeeded := &tokenRefresherStub{
|
||||
credentials: map[string]any{"access_token": "token"},
|
||||
}
|
||||
// 覆盖 NeedsRefresh 行为 — 我们需要一个新的 stub 类型
|
||||
alwaysFreshStub := &alwaysFreshRefresherStub{}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, noRefreshNeeded, alwaysFreshStub, time.Hour)
|
||||
require.ErrorIs(t, err, errRefreshSkipped)
|
||||
require.Equal(t, 0, repo.updateCalls)
|
||||
require.Equal(t, 0, invalidator.calls)
|
||||
}
|
||||
|
||||
// alwaysFreshRefresherStub 二次检查时认为不需要刷新(模拟已被其他路径刷新)
|
||||
type alwaysFreshRefresherStub struct{}
|
||||
|
||||
func (r *alwaysFreshRefresherStub) CanRefresh(_ *Account) bool { return true }
|
||||
func (r *alwaysFreshRefresherStub) NeedsRefresh(_ *Account, _ time.Duration) bool { return false }
|
||||
func (r *alwaysFreshRefresherStub) Refresh(_ context.Context, _ *Account) (map[string]any, error) {
|
||||
return nil, errors.New("should not be called")
|
||||
}
|
||||
func (r *alwaysFreshRefresherStub) CacheKey(account *Account) string {
|
||||
return "test:fresh:" + account.Platform
|
||||
}
|
||||
|
||||
// TestPathA_NonRetryableError 统一 API 路径返回不可重试错误 → SetError
|
||||
func TestPathA_NonRetryableError(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
repo.accountsByID = map[int64]*Account{account.ID: account}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
|
||||
|
||||
service, _ := buildPathAService(repo, cache, invalidator)
|
||||
|
||||
refresher := &tokenRefresherStub{
|
||||
err: errors.New("invalid_grant: token revoked"),
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 1, repo.setErrorCalls) // 应标记 error 状态
|
||||
require.Equal(t, 0, repo.updateCalls) // 不应更新 credentials
|
||||
require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效
|
||||
}
|
||||
|
||||
// TestPathA_RetryableErrorExhausted 统一 API 路径可重试错误耗尽 → 不标记 error
|
||||
func TestPathA_RetryableErrorExhausted(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 104,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
repo.accountsByID = map[int64]*Account{account.ID: account}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
|
||||
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 2,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
refreshAPI := NewOAuthRefreshAPI(repo, cache)
|
||||
service.SetRefreshAPI(refreshAPI)
|
||||
|
||||
refresher := &tokenRefresherStub{
|
||||
err: errors.New("network timeout"),
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 0, repo.setErrorCalls) // 可重试错误不标记 error
|
||||
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
|
||||
require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效
|
||||
}
|
||||
|
||||
// TestPathA_DBUpdateFailed 统一 API 路径 DB 更新失败 → 返回 error,不执行 postRefreshActions
|
||||
func TestPathA_DBUpdateFailed(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 105,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
repo := &tokenRefreshAccountRepo{updateErr: errors.New("db connection lost")}
|
||||
repo.accountsByID = map[int64]*Account{account.ID: account}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
|
||||
|
||||
service, refresher := buildPathAService(repo, cache, invalidator)
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "DB update failed")
|
||||
require.Equal(t, 1, repo.updateCalls) // DB 更新被尝试
|
||||
require.Equal(t, 0, invalidator.calls) // DB 失败时不应触发缓存失效
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user