fix(OAuth缓存): 修复缓存键冲突、401强制刷新及Redis降级处理
- Gemini 缓存键统一增加 gemini: 前缀,避免与其他平台命名空间冲突 - OAuth 账号 401 错误时设置 expires_at=now 并持久化,强制下次请求刷新 token - Redis 锁获取失败时降级为无锁刷新,仅在 token 接近过期时执行,并检查 ctx 取消状态 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -65,8 +65,8 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
||||||
refreshFailed := false
|
refreshFailed := false
|
||||||
if needsRefresh && p.tokenCache != nil {
|
if needsRefresh && p.tokenCache != nil {
|
||||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||||
if err == nil && locked {
|
if lockErr == nil && locked {
|
||||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||||
|
|
||||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
||||||
@@ -114,8 +114,60 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if lockErr != nil {
|
||||||
|
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
||||||
|
slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
||||||
|
|
||||||
|
// 检查 ctx 是否已取消
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return "", ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从数据库获取最新账户信息
|
||||||
|
if p.accountRepo != nil {
|
||||||
|
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||||
|
if err == nil && fresh != nil {
|
||||||
|
account = fresh
|
||||||
|
}
|
||||||
|
}
|
||||||
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
|
|
||||||
|
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
||||||
|
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
||||||
|
if p.oauthService == nil {
|
||||||
|
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
|
||||||
|
refreshFailed = true
|
||||||
|
} else {
|
||||||
|
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
||||||
|
refreshFailed = true
|
||||||
|
} else {
|
||||||
|
// 构建新 credentials,保留原有字段
|
||||||
|
newCredentials := make(map[string]any)
|
||||||
|
for k, v := range account.Credentials {
|
||||||
|
newCredentials[k] = v
|
||||||
|
}
|
||||||
|
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||||
|
newCredentials["token_type"] = tokenInfo.TokenType
|
||||||
|
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||||
|
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||||
|
if tokenInfo.RefreshToken != "" {
|
||||||
|
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||||
|
}
|
||||||
|
if tokenInfo.Scope != "" {
|
||||||
|
newCredentials["scope"] = tokenInfo.Scope
|
||||||
|
}
|
||||||
|
account.Credentials = newCredentials
|
||||||
|
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||||
|
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||||
|
}
|
||||||
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// 锁获取失败,等待 200ms 后重试读取缓存(改进:减少并发时的缓存未命中)
|
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
|
||||||
time.Sleep(claudeLockWaitTime)
|
time.Sleep(claudeLockWaitTime)
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||||
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
func GeminiTokenCacheKey(account *Account) string {
|
func GeminiTokenCacheKey(account *Account) string {
|
||||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||||
if projectID != "" {
|
if projectID != "" {
|
||||||
return projectID
|
return "gemini:" + projectID
|
||||||
}
|
}
|
||||||
return "account:" + strconv.FormatInt(account.ID, 10)
|
return "gemini:account:" + strconv.FormatInt(account.ID, 10)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,8 +64,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
||||||
refreshFailed := false
|
refreshFailed := false
|
||||||
if needsRefresh && p.tokenCache != nil {
|
if needsRefresh && p.tokenCache != nil {
|
||||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||||
if err == nil && locked {
|
if lockErr == nil && locked {
|
||||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||||
|
|
||||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
||||||
@@ -104,8 +104,51 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if lockErr != nil {
|
||||||
|
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
||||||
|
slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
||||||
|
|
||||||
|
// 检查 ctx 是否已取消
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return "", ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从数据库获取最新账户信息
|
||||||
|
if p.accountRepo != nil {
|
||||||
|
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||||
|
if err == nil && fresh != nil {
|
||||||
|
account = fresh
|
||||||
|
}
|
||||||
|
}
|
||||||
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
|
|
||||||
|
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
||||||
|
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||||
|
if p.openAIOAuthService == nil {
|
||||||
|
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||||
|
refreshFailed = true
|
||||||
|
} else {
|
||||||
|
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
||||||
|
refreshFailed = true
|
||||||
|
} else {
|
||||||
|
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
|
for k, v := range account.Credentials {
|
||||||
|
if _, exists := newCredentials[k]; !exists {
|
||||||
|
newCredentials[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
account.Credentials = newCredentials
|
||||||
|
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||||
|
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||||
|
}
|
||||||
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// 锁获取失败,等待 200ms 后重试读取缓存(改进:减少并发时的缓存未命中)
|
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
|
||||||
time.Sleep(openAILockWaitTime)
|
time.Sleep(openAILockWaitTime)
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||||
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
|
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
|
||||||
|
|||||||
@@ -85,13 +85,24 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
|||||||
|
|
||||||
switch statusCode {
|
switch statusCode {
|
||||||
case 401:
|
case 401:
|
||||||
// 对所有 OAuth 账号在 401 错误时调用缓存失效
|
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
|
||||||
if account.Type == AccountTypeOAuth {
|
if account.Type == AccountTypeOAuth {
|
||||||
|
// 1. 失效缓存
|
||||||
if s.tokenCacheInvalidator != nil {
|
if s.tokenCacheInvalidator != nil {
|
||||||
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||||
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
|
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// 2. 设置 expires_at 为当前时间,强制下次请求刷新 token
|
||||||
|
if account.Credentials == nil {
|
||||||
|
account.Credentials = make(map[string]any)
|
||||||
|
}
|
||||||
|
account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
|
||||||
|
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||||
|
slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
|
||||||
|
} else {
|
||||||
|
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
msg := "Authentication failed (401): invalid or expired credentials"
|
msg := "Authentication failed (401): invalid or expired credentials"
|
||||||
if upstreamMsg != "" {
|
if upstreamMsg != "" {
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
|
|||||||
|
|
||||||
err := invalidator.InvalidateToken(context.Background(), account)
|
err := invalidator.InvalidateToken(context.Background(), account)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, []string{"project-x"}, cache.deletedKeys)
|
require.Equal(t, []string{"gemini:project-x"}, cache.deletedKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
|
func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
|
||||||
@@ -253,7 +253,7 @@ func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
expectedKeys := []string{
|
expectedKeys := []string{
|
||||||
"gemini-proj",
|
"gemini:gemini-proj",
|
||||||
"ag:ag-proj",
|
"ag:ag-proj",
|
||||||
"openai:account:3",
|
"openai:account:3",
|
||||||
"claude:account:4",
|
"claude:account:4",
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
|
|||||||
"project_id": "my-project-123",
|
"project_id": "my-project-123",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expected: "my-project-123",
|
expected: "gemini:my-project-123",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "project_id_with_whitespace",
|
name: "project_id_with_whitespace",
|
||||||
@@ -32,7 +32,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
|
|||||||
"project_id": " project-with-spaces ",
|
"project_id": " project-with-spaces ",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expected: "project-with-spaces",
|
expected: "gemini:project-with-spaces",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty_project_id_fallback_to_account_id",
|
name: "empty_project_id_fallback_to_account_id",
|
||||||
@@ -42,7 +42,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
|
|||||||
"project_id": "",
|
"project_id": "",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expected: "account:102",
|
expected: "gemini:account:102",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "whitespace_only_project_id_fallback_to_account_id",
|
name: "whitespace_only_project_id_fallback_to_account_id",
|
||||||
@@ -52,7 +52,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
|
|||||||
"project_id": " ",
|
"project_id": " ",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expected: "account:103",
|
expected: "gemini:account:103",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no_project_id_key_fallback_to_account_id",
|
name: "no_project_id_key_fallback_to_account_id",
|
||||||
@@ -60,7 +60,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
|
|||||||
ID: 104,
|
ID: 104,
|
||||||
Credentials: map[string]any{},
|
Credentials: map[string]any{},
|
||||||
},
|
},
|
||||||
expected: "account:104",
|
expected: "gemini:account:104",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "nil_credentials_fallback_to_account_id",
|
name: "nil_credentials_fallback_to_account_id",
|
||||||
@@ -68,7 +68,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
|
|||||||
ID: 105,
|
ID: 105,
|
||||||
Credentials: nil,
|
Credentials: nil,
|
||||||
},
|
},
|
||||||
expected: "account:105",
|
expected: "gemini:account:105",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user