上次修复(2665230)只阻止了写入缓存,但仍返回旧token导致403。
现在版本过时时直接使用DB中的最新token返回。
- 重构 IsTokenVersionStale 为 CheckTokenVersion,返回最新account
- 消除重复DB查询,复用版本检查时已获取的account
173 lines
5.3 KiB
Go
173 lines
5.3 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"log"
|
||
"log/slog"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
const (
|
||
geminiTokenRefreshSkew = 3 * time.Minute
|
||
geminiTokenCacheSkew = 5 * time.Minute
|
||
)
|
||
|
||
type GeminiTokenProvider struct {
|
||
accountRepo AccountRepository
|
||
tokenCache GeminiTokenCache
|
||
geminiOAuthService *GeminiOAuthService
|
||
}
|
||
|
||
func NewGeminiTokenProvider(
|
||
accountRepo AccountRepository,
|
||
tokenCache GeminiTokenCache,
|
||
geminiOAuthService *GeminiOAuthService,
|
||
) *GeminiTokenProvider {
|
||
return &GeminiTokenProvider{
|
||
accountRepo: accountRepo,
|
||
tokenCache: tokenCache,
|
||
geminiOAuthService: geminiOAuthService,
|
||
}
|
||
}
|
||
|
||
func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||
if account == nil {
|
||
return "", errors.New("account is nil")
|
||
}
|
||
if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
|
||
return "", errors.New("not a gemini oauth account")
|
||
}
|
||
|
||
cacheKey := GeminiTokenCacheKey(account)
|
||
|
||
// 1) Try cache first.
|
||
if p.tokenCache != nil {
|
||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||
return token, nil
|
||
}
|
||
}
|
||
|
||
// 2) Refresh if needed (pre-expiry skew).
|
||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew
|
||
if needsRefresh && p.tokenCache != nil {
|
||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||
if err == nil && locked {
|
||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||
|
||
// Re-check after lock (another worker may have refreshed).
|
||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||
return token, nil
|
||
}
|
||
|
||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||
if err == nil && fresh != nil {
|
||
account = fresh
|
||
}
|
||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||
if expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew {
|
||
if p.geminiOAuthService == nil {
|
||
return "", errors.New("gemini oauth service not configured")
|
||
}
|
||
tokenInfo, err := p.geminiOAuthService.RefreshAccountToken(ctx, account)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
newCredentials := p.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||
for k, v := range account.Credentials {
|
||
if _, exists := newCredentials[k]; !exists {
|
||
newCredentials[k] = v
|
||
}
|
||
}
|
||
account.Credentials = newCredentials
|
||
_ = p.accountRepo.Update(ctx, account)
|
||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||
}
|
||
}
|
||
}
|
||
|
||
accessToken := account.GetCredential("access_token")
|
||
if strings.TrimSpace(accessToken) == "" {
|
||
return "", errors.New("access_token not found in credentials")
|
||
}
|
||
|
||
// project_id is optional now:
|
||
// - If present: will use Code Assist API (requires project_id)
|
||
// - If absent: will use AI Studio API with OAuth token (like regular API key mode)
|
||
// Auto-detect project_id only if explicitly enabled via a credential flag
|
||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||
autoDetectProjectID := account.GetCredential("auto_detect_project_id") == "true"
|
||
|
||
if projectID == "" && autoDetectProjectID {
|
||
if p.geminiOAuthService == nil {
|
||
return accessToken, nil // Fallback to AI Studio API mode
|
||
}
|
||
|
||
var proxyURL string
|
||
if account.ProxyID != nil && p.geminiOAuthService.proxyRepo != nil {
|
||
if proxy, err := p.geminiOAuthService.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && proxy != nil {
|
||
proxyURL = proxy.URL()
|
||
}
|
||
}
|
||
|
||
detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
|
||
if err != nil {
|
||
log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err)
|
||
return accessToken, nil
|
||
}
|
||
detected = strings.TrimSpace(detected)
|
||
tierID = strings.TrimSpace(tierID)
|
||
if detected != "" {
|
||
if account.Credentials == nil {
|
||
account.Credentials = make(map[string]any)
|
||
}
|
||
account.Credentials["project_id"] = detected
|
||
if tierID != "" {
|
||
account.Credentials["tier_id"] = tierID
|
||
}
|
||
_ = p.accountRepo.Update(ctx, account)
|
||
}
|
||
}
|
||
|
||
// 3) Populate cache with TTL(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||
if p.tokenCache != nil {
|
||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||
if isStale && latestAccount != nil {
|
||
// 版本过时,使用 DB 中的最新 token
|
||
slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID)
|
||
accessToken = latestAccount.GetCredential("access_token")
|
||
if strings.TrimSpace(accessToken) == "" {
|
||
return "", errors.New("access_token not found after version check")
|
||
}
|
||
// 不写入缓存,让下次请求重新处理
|
||
} else {
|
||
ttl := 30 * time.Minute
|
||
if expiresAt != nil {
|
||
until := time.Until(*expiresAt)
|
||
switch {
|
||
case until > geminiTokenCacheSkew:
|
||
ttl = until - geminiTokenCacheSkew
|
||
case until > 0:
|
||
ttl = until
|
||
default:
|
||
ttl = time.Minute
|
||
}
|
||
}
|
||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||
}
|
||
}
|
||
|
||
return accessToken, nil
|
||
}
|
||
|
||
func GeminiTokenCacheKey(account *Account) string {
|
||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||
if projectID != "" {
|
||
return "gemini:" + projectID
|
||
}
|
||
return "gemini:account:" + strconv.FormatInt(account.ID, 10)
|
||
}
|