Merge branch 'dev'
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -127,3 +127,4 @@ deploy/docker-compose.override.yml
|
||||
.gocache/
|
||||
vite.config.js
|
||||
docs/*
|
||||
.serena/
|
||||
@@ -100,8 +100,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||
tokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, tokenCacheInvalidator)
|
||||
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
||||
usageCache := service.NewUsageCache()
|
||||
@@ -136,8 +136,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
identityCache := repository.NewIdentityCache(redisClient)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
|
||||
@@ -168,7 +170,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, tokenCacheInvalidator, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||
application := &Application{
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
geminiTokenKeyPrefix = "gemini:token:"
|
||||
geminiRefreshLockKeyPrefix = "gemini:refresh_lock:"
|
||||
oauthTokenKeyPrefix = "oauth:token:"
|
||||
oauthRefreshLockKeyPrefix = "oauth:refresh_lock:"
|
||||
)
|
||||
|
||||
type geminiTokenCache struct {
|
||||
@@ -24,26 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
|
||||
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
|
||||
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
|
||||
return c.rdb.Get(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
|
||||
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
|
||||
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
|
||||
return c.rdb.Set(ctx, key, token, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
|
||||
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
|
||||
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
|
||||
key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
|
||||
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
|
||||
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
|
||||
key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
208
backend/internal/service/claude_token_provider.go
Normal file
208
backend/internal/service/claude_token_provider.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
claudeTokenRefreshSkew = 3 * time.Minute
|
||||
claudeTokenCacheSkew = 5 * time.Minute
|
||||
claudeLockWaitTime = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
type ClaudeTokenCache = GeminiTokenCache
|
||||
|
||||
// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token
|
||||
type ClaudeTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache ClaudeTokenCache
|
||||
oauthService *OAuthService
|
||||
}
|
||||
|
||||
func NewClaudeTokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache ClaudeTokenCache,
|
||||
oauthService *OAuthService,
|
||||
) *ClaudeTokenProvider {
|
||||
return &ClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取有效的 access_token
|
||||
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an anthropic oauth account")
|
||||
}
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
|
||||
// 1. 先尝试缓存
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
|
||||
return token, nil
|
||||
} else if err != nil {
|
||||
slog.Warn("claude_token_cache_get_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
|
||||
|
||||
// 2. 如果即将过期则刷新
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
|
||||
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
// 构建新 credentials,保留原有字段
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if lockErr != nil {
|
||||
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
||||
slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
||||
|
||||
// 检查 ctx 是否已取消
|
||||
if ctx.Err() != nil {
|
||||
return "", ctx.Err()
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
if p.accountRepo != nil {
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
|
||||
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
// 构建新 credentials,保留原有字段
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
|
||||
time.Sleep(claudeLockWaitTime)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > claudeTokenCacheSkew:
|
||||
ttl = until - claudeTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
||||
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
939
backend/internal/service/claude_token_provider_test.go
Normal file
939
backend/internal/service/claude_token_provider_test.go
Normal file
@@ -0,0 +1,939 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// claudeTokenCacheStub implements ClaudeTokenCache for testing
|
||||
type claudeTokenCacheStub struct {
|
||||
mu sync.Mutex
|
||||
tokens map[string]string
|
||||
getErr error
|
||||
setErr error
|
||||
deleteErr error
|
||||
lockAcquired bool
|
||||
lockErr error
|
||||
releaseLockErr error
|
||||
getCalled int32
|
||||
setCalled int32
|
||||
lockCalled int32
|
||||
unlockCalled int32
|
||||
simulateLockRace bool
|
||||
}
|
||||
|
||||
func newClaudeTokenCacheStub() *claudeTokenCacheStub {
|
||||
return &claudeTokenCacheStub{
|
||||
tokens: make(map[string]string),
|
||||
lockAcquired: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
|
||||
atomic.AddInt32(&s.getCalled, 1)
|
||||
if s.getErr != nil {
|
||||
return "", s.getErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.tokens[cacheKey], nil
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
|
||||
atomic.AddInt32(&s.setCalled, 1)
|
||||
if s.setErr != nil {
|
||||
return s.setErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.tokens[cacheKey] = token
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
|
||||
if s.deleteErr != nil {
|
||||
return s.deleteErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.tokens, cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||
atomic.AddInt32(&s.lockCalled, 1)
|
||||
if s.lockErr != nil {
|
||||
return false, s.lockErr
|
||||
}
|
||||
if s.simulateLockRace {
|
||||
return false, nil
|
||||
}
|
||||
return s.lockAcquired, nil
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
|
||||
atomic.AddInt32(&s.unlockCalled, 1)
|
||||
return s.releaseLockErr
|
||||
}
|
||||
|
||||
// claudeAccountRepoStub is a minimal stub implementing only the methods used by ClaudeTokenProvider
|
||||
type claudeAccountRepoStub struct {
|
||||
account *Account
|
||||
getErr error
|
||||
updateErr error
|
||||
getCalled int32
|
||||
updateCalled int32
|
||||
}
|
||||
|
||||
func (r *claudeAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
atomic.AddInt32(&r.getCalled, 1)
|
||||
if r.getErr != nil {
|
||||
return nil, r.getErr
|
||||
}
|
||||
return r.account, nil
|
||||
}
|
||||
|
||||
func (r *claudeAccountRepoStub) Update(ctx context.Context, account *Account) error {
|
||||
atomic.AddInt32(&r.updateCalled, 1)
|
||||
if r.updateErr != nil {
|
||||
return r.updateErr
|
||||
}
|
||||
r.account = account
|
||||
return nil
|
||||
}
|
||||
|
||||
// claudeOAuthServiceStub implements OAuthService methods for testing
|
||||
type claudeOAuthServiceStub struct {
|
||||
tokenInfo *TokenInfo
|
||||
refreshErr error
|
||||
refreshCalled int32
|
||||
}
|
||||
|
||||
func (s *claudeOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
|
||||
atomic.AddInt32(&s.refreshCalled, 1)
|
||||
if s.refreshErr != nil {
|
||||
return nil, s.refreshErr
|
||||
}
|
||||
return s.tokenInfo, nil
|
||||
}
|
||||
|
||||
// testClaudeTokenProvider is a test version that uses the stub OAuth service
|
||||
type testClaudeTokenProvider struct {
|
||||
accountRepo *claudeAccountRepoStub
|
||||
tokenCache *claudeTokenCacheStub
|
||||
oauthService *claudeOAuthServiceStub
|
||||
}
|
||||
|
||||
func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an anthropic oauth account")
|
||||
}
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
|
||||
// 1. Check cache
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check if refresh needed
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if err == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// Check cache again after acquiring lock
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Get fresh account from DB
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
// Build new credentials
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_at"] = time.Now().Add(time.Duration(tokenInfo.ExpiresIn) * time.Second).Format(time.RFC3339)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if p.tokenCache.simulateLockRace {
|
||||
// Wait and retry cache
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. Store in cache
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
ttl = time.Minute // 刷新失败时使用短 TTL
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
if until > claudeTokenCacheSkew {
|
||||
ttl = until - claudeTokenCacheSkew
|
||||
} else if until > 0 {
|
||||
ttl = until
|
||||
} else {
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_CacheHit(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "db-token",
|
||||
},
|
||||
}
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
cache.tokens[cacheKey] = "cached-token"
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "cached-token", token)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_CacheMiss_FromCredentials(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
// Token expires in far future, no refresh needed
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "credential-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "credential-token", token)
|
||||
|
||||
// Should have stored in cache
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
require.Equal(t, "credential-token", cache.tokens[cacheKey])
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_TokenRefresh(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon (within refresh skew)
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refreshed-token", token)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_LockRaceCondition(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.simulateLockRace = true
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "race-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
// Simulate another worker already refreshed and cached
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "winner-token"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_NilAccount(t *testing.T) {
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), nil)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "account is nil")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 104,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 105,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 106,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeSetupToken,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_NilCache(t *testing.T) {
|
||||
// Token doesn't need refresh
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 107,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "nocache-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "nocache-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_CacheGetError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.getErr = errors.New("redis connection failed")
|
||||
|
||||
// Token doesn't need refresh
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 108,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
// Should gracefully degrade and return from credentials
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fallback-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_CacheSetError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.setErr = errors.New("redis write failed")
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 109,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "still-works-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
// Should still work even if cache set fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "still-works-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_MissingAccessToken(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 110,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"expires_at": expiresAt,
|
||||
// missing access_token
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_RefreshError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
refreshErr: errors.New("oauth refresh failed"),
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 111,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// Now with fallback behavior, should return existing token even if refresh fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "old-token", token) // Fallback to existing token
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_OAuthServiceNotConfigured(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 112,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: nil, // not configured
|
||||
}
|
||||
|
||||
// Now with fallback behavior, should return existing token even if oauth service not configured
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "old-token", token) // Fallback to existing token
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_TTLCalculation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expiresIn time.Duration
|
||||
}{
|
||||
{
|
||||
name: "far_future_expiry",
|
||||
expiresIn: 1 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "medium_expiry",
|
||||
expiresIn: 10 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "near_expiry",
|
||||
expiresIn: 6 * time.Minute,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 200,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
_, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify token was cached
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
require.Equal(t, "test-token", cache.tokens[cacheKey])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_AccountRepoGetError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{
|
||||
getErr: errors.New("db connection failed"),
|
||||
}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 113,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// Should still work, just using the passed-in account
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refreshed-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_AccountUpdateError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{
|
||||
updateErr: errors.New("db write failed"),
|
||||
}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 114,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// Should still return token even if update fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refreshed-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_RefreshPreservesExistingCredentials(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "new-access-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 115,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-access-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
"custom_field": "should-be-preserved",
|
||||
"organization": "test-org",
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "new-access-token", token)
|
||||
|
||||
// Verify existing fields are preserved
|
||||
require.Equal(t, "should-be-preserved", accountRepo.account.Credentials["custom_field"])
|
||||
require.Equal(t, "test-org", accountRepo.account.Credentials["organization"])
|
||||
// Verify new fields are updated
|
||||
require.Equal(t, "new-access-token", accountRepo.account.Credentials["access_token"])
|
||||
require.Equal(t, "new-refresh-token", accountRepo.account.Credentials["refresh_token"])
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_DoubleCheckCacheAfterLock(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 116,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
|
||||
// After lock is acquired, cache should have the token (simulating another worker)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "cached-by-other-worker"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
// Tests for real provider - to increase coverage
|
||||
func TestClaudeTokenProvider_Real_LockFailedWait(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.lockAcquired = false // Lock acquisition fails
|
||||
|
||||
// Token expires soon (within refresh skew) to trigger lock attempt
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 300,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
// Set token in cache after lock wait period (simulate other worker refreshing)
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "refreshed-by-other"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_CacheHitAfterWait(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.lockAcquired = false // Lock acquisition fails
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 301,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "original-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
// Set token in cache immediately after wait starts
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "winner-token"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_NoExpiresAt(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.lockAcquired = false // Prevent entering refresh logic
|
||||
|
||||
// Token with nil expires_at (no expiry set)
|
||||
account := &Account{
|
||||
ID: 302,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "no-expiry-token",
|
||||
},
|
||||
}
|
||||
|
||||
// After lock wait, return token from credentials
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "no-expiry-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_WhitespaceToken(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cacheKey := "claude:account:303"
|
||||
cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 303,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "real-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "real-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_EmptyCredentialToken(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 304,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": " ", // Whitespace only
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_LockError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.lockErr = errors.New("redis lock failed")
|
||||
|
||||
// Token expires soon (within refresh skew)
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 305,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-on-lock-error",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fallback-on-lock-error", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_NilCredentials(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 306,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"expires_at": expiresAt,
|
||||
// No access_token
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
@@ -144,21 +144,22 @@ func (e *UpstreamFailoverError) Error() string {
|
||||
|
||||
// GatewayService handles API gateway operations
|
||||
type GatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
identityService *IdentityService
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
concurrencyService *ConcurrencyService
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
identityService *IdentityService
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
concurrencyService *ConcurrencyService
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
@@ -178,23 +179,25 @@ func NewGatewayService(
|
||||
identityService *IdentityService,
|
||||
httpUpstream HTTPUpstream,
|
||||
deferredService *DeferredService,
|
||||
claudeTokenProvider *ClaudeTokenProvider,
|
||||
) *GatewayService {
|
||||
return &GatewayService{
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
concurrencyService: concurrencyService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
identityService: identityService,
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
concurrencyService: concurrencyService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
identityService: identityService,
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
claudeTokenProvider: claudeTokenProvider,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1079,6 +1082,16 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
|
||||
}
|
||||
|
||||
func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
|
||||
// 对于 Anthropic OAuth 账号,使用 ClaudeTokenProvider 获取缓存的 token
|
||||
if account.Platform == PlatformAnthropic && account.Type == AccountTypeOAuth && s.claudeTokenProvider != nil {
|
||||
accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
}
|
||||
|
||||
// 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
|
||||
@@ -154,7 +154,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
func GeminiTokenCacheKey(account *Account) string {
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID != "" {
|
||||
return projectID
|
||||
return "gemini:" + projectID
|
||||
}
|
||||
return "account:" + strconv.FormatInt(account.ID, 10)
|
||||
return "gemini:account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
|
||||
@@ -80,19 +80,20 @@ type OpenAIForwardResult struct {
|
||||
|
||||
// OpenAIGatewayService handles OpenAI API gateway operations
|
||||
type OpenAIGatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
concurrencyService *ConcurrencyService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
accountRepo AccountRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
concurrencyService *ConcurrencyService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
openAITokenProvider *OpenAITokenProvider
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
||||
@@ -110,21 +111,23 @@ func NewOpenAIGatewayService(
|
||||
billingCacheService *BillingCacheService,
|
||||
httpUpstream HTTPUpstream,
|
||||
deferredService *DeferredService,
|
||||
openAITokenProvider *OpenAITokenProvider,
|
||||
) *OpenAIGatewayService {
|
||||
return &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
concurrencyService: concurrencyService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
concurrencyService: concurrencyService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
openAITokenProvider: openAITokenProvider,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -503,6 +506,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
|
||||
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
||||
switch account.Type {
|
||||
case AccountTypeOAuth:
|
||||
// 使用 TokenProvider 获取缓存的 token
|
||||
if s.openAITokenProvider != nil {
|
||||
accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
}
|
||||
// 降级:TokenProvider 未配置时直接从账号读取
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
|
||||
189
backend/internal/service/openai_token_provider.go
Normal file
189
backend/internal/service/openai_token_provider.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
openAITokenRefreshSkew = 3 * time.Minute
|
||||
openAITokenCacheSkew = 5 * time.Minute
|
||||
openAILockWaitTime = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
type OpenAITokenCache = GeminiTokenCache
|
||||
|
||||
// OpenAITokenProvider 管理 OpenAI OAuth 账户的 access_token
|
||||
type OpenAITokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache OpenAITokenCache
|
||||
openAIOAuthService *OpenAIOAuthService
|
||||
}
|
||||
|
||||
func NewOpenAITokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache OpenAITokenCache,
|
||||
openAIOAuthService *OpenAIOAuthService,
|
||||
) *OpenAITokenProvider {
|
||||
return &OpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
openAIOAuthService: openAIOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取有效的 access_token
|
||||
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an openai oauth account")
|
||||
}
|
||||
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
|
||||
// 1. 先尝试缓存
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("openai_token_cache_hit", "account_id", account.ID)
|
||||
return token, nil
|
||||
} else if err != nil {
|
||||
slog.Warn("openai_token_cache_get_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("openai_token_cache_miss", "account_id", account.ID)
|
||||
|
||||
// 2. 如果即将过期则刷新
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if p.openAIOAuthService == nil {
|
||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
|
||||
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if lockErr != nil {
|
||||
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
||||
slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
||||
|
||||
// 检查 ctx 是否已取消
|
||||
if ctx.Err() != nil {
|
||||
return "", ctx.Err()
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
if p.accountRepo != nil {
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
|
||||
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if p.openAIOAuthService == nil {
|
||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
|
||||
time.Sleep(openAILockWaitTime)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > openAITokenCacheSkew:
|
||||
ttl = until - openAITokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
||||
slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
810
backend/internal/service/openai_token_provider_test.go
Normal file
810
backend/internal/service/openai_token_provider_test.go
Normal file
@@ -0,0 +1,810 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// openAITokenCacheStub implements OpenAITokenCache for testing
|
||||
type openAITokenCacheStub struct {
|
||||
mu sync.Mutex
|
||||
tokens map[string]string
|
||||
getErr error
|
||||
setErr error
|
||||
deleteErr error
|
||||
lockAcquired bool
|
||||
lockErr error
|
||||
releaseLockErr error
|
||||
getCalled int32
|
||||
setCalled int32
|
||||
lockCalled int32
|
||||
unlockCalled int32
|
||||
simulateLockRace bool
|
||||
}
|
||||
|
||||
func newOpenAITokenCacheStub() *openAITokenCacheStub {
|
||||
return &openAITokenCacheStub{
|
||||
tokens: make(map[string]string),
|
||||
lockAcquired: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *openAITokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
|
||||
atomic.AddInt32(&s.getCalled, 1)
|
||||
if s.getErr != nil {
|
||||
return "", s.getErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.tokens[cacheKey], nil
|
||||
}
|
||||
|
||||
func (s *openAITokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
|
||||
atomic.AddInt32(&s.setCalled, 1)
|
||||
if s.setErr != nil {
|
||||
return s.setErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.tokens[cacheKey] = token
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *openAITokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
|
||||
if s.deleteErr != nil {
|
||||
return s.deleteErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.tokens, cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *openAITokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||
atomic.AddInt32(&s.lockCalled, 1)
|
||||
if s.lockErr != nil {
|
||||
return false, s.lockErr
|
||||
}
|
||||
if s.simulateLockRace {
|
||||
return false, nil
|
||||
}
|
||||
return s.lockAcquired, nil
|
||||
}
|
||||
|
||||
func (s *openAITokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
|
||||
atomic.AddInt32(&s.unlockCalled, 1)
|
||||
return s.releaseLockErr
|
||||
}
|
||||
|
||||
// openAIAccountRepoStub is a minimal stub implementing only the methods used by OpenAITokenProvider
|
||||
type openAIAccountRepoStub struct {
|
||||
account *Account
|
||||
getErr error
|
||||
updateErr error
|
||||
getCalled int32
|
||||
updateCalled int32
|
||||
}
|
||||
|
||||
func (r *openAIAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
atomic.AddInt32(&r.getCalled, 1)
|
||||
if r.getErr != nil {
|
||||
return nil, r.getErr
|
||||
}
|
||||
return r.account, nil
|
||||
}
|
||||
|
||||
func (r *openAIAccountRepoStub) Update(ctx context.Context, account *Account) error {
|
||||
atomic.AddInt32(&r.updateCalled, 1)
|
||||
if r.updateErr != nil {
|
||||
return r.updateErr
|
||||
}
|
||||
r.account = account
|
||||
return nil
|
||||
}
|
||||
|
||||
// openAIOAuthServiceStub implements OpenAIOAuthService methods for testing
|
||||
type openAIOAuthServiceStub struct {
|
||||
tokenInfo *OpenAITokenInfo
|
||||
refreshErr error
|
||||
refreshCalled int32
|
||||
}
|
||||
|
||||
func (s *openAIOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||
atomic.AddInt32(&s.refreshCalled, 1)
|
||||
if s.refreshErr != nil {
|
||||
return nil, s.refreshErr
|
||||
}
|
||||
return s.tokenInfo, nil
|
||||
}
|
||||
|
||||
func (s *openAIOAuthServiceStub) BuildAccountCredentials(info *OpenAITokenInfo) map[string]any {
|
||||
now := time.Now()
|
||||
return map[string]any{
|
||||
"access_token": info.AccessToken,
|
||||
"refresh_token": info.RefreshToken,
|
||||
"expires_at": now.Add(time.Duration(info.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_CacheHit(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "db-token",
|
||||
},
|
||||
}
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
cache.tokens[cacheKey] = "cached-token"
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "cached-token", token)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_CacheMiss_FromCredentials(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
// Token expires in far future, no refresh needed
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "credential-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "credential-token", token)
|
||||
|
||||
// Should have stored in cache
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
require.Equal(t, "credential-token", cache.tokens[cacheKey])
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_TokenRefresh(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
accountRepo := &openAIAccountRepoStub{}
|
||||
oauthService := &openAIOAuthServiceStub{
|
||||
tokenInfo: &OpenAITokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon (within refresh skew)
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
// We need to directly test with the stub - create a custom provider
|
||||
customProvider := &testOpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
token, err := customProvider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refreshed-token", token)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
|
||||
}
|
||||
|
||||
// testOpenAITokenProvider is a test version that uses the stub OAuth service
|
||||
type testOpenAITokenProvider struct {
|
||||
accountRepo *openAIAccountRepoStub
|
||||
tokenCache *openAITokenCacheStub
|
||||
oauthService *openAIOAuthServiceStub
|
||||
}
|
||||
|
||||
func (p *testOpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an openai oauth account")
|
||||
}
|
||||
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
|
||||
// 1. Check cache
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check if refresh needed
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if err == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// Check cache again after acquiring lock
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Get fresh account from DB
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
newCredentials := p.oauthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if p.tokenCache.simulateLockRace {
|
||||
// Wait and retry cache
|
||||
time.Sleep(10 * time.Millisecond) // Short wait for test
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if accessToken == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. Store in cache
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
ttl = time.Minute // 刷新失败时使用短 TTL
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
if until > openAITokenCacheSkew {
|
||||
ttl = until - openAITokenCacheSkew
|
||||
} else if until > 0 {
|
||||
ttl = until
|
||||
} else {
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_LockRaceCondition(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.simulateLockRace = true
|
||||
accountRepo := &openAIAccountRepoStub{}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "race-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
// Simulate another worker already refreshed and cached
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "winner-token"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := &testOpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
// Should get the token set by the "winner" or the original
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_NilAccount(t *testing.T) {
|
||||
provider := NewOpenAITokenProvider(nil, nil, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), nil)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "account is nil")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
|
||||
provider := NewOpenAITokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 104,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an openai oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
|
||||
provider := NewOpenAITokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 105,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an openai oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_NilCache(t *testing.T) {
|
||||
// Token doesn't need refresh
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 106,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "nocache-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, nil, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "nocache-token", token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_CacheGetError(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.getErr = errors.New("redis connection failed")
|
||||
|
||||
// Token doesn't need refresh
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 107,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
// Should gracefully degrade and return from credentials
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fallback-token", token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_CacheSetError(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.setErr = errors.New("redis write failed")
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 108,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "still-works-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
// Should still work even if cache set fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "still-works-token", token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_MissingAccessToken(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 109,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"expires_at": expiresAt,
|
||||
// missing access_token
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_RefreshError(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
accountRepo := &openAIAccountRepoStub{}
|
||||
oauthService := &openAIOAuthServiceStub{
|
||||
refreshErr: errors.New("oauth refresh failed"),
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 110,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testOpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// Now with fallback behavior, should return existing token even if refresh fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "old-token", token) // Fallback to existing token
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_OAuthServiceNotConfigured(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
accountRepo := &openAIAccountRepoStub{}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 111,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testOpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: nil, // not configured
|
||||
}
|
||||
|
||||
// Now with fallback behavior, should return existing token even if oauth service not configured
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "old-token", token) // Fallback to existing token
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_TTLCalculation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expiresIn time.Duration
|
||||
}{
|
||||
{
|
||||
name: "far_future_expiry",
|
||||
expiresIn: 1 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "medium_expiry",
|
||||
expiresIn: 10 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "near_expiry",
|
||||
expiresIn: 6 * time.Minute,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 200,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
_, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify token was cached
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
require.Equal(t, "test-token", cache.tokens[cacheKey])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_DoubleCheckAfterLock(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
accountRepo := &openAIAccountRepoStub{}
|
||||
oauthService := &openAIOAuthServiceStub{
|
||||
tokenInfo: &OpenAITokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 112,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
|
||||
// Simulate: first GetAccessToken returns empty, but after lock acquired, cache has token
|
||||
originalGet := int32(0)
|
||||
cache.tokens[cacheKey] = "" // Empty initially
|
||||
|
||||
provider := &testOpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// In a goroutine, set the cached token after a small delay (simulating race)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "cached-by-other"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
// Should get either the refreshed token or the cached one
|
||||
require.NotEmpty(t, token)
|
||||
_ = originalGet // Suppress unused warning
|
||||
}
|
||||
|
||||
// Tests for real provider - to increase coverage
|
||||
func TestOpenAITokenProvider_Real_LockFailedWait(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.lockAcquired = false // Lock acquisition fails
|
||||
|
||||
// Token expires soon (within refresh skew) to trigger lock attempt
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 200,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
// Set token in cache after lock wait period (simulate other worker refreshing)
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "refreshed-by-other"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
// Should get either the fallback token or the refreshed one
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_CacheHitAfterWait(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.lockAcquired = false // Lock acquisition fails
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 201,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "original-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
// Set token in cache immediately after wait starts
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "winner-token"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.lockAcquired = false // Prevent entering refresh logic
|
||||
|
||||
// Token with nil expires_at (no expiry set) - should use credentials
|
||||
account := &Account{
|
||||
ID: 202,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "no-expiry-token",
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
// Without OAuth service, refresh will fail but token should be returned from credentials
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "no-expiry-token", token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_WhitespaceToken(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cacheKey := "openai:account:203"
|
||||
cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 203,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "real-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "real-token", token) // Should fall back to credentials
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_LockError(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.lockErr = errors.New("redis lock failed")
|
||||
|
||||
// Token expires soon (within refresh skew)
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 204,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-on-lock-error",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fallback-on-lock-error", token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_WhitespaceCredentialToken(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 205,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": " ", // Whitespace only
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 206,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"expires_at": expiresAt,
|
||||
// No access_token
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
@@ -85,13 +85,24 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
|
||||
switch statusCode {
|
||||
case 401:
|
||||
if account.Type == AccountTypeOAuth &&
|
||||
(account.Platform == PlatformAntigravity || account.Platform == PlatformGemini) {
|
||||
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
|
||||
if account.Type == AccountTypeOAuth {
|
||||
// 1. 失效缓存
|
||||
if s.tokenCacheInvalidator != nil {
|
||||
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
// 2. 设置 expires_at 为当前时间,强制下次请求刷新 token
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
|
||||
} else {
|
||||
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
|
||||
}
|
||||
}
|
||||
msg := "Authentication failed (401): invalid or expired credentials"
|
||||
if upstreamMsg != "" {
|
||||
|
||||
@@ -7,29 +7,35 @@ type TokenCacheInvalidator interface {
|
||||
}
|
||||
|
||||
type CompositeTokenCacheInvalidator struct {
|
||||
geminiCache GeminiTokenCache
|
||||
cache GeminiTokenCache // 统一使用一个缓存接口,通过缓存键前缀区分平台
|
||||
}
|
||||
|
||||
func NewCompositeTokenCacheInvalidator(geminiCache GeminiTokenCache) *CompositeTokenCacheInvalidator {
|
||||
func NewCompositeTokenCacheInvalidator(cache GeminiTokenCache) *CompositeTokenCacheInvalidator {
|
||||
return &CompositeTokenCacheInvalidator{
|
||||
geminiCache: geminiCache,
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, account *Account) error {
|
||||
if c == nil || c.geminiCache == nil || account == nil {
|
||||
if c == nil || c.cache == nil || account == nil {
|
||||
return nil
|
||||
}
|
||||
if account.Type != AccountTypeOAuth {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cacheKey string
|
||||
switch account.Platform {
|
||||
case PlatformGemini:
|
||||
return c.geminiCache.DeleteAccessToken(ctx, GeminiTokenCacheKey(account))
|
||||
cacheKey = GeminiTokenCacheKey(account)
|
||||
case PlatformAntigravity:
|
||||
return c.geminiCache.DeleteAccessToken(ctx, AntigravityTokenCacheKey(account))
|
||||
cacheKey = AntigravityTokenCacheKey(account)
|
||||
case PlatformOpenAI:
|
||||
cacheKey = OpenAITokenCacheKey(account)
|
||||
case PlatformAnthropic:
|
||||
cacheKey = ClaudeTokenCacheKey(account)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return c.cache.DeleteAccessToken(ctx, cacheKey)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -50,7 +51,7 @@ func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"project-x"}, cache.deletedKeys)
|
||||
require.Equal(t, []string{"gemini:project-x"}, cache.deletedKeys)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
|
||||
@@ -70,13 +71,99 @@ func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
|
||||
require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) {
|
||||
func TestCompositeTokenCacheInvalidator_OpenAI(t *testing.T) {
|
||||
cache := &geminiTokenCacheStub{}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
ID: 500,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "openai-token",
|
||||
},
|
||||
}
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"openai:account:500"}, cache.deletedKeys)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_Claude(t *testing.T) {
|
||||
cache := &geminiTokenCacheStub{}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
account := &Account{
|
||||
ID: 600,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "claude-token",
|
||||
},
|
||||
}
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"claude:account:600"}, cache.deletedKeys)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) {
|
||||
cache := &geminiTokenCacheStub{}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
}{
|
||||
{
|
||||
name: "gemini_api_key",
|
||||
account: &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "openai_api_key",
|
||||
account: &Account{
|
||||
ID: 2,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "claude_api_key",
|
||||
account: &Account{
|
||||
ID: 3,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "claude_setup_token",
|
||||
account: &Account{
|
||||
ID: 4,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeSetupToken,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cache.deletedKeys = nil
|
||||
err := invalidator.InvalidateToken(context.Background(), tt.account)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, cache.deletedKeys)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_SkipUnsupportedPlatform(t *testing.T) {
|
||||
cache := &geminiTokenCacheStub{}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: "unknown-platform",
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
@@ -95,3 +182,87 @@ func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) {
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_NilAccount(t *testing.T) {
|
||||
cache := &geminiTokenCacheStub{}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, cache.deletedKeys)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_NilInvalidator(t *testing.T) {
|
||||
var invalidator *CompositeTokenCacheInvalidator
|
||||
account := &Account{
|
||||
ID: 5,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_DeleteError(t *testing.T) {
|
||||
expectedErr := errors.New("redis connection failed")
|
||||
cache := &geminiTokenCacheStub{deleteErr: expectedErr}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
}{
|
||||
{
|
||||
name: "openai_delete_error",
|
||||
account: &Account{
|
||||
ID: 700,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "claude_delete_error",
|
||||
account: &Account{
|
||||
ID: 800,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := invalidator.InvalidateToken(context.Background(), tt.account)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, expectedErr, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
|
||||
// 测试所有平台的缓存键生成和删除
|
||||
cache := &geminiTokenCacheStub{}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
|
||||
accounts := []*Account{
|
||||
{ID: 1, Platform: PlatformGemini, Type: AccountTypeOAuth, Credentials: map[string]any{"project_id": "gemini-proj"}},
|
||||
{ID: 2, Platform: PlatformAntigravity, Type: AccountTypeOAuth, Credentials: map[string]any{"project_id": "ag-proj"}},
|
||||
{ID: 3, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
|
||||
{ID: 4, Platform: PlatformAnthropic, Type: AccountTypeOAuth},
|
||||
}
|
||||
|
||||
expectedKeys := []string{
|
||||
"gemini:gemini-proj",
|
||||
"ag:ag-proj",
|
||||
"openai:account:3",
|
||||
"claude:account:4",
|
||||
}
|
||||
|
||||
for _, acc := range accounts {
|
||||
err := invalidator.InvalidateToken(context.Background(), acc)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t, expectedKeys, cache.deletedKeys)
|
||||
}
|
||||
|
||||
15
backend/internal/service/token_cache_key.go
Normal file
15
backend/internal/service/token_cache_key.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package service
|
||||
|
||||
import "strconv"
|
||||
|
||||
// OpenAITokenCacheKey 生成 OpenAI OAuth 账号的缓存键
|
||||
// 格式: "openai:account:{account_id}"
|
||||
func OpenAITokenCacheKey(account *Account) string {
|
||||
return "openai:account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
|
||||
// ClaudeTokenCacheKey 生成 Claude (Anthropic) OAuth 账号的缓存键
|
||||
// 格式: "claude:account:{account_id}"
|
||||
func ClaudeTokenCacheKey(account *Account) string {
|
||||
return "claude:account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
@@ -22,7 +22,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
|
||||
"project_id": "my-project-123",
|
||||
},
|
||||
},
|
||||
expected: "my-project-123",
|
||||
expected: "gemini:my-project-123",
|
||||
},
|
||||
{
|
||||
name: "project_id_with_whitespace",
|
||||
@@ -32,7 +32,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
|
||||
"project_id": " project-with-spaces ",
|
||||
},
|
||||
},
|
||||
expected: "project-with-spaces",
|
||||
expected: "gemini:project-with-spaces",
|
||||
},
|
||||
{
|
||||
name: "empty_project_id_fallback_to_account_id",
|
||||
@@ -42,7 +42,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
|
||||
"project_id": "",
|
||||
},
|
||||
},
|
||||
expected: "account:102",
|
||||
expected: "gemini:account:102",
|
||||
},
|
||||
{
|
||||
name: "whitespace_only_project_id_fallback_to_account_id",
|
||||
@@ -52,7 +52,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
|
||||
"project_id": " ",
|
||||
},
|
||||
},
|
||||
expected: "account:103",
|
||||
expected: "gemini:account:103",
|
||||
},
|
||||
{
|
||||
name: "no_project_id_key_fallback_to_account_id",
|
||||
@@ -60,7 +60,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
|
||||
ID: 104,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
expected: "account:104",
|
||||
expected: "gemini:account:104",
|
||||
},
|
||||
{
|
||||
name: "nil_credentials_fallback_to_account_id",
|
||||
@@ -68,7 +68,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
|
||||
ID: 105,
|
||||
Credentials: nil,
|
||||
},
|
||||
expected: "account:105",
|
||||
expected: "gemini:account:105",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -151,3 +151,109 @@ func TestAntigravityTokenCacheKey(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAITokenCacheKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic_account",
|
||||
account: &Account{
|
||||
ID: 300,
|
||||
},
|
||||
expected: "openai:account:300",
|
||||
},
|
||||
{
|
||||
name: "account_with_credentials",
|
||||
account: &Account{
|
||||
ID: 301,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test-token",
|
||||
},
|
||||
},
|
||||
expected: "openai:account:301",
|
||||
},
|
||||
{
|
||||
name: "account_id_zero",
|
||||
account: &Account{
|
||||
ID: 0,
|
||||
},
|
||||
expected: "openai:account:0",
|
||||
},
|
||||
{
|
||||
name: "large_account_id",
|
||||
account: &Account{
|
||||
ID: 9999999999,
|
||||
},
|
||||
expected: "openai:account:9999999999",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := OpenAITokenCacheKey(tt.account)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeTokenCacheKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic_account",
|
||||
account: &Account{
|
||||
ID: 400,
|
||||
},
|
||||
expected: "claude:account:400",
|
||||
},
|
||||
{
|
||||
name: "account_with_credentials",
|
||||
account: &Account{
|
||||
ID: 401,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "claude-token",
|
||||
},
|
||||
},
|
||||
expected: "claude:account:401",
|
||||
},
|
||||
{
|
||||
name: "account_id_zero",
|
||||
account: &Account{
|
||||
ID: 0,
|
||||
},
|
||||
expected: "claude:account:0",
|
||||
},
|
||||
{
|
||||
name: "large_account_id",
|
||||
account: &Account{
|
||||
ID: 9999999999,
|
||||
},
|
||||
expected: "claude:account:9999999999",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ClaudeTokenCacheKey(tt.account)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheKeyUniqueness(t *testing.T) {
|
||||
// 确保不同平台的缓存键不会冲突
|
||||
account := &Account{ID: 123}
|
||||
|
||||
openaiKey := OpenAITokenCacheKey(account)
|
||||
claudeKey := ClaudeTokenCacheKey(account)
|
||||
|
||||
require.NotEqual(t, openaiKey, claudeKey, "OpenAI and Claude cache keys should be different")
|
||||
require.Contains(t, openaiKey, "openai:")
|
||||
require.Contains(t, claudeKey, "claude:")
|
||||
}
|
||||
|
||||
@@ -172,8 +172,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", err)
|
||||
}
|
||||
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth &&
|
||||
(account.Platform == PlatformGemini || account.Platform == PlatformAntigravity) {
|
||||
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
|
||||
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
|
||||
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||
log.Printf("[TokenRefresh] Failed to invalidate token cache for account %d: %v", account.ID, err)
|
||||
} else {
|
||||
|
||||
@@ -197,7 +197,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
|
||||
require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效
|
||||
}
|
||||
|
||||
// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试其他平台的 OAuth 账号不触发缓存失效
|
||||
// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试所有 OAuth 平台都触发缓存失效
|
||||
func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
@@ -210,7 +210,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||
account := &Account{
|
||||
ID: 10,
|
||||
Platform: PlatformOpenAI, // 其他平台
|
||||
Platform: PlatformOpenAI, // OpenAI OAuth 账户
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
refresher := &tokenRefresherStub{
|
||||
@@ -222,7 +222,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 0, invalidator.calls) // 其他平台不触发缓存失效
|
||||
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
|
||||
}
|
||||
|
||||
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
|
||||
|
||||
@@ -214,10 +214,13 @@ var ProviderSet = wire.NewSet(
|
||||
NewGeminiOAuthService,
|
||||
NewGeminiQuotaService,
|
||||
NewCompositeTokenCacheInvalidator,
|
||||
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
|
||||
NewAntigravityOAuthService,
|
||||
NewGeminiTokenProvider,
|
||||
NewGeminiMessagesCompatService,
|
||||
NewAntigravityTokenProvider,
|
||||
NewOpenAITokenProvider,
|
||||
NewClaudeTokenProvider,
|
||||
NewAntigravityGatewayService,
|
||||
ProvideRateLimitService,
|
||||
NewAccountUsageService,
|
||||
|
||||
Reference in New Issue
Block a user