feat: unified OAuth token refresh API with distributed locking

Introduce OAuthRefreshAPI as the single entry point for all OAuth token
refresh operations, eliminating the race condition where background
refresh and inline refresh could simultaneously use the same
refresh_token (fixes #1035).

Key changes:
- Add OAuthRefreshExecutor interface extending TokenRefresher with CacheKey
- Add OAuthRefreshAPI.RefreshIfNeeded with lock → DB re-read → double-check flow
- Add ProviderRefreshPolicy / BackgroundRefreshPolicy strategy types
- Simplify all 4 TokenProviders to delegate to OAuthRefreshAPI
- Rewrite TokenRefreshService.refreshWithRetry to use unified API path
- Add MergeCredentials and BuildClaudeAccountCredentials helpers
- Add 40 unit tests covering all new and modified code paths
This commit is contained in:
erio
2026-03-16 01:31:54 +08:00
parent d3a9f5bb88
commit 1fc9dd7b68
14 changed files with 1336 additions and 452 deletions

View File

@@ -84,6 +84,10 @@ func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map
return r.credentials, nil
}
func (r *tokenRefresherStub) CacheKey(account *Account) string {
return "test:stub:" + account.Platform
}
func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
@@ -105,7 +109,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls)
@@ -133,7 +137,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls)
@@ -159,7 +163,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
}
@@ -186,7 +190,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls) // Antigravity 也应触发缓存失效
@@ -214,7 +218,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效
@@ -242,7 +246,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
@@ -270,7 +274,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.Error(t, err)
require.Contains(t, err.Error(), "failed to save credentials")
require.Equal(t, 1, repo.updateCalls)
@@ -297,7 +301,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
err: errors.New("refresh failed"),
}
err := service.refreshWithRetry(context.Background(), account, refresher)
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效
@@ -324,7 +328,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin
err: errors.New("network error"), // 可重试错误
}
err := service.refreshWithRetry(context.Background(), account, refresher)
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, invalidator.calls)
@@ -351,7 +355,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
err: errors.New("invalid_grant: token revoked"), // 不可重试错误
}
err := service.refreshWithRetry(context.Background(), account, refresher)
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, invalidator.calls)
@@ -383,7 +387,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.clearTempCalls) // DB 清除
@@ -422,7 +426,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
err: errors.New("invalid_grant: token revoked"),
}
err := service.refreshWithRetry(context.Background(), account, refresher)
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.Error(t, err)
require.Equal(t, 1, repo.setErrorCalls) // 所有平台不可重试错误都应 SetError
})
@@ -453,3 +457,212 @@ func TestIsNonRetryableRefreshError(t *testing.T) {
})
}
}
// ========== Path A (refreshAPI) 测试用例 ==========
// mockTokenCacheForRefreshAPI 用于 Path A 测试的 GeminiTokenCache mock
type mockTokenCacheForRefreshAPI struct {
lockResult bool
lockErr error
releaseCalls int
}
func (m *mockTokenCacheForRefreshAPI) GetAccessToken(_ context.Context, _ string) (string, error) {
return "", errors.New("not cached")
}
func (m *mockTokenCacheForRefreshAPI) SetAccessToken(_ context.Context, _ string, _ string, _ time.Duration) error {
return nil
}
func (m *mockTokenCacheForRefreshAPI) DeleteAccessToken(_ context.Context, _ string) error {
return nil
}
func (m *mockTokenCacheForRefreshAPI) AcquireRefreshLock(_ context.Context, _ string, _ time.Duration) (bool, error) {
return m.lockResult, m.lockErr
}
func (m *mockTokenCacheForRefreshAPI) ReleaseRefreshLock(_ context.Context, _ string) error {
m.releaseCalls++
return nil
}
// buildPathAService 构建注入了 refreshAPI 的 servicePath A 测试辅助)
func buildPathAService(repo *tokenRefreshAccountRepo, cache GeminiTokenCache, invalidator TokenCacheInvalidator) (*TokenRefreshService, *tokenRefresherStub) {
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
refreshAPI := NewOAuthRefreshAPI(repo, cache)
service.SetRefreshAPI(refreshAPI)
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "refreshed-token",
},
}
return service, refresher
}
// TestPathA_Success 统一 API 路径正常成功:刷新 + DB 更新 + postRefreshActions
func TestPathA_Success(t *testing.T) {
account := &Account{
ID: 100,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
repo := &tokenRefreshAccountRepo{}
repo.accountsByID = map[int64]*Account{account.ID: account}
invalidator := &tokenCacheInvalidatorStub{}
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
service, refresher := buildPathAService(repo, cache, invalidator)
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls) // DB 更新被调用
require.Equal(t, 1, invalidator.calls) // 缓存失效被调用
require.Equal(t, 1, cache.releaseCalls) // 锁被释放
}
// TestPathA_LockHeld 锁被其他 worker 持有 → 返回 errRefreshSkipped
func TestPathA_LockHeld(t *testing.T) {
account := &Account{
ID: 101,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cache := &mockTokenCacheForRefreshAPI{lockResult: false} // 锁获取失败(被占)
service, refresher := buildPathAService(repo, cache, invalidator)
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.ErrorIs(t, err, errRefreshSkipped)
require.Equal(t, 0, repo.updateCalls) // 不应更新 DB
require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效
}
// TestPathA_AlreadyRefreshed 二次检查发现已被其他路径刷新 → 返回 errRefreshSkipped
func TestPathA_AlreadyRefreshed(t *testing.T) {
// NeedsRefresh 返回 false → RefreshIfNeeded 返回 {Refreshed: false}
account := &Account{
ID: 102,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
repo := &tokenRefreshAccountRepo{}
repo.accountsByID = map[int64]*Account{account.ID: account}
invalidator := &tokenCacheInvalidatorStub{}
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
service, _ := buildPathAService(repo, cache, invalidator)
// 使用一个 NeedsRefresh 返回 false 的 stub
noRefreshNeeded := &tokenRefresherStub{
credentials: map[string]any{"access_token": "token"},
}
// 覆盖 NeedsRefresh 行为 — 我们需要一个新的 stub 类型
alwaysFreshStub := &alwaysFreshRefresherStub{}
err := service.refreshWithRetry(context.Background(), account, noRefreshNeeded, alwaysFreshStub, time.Hour)
require.ErrorIs(t, err, errRefreshSkipped)
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, invalidator.calls)
}
// alwaysFreshRefresherStub 二次检查时认为不需要刷新(模拟已被其他路径刷新)
type alwaysFreshRefresherStub struct{}
func (r *alwaysFreshRefresherStub) CanRefresh(_ *Account) bool { return true }
func (r *alwaysFreshRefresherStub) NeedsRefresh(_ *Account, _ time.Duration) bool { return false }
func (r *alwaysFreshRefresherStub) Refresh(_ context.Context, _ *Account) (map[string]any, error) {
return nil, errors.New("should not be called")
}
func (r *alwaysFreshRefresherStub) CacheKey(account *Account) string {
return "test:fresh:" + account.Platform
}
// TestPathA_NonRetryableError 统一 API 路径返回不可重试错误 → SetError
func TestPathA_NonRetryableError(t *testing.T) {
account := &Account{
ID: 103,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
repo := &tokenRefreshAccountRepo{}
repo.accountsByID = map[int64]*Account{account.ID: account}
invalidator := &tokenCacheInvalidatorStub{}
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
service, _ := buildPathAService(repo, cache, invalidator)
refresher := &tokenRefresherStub{
err: errors.New("invalid_grant: token revoked"),
}
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.Error(t, err)
require.Equal(t, 1, repo.setErrorCalls) // 应标记 error 状态
require.Equal(t, 0, repo.updateCalls) // 不应更新 credentials
require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效
}
// TestPathA_RetryableErrorExhausted 统一 API 路径可重试错误耗尽 → 不标记 error
func TestPathA_RetryableErrorExhausted(t *testing.T) {
account := &Account{
ID: 104,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
repo := &tokenRefreshAccountRepo{}
repo.accountsByID = map[int64]*Account{account.ID: account}
invalidator := &tokenCacheInvalidatorStub{}
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 2,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
refreshAPI := NewOAuthRefreshAPI(repo, cache)
service.SetRefreshAPI(refreshAPI)
refresher := &tokenRefresherStub{
err: errors.New("network timeout"),
}
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.Error(t, err)
require.Equal(t, 0, repo.setErrorCalls) // 可重试错误不标记 error
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效
}
// TestPathA_DBUpdateFailed 统一 API 路径 DB 更新失败 → 返回 error不执行 postRefreshActions
func TestPathA_DBUpdateFailed(t *testing.T) {
account := &Account{
ID: 105,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
repo := &tokenRefreshAccountRepo{updateErr: errors.New("db connection lost")}
repo.accountsByID = map[int64]*Account{account.ID: account}
invalidator := &tokenCacheInvalidatorStub{}
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
service, refresher := buildPathAService(repo, cache, invalidator)
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.Error(t, err)
require.Contains(t, err.Error(), "DB update failed")
require.Equal(t, 1, repo.updateCalls) // DB 更新被尝试
require.Equal(t, 0, invalidator.calls) // DB 失败时不应触发缓存失效
}