Files
sub2api/backend/internal/service/token_refresh_service.go
shaw 2665230a09 fix(token-cache): 修复异步刷新与请求线程的缓存竞态条件
- 新增 _token_version 版本号机制,防止过期 token 污染缓存
- TokenRefreshService 刷新成功后写入版本号并清除缓存
- TokenProvider 写入缓存前检查版本,过时则跳过
- ClearError 时同步清除 token 缓存
2026-01-22 21:09:28 +08:00

260 lines
8.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"fmt"
"log"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// TokenRefreshService is a background service that automatically refreshes OAuth tokens.
// It periodically scans the active accounts and refreshes tokens that are about to expire.
type TokenRefreshService struct {
// accountRepo lists the active accounts and persists refreshed credentials
// and error status changes.
accountRepo AccountRepository
// refreshers holds the platform-specific refreshers; for each account they
// are tried in slice order and the first one that accepts the account is used.
refreshers []TokenRefresher
// cfg is the token-refresh section of the application configuration
// (enabled flag, check interval, refresh window, retry settings).
cfg *config.TokenRefreshConfig
// cacheInvalidator, when non-nil, is asked to invalidate the token cache of
// an OAuth account after a successful refresh.
cacheInvalidator TokenCacheInvalidator
// stopCh is closed by Stop to tell the refresh loop to exit.
stopCh chan struct{}
// wg tracks the refresh-loop goroutine so that Stop can wait for it.
wg sync.WaitGroup
}
// NewTokenRefreshService builds a TokenRefreshService wired with one refresher
// per supported platform (Claude, OpenAI, Gemini, Antigravity).
//
// The service stores a pointer to cfg.TokenRefresh, so cfg must be non-nil.
// The returned service does nothing until Start is called.
func NewTokenRefreshService(
	accountRepo AccountRepository,
	oauthService *OAuthService,
	openaiOAuthService *OpenAIOAuthService,
	geminiOAuthService *GeminiOAuthService,
	antigravityOAuthService *AntigravityOAuthService,
	cacheInvalidator TokenCacheInvalidator,
	cfg *config.Config,
) *TokenRefreshService {
	return &TokenRefreshService{
		accountRepo:      accountRepo,
		cfg:              &cfg.TokenRefresh,
		cacheInvalidator: cacheInvalidator,
		stopCh:           make(chan struct{}),
		// Platform-specific refreshers; a refresh pass consults them in this order.
		refreshers: []TokenRefresher{
			NewClaudeTokenRefresher(oauthService),
			NewOpenAITokenRefresher(openaiOAuthService),
			NewGeminiTokenRefresher(geminiOAuthService),
			NewAntigravityTokenRefresher(antigravityOAuthService),
		},
	}
}
// Start launches the background refresh loop in its own goroutine.
// When the service is disabled by configuration it only logs that fact and
// starts nothing.
func (s *TokenRefreshService) Start() {
	settings := s.cfg
	if !settings.Enabled {
		log.Println("[TokenRefresh] Service disabled by configuration")
		return
	}

	// Register the goroutine before launching it so Stop can wait for it.
	s.wg.Add(1)
	go s.refreshLoop()

	log.Printf(
		"[TokenRefresh] Service started (check every %d minutes, refresh %v hours before expiry)",
		settings.CheckIntervalMinutes,
		settings.RefreshBeforeExpiryHours,
	)
}
// Stop signals the refresh loop to exit and blocks until it has finished.
//
// Stop may be called more than once: a later call finds stopCh already closed
// and skips the close instead of panicking with "close of closed channel".
// The guard is a plain check-then-close, so it does NOT make concurrent Stop
// calls safe; call Stop from a single goroutine.
//
// The refresh loop only checks stopCh between refresh passes, so if a pass is
// in progress Stop waits for that pass to complete.
func (s *TokenRefreshService) Stop() {
	select {
	case <-s.stopCh:
		// Already closed by an earlier Stop call; nothing to signal.
	default:
		close(s.stopCh)
	}
	s.wg.Wait()
	log.Println("[TokenRefresh] Service stopped")
}
// refreshLoop is the body of the background goroutine started by Start.
// It runs one refresh pass immediately and then one per ticker tick, until
// stopCh is closed. It marks the WaitGroup done on exit.
func (s *TokenRefreshService) refreshLoop() {
	defer s.wg.Done()

	// Intervals shorter than one minute (including zero or negative settings)
	// fall back to five minutes.
	interval := time.Duration(s.cfg.CheckIntervalMinutes) * time.Minute
	if interval < time.Minute {
		interval = 5 * time.Minute
	}

	tick := time.NewTicker(interval)
	defer tick.Stop()

	// Process first, then wait: the first iteration is the start-up check,
	// every later iteration is triggered by a tick.
	for {
		s.processRefresh()

		select {
		case <-s.stopCh:
			return
		case <-tick.C:
		}
	}
}
// processRefresh performs a single refresh pass over all active accounts.
//
// For each account, the first refresher whose CanRefresh accepts it is the
// only one consulted. If that refresher reports NeedsRefresh within the
// configured window (RefreshBeforeExpiryHours), the token is refreshed through
// refreshWithRetry. A summary line is always logged at the end of the pass so
// the service's activity can be followed even when nothing was due.
func (s *TokenRefreshService) processRefresh() {
	ctx := context.Background()

	// How long before expiry a token becomes eligible for refresh.
	window := time.Duration(s.cfg.RefreshBeforeExpiryHours * float64(time.Hour))

	accounts, listErr := s.listActiveAccounts(ctx)
	if listErr != nil {
		log.Printf("[TokenRefresh] Failed to list accounts: %v", listErr)
		return
	}

	var (
		oauthCount   int // accounts some refresher can handle
		dueCount     int // accounts inside the refresh window
		successCount int
		failureCount int
	)

nextAccount:
	for idx := range accounts {
		acct := &accounts[idx]

		for _, candidate := range s.refreshers {
			if !candidate.CanRefresh(acct) {
				continue
			}
			oauthCount++

			// Each account is handled by exactly one refresher, so every path
			// below moves on to the next account.
			if !candidate.NeedsRefresh(acct, window) {
				continue nextAccount
			}
			dueCount++

			refreshErr := s.refreshWithRetry(ctx, acct, candidate)
			if refreshErr == nil {
				log.Printf("[TokenRefresh] Account %d (%s) refreshed successfully", acct.ID, acct.Name)
				successCount++
			} else {
				log.Printf("[TokenRefresh] Account %d (%s) failed: %v", acct.ID, acct.Name, refreshErr)
				failureCount++
			}
			continue nextAccount
		}
	}

	log.Printf("[TokenRefresh] Cycle complete: total=%d, oauth=%d, needs_refresh=%d, refreshed=%d, failed=%d",
		len(accounts), oauthCount, dueCount, successCount, failureCount)
}
// listActiveAccounts returns the accounts the repository reports as active.
// It delegates to AccountRepository.ListActive. According to the original
// author's note, ListActive is used so that temporarily disabled accounts are
// also covered; that behaviour lives in the repository and is not visible here.
func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]Account, error) {
	active, err := s.accountRepo.ListActive(ctx)
	return active, err
}
// refreshWithRetry refreshes the token of one account, retrying failed
// attempts with exponential backoff.
//
// At least one attempt is always made: a configured MaxRetries below 1 is
// treated as 1. (Previously such a value skipped the loop, marked the account
// as errored without ever calling the refresher, and returned nil.)
//
// Per attempt:
//   - If the refresher returns credentials (even together with an error) they
//     are stamped with "_token_version" (current Unix time in milliseconds),
//     stored on the account and persisted via Update. A persistence failure
//     aborts immediately with a wrapped error.
//   - On success, an Antigravity account that is in error status because of
//     "missing_project_id:" has its error cleared, and for OAuth accounts the
//     token cache is invalidated when a cacheInvalidator is configured.
//     Failures of these two follow-up steps are only logged. Returns nil.
//   - On failure of an Antigravity account with a non-retryable error, the
//     account is marked as errored and the error is returned at once.
//   - Otherwise the attempt is logged and, if attempts remain, the function
//     waits RetryBackoffSeconds * 2^(attempt-1) seconds before retrying.
//
// If ctx is cancelled while waiting, the function returns an error wrapping
// ctx.Err() and does not mark the account as errored.
//
// When all attempts fail, Antigravity accounts are only logged (the failure
// may be a transient network problem) while accounts of other platforms are
// marked as errored. The last refresh error is returned.
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Account, refresher TokenRefresher) error {
	// Guard against a zero or negative MaxRetries so the loop body runs at least once.
	maxAttempts := s.cfg.MaxRetries
	if maxAttempts < 1 {
		maxAttempts = 1
	}

	var lastErr error
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		newCredentials, err := refresher.Refresh(ctx, account)

		// Save new credentials first, even if Refresh also reported an error,
		// so a freshly issued token is never lost.
		if newCredentials != nil {
			// Version stamp used for cache consistency. Per the original
			// author's note, TokenProvider compares it before writing to its
			// cache and skips the write when a newer version exists.
			newCredentials["_token_version"] = time.Now().UnixMilli()
			account.Credentials = newCredentials
			if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
				return fmt.Errorf("failed to save credentials: %w", saveErr)
			}
		}

		if err == nil {
			// Antigravity: the account was flagged only because project_id was
			// missing; the refresh has now succeeded, so clear that flag.
			if account.Platform == PlatformAntigravity &&
				account.Status == StatusError &&
				strings.Contains(account.ErrorMessage, "missing_project_id:") {
				if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
					log.Printf("[TokenRefresh] Failed to clear error status for account %d: %v", account.ID, clearErr)
				} else {
					log.Printf("[TokenRefresh] Account %d: cleared missing_project_id error", account.ID)
				}
			}

			// Invalidate the cached token for every OAuth account; the
			// invalidator decides per platform whether anything must be done.
			if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
				if invErr := s.cacheInvalidator.InvalidateToken(ctx, account); invErr != nil {
					log.Printf("[TokenRefresh] Failed to invalidate token cache for account %d: %v", account.ID, invErr)
				} else {
					log.Printf("[TokenRefresh] Token cache invalidated for account %d", account.ID)
				}
			}
			return nil
		}

		// Antigravity: a non-retryable error means retrying is pointless, so
		// mark the account as errored and stop right away.
		if account.Platform == PlatformAntigravity && isNonRetryableRefreshError(err) {
			errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err)
			if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil {
				log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, setErr)
			}
			return err
		}

		lastErr = err
		log.Printf("[TokenRefresh] Account %d attempt %d/%d failed: %v",
			account.ID, attempt, maxAttempts, err)

		if attempt < maxAttempts {
			// Exponential backoff: 2^(attempt-1) * RetryBackoffSeconds.
			backoff := time.Duration(s.cfg.RetryBackoffSeconds) * time.Second * time.Duration(1<<(attempt-1))
			timer := time.NewTimer(backoff)
			select {
			case <-ctx.Done():
				timer.Stop()
				// Cancellation is not a verdict on the credentials, so the
				// account is deliberately not marked as errored here.
				return fmt.Errorf("token refresh canceled during backoff (last error: %v): %w", lastErr, ctx.Err())
			case <-timer.C:
			}
		}
	}

	// Antigravity: other errors are only logged (may be a transient network issue).
	// Other platforms: mark the account as errored once all attempts are used up.
	if account.Platform == PlatformAntigravity {
		log.Printf("[TokenRefresh] Account %d: refresh failed after %d retries: %v", account.ID, maxAttempts, lastErr)
	} else {
		errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", maxAttempts, lastErr)
		if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil {
			log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, setErr)
		}
	}
	return lastErr
}
// isNonRetryableRefreshError reports whether err is a refresh failure that a
// retry cannot fix. Such errors usually mean the credentials are no longer
// valid and the user has to re-authorize.
//
// The check is a case-insensitive substring match on err.Error(); a nil error
// is never non-retryable.
func isNonRetryableRefreshError(err error) bool {
	if err == nil {
		return false
	}

	text := strings.ToLower(err.Error())
	switch {
	case strings.Contains(text, "invalid_grant"): // refresh_token is no longer valid
		return true
	case strings.Contains(text, "invalid_client"): // client is misconfigured
		return true
	case strings.Contains(text, "unauthorized_client"): // client is not authorized
		return true
	case strings.Contains(text, "access_denied"): // access was denied
		return true
	case strings.Contains(text, "missing_project_id"): // project_id is missing
		return true
	}
	return false
}