feat: unified OAuth token refresh API with distributed locking
Introduce OAuthRefreshAPI as the single entry point for all OAuth token refresh operations, eliminating the race condition where background refresh and inline refresh could simultaneously use the same refresh_token (fixes #1035).

Key changes:
- Add OAuthRefreshExecutor interface extending TokenRefresher with CacheKey
- Add OAuthRefreshAPI.RefreshIfNeeded with lock → DB re-read → double-check flow
- Add ProviderRefreshPolicy / BackgroundRefreshPolicy strategy types
- Simplify all 4 TokenProviders to delegate to OAuthRefreshAPI
- Rewrite TokenRefreshService.refreshWithRetry to use unified API path
- Add MergeCredentials and BuildClaudeAccountCredentials helpers
- Add 40 unit tests covering all new and modified code paths
This commit is contained in:
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
@@ -16,10 +17,13 @@ import (
|
||||
type TokenRefreshService struct {
|
||||
accountRepo AccountRepository
|
||||
refreshers []TokenRefresher
|
||||
executors []OAuthRefreshExecutor // 与 refreshers 一一对应的 executor(带 CacheKey)
|
||||
refreshPolicy BackgroundRefreshPolicy
|
||||
cfg *config.TokenRefreshConfig
|
||||
cacheInvalidator TokenCacheInvalidator
|
||||
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
|
||||
tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存
|
||||
refreshAPI *OAuthRefreshAPI // 统一刷新 API
|
||||
|
||||
// OpenAI privacy: 刷新成功后检查并设置 training opt-out
|
||||
privacyClientFactory PrivacyClientFactory
|
||||
@@ -43,6 +47,7 @@ func NewTokenRefreshService(
|
||||
) *TokenRefreshService {
|
||||
s := &TokenRefreshService{
|
||||
accountRepo: accountRepo,
|
||||
refreshPolicy: DefaultBackgroundRefreshPolicy(),
|
||||
cfg: &cfg.TokenRefresh,
|
||||
cacheInvalidator: cacheInvalidator,
|
||||
schedulerCache: schedulerCache,
|
||||
@@ -53,12 +58,24 @@ func NewTokenRefreshService(
|
||||
openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
|
||||
openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts)
|
||||
|
||||
// 注册平台特定的刷新器
|
||||
claudeRefresher := NewClaudeTokenRefresher(oauthService)
|
||||
geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService)
|
||||
agRefresher := NewAntigravityTokenRefresher(antigravityOAuthService)
|
||||
|
||||
// 注册平台特定的刷新器(TokenRefresher 接口)
|
||||
s.refreshers = []TokenRefresher{
|
||||
NewClaudeTokenRefresher(oauthService),
|
||||
claudeRefresher,
|
||||
openAIRefresher,
|
||||
NewGeminiTokenRefresher(geminiOAuthService),
|
||||
NewAntigravityTokenRefresher(antigravityOAuthService),
|
||||
geminiRefresher,
|
||||
agRefresher,
|
||||
}
|
||||
|
||||
// 注册对应的 OAuthRefreshExecutor(带 CacheKey 方法)
|
||||
s.executors = []OAuthRefreshExecutor{
|
||||
claudeRefresher,
|
||||
openAIRefresher,
|
||||
geminiRefresher,
|
||||
agRefresher,
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -82,6 +99,16 @@ func (s *TokenRefreshService) SetPrivacyDeps(factory PrivacyClientFactory, proxy
|
||||
s.proxyRepo = proxyRepo
|
||||
}
|
||||
|
||||
// SetRefreshAPI 注入统一的 OAuth 刷新 API
|
||||
func (s *TokenRefreshService) SetRefreshAPI(api *OAuthRefreshAPI) {
|
||||
s.refreshAPI = api
|
||||
}
|
||||
|
||||
// SetRefreshPolicy 注入后台刷新调用侧策略(用于显式化平台/场景差异行为)。
|
||||
func (s *TokenRefreshService) SetRefreshPolicy(policy BackgroundRefreshPolicy) {
|
||||
s.refreshPolicy = policy
|
||||
}
|
||||
|
||||
// Start 启动后台刷新服务
|
||||
func (s *TokenRefreshService) Start() {
|
||||
if !s.cfg.Enabled {
|
||||
@@ -148,13 +175,13 @@ func (s *TokenRefreshService) processRefresh() {
|
||||
totalAccounts := len(accounts)
|
||||
oauthAccounts := 0 // 可刷新的OAuth账号数
|
||||
needsRefresh := 0 // 需要刷新的账号数
|
||||
refreshed, failed := 0, 0
|
||||
refreshed, failed, skipped := 0, 0, 0
|
||||
|
||||
for i := range accounts {
|
||||
account := &accounts[i]
|
||||
|
||||
// 遍历所有刷新器,找到能处理此账号的
|
||||
for _, refresher := range s.refreshers {
|
||||
for idx, refresher := range s.refreshers {
|
||||
if !refresher.CanRefresh(account) {
|
||||
continue
|
||||
}
|
||||
@@ -168,14 +195,24 @@ func (s *TokenRefreshService) processRefresh() {
|
||||
|
||||
needsRefresh++
|
||||
|
||||
// 获取对应的 executor
|
||||
var executor OAuthRefreshExecutor
|
||||
if idx < len(s.executors) {
|
||||
executor = s.executors[idx]
|
||||
}
|
||||
|
||||
// 执行刷新
|
||||
if err := s.refreshWithRetry(ctx, account, refresher); err != nil {
|
||||
slog.Warn("token_refresh.account_refresh_failed",
|
||||
"account_id", account.ID,
|
||||
"account_name", account.Name,
|
||||
"error", err,
|
||||
)
|
||||
failed++
|
||||
if err := s.refreshWithRetry(ctx, account, refresher, executor, refreshWindow); err != nil {
|
||||
if errors.Is(err, errRefreshSkipped) {
|
||||
skipped++
|
||||
} else {
|
||||
slog.Warn("token_refresh.account_refresh_failed",
|
||||
"account_id", account.ID,
|
||||
"account_name", account.Name,
|
||||
"error", err,
|
||||
)
|
||||
failed++
|
||||
}
|
||||
} else {
|
||||
slog.Info("token_refresh.account_refreshed",
|
||||
"account_id", account.ID,
|
||||
@@ -193,13 +230,14 @@ func (s *TokenRefreshService) processRefresh() {
|
||||
if needsRefresh == 0 && failed == 0 {
|
||||
slog.Debug("token_refresh.cycle_completed",
|
||||
"total", totalAccounts, "oauth", oauthAccounts,
|
||||
"needs_refresh", needsRefresh, "refreshed", refreshed, "failed", failed)
|
||||
"needs_refresh", needsRefresh, "refreshed", refreshed, "skipped", skipped, "failed", failed)
|
||||
} else {
|
||||
slog.Info("token_refresh.cycle_completed",
|
||||
"total", totalAccounts,
|
||||
"oauth", oauthAccounts,
|
||||
"needs_refresh", needsRefresh,
|
||||
"refreshed", refreshed,
|
||||
"skipped", skipped,
|
||||
"failed", failed,
|
||||
)
|
||||
}
|
||||
@@ -212,83 +250,43 @@ func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]Account
|
||||
}
|
||||
|
||||
// refreshWithRetry 带重试的刷新
|
||||
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Account, refresher TokenRefresher) error {
|
||||
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Account, refresher TokenRefresher, executor OAuthRefreshExecutor, refreshWindow time.Duration) error {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
|
||||
newCredentials, err := refresher.Refresh(ctx, account)
|
||||
var newCredentials map[string]any
|
||||
var err error
|
||||
|
||||
// 如果有新凭证,先更新(即使有错误也要保存 token)
|
||||
if newCredentials != nil {
|
||||
// 记录刷新版本时间戳,用于解决缓存一致性问题
|
||||
// TokenProvider 写入缓存前会检查此版本,如果版本已更新则跳过写入
|
||||
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||||
|
||||
account.Credentials = newCredentials
|
||||
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", saveErr)
|
||||
// 优先使用统一 API(带分布式锁 + DB 重读保护)
|
||||
if s.refreshAPI != nil && executor != nil {
|
||||
result, refreshErr := s.refreshAPI.RefreshIfNeeded(ctx, account, executor, refreshWindow)
|
||||
if refreshErr != nil {
|
||||
err = refreshErr
|
||||
} else if result.LockHeld {
|
||||
// 锁被其他 worker 持有,由调用侧策略决定如何计数
|
||||
return s.refreshPolicy.handleLockHeld()
|
||||
} else if !result.Refreshed {
|
||||
// 已被其他路径刷新,由调用侧策略决定如何计数
|
||||
return s.refreshPolicy.handleAlreadyRefreshed()
|
||||
} else {
|
||||
account = result.Account
|
||||
newCredentials = result.NewCredentials
|
||||
// 统一 API 已设置 _token_version 并更新 DB,无需重复操作
|
||||
}
|
||||
} else {
|
||||
// 降级:直接调用 refresher(兼容旧路径)
|
||||
newCredentials, err = refresher.Refresh(ctx, account)
|
||||
if newCredentials != nil {
|
||||
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||||
account.Credentials = newCredentials
|
||||
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", saveErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
// Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态
|
||||
if account.Platform == PlatformAntigravity &&
|
||||
account.Status == StatusError &&
|
||||
strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
|
||||
slog.Warn("token_refresh.clear_account_error_failed",
|
||||
"account_id", account.ID,
|
||||
"error", clearErr,
|
||||
)
|
||||
} else {
|
||||
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
|
||||
}
|
||||
}
|
||||
// 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景)
|
||||
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
|
||||
if clearErr := s.accountRepo.ClearTempUnschedulable(ctx, account.ID); clearErr != nil {
|
||||
slog.Warn("token_refresh.clear_temp_unschedulable_failed",
|
||||
"account_id", account.ID,
|
||||
"error", clearErr,
|
||||
)
|
||||
} else {
|
||||
slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID)
|
||||
}
|
||||
// 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态
|
||||
if s.tempUnschedCache != nil {
|
||||
if clearErr := s.tempUnschedCache.DeleteTempUnsched(ctx, account.ID); clearErr != nil {
|
||||
slog.Warn("token_refresh.clear_temp_unsched_cache_failed",
|
||||
"account_id", account.ID,
|
||||
"error", clearErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
|
||||
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
|
||||
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||
slog.Warn("token_refresh.invalidate_token_cache_failed",
|
||||
"account_id", account.ID,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
slog.Debug("token_refresh.token_cache_invalidated", "account_id", account.ID)
|
||||
}
|
||||
}
|
||||
// 同步更新调度器缓存,确保调度获取的 Account 对象包含最新的 credentials
|
||||
// 这解决了 token 刷新后调度器缓存数据不一致的问题(#445)
|
||||
if s.schedulerCache != nil {
|
||||
if err := s.schedulerCache.SetAccount(ctx, account); err != nil {
|
||||
slog.Warn("token_refresh.sync_scheduler_cache_failed",
|
||||
"account_id", account.ID,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
slog.Debug("token_refresh.scheduler_cache_synced", "account_id", account.ID)
|
||||
}
|
||||
}
|
||||
// OpenAI OAuth: 刷新成功后,检查是否已设置 privacy_mode,未设置则尝试关闭训练数据共享
|
||||
s.ensureOpenAIPrivacy(ctx, account)
|
||||
s.postRefreshActions(ctx, account)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -331,6 +329,70 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// postRefreshActions 刷新成功后的后续动作(清除错误状态、缓存失效、调度器同步等)
|
||||
func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *Account) {
|
||||
// Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态
|
||||
if account.Platform == PlatformAntigravity &&
|
||||
account.Status == StatusError &&
|
||||
strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
|
||||
slog.Warn("token_refresh.clear_account_error_failed",
|
||||
"account_id", account.ID,
|
||||
"error", clearErr,
|
||||
)
|
||||
} else {
|
||||
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
|
||||
}
|
||||
}
|
||||
// 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景)
|
||||
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
|
||||
if clearErr := s.accountRepo.ClearTempUnschedulable(ctx, account.ID); clearErr != nil {
|
||||
slog.Warn("token_refresh.clear_temp_unschedulable_failed",
|
||||
"account_id", account.ID,
|
||||
"error", clearErr,
|
||||
)
|
||||
} else {
|
||||
slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID)
|
||||
}
|
||||
// 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态
|
||||
if s.tempUnschedCache != nil {
|
||||
if clearErr := s.tempUnschedCache.DeleteTempUnsched(ctx, account.ID); clearErr != nil {
|
||||
slog.Warn("token_refresh.clear_temp_unsched_cache_failed",
|
||||
"account_id", account.ID,
|
||||
"error", clearErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
|
||||
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
|
||||
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||
slog.Warn("token_refresh.invalidate_token_cache_failed",
|
||||
"account_id", account.ID,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
slog.Debug("token_refresh.token_cache_invalidated", "account_id", account.ID)
|
||||
}
|
||||
}
|
||||
// 同步更新调度器缓存,确保调度获取的 Account 对象包含最新的 credentials
|
||||
if s.schedulerCache != nil {
|
||||
if err := s.schedulerCache.SetAccount(ctx, account); err != nil {
|
||||
slog.Warn("token_refresh.sync_scheduler_cache_failed",
|
||||
"account_id", account.ID,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
slog.Debug("token_refresh.scheduler_cache_synced", "account_id", account.ID)
|
||||
}
|
||||
}
|
||||
// OpenAI OAuth: 刷新成功后,检查是否已设置 privacy_mode,未设置则尝试关闭训练数据共享
|
||||
s.ensureOpenAIPrivacy(ctx, account)
|
||||
}
|
||||
|
||||
// errRefreshSkipped is the sentinel error signalling that a refresh was skipped
// (per the original note: lock contention, or the token was already refreshed
// by another path). processRefresh matches it with errors.Is and counts the
// account as skipped rather than failed or refreshed.
//
// It is built with errors.New because the message has no format verbs.
var errRefreshSkipped = errors.New("refresh skipped")
|
||||
|
||||
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误
|
||||
// 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权
|
||||
// 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误
|
||||
|
||||
Reference in New Issue
Block a user