feat(网关): 引入 OpenAI/Claude OAuth token 缓存

新增 OpenAI/Claude TokenProvider 与缓存键生成
扩展 OAuth 缓存失效覆盖更多平台
统一 OAuth 缓存前缀与依赖注入
This commit is contained in:
yangjianbo
2026-01-15 18:27:06 +08:00
parent 35e3a89385
commit 1820389a05
16 changed files with 2464 additions and 85 deletions

View File

@@ -100,8 +100,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
tokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, tokenCacheInvalidator)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
@@ -136,8 +136,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityCache := repository.NewIdentityCache(redisClient)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
@@ -168,7 +170,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, tokenCacheInvalidator, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{

View File

@@ -11,8 +11,8 @@ import (
)
const (
geminiTokenKeyPrefix = "gemini:token:"
geminiRefreshLockKeyPrefix = "gemini:refresh_lock:"
oauthTokenKeyPrefix = "oauth:token:"
oauthRefreshLockKeyPrefix = "oauth:refresh_lock:"
)
type geminiTokenCache struct {
@@ -24,26 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
}
func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Get(ctx, key).Result()
}
func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Set(ctx, key, token, ttl).Err()
}
func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err()
}
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
}
func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,156 @@
package service
import (
"context"
"errors"
"log/slog"
"strconv"
"strings"
"time"
)
const (
claudeTokenRefreshSkew = 3 * time.Minute
claudeTokenCacheSkew = 5 * time.Minute
claudeLockWaitTime = 200 * time.Millisecond
)
// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type ClaudeTokenCache = GeminiTokenCache
// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token
type ClaudeTokenProvider struct {
accountRepo AccountRepository
tokenCache ClaudeTokenCache
oauthService *OAuthService
}
func NewClaudeTokenProvider(
accountRepo AccountRepository,
tokenCache ClaudeTokenCache,
oauthService *OAuthService,
) *ClaudeTokenProvider {
return &ClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
oauthService: oauthService,
}
}
// GetAccessToken 获取有效的 access_token
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
return "", errors.New("not an anthropic oauth account")
}
cacheKey := ClaudeTokenCacheKey(account)
// 1. 先尝试缓存
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
return token, nil
} else if err != nil {
slog.Warn("claude_token_cache_get_failed", "account_id", account.ID, "error", err)
}
}
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
// 2. 如果即将过期则刷新
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
refreshFailed := false
if needsRefresh && p.tokenCache != nil {
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if err == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
// 从数据库获取最新账户信息
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
if p.oauthService == nil {
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
refreshFailed = true // 无法刷新,标记失败
} else {
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
refreshFailed = true // 刷新失败,标记以使用短 TTL
} else {
// 构建新 credentials保留原有字段
newCredentials := make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
newCredentials["scope"] = tokenInfo.Scope
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else {
// 锁获取失败,等待 200ms 后重试读取缓存(改进:减少并发时的缓存未命中)
time.Sleep(claudeLockWaitTime)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
}
}
accessToken := account.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
if p.tokenCache != nil {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > claudeTokenCacheSkew:
ttl = until - claudeTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}
return accessToken, nil
}

View File

@@ -0,0 +1,939 @@
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// claudeTokenCacheStub implements ClaudeTokenCache for testing
type claudeTokenCacheStub struct {
mu sync.Mutex
tokens map[string]string
getErr error
setErr error
deleteErr error
lockAcquired bool
lockErr error
releaseLockErr error
getCalled int32
setCalled int32
lockCalled int32
unlockCalled int32
simulateLockRace bool
}
func newClaudeTokenCacheStub() *claudeTokenCacheStub {
return &claudeTokenCacheStub{
tokens: make(map[string]string),
lockAcquired: true,
}
}
func (s *claudeTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
atomic.AddInt32(&s.getCalled, 1)
if s.getErr != nil {
return "", s.getErr
}
s.mu.Lock()
defer s.mu.Unlock()
return s.tokens[cacheKey], nil
}
func (s *claudeTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
atomic.AddInt32(&s.setCalled, 1)
if s.setErr != nil {
return s.setErr
}
s.mu.Lock()
defer s.mu.Unlock()
s.tokens[cacheKey] = token
return nil
}
func (s *claudeTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
if s.deleteErr != nil {
return s.deleteErr
}
s.mu.Lock()
defer s.mu.Unlock()
delete(s.tokens, cacheKey)
return nil
}
func (s *claudeTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
atomic.AddInt32(&s.lockCalled, 1)
if s.lockErr != nil {
return false, s.lockErr
}
if s.simulateLockRace {
return false, nil
}
return s.lockAcquired, nil
}
func (s *claudeTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
atomic.AddInt32(&s.unlockCalled, 1)
return s.releaseLockErr
}
// claudeAccountRepoStub is a minimal stub implementing only the methods used by ClaudeTokenProvider
type claudeAccountRepoStub struct {
account *Account
getErr error
updateErr error
getCalled int32
updateCalled int32
}
func (r *claudeAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
atomic.AddInt32(&r.getCalled, 1)
if r.getErr != nil {
return nil, r.getErr
}
return r.account, nil
}
func (r *claudeAccountRepoStub) Update(ctx context.Context, account *Account) error {
atomic.AddInt32(&r.updateCalled, 1)
if r.updateErr != nil {
return r.updateErr
}
r.account = account
return nil
}
// claudeOAuthServiceStub implements OAuthService methods for testing
type claudeOAuthServiceStub struct {
tokenInfo *TokenInfo
refreshErr error
refreshCalled int32
}
func (s *claudeOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
atomic.AddInt32(&s.refreshCalled, 1)
if s.refreshErr != nil {
return nil, s.refreshErr
}
return s.tokenInfo, nil
}
// testClaudeTokenProvider is a test version that uses the stub OAuth service
type testClaudeTokenProvider struct {
accountRepo *claudeAccountRepoStub
tokenCache *claudeTokenCacheStub
oauthService *claudeOAuthServiceStub
}
func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
return "", errors.New("not an anthropic oauth account")
}
cacheKey := ClaudeTokenCacheKey(account)
// 1. Check cache
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
return token, nil
}
}
// 2. Check if refresh needed
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
refreshFailed := false
if needsRefresh && p.tokenCache != nil {
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if err == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// Check cache again after acquiring lock
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
return token, nil
}
// Get fresh account from DB
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
if p.oauthService == nil {
refreshFailed = true // 无法刷新,标记失败
} else {
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
refreshFailed = true // 刷新失败,标记以使用短 TTL
} else {
// Build new credentials
newCredentials := make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_at"] = time.Now().Add(time.Duration(tokenInfo.ExpiresIn) * time.Second).Format(time.RFC3339)
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
account.Credentials = newCredentials
_ = p.accountRepo.Update(ctx, account)
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else if p.tokenCache.simulateLockRace {
// Wait and retry cache
time.Sleep(10 * time.Millisecond)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
return token, nil
}
}
}
accessToken := account.GetCredential("access_token")
if accessToken == "" {
return "", errors.New("access_token not found in credentials")
}
// 3. Store in cache
if p.tokenCache != nil {
ttl := 30 * time.Minute
if refreshFailed {
ttl = time.Minute // 刷新失败时使用短 TTL
} else if expiresAt != nil {
until := time.Until(*expiresAt)
if until > claudeTokenCacheSkew {
ttl = until - claudeTokenCacheSkew
} else if until > 0 {
ttl = until
} else {
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil
}
func TestClaudeTokenProvider_CacheHit(t *testing.T) {
cache := newClaudeTokenCacheStub()
account := &Account{
ID: 100,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "db-token",
},
}
cacheKey := ClaudeTokenCacheKey(account)
cache.tokens[cacheKey] = "cached-token"
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "cached-token", token)
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
}
func TestClaudeTokenProvider_CacheMiss_FromCredentials(t *testing.T) {
cache := newClaudeTokenCacheStub()
// Token expires in far future, no refresh needed
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 101,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "credential-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "credential-token", token)
// Should have stored in cache
cacheKey := ClaudeTokenCacheKey(account)
require.Equal(t, "credential-token", cache.tokens[cacheKey])
}
func TestClaudeTokenProvider_TokenRefresh(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
oauthService := &claudeOAuthServiceStub{
tokenInfo: &TokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh-token",
TokenType: "Bearer",
ExpiresIn: 3600,
ExpiresAt: time.Now().Add(time.Hour).Unix(),
},
}
// Token expires soon (within refresh skew)
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 102,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refreshed-token", token)
require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
}
func TestClaudeTokenProvider_LockRaceCondition(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.simulateLockRace = true
accountRepo := &claudeAccountRepoStub{}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 103,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "race-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
// Simulate another worker already refreshed and cached
cacheKey := ClaudeTokenCacheKey(account)
go func() {
time.Sleep(5 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "winner-token"
cache.mu.Unlock()
}()
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.NotEmpty(t, token)
}
func TestClaudeTokenProvider_NilAccount(t *testing.T) {
provider := NewClaudeTokenProvider(nil, nil, nil)
token, err := provider.GetAccessToken(context.Background(), nil)
require.Error(t, err)
require.Contains(t, err.Error(), "account is nil")
require.Empty(t, token)
}
func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
provider := NewClaudeTokenProvider(nil, nil, nil)
account := &Account{
ID: 104,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Empty(t, token)
}
func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
provider := NewClaudeTokenProvider(nil, nil, nil)
account := &Account{
ID: 105,
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Empty(t, token)
}
func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
provider := NewClaudeTokenProvider(nil, nil, nil)
account := &Account{
ID: 106,
Platform: PlatformAnthropic,
Type: AccountTypeSetupToken,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Empty(t, token)
}
func TestClaudeTokenProvider_NilCache(t *testing.T) {
// Token doesn't need refresh
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 107,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "nocache-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, nil, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "nocache-token", token)
}
func TestClaudeTokenProvider_CacheGetError(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.getErr = errors.New("redis connection failed")
// Token doesn't need refresh
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 108,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
// Should gracefully degrade and return from credentials
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "fallback-token", token)
}
func TestClaudeTokenProvider_CacheSetError(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.setErr = errors.New("redis write failed")
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 109,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "still-works-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
// Should still work even if cache set fails
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "still-works-token", token)
}
func TestClaudeTokenProvider_MissingAccessToken(t *testing.T) {
cache := newClaudeTokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 110,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"expires_at": expiresAt,
// missing access_token
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}
func TestClaudeTokenProvider_RefreshError(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
oauthService := &claudeOAuthServiceStub{
refreshErr: errors.New("oauth refresh failed"),
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 111,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
// Now with fallback behavior, should return existing token even if refresh fails
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "old-token", token) // Fallback to existing token
}
func TestClaudeTokenProvider_OAuthServiceNotConfigured(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 112,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: nil, // not configured
}
// Now with fallback behavior, should return existing token even if oauth service not configured
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "old-token", token) // Fallback to existing token
}
func TestClaudeTokenProvider_TTLCalculation(t *testing.T) {
tests := []struct {
name string
expiresIn time.Duration
}{
{
name: "far_future_expiry",
expiresIn: 1 * time.Hour,
},
{
name: "medium_expiry",
expiresIn: 10 * time.Minute,
},
{
name: "near_expiry",
expiresIn: 6 * time.Minute,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache := newClaudeTokenCacheStub()
expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
account := &Account{
ID: 200,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "test-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
_, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
// Verify token was cached
cacheKey := ClaudeTokenCacheKey(account)
require.Equal(t, "test-token", cache.tokens[cacheKey])
})
}
}
func TestClaudeTokenProvider_AccountRepoGetError(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{
getErr: errors.New("db connection failed"),
}
oauthService := &claudeOAuthServiceStub{
tokenInfo: &TokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh",
TokenType: "Bearer",
ExpiresIn: 3600,
},
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 113,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh",
"expires_at": expiresAt,
},
}
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
// Should still work, just using the passed-in account
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refreshed-token", token)
}
func TestClaudeTokenProvider_AccountUpdateError(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{
updateErr: errors.New("db write failed"),
}
oauthService := &claudeOAuthServiceStub{
tokenInfo: &TokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh",
TokenType: "Bearer",
ExpiresIn: 3600,
},
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 114,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
// Should still return token even if update fails
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refreshed-token", token)
}
func TestClaudeTokenProvider_RefreshPreservesExistingCredentials(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
oauthService := &claudeOAuthServiceStub{
tokenInfo: &TokenInfo{
AccessToken: "new-access-token",
RefreshToken: "new-refresh-token",
TokenType: "Bearer",
ExpiresIn: 3600,
},
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 115,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-access-token",
"refresh_token": "old-refresh-token",
"expires_at": expiresAt,
"custom_field": "should-be-preserved",
"organization": "test-org",
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "new-access-token", token)
// Verify existing fields are preserved
require.Equal(t, "should-be-preserved", accountRepo.account.Credentials["custom_field"])
require.Equal(t, "test-org", accountRepo.account.Credentials["organization"])
// Verify new fields are updated
require.Equal(t, "new-access-token", accountRepo.account.Credentials["access_token"])
require.Equal(t, "new-refresh-token", accountRepo.account.Credentials["refresh_token"])
}
func TestClaudeTokenProvider_DoubleCheckCacheAfterLock(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
oauthService := &claudeOAuthServiceStub{
tokenInfo: &TokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh",
TokenType: "Bearer",
ExpiresIn: 3600,
},
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 116,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
cacheKey := ClaudeTokenCacheKey(account)
// After lock is acquired, cache should have the token (simulating another worker)
go func() {
time.Sleep(5 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "cached-by-other-worker"
cache.mu.Unlock()
}()
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.NotEmpty(t, token)
}
// Tests for real provider - to increase coverage
func TestClaudeTokenProvider_Real_LockFailedWait(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.lockAcquired = false // Lock acquisition fails
// Token expires soon (within refresh skew) to trigger lock attempt
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 300,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
},
}
// Set token in cache after lock wait period (simulate other worker refreshing)
cacheKey := ClaudeTokenCacheKey(account)
go func() {
time.Sleep(100 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "refreshed-by-other"
cache.mu.Unlock()
}()
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.NotEmpty(t, token)
}
func TestClaudeTokenProvider_Real_CacheHitAfterWait(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.lockAcquired = false // Lock acquisition fails
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 301,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "original-token",
"expires_at": expiresAt,
},
}
cacheKey := ClaudeTokenCacheKey(account)
// Set token in cache immediately after wait starts
go func() {
time.Sleep(50 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "winner-token"
cache.mu.Unlock()
}()
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.NotEmpty(t, token)
}
func TestClaudeTokenProvider_Real_NoExpiresAt(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.lockAcquired = false // Prevent entering refresh logic
// Token with nil expires_at (no expiry set)
account := &Account{
ID: 302,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "no-expiry-token",
},
}
// After lock wait, return token from credentials
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "no-expiry-token", token)
}
func TestClaudeTokenProvider_Real_WhitespaceToken(t *testing.T) {
cache := newClaudeTokenCacheStub()
cacheKey := "claude:account:303"
cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 303,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "real-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "real-token", token)
}
func TestClaudeTokenProvider_Real_EmptyCredentialToken(t *testing.T) {
cache := newClaudeTokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 304,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": " ", // Whitespace only
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}
func TestClaudeTokenProvider_Real_LockError(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.lockErr = errors.New("redis lock failed")
// Token expires soon (within refresh skew)
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 305,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-on-lock-error",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "fallback-on-lock-error", token)
}
func TestClaudeTokenProvider_Real_NilCredentials(t *testing.T) {
cache := newClaudeTokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 306,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"expires_at": expiresAt,
// No access_token
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}

View File

@@ -144,21 +144,22 @@ func (e *UpstreamFailoverError) Error() string {
// GatewayService handles API gateway operations
type GatewayService struct {
accountRepo AccountRepository
groupRepo GroupRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
cache GatewayCache
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
identityService *IdentityService
httpUpstream HTTPUpstream
deferredService *DeferredService
concurrencyService *ConcurrencyService
accountRepo AccountRepository
groupRepo GroupRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
cache GatewayCache
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
identityService *IdentityService
httpUpstream HTTPUpstream
deferredService *DeferredService
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
}
// NewGatewayService creates a new GatewayService
@@ -178,23 +179,25 @@ func NewGatewayService(
identityService *IdentityService,
httpUpstream HTTPUpstream,
deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider,
) *GatewayService {
return &GatewayService{
accountRepo: accountRepo,
groupRepo: groupRepo,
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
identityService: identityService,
httpUpstream: httpUpstream,
deferredService: deferredService,
accountRepo: accountRepo,
groupRepo: groupRepo,
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
identityService: identityService,
httpUpstream: httpUpstream,
deferredService: deferredService,
claudeTokenProvider: claudeTokenProvider,
}
}
@@ -1079,6 +1082,16 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
}
func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
// 对于 Anthropic OAuth 账号,使用 ClaudeTokenProvider 获取缓存的 token
if account.Platform == PlatformAnthropic && account.Type == AccountTypeOAuth && s.claudeTokenProvider != nil {
accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
if err != nil {
return "", "", err
}
return accessToken, "oauth", nil
}
// 其他情况Gemini 有自己的 TokenProvidersetup-token 类型等)直接从账号读取
accessToken := account.GetCredential("access_token")
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")

View File

@@ -80,19 +80,20 @@ type OpenAIForwardResult struct {
// OpenAIGatewayService handles OpenAI API gateway operations
type OpenAIGatewayService struct {
accountRepo AccountRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
cache GatewayCache
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
concurrencyService *ConcurrencyService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
httpUpstream HTTPUpstream
deferredService *DeferredService
accountRepo AccountRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
cache GatewayCache
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
concurrencyService *ConcurrencyService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
httpUpstream HTTPUpstream
deferredService *DeferredService
openAITokenProvider *OpenAITokenProvider
}
// NewOpenAIGatewayService creates a new OpenAIGatewayService
@@ -110,21 +111,23 @@ func NewOpenAIGatewayService(
billingCacheService *BillingCacheService,
httpUpstream HTTPUpstream,
deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider,
) *OpenAIGatewayService {
return &OpenAIGatewayService{
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
httpUpstream: httpUpstream,
deferredService: deferredService,
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
httpUpstream: httpUpstream,
deferredService: deferredService,
openAITokenProvider: openAITokenProvider,
}
}
@@ -503,6 +506,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
case AccountTypeOAuth:
// 使用 TokenProvider 获取缓存的 token
if s.openAITokenProvider != nil {
accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account)
if err != nil {
return "", "", err
}
return accessToken, "oauth", nil
}
// 降级TokenProvider 未配置时直接从账号读取
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")

View File

@@ -0,0 +1,146 @@
package service
import (
"context"
"errors"
"log/slog"
"strings"
"time"
)
const (
openAITokenRefreshSkew = 3 * time.Minute
openAITokenCacheSkew = 5 * time.Minute
openAILockWaitTime = 200 * time.Millisecond
)
// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type OpenAITokenCache = GeminiTokenCache
// OpenAITokenProvider 管理 OpenAI OAuth 账户的 access_token
type OpenAITokenProvider struct {
accountRepo AccountRepository
tokenCache OpenAITokenCache
openAIOAuthService *OpenAIOAuthService
}
func NewOpenAITokenProvider(
accountRepo AccountRepository,
tokenCache OpenAITokenCache,
openAIOAuthService *OpenAIOAuthService,
) *OpenAITokenProvider {
return &OpenAITokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
openAIOAuthService: openAIOAuthService,
}
}
// GetAccessToken 获取有效的 access_token
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
return "", errors.New("not an openai oauth account")
}
cacheKey := OpenAITokenCacheKey(account)
// 1. 先尝试缓存
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("openai_token_cache_hit", "account_id", account.ID)
return token, nil
} else if err != nil {
slog.Warn("openai_token_cache_get_failed", "account_id", account.ID, "error", err)
}
}
slog.Debug("openai_token_cache_miss", "account_id", account.ID)
// 2. 如果即将过期则刷新
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
refreshFailed := false
if needsRefresh && p.tokenCache != nil {
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if err == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
// 从数据库获取最新账户信息
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
refreshFailed = true // 无法刷新,标记失败
} else {
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
refreshFailed = true // 刷新失败,标记以使用短 TTL
} else {
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else {
// 锁获取失败,等待 200ms 后重试读取缓存(改进:减少并发时的缓存未命中)
time.Sleep(openAILockWaitTime)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
}
}
accessToken := account.GetOpenAIAccessToken()
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
if p.tokenCache != nil {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > openAITokenCacheSkew:
ttl = until - openAITokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}
return accessToken, nil
}

View File

@@ -0,0 +1,810 @@
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// openAITokenCacheStub implements OpenAITokenCache for testing
type openAITokenCacheStub struct {
mu sync.Mutex
tokens map[string]string
getErr error
setErr error
deleteErr error
lockAcquired bool
lockErr error
releaseLockErr error
getCalled int32
setCalled int32
lockCalled int32
unlockCalled int32
simulateLockRace bool
}
func newOpenAITokenCacheStub() *openAITokenCacheStub {
return &openAITokenCacheStub{
tokens: make(map[string]string),
lockAcquired: true,
}
}
func (s *openAITokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
atomic.AddInt32(&s.getCalled, 1)
if s.getErr != nil {
return "", s.getErr
}
s.mu.Lock()
defer s.mu.Unlock()
return s.tokens[cacheKey], nil
}
func (s *openAITokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
atomic.AddInt32(&s.setCalled, 1)
if s.setErr != nil {
return s.setErr
}
s.mu.Lock()
defer s.mu.Unlock()
s.tokens[cacheKey] = token
return nil
}
func (s *openAITokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
if s.deleteErr != nil {
return s.deleteErr
}
s.mu.Lock()
defer s.mu.Unlock()
delete(s.tokens, cacheKey)
return nil
}
func (s *openAITokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
atomic.AddInt32(&s.lockCalled, 1)
if s.lockErr != nil {
return false, s.lockErr
}
if s.simulateLockRace {
return false, nil
}
return s.lockAcquired, nil
}
func (s *openAITokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
atomic.AddInt32(&s.unlockCalled, 1)
return s.releaseLockErr
}
// openAIAccountRepoStub is a minimal stub implementing only the methods used by OpenAITokenProvider
type openAIAccountRepoStub struct {
account *Account
getErr error
updateErr error
getCalled int32
updateCalled int32
}
func (r *openAIAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
atomic.AddInt32(&r.getCalled, 1)
if r.getErr != nil {
return nil, r.getErr
}
return r.account, nil
}
func (r *openAIAccountRepoStub) Update(ctx context.Context, account *Account) error {
atomic.AddInt32(&r.updateCalled, 1)
if r.updateErr != nil {
return r.updateErr
}
r.account = account
return nil
}
// openAIOAuthServiceStub implements OpenAIOAuthService methods for testing
type openAIOAuthServiceStub struct {
tokenInfo *OpenAITokenInfo
refreshErr error
refreshCalled int32
}
func (s *openAIOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
atomic.AddInt32(&s.refreshCalled, 1)
if s.refreshErr != nil {
return nil, s.refreshErr
}
return s.tokenInfo, nil
}
func (s *openAIOAuthServiceStub) BuildAccountCredentials(info *OpenAITokenInfo) map[string]any {
now := time.Now()
return map[string]any{
"access_token": info.AccessToken,
"refresh_token": info.RefreshToken,
"expires_at": now.Add(time.Duration(info.ExpiresIn) * time.Second).Format(time.RFC3339),
}
}
func TestOpenAITokenProvider_CacheHit(t *testing.T) {
cache := newOpenAITokenCacheStub()
account := &Account{
ID: 100,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "db-token",
},
}
cacheKey := OpenAITokenCacheKey(account)
cache.tokens[cacheKey] = "cached-token"
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "cached-token", token)
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
}
func TestOpenAITokenProvider_CacheMiss_FromCredentials(t *testing.T) {
cache := newOpenAITokenCacheStub()
// Token expires in far future, no refresh needed
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 101,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "credential-token",
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "credential-token", token)
// Should have stored in cache
cacheKey := OpenAITokenCacheKey(account)
require.Equal(t, "credential-token", cache.tokens[cacheKey])
}
func TestOpenAITokenProvider_TokenRefresh(t *testing.T) {
cache := newOpenAITokenCacheStub()
accountRepo := &openAIAccountRepoStub{}
oauthService := &openAIOAuthServiceStub{
tokenInfo: &OpenAITokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh-token",
ExpiresIn: 3600,
},
}
// Token expires soon (within refresh skew)
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 102,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
// We need to directly test with the stub - create a custom provider
customProvider := &testOpenAITokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
token, err := customProvider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refreshed-token", token)
require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
}
// testOpenAITokenProvider is a test version that uses the stub OAuth service
type testOpenAITokenProvider struct {
accountRepo *openAIAccountRepoStub
tokenCache *openAITokenCacheStub
oauthService *openAIOAuthServiceStub
}
func (p *testOpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
return "", errors.New("not an openai oauth account")
}
cacheKey := OpenAITokenCacheKey(account)
// 1. Check cache
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
return token, nil
}
}
// 2. Check if refresh needed
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
refreshFailed := false
if needsRefresh && p.tokenCache != nil {
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if err == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// Check cache again after acquiring lock
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
return token, nil
}
// Get fresh account from DB
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.oauthService == nil {
refreshFailed = true // 无法刷新,标记失败
} else {
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
refreshFailed = true // 刷新失败,标记以使用短 TTL
} else {
newCredentials := p.oauthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
_ = p.accountRepo.Update(ctx, account)
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else if p.tokenCache.simulateLockRace {
// Wait and retry cache
time.Sleep(10 * time.Millisecond) // Short wait for test
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
return token, nil
}
}
}
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
return "", errors.New("access_token not found in credentials")
}
// 3. Store in cache
if p.tokenCache != nil {
ttl := 30 * time.Minute
if refreshFailed {
ttl = time.Minute // 刷新失败时使用短 TTL
} else if expiresAt != nil {
until := time.Until(*expiresAt)
if until > openAITokenCacheSkew {
ttl = until - openAITokenCacheSkew
} else if until > 0 {
ttl = until
} else {
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil
}
func TestOpenAITokenProvider_LockRaceCondition(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.simulateLockRace = true
accountRepo := &openAIAccountRepoStub{}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 103,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "race-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
// Simulate another worker already refreshed and cached
cacheKey := OpenAITokenCacheKey(account)
go func() {
time.Sleep(5 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "winner-token"
cache.mu.Unlock()
}()
provider := &testOpenAITokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
// Should get the token set by the "winner" or the original
require.NotEmpty(t, token)
}
func TestOpenAITokenProvider_NilAccount(t *testing.T) {
provider := NewOpenAITokenProvider(nil, nil, nil)
token, err := provider.GetAccessToken(context.Background(), nil)
require.Error(t, err)
require.Contains(t, err.Error(), "account is nil")
require.Empty(t, token)
}
func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
provider := NewOpenAITokenProvider(nil, nil, nil)
account := &Account{
ID: 104,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an openai oauth account")
require.Empty(t, token)
}
func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
provider := NewOpenAITokenProvider(nil, nil, nil)
account := &Account{
ID: 105,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an openai oauth account")
require.Empty(t, token)
}
func TestOpenAITokenProvider_NilCache(t *testing.T) {
// Token doesn't need refresh
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 106,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "nocache-token",
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, nil, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "nocache-token", token)
}
func TestOpenAITokenProvider_CacheGetError(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.getErr = errors.New("redis connection failed")
// Token doesn't need refresh
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 107,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
// Should gracefully degrade and return from credentials
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "fallback-token", token)
}
func TestOpenAITokenProvider_CacheSetError(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.setErr = errors.New("redis write failed")
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 108,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "still-works-token",
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
// Should still work even if cache set fails
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "still-works-token", token)
}
func TestOpenAITokenProvider_MissingAccessToken(t *testing.T) {
cache := newOpenAITokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 109,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"expires_at": expiresAt,
// missing access_token
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}
func TestOpenAITokenProvider_RefreshError(t *testing.T) {
cache := newOpenAITokenCacheStub()
accountRepo := &openAIAccountRepoStub{}
oauthService := &openAIOAuthServiceStub{
refreshErr: errors.New("oauth refresh failed"),
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 110,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testOpenAITokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
// Now with fallback behavior, should return existing token even if refresh fails
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "old-token", token) // Fallback to existing token
}
func TestOpenAITokenProvider_OAuthServiceNotConfigured(t *testing.T) {
cache := newOpenAITokenCacheStub()
accountRepo := &openAIAccountRepoStub{}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 111,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testOpenAITokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: nil, // not configured
}
// Now with fallback behavior, should return existing token even if oauth service not configured
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "old-token", token) // Fallback to existing token
}
func TestOpenAITokenProvider_TTLCalculation(t *testing.T) {
tests := []struct {
name string
expiresIn time.Duration
}{
{
name: "far_future_expiry",
expiresIn: 1 * time.Hour,
},
{
name: "medium_expiry",
expiresIn: 10 * time.Minute,
},
{
name: "near_expiry",
expiresIn: 6 * time.Minute,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache := newOpenAITokenCacheStub()
expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
account := &Account{
ID: 200,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "test-token",
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
_, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
// Verify token was cached
cacheKey := OpenAITokenCacheKey(account)
require.Equal(t, "test-token", cache.tokens[cacheKey])
})
}
}
func TestOpenAITokenProvider_DoubleCheckAfterLock(t *testing.T) {
cache := newOpenAITokenCacheStub()
accountRepo := &openAIAccountRepoStub{}
oauthService := &openAIOAuthServiceStub{
tokenInfo: &OpenAITokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh",
ExpiresIn: 3600,
},
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 112,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
cacheKey := OpenAITokenCacheKey(account)
// Simulate: first GetAccessToken returns empty, but after lock acquired, cache has token
originalGet := int32(0)
cache.tokens[cacheKey] = "" // Empty initially
provider := &testOpenAITokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
// In a goroutine, set the cached token after a small delay (simulating race)
go func() {
time.Sleep(5 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "cached-by-other"
cache.mu.Unlock()
}()
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
// Should get either the refreshed token or the cached one
require.NotEmpty(t, token)
_ = originalGet // Suppress unused warning
}
// Tests for real provider - to increase coverage
func TestOpenAITokenProvider_Real_LockFailedWait(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.lockAcquired = false // Lock acquisition fails
// Token expires soon (within refresh skew) to trigger lock attempt
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 200,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
},
}
// Set token in cache after lock wait period (simulate other worker refreshing)
cacheKey := OpenAITokenCacheKey(account)
go func() {
time.Sleep(100 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "refreshed-by-other"
cache.mu.Unlock()
}()
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
// Should get either the fallback token or the refreshed one
require.NotEmpty(t, token)
}
func TestOpenAITokenProvider_Real_CacheHitAfterWait(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.lockAcquired = false // Lock acquisition fails
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 201,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "original-token",
"expires_at": expiresAt,
},
}
cacheKey := OpenAITokenCacheKey(account)
// Set token in cache immediately after wait starts
go func() {
time.Sleep(50 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "winner-token"
cache.mu.Unlock()
}()
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.NotEmpty(t, token)
}
func TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.lockAcquired = false // Prevent entering refresh logic
// Token with nil expires_at (no expiry set) - should use credentials
account := &Account{
ID: 202,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "no-expiry-token",
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
// Without OAuth service, refresh will fail but token should be returned from credentials
require.NoError(t, err)
require.Equal(t, "no-expiry-token", token)
}
func TestOpenAITokenProvider_Real_WhitespaceToken(t *testing.T) {
cache := newOpenAITokenCacheStub()
cacheKey := "openai:account:203"
cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 203,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "real-token",
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "real-token", token) // Should fall back to credentials
}
func TestOpenAITokenProvider_Real_LockError(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.lockErr = errors.New("redis lock failed")
// Token expires soon (within refresh skew)
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 204,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-on-lock-error",
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "fallback-on-lock-error", token)
}
func TestOpenAITokenProvider_Real_WhitespaceCredentialToken(t *testing.T) {
cache := newOpenAITokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 205,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": " ", // Whitespace only
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}
func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) {
cache := newOpenAITokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 206,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"expires_at": expiresAt,
// No access_token
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}

View File

@@ -85,8 +85,8 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
switch statusCode {
case 401:
if account.Type == AccountTypeOAuth &&
(account.Platform == PlatformAntigravity || account.Platform == PlatformGemini) {
// 对所有 OAuth 账号在 401 错误时调用缓存失效
if account.Type == AccountTypeOAuth {
if s.tokenCacheInvalidator != nil {
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)

View File

@@ -7,29 +7,35 @@ type TokenCacheInvalidator interface {
}
type CompositeTokenCacheInvalidator struct {
geminiCache GeminiTokenCache
cache GeminiTokenCache // 统一使用一个缓存接口,通过缓存键前缀区分平台
}
func NewCompositeTokenCacheInvalidator(geminiCache GeminiTokenCache) *CompositeTokenCacheInvalidator {
func NewCompositeTokenCacheInvalidator(cache GeminiTokenCache) *CompositeTokenCacheInvalidator {
return &CompositeTokenCacheInvalidator{
geminiCache: geminiCache,
cache: cache,
}
}
func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, account *Account) error {
if c == nil || c.geminiCache == nil || account == nil {
if c == nil || c.cache == nil || account == nil {
return nil
}
if account.Type != AccountTypeOAuth {
return nil
}
var cacheKey string
switch account.Platform {
case PlatformGemini:
return c.geminiCache.DeleteAccessToken(ctx, GeminiTokenCacheKey(account))
cacheKey = GeminiTokenCacheKey(account)
case PlatformAntigravity:
return c.geminiCache.DeleteAccessToken(ctx, AntigravityTokenCacheKey(account))
cacheKey = AntigravityTokenCacheKey(account)
case PlatformOpenAI:
cacheKey = OpenAITokenCacheKey(account)
case PlatformAnthropic:
cacheKey = ClaudeTokenCacheKey(account)
default:
return nil
}
return c.cache.DeleteAccessToken(ctx, cacheKey)
}

View File

@@ -4,6 +4,7 @@ package service
import (
"context"
"errors"
"testing"
"time"
@@ -70,13 +71,99 @@ func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) {
func TestCompositeTokenCacheInvalidator_OpenAI(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 1,
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
ID: 500,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "openai-token",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"openai:account:500"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_Claude(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 600,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "claude-token",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"claude:account:600"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
tests := []struct {
name string
account *Account
}{
{
name: "gemini_api_key",
account: &Account{
ID: 1,
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
},
},
{
name: "openai_api_key",
account: &Account{
ID: 2,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
},
},
{
name: "claude_api_key",
account: &Account{
ID: 3,
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
},
},
{
name: "claude_setup_token",
account: &Account{
ID: 4,
Platform: PlatformAnthropic,
Type: AccountTypeSetupToken,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache.deletedKeys = nil
err := invalidator.InvalidateToken(context.Background(), tt.account)
require.NoError(t, err)
require.Empty(t, cache.deletedKeys)
})
}
}
func TestCompositeTokenCacheInvalidator_SkipUnsupportedPlatform(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 100,
Platform: "unknown-platform",
Type: AccountTypeOAuth,
}
err := invalidator.InvalidateToken(context.Background(), account)
@@ -95,3 +182,87 @@ func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) {
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
}
func TestCompositeTokenCacheInvalidator_NilAccount(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
err := invalidator.InvalidateToken(context.Background(), nil)
require.NoError(t, err)
require.Empty(t, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_NilInvalidator(t *testing.T) {
var invalidator *CompositeTokenCacheInvalidator
account := &Account{
ID: 5,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
}
func TestCompositeTokenCacheInvalidator_DeleteError(t *testing.T) {
expectedErr := errors.New("redis connection failed")
cache := &geminiTokenCacheStub{deleteErr: expectedErr}
invalidator := NewCompositeTokenCacheInvalidator(cache)
tests := []struct {
name string
account *Account
}{
{
name: "openai_delete_error",
account: &Account{
ID: 700,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
},
},
{
name: "claude_delete_error",
account: &Account{
ID: 800,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := invalidator.InvalidateToken(context.Background(), tt.account)
require.Error(t, err)
require.Equal(t, expectedErr, err)
})
}
}
func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
// 测试所有平台的缓存键生成和删除
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
accounts := []*Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeOAuth, Credentials: map[string]any{"project_id": "gemini-proj"}},
{ID: 2, Platform: PlatformAntigravity, Type: AccountTypeOAuth, Credentials: map[string]any{"project_id": "ag-proj"}},
{ID: 3, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
{ID: 4, Platform: PlatformAnthropic, Type: AccountTypeOAuth},
}
expectedKeys := []string{
"gemini-proj",
"ag:ag-proj",
"openai:account:3",
"claude:account:4",
}
for _, acc := range accounts {
err := invalidator.InvalidateToken(context.Background(), acc)
require.NoError(t, err)
}
require.Equal(t, expectedKeys, cache.deletedKeys)
}

View File

@@ -0,0 +1,15 @@
package service
import "strconv"
// OpenAITokenCacheKey 生成 OpenAI OAuth 账号的缓存键
// 格式: "openai:account:{account_id}"
func OpenAITokenCacheKey(account *Account) string {
return "openai:account:" + strconv.FormatInt(account.ID, 10)
}
// ClaudeTokenCacheKey 生成 Claude (Anthropic) OAuth 账号的缓存键
// 格式: "claude:account:{account_id}"
func ClaudeTokenCacheKey(account *Account) string {
return "claude:account:" + strconv.FormatInt(account.ID, 10)
}

View File

@@ -151,3 +151,109 @@ func TestAntigravityTokenCacheKey(t *testing.T) {
})
}
}
func TestOpenAITokenCacheKey(t *testing.T) {
tests := []struct {
name string
account *Account
expected string
}{
{
name: "basic_account",
account: &Account{
ID: 300,
},
expected: "openai:account:300",
},
{
name: "account_with_credentials",
account: &Account{
ID: 301,
Credentials: map[string]any{
"access_token": "test-token",
},
},
expected: "openai:account:301",
},
{
name: "account_id_zero",
account: &Account{
ID: 0,
},
expected: "openai:account:0",
},
{
name: "large_account_id",
account: &Account{
ID: 9999999999,
},
expected: "openai:account:9999999999",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := OpenAITokenCacheKey(tt.account)
require.Equal(t, tt.expected, result)
})
}
}
func TestClaudeTokenCacheKey(t *testing.T) {
tests := []struct {
name string
account *Account
expected string
}{
{
name: "basic_account",
account: &Account{
ID: 400,
},
expected: "claude:account:400",
},
{
name: "account_with_credentials",
account: &Account{
ID: 401,
Credentials: map[string]any{
"access_token": "claude-token",
},
},
expected: "claude:account:401",
},
{
name: "account_id_zero",
account: &Account{
ID: 0,
},
expected: "claude:account:0",
},
{
name: "large_account_id",
account: &Account{
ID: 9999999999,
},
expected: "claude:account:9999999999",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ClaudeTokenCacheKey(tt.account)
require.Equal(t, tt.expected, result)
})
}
}
func TestCacheKeyUniqueness(t *testing.T) {
// 确保不同平台的缓存键不会冲突
account := &Account{ID: 123}
openaiKey := OpenAITokenCacheKey(account)
claudeKey := ClaudeTokenCacheKey(account)
require.NotEqual(t, openaiKey, claudeKey, "OpenAI and Claude cache keys should be different")
require.Contains(t, openaiKey, "openai:")
require.Contains(t, claudeKey, "claude:")
}

View File

@@ -172,8 +172,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
if err := s.accountRepo.Update(ctx, account); err != nil {
return fmt.Errorf("failed to save credentials: %w", err)
}
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth &&
(account.Platform == PlatformGemini || account.Platform == PlatformAntigravity) {
// 对所有 OAuth 账号调用缓存失效InvalidateToken 内部根据平台判断是否需要处理)
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
log.Printf("[TokenRefresh] Failed to invalidate token cache for account %d: %v", account.ID, err)
} else {

View File

@@ -197,7 +197,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试其他平台的 OAuth 账号不触发缓存失效
// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试所有 OAuth 平台都触发缓存失效
func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
@@ -210,7 +210,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 10,
Platform: PlatformOpenAI, // 其他平台
Platform: PlatformOpenAI, // OpenAI OAuth 账户
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
@@ -222,7 +222,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 0, invalidator.calls) // 其他平台不触发缓存失效
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况

View File

@@ -214,10 +214,13 @@ var ProviderSet = wire.NewSet(
NewGeminiOAuthService,
NewGeminiQuotaService,
NewCompositeTokenCacheInvalidator,
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
NewAntigravityOAuthService,
NewGeminiTokenProvider,
NewGeminiMessagesCompatService,
NewAntigravityTokenProvider,
NewOpenAITokenProvider,
NewClaudeTokenProvider,
NewAntigravityGatewayService,
ProvideRateLimitService,
NewAccountUsageService,