diff --git a/backend/internal/service/antigravity_token_provider.go b/backend/internal/service/antigravity_token_provider.go index ad8acbd3..94eca94d 100644 --- a/backend/internal/service/antigravity_token_provider.go +++ b/backend/internal/service/antigravity_token_provider.go @@ -4,6 +4,7 @@ import ( "context" "errors" "log" + "log/slog" "strconv" "strings" "time" @@ -102,20 +103,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * } // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) - if p.tokenCache != nil && !IsTokenVersionStale(ctx, account, p.accountRepo) { - ttl := 30 * time.Minute - if expiresAt != nil { - until := time.Until(*expiresAt) - switch { - case until > antigravityTokenCacheSkew: - ttl = until - antigravityTokenCacheSkew - case until > 0: - ttl = until - default: - ttl = time.Minute + if p.tokenCache != nil { + latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) + if isStale && latestAccount != nil { + // 版本过时,使用 DB 中的最新 token + slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID) + accessToken = latestAccount.GetCredential("access_token") + if strings.TrimSpace(accessToken) == "" { + return "", errors.New("access_token not found after version check") } + // 不写入缓存,让下次请求重新处理 + } else { + ttl := 30 * time.Minute + if expiresAt != nil { + until := time.Until(*expiresAt) + switch { + case until > antigravityTokenCacheSkew: + ttl = until - antigravityTokenCacheSkew + case until > 0: + ttl = until + default: + ttl = time.Minute + } + } + _ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl) } - _ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl) } return accessToken, nil diff --git a/backend/internal/service/claude_token_provider.go b/backend/internal/service/claude_token_provider.go index caf69056..f6cab204 100644 --- a/backend/internal/service/claude_token_provider.go +++ b/backend/internal/service/claude_token_provider.go @@ -182,25 +182,36 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou } // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) - if p.tokenCache != nil && !IsTokenVersionStale(ctx, account, p.accountRepo) { - ttl := 30 * time.Minute - if refreshFailed { - // 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动 - ttl = time.Minute - slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed") - } else if expiresAt != nil { - until := time.Until(*expiresAt) - switch { - case until > claudeTokenCacheSkew: - ttl = until - claudeTokenCacheSkew - case until > 0: - ttl = until - default: - ttl = time.Minute + if p.tokenCache != nil { + latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) + if isStale && latestAccount != nil { + // 版本过时,使用 DB 中的最新 token + slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID) + accessToken = latestAccount.GetCredential("access_token") + if strings.TrimSpace(accessToken) == "" { + return "", errors.New("access_token not found after version check") + } + // 不写入缓存,让下次请求重新处理 + } else { + ttl := 30 * time.Minute + if refreshFailed { + // 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动 + ttl = time.Minute + slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed") + } else if expiresAt != nil { + until := time.Until(*expiresAt) + switch { + case until > claudeTokenCacheSkew: + ttl = until - claudeTokenCacheSkew + case until > 0: + ttl = until + default: + ttl = time.Minute + } + } + if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil { + slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err) } - } - if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil { - slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err) } } diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index 85dc64dc..313b048f 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -4,6 +4,7 @@ import ( "context" "errors" "log" + "log/slog" "strconv" "strings" "time" @@ -132,20 +133,31 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou } // 3) Populate cache with TTL(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) - if p.tokenCache != nil && !IsTokenVersionStale(ctx, account, p.accountRepo) { - ttl := 30 * time.Minute - if expiresAt != nil { - until := time.Until(*expiresAt) - switch { - case until > geminiTokenCacheSkew: - ttl = until - geminiTokenCacheSkew - case until > 0: - ttl = until - default: - ttl = time.Minute + if p.tokenCache != nil { + latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) + if isStale && latestAccount != nil { + // 版本过时,使用 DB 中的最新 token + slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID) + accessToken = latestAccount.GetCredential("access_token") + if strings.TrimSpace(accessToken) == "" { + return "", errors.New("access_token not found after version check") } + // 不写入缓存,让下次请求重新处理 + } else { + ttl := 30 * time.Minute + if expiresAt != nil { + until := time.Until(*expiresAt) + switch { + case until > geminiTokenCacheSkew: + ttl = until - geminiTokenCacheSkew + case until > 0: + ttl = until + default: + ttl = time.Minute + } + } + _ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl) } - _ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl) } return accessToken, nil diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index de3b690d..87a7713b 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -163,25 +163,36 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) - if p.tokenCache != nil && !IsTokenVersionStale(ctx, account, p.accountRepo) { - ttl := 30 * time.Minute - if refreshFailed { - // 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动 - ttl = time.Minute - slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed") - } else if expiresAt != nil { - until := time.Until(*expiresAt) - switch { - case until > openAITokenCacheSkew: - ttl = until - openAITokenCacheSkew - case until > 0: - ttl = until - default: - ttl = time.Minute + if p.tokenCache != nil { + latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) + if isStale && latestAccount != nil { + // 版本过时,使用 DB 中的最新 token + slog.Debug("openai_token_version_stale_use_latest", "account_id", account.ID) + accessToken = latestAccount.GetOpenAIAccessToken() + if strings.TrimSpace(accessToken) == "" { + return "", errors.New("access_token not found after version check") + } + // 不写入缓存,让下次请求重新处理 + } else { + ttl := 30 * time.Minute + if refreshFailed { + // 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动 + ttl = time.Minute + slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed") + } else if expiresAt != nil { + until := time.Until(*expiresAt) + switch { + case until > openAITokenCacheSkew: + ttl = until - openAITokenCacheSkew + case until > 0: + ttl = until + default: + ttl = time.Minute + } + } + if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil { + slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err) } - } - if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil { - slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err) } } diff --git a/backend/internal/service/token_cache_invalidator.go b/backend/internal/service/token_cache_invalidator.go index 4407df3a..74c9edc3 100644 --- a/backend/internal/service/token_cache_invalidator.go +++ b/backend/internal/service/token_cache_invalidator.go @@ -65,22 +65,24 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac return nil } -// IsTokenVersionStale 检查 account 的 token 版本是否已过时 +// CheckTokenVersion 检查 account 的 token 版本是否已过时,并返回最新的 account // 用于解决异步刷新任务与请求线程的竞态条件: // 如果刷新任务已更新 token 并删除缓存,此时请求线程的旧 account 对象不应写入缓存 // -// 返回 true 表示 token 已过时(不应缓存),false 表示可以缓存 -func IsTokenVersionStale(ctx context.Context, account *Account, repo AccountRepository) bool { +// 返回值: +// - latestAccount: 从 DB 获取的最新 account(如果查询失败则返回 nil) +// - isStale: true 表示 token 已过时(应使用 latestAccount),false 表示可以使用当前 account +func CheckTokenVersion(ctx context.Context, account *Account, repo AccountRepository) (latestAccount *Account, isStale bool) { if account == nil || repo == nil { - return false + return nil, false } currentVersion := account.GetCredentialAsInt64("_token_version") latestAccount, err := repo.GetByID(ctx, account.ID) if err != nil || latestAccount == nil { - // 查询失败,默认允许缓存 - return false + // 查询失败,默认允许缓存,不返回 latestAccount + return nil, false } latestVersion := latestAccount.GetCredentialAsInt64("_token_version") @@ -91,12 +93,12 @@ func IsTokenVersionStale(ctx context.Context, account *Account, repo AccountRepo slog.Debug("token_version_stale_no_current_version", "account_id", account.ID, "latest_version", latestVersion) - return true + return latestAccount, true } // 情况2: 两边都没有版本号,说明从未被异步刷新过,允许缓存 if currentVersion == 0 && latestVersion == 0 { - return false + return latestAccount, false } // 情况3: 比较版本号,如果 DB 中的版本更新,当前 account 已过时 @@ -105,8 +107,8 @@ func IsTokenVersionStale(ctx context.Context, account *Account, repo AccountRepo "account_id", account.ID, "current_version", currentVersion, "latest_version", latestVersion) - return true + return latestAccount, true } - return false + return latestAccount, false }