feat: unified OAuth token refresh API with distributed locking
Introduce OAuthRefreshAPI as the single entry point for all OAuth token refresh operations, eliminating the race condition where background refresh and inline refresh could simultaneously use the same refresh_token (fixes #1035). Key changes: - Add OAuthRefreshExecutor interface extending TokenRefresher with CacheKey - Add OAuthRefreshAPI.RefreshIfNeeded with lock → DB re-read → double-check flow - Add ProviderRefreshPolicy / BackgroundRefreshPolicy strategy types - Simplify all 4 TokenProviders to delegate to OAuthRefreshAPI - Rewrite TokenRefreshService.refreshWithRetry to use unified API path - Add MergeCredentials and BuildClaudeAccountCredentials helpers - Add 40 unit tests covering all new and modified code paths
This commit is contained in:
159
backend/internal/service/oauth_refresh_api.go
Normal file
159
backend/internal/service/oauth_refresh_api.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器
|
||||
// TokenRefresher 接口的超集:增加了 CacheKey 方法用于分布式锁
|
||||
type OAuthRefreshExecutor interface {
|
||||
TokenRefresher
|
||||
|
||||
// CacheKey 返回用于分布式锁的缓存键(与 TokenProvider 使用的一致)
|
||||
CacheKey(account *Account) string
|
||||
}
|
||||
|
||||
const refreshLockTTL = 30 * time.Second
|
||||
|
||||
// OAuthRefreshResult 统一刷新结果
|
||||
type OAuthRefreshResult struct {
|
||||
Refreshed bool // 实际执行了刷新
|
||||
NewCredentials map[string]any // 刷新后的 credentials(nil 表示未刷新)
|
||||
Account *Account // 从 DB 重新读取的最新 account
|
||||
LockHeld bool // 锁被其他 worker 持有(未执行刷新)
|
||||
}
|
||||
|
||||
// OAuthRefreshAPI 统一的 OAuth Token 刷新入口
|
||||
// 封装分布式锁、DB 重读、已刷新检查等通用逻辑
|
||||
type OAuthRefreshAPI struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache GeminiTokenCache // 可选,nil = 无锁
|
||||
}
|
||||
|
||||
// NewOAuthRefreshAPI 创建统一刷新 API
|
||||
func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI {
|
||||
return &OAuthRefreshAPI{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token
|
||||
//
|
||||
// 流程:
|
||||
// 1. 获取分布式锁
|
||||
// 2. 从 DB 重读最新 account(防止使用过时的 refresh_token)
|
||||
// 3. 二次检查是否仍需刷新
|
||||
// 4. 调用 executor.Refresh() 执行平台特定刷新逻辑
|
||||
// 5. 设置 _token_version + 更新 DB
|
||||
// 6. 释放锁
|
||||
func (api *OAuthRefreshAPI) RefreshIfNeeded(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
executor OAuthRefreshExecutor,
|
||||
refreshWindow time.Duration,
|
||||
) (*OAuthRefreshResult, error) {
|
||||
cacheKey := executor.CacheKey(account)
|
||||
|
||||
// 1. 获取分布式锁
|
||||
lockAcquired := false
|
||||
if api.tokenCache != nil {
|
||||
acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, refreshLockTTL)
|
||||
if lockErr != nil {
|
||||
// Redis 错误,降级为无锁刷新
|
||||
slog.Warn("oauth_refresh_lock_failed_degraded",
|
||||
"account_id", account.ID,
|
||||
"cache_key", cacheKey,
|
||||
"error", lockErr,
|
||||
)
|
||||
} else if !acquired {
|
||||
// 锁被其他 worker 持有
|
||||
return &OAuthRefreshResult{LockHeld: true}, nil
|
||||
} else {
|
||||
lockAcquired = true
|
||||
defer func() { _ = api.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 从 DB 重读最新 account(锁保护下,确保使用最新的 refresh_token)
|
||||
freshAccount, err := api.accountRepo.GetByID(ctx, account.ID)
|
||||
if err != nil {
|
||||
slog.Warn("oauth_refresh_db_reread_failed",
|
||||
"account_id", account.ID,
|
||||
"error", err,
|
||||
)
|
||||
// 降级使用传入的 account
|
||||
freshAccount = account
|
||||
} else if freshAccount == nil {
|
||||
freshAccount = account
|
||||
}
|
||||
|
||||
// 3. 二次检查是否仍需刷新(另一条路径可能已刷新)
|
||||
if !executor.NeedsRefresh(freshAccount, refreshWindow) {
|
||||
return &OAuthRefreshResult{
|
||||
Account: freshAccount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 4. 执行平台特定刷新逻辑
|
||||
newCredentials, refreshErr := executor.Refresh(ctx, freshAccount)
|
||||
if refreshErr != nil {
|
||||
return nil, refreshErr
|
||||
}
|
||||
|
||||
// 5. 设置版本号 + 更新 DB
|
||||
if newCredentials != nil {
|
||||
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||||
freshAccount.Credentials = newCredentials
|
||||
if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil {
|
||||
slog.Error("oauth_refresh_update_failed",
|
||||
"account_id", freshAccount.ID,
|
||||
"error", updateErr,
|
||||
)
|
||||
return nil, fmt.Errorf("oauth refresh succeeded but DB update failed: %w", updateErr)
|
||||
}
|
||||
}
|
||||
|
||||
_ = lockAcquired // suppress unused warning when tokenCache is nil
|
||||
|
||||
return &OAuthRefreshResult{
|
||||
Refreshed: true,
|
||||
NewCredentials: newCredentials,
|
||||
Account: freshAccount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MergeCredentials 将旧 credentials 中不存在于新 map 的字段保留到新 map 中
|
||||
func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any {
|
||||
if newCreds == nil {
|
||||
newCreds = make(map[string]any)
|
||||
}
|
||||
for k, v := range oldCreds {
|
||||
if _, exists := newCreds[k]; !exists {
|
||||
newCreds[k] = v
|
||||
}
|
||||
}
|
||||
return newCreds
|
||||
}
|
||||
|
||||
// BuildClaudeAccountCredentials 为 Claude 平台构建 OAuth credentials map
|
||||
// 消除 Claude 平台没有 BuildAccountCredentials 方法的问题
|
||||
func BuildClaudeAccountCredentials(tokenInfo *TokenInfo) map[string]any {
|
||||
creds := map[string]any{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"token_type": tokenInfo.TokenType,
|
||||
"expires_in": strconv.FormatInt(tokenInfo.ExpiresIn, 10),
|
||||
"expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10),
|
||||
}
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
creds["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
creds["scope"] = tokenInfo.Scope
|
||||
}
|
||||
return creds
|
||||
}
|
||||
Reference in New Issue
Block a user