From 5b37e9aea432df2dd6ff36bc4a3ba8c35178c14f Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Thu, 15 Jan 2026 19:08:07 +0800 Subject: [PATCH] =?UTF-8?q?fix(OAuth=E7=BC=93=E5=AD=98):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E7=BC=93=E5=AD=98=E9=94=AE=E5=86=B2=E7=AA=81=E3=80=81?= =?UTF-8?q?401=E5=BC=BA=E5=88=B6=E5=88=B7=E6=96=B0=E5=8F=8ARedis=E9=99=8D?= =?UTF-8?q?=E7=BA=A7=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Gemini 缓存键统一增加 gemini: 前缀,避免与其他平台命名空间冲突 - OAuth 账号 401 错误时设置 expires_at=now 并持久化,强制下次请求刷新 token - Redis 锁获取失败时降级为无锁刷新,仅在 token 接近过期时执行,并检查 ctx 取消状态 Co-Authored-By: Claude Opus 4.5 --- .../internal/service/claude_token_provider.go | 58 ++++++++++++++++++- .../internal/service/gemini_token_provider.go | 4 +- .../internal/service/openai_token_provider.go | 49 +++++++++++++++- backend/internal/service/ratelimit_service.go | 13 ++++- .../service/token_cache_invalidator_test.go | 4 +- .../internal/service/token_cache_key_test.go | 12 ++-- 6 files changed, 123 insertions(+), 17 deletions(-) diff --git a/backend/internal/service/claude_token_provider.go b/backend/internal/service/claude_token_provider.go index d2db162f..c7c6e42d 100644 --- a/backend/internal/service/claude_token_provider.go +++ b/backend/internal/service/claude_token_provider.go @@ -65,8 +65,8 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew refreshFailed := false if needsRefresh && p.tokenCache != nil { - locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) - if err == nil && locked { + locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if lockErr == nil && locked { defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() // 拿到锁后再次检查缓存(另一个 worker 可能已刷新) @@ -114,8 +114,60 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou } } } + } else if lockErr != nil { + // Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时) + slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr) + + // 检查 ctx 是否已取消 + if ctx.Err() != nil { + return "", ctx.Err() + } + + // 从数据库获取最新账户信息 + if p.accountRepo != nil { + fresh, err := p.accountRepo.GetByID(ctx, account.ID) + if err == nil && fresh != nil { + account = fresh + } + } + expiresAt = account.GetCredentialAsTime("expires_at") + + // 仅在 expires_at 已过期/接近过期时才执行无锁刷新 + if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew { + if p.oauthService == nil { + slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID) + refreshFailed = true + } else { + tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account) + if err != nil { + slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err) + refreshFailed = true + } else { + // 构建新 credentials,保留原有字段 + newCredentials := make(map[string]any) + for k, v := range account.Credentials { + newCredentials[k] = v + } + newCredentials["access_token"] = tokenInfo.AccessToken + newCredentials["token_type"] = tokenInfo.TokenType + newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10) + newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10) + if tokenInfo.RefreshToken != "" { + newCredentials["refresh_token"] = tokenInfo.RefreshToken + } + if tokenInfo.Scope != "" { + newCredentials["scope"] = tokenInfo.Scope + } + account.Credentials = newCredentials + if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { + slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr) + } + expiresAt = account.GetCredentialAsTime("expires_at") + } + } + } } else { - // 锁获取失败,等待 200ms 后重试读取缓存(改进:减少并发时的缓存未命中) + // 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存 time.Sleep(claudeLockWaitTime) if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID) diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index a5cacc9a..f13ae169 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -154,7 +154,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou func GeminiTokenCacheKey(account *Account) string { projectID := strings.TrimSpace(account.GetCredential("project_id")) if projectID != "" { - return projectID + return "gemini:" + projectID } - return "account:" + strconv.FormatInt(account.ID, 10) + return "gemini:account:" + strconv.FormatInt(account.ID, 10) } diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index e96892cd..82a0866f 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -64,8 +64,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew refreshFailed := false if needsRefresh && p.tokenCache != nil { - locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) - if err == nil && locked { + locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if lockErr == nil && locked { defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() // 拿到锁后再次检查缓存(另一个 worker 可能已刷新) @@ -104,8 +104,51 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } } } + } else if lockErr != nil { + // Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时) + slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr) + + // 检查 ctx 是否已取消 + if ctx.Err() != nil { + return "", ctx.Err() + } + + // 从数据库获取最新账户信息 + if p.accountRepo != nil { + fresh, err := p.accountRepo.GetByID(ctx, account.ID) + if err == nil && fresh != nil { + account = fresh + } + } + expiresAt = account.GetCredentialAsTime("expires_at") + + // 仅在 expires_at 已过期/接近过期时才执行无锁刷新 + if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { + if p.openAIOAuthService == nil { + slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) + refreshFailed = true + } else { + tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account) + if err != nil { + slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err) + refreshFailed = true + } else { + newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo) + for k, v := range account.Credentials { + if _, exists := newCredentials[k]; !exists { + newCredentials[k] = v + } + } + account.Credentials = newCredentials + if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { + slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr) + } + expiresAt = account.GetCredentialAsTime("expires_at") + } + } + } } else { - // 锁获取失败,等待 200ms 后重试读取缓存(改进:减少并发时的缓存未命中) + // 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存 time.Sleep(openAILockWaitTime) if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID) diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index effb7e9a..20e08c52 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -85,13 +85,24 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc switch statusCode { case 401: - // 对所有 OAuth 账号在 401 错误时调用缓存失效 + // 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新 if account.Type == AccountTypeOAuth { + // 1. 失效缓存 if s.tokenCacheInvalidator != nil { if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil { slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err) } } + // 2. 设置 expires_at 为当前时间,强制下次请求刷新 token + if account.Credentials == nil { + account.Credentials = make(map[string]any) + } + account.Credentials["expires_at"] = time.Now().Format(time.RFC3339) + if err := s.accountRepo.Update(ctx, account); err != nil { + slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err) + } else { + slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform) + } } msg := "Authentication failed (401): invalid or expired credentials" if upstreamMsg != "" { diff --git a/backend/internal/service/token_cache_invalidator_test.go b/backend/internal/service/token_cache_invalidator_test.go index a33da60d..ad13bde7 100644 --- a/backend/internal/service/token_cache_invalidator_test.go +++ b/backend/internal/service/token_cache_invalidator_test.go @@ -51,7 +51,7 @@ func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) { err := invalidator.InvalidateToken(context.Background(), account) require.NoError(t, err) - require.Equal(t, []string{"project-x"}, cache.deletedKeys) + require.Equal(t, []string{"gemini:project-x"}, cache.deletedKeys) } func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) { @@ -253,7 +253,7 @@ func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) { } expectedKeys := []string{ - "gemini-proj", + "gemini:gemini-proj", "ag:ag-proj", "openai:account:3", "claude:account:4", diff --git a/backend/internal/service/token_cache_key_test.go b/backend/internal/service/token_cache_key_test.go index e6b33747..6215eeaf 100644 --- a/backend/internal/service/token_cache_key_test.go +++ b/backend/internal/service/token_cache_key_test.go @@ -22,7 +22,7 @@ func TestGeminiTokenCacheKey(t *testing.T) { "project_id": "my-project-123", }, }, - expected: "my-project-123", + expected: "gemini:my-project-123", }, { name: "project_id_with_whitespace", @@ -32,7 +32,7 @@ func TestGeminiTokenCacheKey(t *testing.T) { "project_id": " project-with-spaces ", }, }, - expected: "project-with-spaces", + expected: "gemini:project-with-spaces", }, { name: "empty_project_id_fallback_to_account_id", @@ -42,7 +42,7 @@ func TestGeminiTokenCacheKey(t *testing.T) { "project_id": "", }, }, - expected: "account:102", + expected: "gemini:account:102", }, { name: "whitespace_only_project_id_fallback_to_account_id", @@ -52,7 +52,7 @@ func TestGeminiTokenCacheKey(t *testing.T) { "project_id": " ", }, }, - expected: "account:103", + expected: "gemini:account:103", }, { name: "no_project_id_key_fallback_to_account_id", @@ -60,7 +60,7 @@ func TestGeminiTokenCacheKey(t *testing.T) { ID: 104, Credentials: map[string]any{}, }, - expected: "account:104", + expected: "gemini:account:104", }, { name: "nil_credentials_fallback_to_account_id", @@ -68,7 +68,7 @@ func TestGeminiTokenCacheKey(t *testing.T) { ID: 105, Credentials: nil, }, - expected: "account:105", + expected: "gemini:account:105", }, }