diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index c8304831..289a14bd 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -100,8 +100,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { tempUnschedCache := repository.NewTempUnschedCache(redisClient) timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient) geminiTokenCache := repository.NewGeminiTokenCache(redisClient) - tokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache) - rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, tokenCacheInvalidator) + compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache) + rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator) claudeUsageFetcher := repository.NewClaudeUsageFetcher() antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) usageCache := service.NewUsageCache() @@ -136,8 +136,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { identityCache := repository.NewIdentityCache(redisClient) identityService := service.NewIdentityService(identityCache) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) + claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider) + openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService) settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService) @@ -168,7 +170,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, tokenCacheInvalidator, configConfig) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ diff --git a/backend/internal/repository/gemini_token_cache.go b/backend/internal/repository/gemini_token_cache.go index 82c14def..d4f552bc 100644 --- a/backend/internal/repository/gemini_token_cache.go +++ b/backend/internal/repository/gemini_token_cache.go @@ -11,8 +11,8 @@ import ( ) const ( - geminiTokenKeyPrefix = "gemini:token:" - geminiRefreshLockKeyPrefix = "gemini:refresh_lock:" + oauthTokenKeyPrefix = "oauth:token:" + oauthRefreshLockKeyPrefix = "oauth:refresh_lock:" ) type geminiTokenCache struct { @@ -24,26 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache { } func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) { - key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey) + key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey) return c.rdb.Get(ctx, key).Result() } func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error { - key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey) + key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey) return c.rdb.Set(ctx, key, token, ttl).Err() } func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error { - key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey) + key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey) return c.rdb.Del(ctx, key).Err() } func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) { - key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey) + key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey) return c.rdb.SetNX(ctx, key, 1, ttl).Result() } func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error { - key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey) + key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey) return c.rdb.Del(ctx, key).Err() } diff --git a/backend/internal/service/claude_token_provider.go b/backend/internal/service/claude_token_provider.go new file mode 100644 index 00000000..d2db162f --- /dev/null +++ b/backend/internal/service/claude_token_provider.go @@ -0,0 +1,156 @@ +package service + +import ( + "context" + "errors" + "log/slog" + "strconv" + "strings" + "time" +) + +const ( + claudeTokenRefreshSkew = 3 * time.Minute + claudeTokenCacheSkew = 5 * time.Minute + claudeLockWaitTime = 200 * time.Millisecond +) + +// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) +type ClaudeTokenCache = GeminiTokenCache + +// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token +type ClaudeTokenProvider struct { + accountRepo AccountRepository + tokenCache ClaudeTokenCache + oauthService *OAuthService +} + +func NewClaudeTokenProvider( + accountRepo AccountRepository, + tokenCache ClaudeTokenCache, + oauthService *OAuthService, +) *ClaudeTokenProvider { + return &ClaudeTokenProvider{ + accountRepo: accountRepo, + tokenCache: tokenCache, + oauthService: oauthService, + } +} + +// GetAccessToken 获取有效的 access_token +func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth { + return "", errors.New("not an anthropic oauth account") + } + + cacheKey := ClaudeTokenCacheKey(account) + + // 1. 先尝试缓存 + if p.tokenCache != nil { + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + slog.Debug("claude_token_cache_hit", "account_id", account.ID) + return token, nil + } else if err != nil { + slog.Warn("claude_token_cache_get_failed", "account_id", account.ID, "error", err) + } + } + + slog.Debug("claude_token_cache_miss", "account_id", account.ID) + + // 2. 如果即将过期则刷新 + expiresAt := account.GetCredentialAsTime("expires_at") + needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew + refreshFailed := false + if needsRefresh && p.tokenCache != nil { + locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if err == nil && locked { + defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() + + // 拿到锁后再次检查缓存(另一个 worker 可能已刷新) + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + + // 从数据库获取最新账户信息 + fresh, err := p.accountRepo.GetByID(ctx, account.ID) + if err == nil && fresh != nil { + account = fresh + } + expiresAt = account.GetCredentialAsTime("expires_at") + if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew { + if p.oauthService == nil { + slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID) + refreshFailed = true // 无法刷新,标记失败 + } else { + tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account) + if err != nil { + // 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token + slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err) + refreshFailed = true // 刷新失败,标记以使用短 TTL + } else { + // 构建新 credentials,保留原有字段 + newCredentials := make(map[string]any) + for k, v := range account.Credentials { + newCredentials[k] = v + } + newCredentials["access_token"] = tokenInfo.AccessToken + newCredentials["token_type"] = tokenInfo.TokenType + newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10) + newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10) + if tokenInfo.RefreshToken != "" { + newCredentials["refresh_token"] = tokenInfo.RefreshToken + } + if tokenInfo.Scope != "" { + newCredentials["scope"] = tokenInfo.Scope + } + account.Credentials = newCredentials + if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { + slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr) + } + expiresAt = account.GetCredentialAsTime("expires_at") + } + } + } + } else { + // 锁获取失败,等待 200ms 后重试读取缓存(改进:减少并发时的缓存未命中) + time.Sleep(claudeLockWaitTime) + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID) + return token, nil + } + } + } + + accessToken := account.GetCredential("access_token") + if strings.TrimSpace(accessToken) == "" { + return "", errors.New("access_token not found in credentials") + } + + // 3. 存入缓存 + if p.tokenCache != nil { + ttl := 30 * time.Minute + if refreshFailed { + // 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动 + ttl = time.Minute + slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed") + } else if expiresAt != nil { + until := time.Until(*expiresAt) + switch { + case until > claudeTokenCacheSkew: + ttl = until - claudeTokenCacheSkew + case until > 0: + ttl = until + default: + ttl = time.Minute + } + } + if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil { + slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err) + } + } + + return accessToken, nil +} diff --git a/backend/internal/service/claude_token_provider_test.go b/backend/internal/service/claude_token_provider_test.go new file mode 100644 index 00000000..37c58e3f --- /dev/null +++ b/backend/internal/service/claude_token_provider_test.go @@ -0,0 +1,939 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// claudeTokenCacheStub implements ClaudeTokenCache for testing +type claudeTokenCacheStub struct { + mu sync.Mutex + tokens map[string]string + getErr error + setErr error + deleteErr error + lockAcquired bool + lockErr error + releaseLockErr error + getCalled int32 + setCalled int32 + lockCalled int32 + unlockCalled int32 + simulateLockRace bool +} + +func newClaudeTokenCacheStub() *claudeTokenCacheStub { + return &claudeTokenCacheStub{ + tokens: make(map[string]string), + lockAcquired: true, + } +} + +func (s *claudeTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) { + atomic.AddInt32(&s.getCalled, 1) + if s.getErr != nil { + return "", s.getErr + } + s.mu.Lock() + defer s.mu.Unlock() + return s.tokens[cacheKey], nil +} + +func (s *claudeTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error { + atomic.AddInt32(&s.setCalled, 1) + if s.setErr != nil { + return s.setErr + } + s.mu.Lock() + defer s.mu.Unlock() + s.tokens[cacheKey] = token + return nil +} + +func (s *claudeTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error { + if s.deleteErr != nil { + return s.deleteErr + } + s.mu.Lock() + defer s.mu.Unlock() + delete(s.tokens, cacheKey) + return nil +} + +func (s *claudeTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) { + atomic.AddInt32(&s.lockCalled, 1) + if s.lockErr != nil { + return false, s.lockErr + } + if s.simulateLockRace { + return false, nil + } + return s.lockAcquired, nil +} + +func (s *claudeTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error { + atomic.AddInt32(&s.unlockCalled, 1) + return s.releaseLockErr +} + +// claudeAccountRepoStub is a minimal stub implementing only the methods used by ClaudeTokenProvider +type claudeAccountRepoStub struct { + account *Account + getErr error + updateErr error + getCalled int32 + updateCalled int32 +} + +func (r *claudeAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) { + atomic.AddInt32(&r.getCalled, 1) + if r.getErr != nil { + return nil, r.getErr + } + return r.account, nil +} + +func (r *claudeAccountRepoStub) Update(ctx context.Context, account *Account) error { + atomic.AddInt32(&r.updateCalled, 1) + if r.updateErr != nil { + return r.updateErr + } + r.account = account + return nil +} + +// claudeOAuthServiceStub implements OAuthService methods for testing +type claudeOAuthServiceStub struct { + tokenInfo *TokenInfo + refreshErr error + refreshCalled int32 +} + +func (s *claudeOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) { + atomic.AddInt32(&s.refreshCalled, 1) + if s.refreshErr != nil { + return nil, s.refreshErr + } + return s.tokenInfo, nil +} + +// testClaudeTokenProvider is a test version that uses the stub OAuth service +type testClaudeTokenProvider struct { + accountRepo *claudeAccountRepoStub + tokenCache *claudeTokenCacheStub + oauthService *claudeOAuthServiceStub +} + +func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth { + return "", errors.New("not an anthropic oauth account") + } + + cacheKey := ClaudeTokenCacheKey(account) + + // 1. Check cache + if p.tokenCache != nil { + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" { + return token, nil + } + } + + // 2. Check if refresh needed + expiresAt := account.GetCredentialAsTime("expires_at") + needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew + refreshFailed := false + if needsRefresh && p.tokenCache != nil { + locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if err == nil && locked { + defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() + + // Check cache again after acquiring lock + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" { + return token, nil + } + + // Get fresh account from DB + fresh, err := p.accountRepo.GetByID(ctx, account.ID) + if err == nil && fresh != nil { + account = fresh + } + expiresAt = account.GetCredentialAsTime("expires_at") + if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew { + if p.oauthService == nil { + refreshFailed = true // 无法刷新,标记失败 + } else { + tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account) + if err != nil { + refreshFailed = true // 刷新失败,标记以使用短 TTL + } else { + // Build new credentials + newCredentials := make(map[string]any) + for k, v := range account.Credentials { + newCredentials[k] = v + } + newCredentials["access_token"] = tokenInfo.AccessToken + newCredentials["token_type"] = tokenInfo.TokenType + newCredentials["expires_at"] = time.Now().Add(time.Duration(tokenInfo.ExpiresIn) * time.Second).Format(time.RFC3339) + if tokenInfo.RefreshToken != "" { + newCredentials["refresh_token"] = tokenInfo.RefreshToken + } + account.Credentials = newCredentials + _ = p.accountRepo.Update(ctx, account) + expiresAt = account.GetCredentialAsTime("expires_at") + } + } + } + } else if p.tokenCache.simulateLockRace { + // Wait and retry cache + time.Sleep(10 * time.Millisecond) + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" { + return token, nil + } + } + } + + accessToken := account.GetCredential("access_token") + if accessToken == "" { + return "", errors.New("access_token not found in credentials") + } + + // 3. Store in cache + if p.tokenCache != nil { + ttl := 30 * time.Minute + if refreshFailed { + ttl = time.Minute // 刷新失败时使用短 TTL + } else if expiresAt != nil { + until := time.Until(*expiresAt) + if until > claudeTokenCacheSkew { + ttl = until - claudeTokenCacheSkew + } else if until > 0 { + ttl = until + } else { + ttl = time.Minute + } + } + _ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl) + } + + return accessToken, nil +} + +func TestClaudeTokenProvider_CacheHit(t *testing.T) { + cache := newClaudeTokenCacheStub() + account := &Account{ + ID: 100, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "db-token", + }, + } + cacheKey := ClaudeTokenCacheKey(account) + cache.tokens[cacheKey] = "cached-token" + + provider := NewClaudeTokenProvider(nil, cache, nil) + + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "cached-token", token) + require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled)) + require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled)) +} + +func TestClaudeTokenProvider_CacheMiss_FromCredentials(t *testing.T) { + cache := newClaudeTokenCacheStub() + // Token expires in far future, no refresh needed + expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339) + account := &Account{ + ID: 101, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "credential-token", + "expires_at": expiresAt, + }, + } + + provider := NewClaudeTokenProvider(nil, cache, nil) + + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "credential-token", token) + + // Should have stored in cache + cacheKey := ClaudeTokenCacheKey(account) + require.Equal(t, "credential-token", cache.tokens[cacheKey]) +} + +func TestClaudeTokenProvider_TokenRefresh(t *testing.T) { + cache := newClaudeTokenCacheStub() + accountRepo := &claudeAccountRepoStub{} + oauthService := &claudeOAuthServiceStub{ + tokenInfo: &TokenInfo{ + AccessToken: "refreshed-token", + RefreshToken: "new-refresh-token", + TokenType: "Bearer", + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(time.Hour).Unix(), + }, + } + + // Token expires soon (within refresh skew) + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 102, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-token", + "refresh_token": "old-refresh-token", + "expires_at": expiresAt, + }, + } + accountRepo.account = account + + provider := &testClaudeTokenProvider{ + accountRepo: accountRepo, + tokenCache: cache, + oauthService: oauthService, + } + + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "refreshed-token", token) + require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled)) +} + +func TestClaudeTokenProvider_LockRaceCondition(t *testing.T) { + cache := newClaudeTokenCacheStub() + cache.simulateLockRace = true + accountRepo := &claudeAccountRepoStub{} + + // Token expires soon + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 103, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "race-token", + "expires_at": expiresAt, + }, + } + accountRepo.account = account + + // Simulate another worker already refreshed and cached + cacheKey := ClaudeTokenCacheKey(account) + go func() { + time.Sleep(5 * time.Millisecond) + cache.mu.Lock() + cache.tokens[cacheKey] = "winner-token" + cache.mu.Unlock() + }() + + provider := &testClaudeTokenProvider{ + accountRepo: accountRepo, + tokenCache: cache, + } + + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.NotEmpty(t, token) +} + +func TestClaudeTokenProvider_NilAccount(t *testing.T) { + provider := NewClaudeTokenProvider(nil, nil, nil) + + token, err := provider.GetAccessToken(context.Background(), nil) + require.Error(t, err) + require.Contains(t, err.Error(), "account is nil") + require.Empty(t, token) +} + +func TestClaudeTokenProvider_WrongPlatform(t *testing.T) { + provider := NewClaudeTokenProvider(nil, nil, nil) + account := &Account{ + ID: 104, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "not an anthropic oauth account") + require.Empty(t, token) +} + +func TestClaudeTokenProvider_WrongAccountType(t *testing.T) { + provider := NewClaudeTokenProvider(nil, nil, nil) + account := &Account{ + ID: 105, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + } + + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "not an anthropic oauth account") + require.Empty(t, token) +} + +func TestClaudeTokenProvider_SetupTokenType(t *testing.T) { + provider := NewClaudeTokenProvider(nil, nil, nil) + account := &Account{ + ID: 106, + Platform: PlatformAnthropic, + Type: AccountTypeSetupToken, + } + + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "not an anthropic oauth account") + require.Empty(t, token) +} + +func TestClaudeTokenProvider_NilCache(t *testing.T) { + // Token doesn't need refresh + expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339) + account := &Account{ + ID: 107, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "nocache-token", + "expires_at": expiresAt, + }, + } + + provider := NewClaudeTokenProvider(nil, nil, nil) + + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "nocache-token", token) +} + +func TestClaudeTokenProvider_CacheGetError(t *testing.T) { + cache := newClaudeTokenCacheStub() + cache.getErr = errors.New("redis connection failed") + + // Token doesn't need refresh + expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339) + account := &Account{ + ID: 108, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + provider := NewClaudeTokenProvider(nil, cache, nil) + + // Should gracefully degrade and return from credentials + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "fallback-token", token) +} + +func TestClaudeTokenProvider_CacheSetError(t *testing.T) { + cache := newClaudeTokenCacheStub() + cache.setErr = errors.New("redis write failed") + + expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339) + account := &Account{ + ID: 109, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "still-works-token", + "expires_at": expiresAt, + }, + } + + provider := NewClaudeTokenProvider(nil, cache, nil) + + // Should still work even if cache set fails + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "still-works-token", token) +} + +func TestClaudeTokenProvider_MissingAccessToken(t *testing.T) { + cache := newClaudeTokenCacheStub() + expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339) + account := &Account{ + ID: 110, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "expires_at": expiresAt, + // missing access_token + }, + } + + provider := NewClaudeTokenProvider(nil, cache, nil) + + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "access_token not found") + require.Empty(t, token) +} + +func TestClaudeTokenProvider_RefreshError(t *testing.T) { + cache := newClaudeTokenCacheStub() + accountRepo := &claudeAccountRepoStub{} + oauthService := &claudeOAuthServiceStub{ + refreshErr: errors.New("oauth refresh failed"), + } + + // Token expires soon + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 111, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-token", + "refresh_token": "old-refresh-token", + "expires_at": expiresAt, + }, + } + accountRepo.account = account + + provider := &testClaudeTokenProvider{ + accountRepo: accountRepo, + tokenCache: cache, + oauthService: oauthService, + } + + // Now with fallback behavior, should return existing token even if refresh fails + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "old-token", token) // Fallback to existing token +} + +func TestClaudeTokenProvider_OAuthServiceNotConfigured(t *testing.T) { + cache := newClaudeTokenCacheStub() + accountRepo := &claudeAccountRepoStub{} + + // Token expires soon + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 112, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-token", + "expires_at": expiresAt, + }, + } + accountRepo.account = account + + provider := &testClaudeTokenProvider{ + accountRepo: accountRepo, + tokenCache: cache, + oauthService: nil, // not configured + } + + // Now with fallback behavior, should return existing token even if oauth service not configured + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "old-token", token) // Fallback to existing token +} + +func TestClaudeTokenProvider_TTLCalculation(t *testing.T) { + tests := []struct { + name string + expiresIn time.Duration + }{ + { + name: "far_future_expiry", + expiresIn: 1 * time.Hour, + }, + { + name: "medium_expiry", + expiresIn: 10 * time.Minute, + }, + { + name: "near_expiry", + expiresIn: 6 * time.Minute, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cache := newClaudeTokenCacheStub() + expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339) + account := &Account{ + ID: 200, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "test-token", + "expires_at": expiresAt, + }, + } + + provider := NewClaudeTokenProvider(nil, cache, nil) + + _, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + + // Verify token was cached + cacheKey := ClaudeTokenCacheKey(account) + require.Equal(t, "test-token", cache.tokens[cacheKey]) + }) + } +} + +func TestClaudeTokenProvider_AccountRepoGetError(t *testing.T) { + cache := newClaudeTokenCacheStub() + accountRepo := &claudeAccountRepoStub{ + getErr: errors.New("db connection failed"), + } + oauthService := &claudeOAuthServiceStub{ + tokenInfo: &TokenInfo{ + AccessToken: "refreshed-token", + RefreshToken: "new-refresh", + TokenType: "Bearer", + ExpiresIn: 3600, + }, + } + + // Token expires soon + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 113, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-token", + "refresh_token": "old-refresh", + "expires_at": expiresAt, + }, + } + + provider := &testClaudeTokenProvider{ + accountRepo: accountRepo, + tokenCache: cache, + oauthService: oauthService, + } + + // Should still work, just using the passed-in account + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "refreshed-token", token) +} + +func TestClaudeTokenProvider_AccountUpdateError(t *testing.T) { + cache := newClaudeTokenCacheStub() + accountRepo := &claudeAccountRepoStub{ + updateErr: errors.New("db write failed"), + } + oauthService := &claudeOAuthServiceStub{ + tokenInfo: &TokenInfo{ + AccessToken: "refreshed-token", + RefreshToken: "new-refresh", + TokenType: "Bearer", + ExpiresIn: 3600, + }, + } + + // Token expires soon + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 114, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-token", + "refresh_token": "old-refresh", + "expires_at": expiresAt, + }, + } + accountRepo.account = account + + provider := &testClaudeTokenProvider{ + accountRepo: accountRepo, + tokenCache: cache, + oauthService: oauthService, + } + + // Should still return token even if update fails + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "refreshed-token", token) +} + +func TestClaudeTokenProvider_RefreshPreservesExistingCredentials(t *testing.T) { + cache := newClaudeTokenCacheStub() + accountRepo := &claudeAccountRepoStub{} + oauthService := &claudeOAuthServiceStub{ + tokenInfo: &TokenInfo{ + AccessToken: "new-access-token", + RefreshToken: "new-refresh-token", + TokenType: "Bearer", + ExpiresIn: 3600, + }, + } + + // Token expires soon + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 115, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-access-token", + "refresh_token": "old-refresh-token", + "expires_at": expiresAt, + "custom_field": "should-be-preserved", + "organization": "test-org", + }, + } + accountRepo.account = account + + provider := &testClaudeTokenProvider{ + accountRepo: accountRepo, + tokenCache: cache, + oauthService: oauthService, + } + + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "new-access-token", token) + + // Verify existing fields are preserved + require.Equal(t, "should-be-preserved", accountRepo.account.Credentials["custom_field"]) + require.Equal(t, "test-org", accountRepo.account.Credentials["organization"]) + // Verify new fields are updated + require.Equal(t, "new-access-token", accountRepo.account.Credentials["access_token"]) + require.Equal(t, "new-refresh-token", accountRepo.account.Credentials["refresh_token"]) +} + +func TestClaudeTokenProvider_DoubleCheckCacheAfterLock(t *testing.T) { + cache := newClaudeTokenCacheStub() + accountRepo := &claudeAccountRepoStub{} + oauthService := &claudeOAuthServiceStub{ + tokenInfo: &TokenInfo{ + AccessToken: "refreshed-token", + RefreshToken: "new-refresh", + TokenType: "Bearer", + ExpiresIn: 3600, + }, + } + + // Token expires soon + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 116, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-token", + "expires_at": expiresAt, + }, + } + accountRepo.account = account + cacheKey := ClaudeTokenCacheKey(account) + + // After lock is acquired, cache should have the token (simulating another worker) + go func() { + time.Sleep(5 * time.Millisecond) + cache.mu.Lock() + cache.tokens[cacheKey] = "cached-by-other-worker" + cache.mu.Unlock() + }() + + provider := &testClaudeTokenProvider{ + accountRepo: accountRepo, + tokenCache: cache, + oauthService: oauthService, + } + + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.NotEmpty(t, token) +} + +// Tests for real provider - to increase coverage +func TestClaudeTokenProvider_Real_LockFailedWait(t *testing.T) { + cache := newClaudeTokenCacheStub() + cache.lockAcquired = false // Lock acquisition fails + + // Token expires soon (within refresh skew) to trigger lock attempt + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 300, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + // Set token in cache after lock wait period (simulate other worker refreshing) + cacheKey := ClaudeTokenCacheKey(account) + go func() { + time.Sleep(100 * time.Millisecond) + cache.mu.Lock() + cache.tokens[cacheKey] = "refreshed-by-other" + cache.mu.Unlock() + }() + + provider := NewClaudeTokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.NotEmpty(t, token) +} + +func TestClaudeTokenProvider_Real_CacheHitAfterWait(t *testing.T) { + cache := newClaudeTokenCacheStub() + cache.lockAcquired = false // Lock acquisition fails + + // Token expires soon + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 301, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "original-token", + "expires_at": expiresAt, + }, + } + + cacheKey := ClaudeTokenCacheKey(account) + // Set token in cache immediately after wait starts + go func() { + time.Sleep(50 * time.Millisecond) + cache.mu.Lock() + cache.tokens[cacheKey] = "winner-token" + cache.mu.Unlock() + }() + + provider := NewClaudeTokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.NotEmpty(t, token) +} + +func TestClaudeTokenProvider_Real_NoExpiresAt(t *testing.T) { + cache := newClaudeTokenCacheStub() + cache.lockAcquired = false // Prevent entering refresh logic + + // Token with nil expires_at (no expiry set) + account := &Account{ + ID: 302, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "no-expiry-token", + }, + } + + // After lock wait, return token from credentials + provider := NewClaudeTokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "no-expiry-token", token) +} + +func TestClaudeTokenProvider_Real_WhitespaceToken(t *testing.T) { + cache := newClaudeTokenCacheStub() + cacheKey := "claude:account:303" + cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty + + expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339) + account := &Account{ + ID: 303, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "real-token", + "expires_at": expiresAt, + }, + } + + provider := NewClaudeTokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "real-token", token) +} + +func TestClaudeTokenProvider_Real_EmptyCredentialToken(t *testing.T) { + cache := newClaudeTokenCacheStub() + + expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339) + account := &Account{ + ID: 304, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": " ", // Whitespace only + "expires_at": expiresAt, + }, + } + + provider := NewClaudeTokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "access_token not found") + require.Empty(t, token) +} + +func TestClaudeTokenProvider_Real_LockError(t *testing.T) { + cache := newClaudeTokenCacheStub() + cache.lockErr = errors.New("redis lock failed") + + // Token expires soon (within refresh skew) + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 305, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-on-lock-error", + "expires_at": expiresAt, + }, + } + + provider := NewClaudeTokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "fallback-on-lock-error", token) +} + +func TestClaudeTokenProvider_Real_NilCredentials(t *testing.T) { + cache := newClaudeTokenCacheStub() + + expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339) + account := &Account{ + ID: 306, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "expires_at": expiresAt, + // No access_token + }, + } + + provider := NewClaudeTokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "access_token not found") + require.Empty(t, token) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index b552f030..042f9f49 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -144,21 +144,22 @@ func (e *UpstreamFailoverError) Error() string { // GatewayService handles API gateway operations type GatewayService struct { - accountRepo AccountRepository - groupRepo GroupRepository - usageLogRepo UsageLogRepository - userRepo UserRepository - userSubRepo UserSubscriptionRepository - cache GatewayCache - cfg *config.Config - schedulerSnapshot *SchedulerSnapshotService - billingService *BillingService - rateLimitService *RateLimitService - billingCacheService *BillingCacheService - identityService *IdentityService - httpUpstream HTTPUpstream - deferredService *DeferredService - concurrencyService *ConcurrencyService + accountRepo AccountRepository + groupRepo GroupRepository + usageLogRepo UsageLogRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + cache GatewayCache + cfg *config.Config + schedulerSnapshot *SchedulerSnapshotService + billingService *BillingService + rateLimitService *RateLimitService + billingCacheService *BillingCacheService + identityService *IdentityService + httpUpstream HTTPUpstream + deferredService *DeferredService + concurrencyService *ConcurrencyService + claudeTokenProvider *ClaudeTokenProvider } // NewGatewayService creates a new GatewayService @@ -178,23 +179,25 @@ func NewGatewayService( identityService *IdentityService, httpUpstream HTTPUpstream, deferredService *DeferredService, + claudeTokenProvider *ClaudeTokenProvider, ) *GatewayService { return &GatewayService{ - accountRepo: accountRepo, - groupRepo: groupRepo, - usageLogRepo: usageLogRepo, - userRepo: userRepo, - userSubRepo: userSubRepo, - cache: cache, - cfg: cfg, - schedulerSnapshot: schedulerSnapshot, - concurrencyService: concurrencyService, - billingService: billingService, - rateLimitService: rateLimitService, - billingCacheService: billingCacheService, - identityService: identityService, - httpUpstream: httpUpstream, - deferredService: deferredService, + accountRepo: accountRepo, + groupRepo: groupRepo, + usageLogRepo: usageLogRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + cache: cache, + cfg: cfg, + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + identityService: identityService, + httpUpstream: httpUpstream, + deferredService: deferredService, + claudeTokenProvider: claudeTokenProvider, } } @@ -1079,6 +1082,16 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) ( } func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) { + // 对于 Anthropic OAuth 账号,使用 ClaudeTokenProvider 获取缓存的 token + if account.Platform == PlatformAnthropic && account.Type == AccountTypeOAuth && s.claudeTokenProvider != nil { + accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account) + if err != nil { + return "", "", err + } + return accessToken, "oauth", nil + } + + // 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取 accessToken := account.GetCredential("access_token") if accessToken == "" { return "", "", errors.New("access_token not found in credentials") diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index cfba6460..0e8a35ed 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -80,19 +80,20 @@ type OpenAIForwardResult struct { // OpenAIGatewayService handles OpenAI API gateway operations type OpenAIGatewayService struct { - accountRepo AccountRepository - usageLogRepo UsageLogRepository - userRepo UserRepository - userSubRepo UserSubscriptionRepository - cache GatewayCache - cfg *config.Config - schedulerSnapshot *SchedulerSnapshotService - concurrencyService *ConcurrencyService - billingService *BillingService - rateLimitService *RateLimitService - billingCacheService *BillingCacheService - httpUpstream HTTPUpstream - deferredService *DeferredService + accountRepo AccountRepository + usageLogRepo UsageLogRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + cache GatewayCache + cfg *config.Config + schedulerSnapshot *SchedulerSnapshotService + concurrencyService *ConcurrencyService + billingService *BillingService + rateLimitService *RateLimitService + billingCacheService *BillingCacheService + httpUpstream HTTPUpstream + deferredService *DeferredService + openAITokenProvider *OpenAITokenProvider } // NewOpenAIGatewayService creates a new OpenAIGatewayService @@ -110,21 +111,23 @@ func NewOpenAIGatewayService( billingCacheService *BillingCacheService, httpUpstream HTTPUpstream, deferredService *DeferredService, + openAITokenProvider *OpenAITokenProvider, ) *OpenAIGatewayService { return &OpenAIGatewayService{ - accountRepo: accountRepo, - usageLogRepo: usageLogRepo, - userRepo: userRepo, - userSubRepo: userSubRepo, - cache: cache, - cfg: cfg, - schedulerSnapshot: schedulerSnapshot, - concurrencyService: concurrencyService, - billingService: billingService, - rateLimitService: rateLimitService, - billingCacheService: billingCacheService, - httpUpstream: httpUpstream, - deferredService: deferredService, + accountRepo: accountRepo, + usageLogRepo: usageLogRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + cache: cache, + cfg: cfg, + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + httpUpstream: httpUpstream, + deferredService: deferredService, + openAITokenProvider: openAITokenProvider, } } @@ -503,6 +506,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { case AccountTypeOAuth: + // 使用 TokenProvider 获取缓存的 token + if s.openAITokenProvider != nil { + accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account) + if err != nil { + return "", "", err + } + return accessToken, "oauth", nil + } + // 降级:TokenProvider 未配置时直接从账号读取 accessToken := account.GetOpenAIAccessToken() if accessToken == "" { return "", "", errors.New("access_token not found in credentials") diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go new file mode 100644 index 00000000..e96892cd --- /dev/null +++ b/backend/internal/service/openai_token_provider.go @@ -0,0 +1,146 @@ +package service + +import ( + "context" + "errors" + "log/slog" + "strings" + "time" +) + +const ( + openAITokenRefreshSkew = 3 * time.Minute + openAITokenCacheSkew = 5 * time.Minute + openAILockWaitTime = 200 * time.Millisecond +) + +// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) +type OpenAITokenCache = GeminiTokenCache + +// OpenAITokenProvider 管理 OpenAI OAuth 账户的 access_token +type OpenAITokenProvider struct { + accountRepo AccountRepository + tokenCache OpenAITokenCache + openAIOAuthService *OpenAIOAuthService +} + +func NewOpenAITokenProvider( + accountRepo AccountRepository, + tokenCache OpenAITokenCache, + openAIOAuthService *OpenAIOAuthService, +) *OpenAITokenProvider { + return &OpenAITokenProvider{ + accountRepo: accountRepo, + tokenCache: tokenCache, + openAIOAuthService: openAIOAuthService, + } +} + +// GetAccessToken 获取有效的 access_token +func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth { + return "", errors.New("not an openai oauth account") + } + + cacheKey := OpenAITokenCacheKey(account) + + // 1. 先尝试缓存 + if p.tokenCache != nil { + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + slog.Debug("openai_token_cache_hit", "account_id", account.ID) + return token, nil + } else if err != nil { + slog.Warn("openai_token_cache_get_failed", "account_id", account.ID, "error", err) + } + } + + slog.Debug("openai_token_cache_miss", "account_id", account.ID) + + // 2. 如果即将过期则刷新 + expiresAt := account.GetCredentialAsTime("expires_at") + needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew + refreshFailed := false + if needsRefresh && p.tokenCache != nil { + locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if err == nil && locked { + defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() + + // 拿到锁后再次检查缓存(另一个 worker 可能已刷新) + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + + // 从数据库获取最新账户信息 + fresh, err := p.accountRepo.GetByID(ctx, account.ID) + if err == nil && fresh != nil { + account = fresh + } + expiresAt = account.GetCredentialAsTime("expires_at") + if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { + if p.openAIOAuthService == nil { + slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) + refreshFailed = true // 无法刷新,标记失败 + } else { + tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account) + if err != nil { + // 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token + slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err) + refreshFailed = true // 刷新失败,标记以使用短 TTL + } else { + newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo) + for k, v := range account.Credentials { + if _, exists := newCredentials[k]; !exists { + newCredentials[k] = v + } + } + account.Credentials = newCredentials + if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { + slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr) + } + expiresAt = account.GetCredentialAsTime("expires_at") + } + } + } + } else { + // 锁获取失败,等待 200ms 后重试读取缓存(改进:减少并发时的缓存未命中) + time.Sleep(openAILockWaitTime) + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID) + return token, nil + } + } + } + + accessToken := account.GetOpenAIAccessToken() + if strings.TrimSpace(accessToken) == "" { + return "", errors.New("access_token not found in credentials") + } + + // 3. 存入缓存 + if p.tokenCache != nil { + ttl := 30 * time.Minute + if refreshFailed { + // 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动 + ttl = time.Minute + slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed") + } else if expiresAt != nil { + until := time.Until(*expiresAt) + switch { + case until > openAITokenCacheSkew: + ttl = until - openAITokenCacheSkew + case until > 0: + ttl = until + default: + ttl = time.Minute + } + } + if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil { + slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err) + } + } + + return accessToken, nil +} diff --git a/backend/internal/service/openai_token_provider_test.go b/backend/internal/service/openai_token_provider_test.go new file mode 100644 index 00000000..29f9f769 --- /dev/null +++ b/backend/internal/service/openai_token_provider_test.go @@ -0,0 +1,810 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// openAITokenCacheStub implements OpenAITokenCache for testing +type openAITokenCacheStub struct { + mu sync.Mutex + tokens map[string]string + getErr error + setErr error + deleteErr error + lockAcquired bool + lockErr error + releaseLockErr error + getCalled int32 + setCalled int32 + lockCalled int32 + unlockCalled int32 + simulateLockRace bool +} + +func newOpenAITokenCacheStub() *openAITokenCacheStub { + return &openAITokenCacheStub{ + tokens: make(map[string]string), + lockAcquired: true, + } +} + +func (s *openAITokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) { + atomic.AddInt32(&s.getCalled, 1) + if s.getErr != nil { + return "", s.getErr + } + s.mu.Lock() + defer s.mu.Unlock() + return s.tokens[cacheKey], nil +} + +func (s *openAITokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error { + atomic.AddInt32(&s.setCalled, 1) + if s.setErr != nil { + return s.setErr + } + s.mu.Lock() + defer s.mu.Unlock() + s.tokens[cacheKey] = token + return nil +} + +func (s *openAITokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error { + if s.deleteErr != nil { + return s.deleteErr + } + s.mu.Lock() + defer s.mu.Unlock() + delete(s.tokens, cacheKey) + return nil +} + +func (s *openAITokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) { + atomic.AddInt32(&s.lockCalled, 1) + if s.lockErr != nil { + return false, s.lockErr + } + if s.simulateLockRace { + return false, nil + } + return s.lockAcquired, nil +} + +func (s *openAITokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error { + atomic.AddInt32(&s.unlockCalled, 1) + return s.releaseLockErr +} + +// openAIAccountRepoStub is a minimal stub implementing only the methods used by OpenAITokenProvider +type openAIAccountRepoStub struct { + account *Account + getErr error + updateErr error + getCalled int32 + updateCalled int32 +} + +func (r *openAIAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) { + atomic.AddInt32(&r.getCalled, 1) + if r.getErr != nil { + return nil, r.getErr + } + return r.account, nil +} + +func (r *openAIAccountRepoStub) Update(ctx context.Context, account *Account) error { + atomic.AddInt32(&r.updateCalled, 1) + if r.updateErr != nil { + return r.updateErr + } + r.account = account + return nil +} + +// openAIOAuthServiceStub implements OpenAIOAuthService methods for testing +type openAIOAuthServiceStub struct { + tokenInfo *OpenAITokenInfo + refreshErr error + refreshCalled int32 +} + +func (s *openAIOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { + atomic.AddInt32(&s.refreshCalled, 1) + if s.refreshErr != nil { + return nil, s.refreshErr + } + return s.tokenInfo, nil +} + +func (s *openAIOAuthServiceStub) BuildAccountCredentials(info *OpenAITokenInfo) map[string]any { + now := time.Now() + return map[string]any{ + "access_token": info.AccessToken, + "refresh_token": info.RefreshToken, + "expires_at": now.Add(time.Duration(info.ExpiresIn) * time.Second).Format(time.RFC3339), + } +} + +func TestOpenAITokenProvider_CacheHit(t *testing.T) { + cache := newOpenAITokenCacheStub() + account := &Account{ + ID: 100, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "db-token", + }, + } + cacheKey := OpenAITokenCacheKey(account) + cache.tokens[cacheKey] = "cached-token" + + provider := NewOpenAITokenProvider(nil, cache, nil) + + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "cached-token", token) + require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled)) + require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled)) +} + +func TestOpenAITokenProvider_CacheMiss_FromCredentials(t *testing.T) { + cache := newOpenAITokenCacheStub() + // Token expires in far future, no refresh needed + expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339) + account := &Account{ + ID: 101, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "credential-token", + "expires_at": expiresAt, + }, + } + + provider := NewOpenAITokenProvider(nil, cache, nil) + + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "credential-token", token) + + // Should have stored in cache + cacheKey := OpenAITokenCacheKey(account) + require.Equal(t, "credential-token", cache.tokens[cacheKey]) +} + +func TestOpenAITokenProvider_TokenRefresh(t *testing.T) { + cache := newOpenAITokenCacheStub() + accountRepo := &openAIAccountRepoStub{} + oauthService := &openAIOAuthServiceStub{ + tokenInfo: &OpenAITokenInfo{ + AccessToken: "refreshed-token", + RefreshToken: "new-refresh-token", + ExpiresIn: 3600, + }, + } + + // Token expires soon (within refresh skew) + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 102, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-token", + "refresh_token": "old-refresh-token", + "expires_at": expiresAt, + }, + } + accountRepo.account = account + + // We need to directly test with the stub - create a custom provider + customProvider := &testOpenAITokenProvider{ + accountRepo: accountRepo, + tokenCache: cache, + oauthService: oauthService, + } + + token, err := customProvider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "refreshed-token", token) + require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled)) +} + +// testOpenAITokenProvider is a test version that uses the stub OAuth service +type testOpenAITokenProvider struct { + accountRepo *openAIAccountRepoStub + tokenCache *openAITokenCacheStub + oauthService *openAIOAuthServiceStub +} + +func (p *testOpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth { + return "", errors.New("not an openai oauth account") + } + + cacheKey := OpenAITokenCacheKey(account) + + // 1. Check cache + if p.tokenCache != nil { + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" { + return token, nil + } + } + + // 2. Check if refresh needed + expiresAt := account.GetCredentialAsTime("expires_at") + needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew + refreshFailed := false + if needsRefresh && p.tokenCache != nil { + locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if err == nil && locked { + defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() + + // Check cache again after acquiring lock + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" { + return token, nil + } + + // Get fresh account from DB + fresh, err := p.accountRepo.GetByID(ctx, account.ID) + if err == nil && fresh != nil { + account = fresh + } + expiresAt = account.GetCredentialAsTime("expires_at") + if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { + if p.oauthService == nil { + refreshFailed = true // 无法刷新,标记失败 + } else { + tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account) + if err != nil { + refreshFailed = true // 刷新失败,标记以使用短 TTL + } else { + newCredentials := p.oauthService.BuildAccountCredentials(tokenInfo) + for k, v := range account.Credentials { + if _, exists := newCredentials[k]; !exists { + newCredentials[k] = v + } + } + account.Credentials = newCredentials + _ = p.accountRepo.Update(ctx, account) + expiresAt = account.GetCredentialAsTime("expires_at") + } + } + } + } else if p.tokenCache.simulateLockRace { + // Wait and retry cache + time.Sleep(10 * time.Millisecond) // Short wait for test + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" { + return token, nil + } + } + } + + accessToken := account.GetOpenAIAccessToken() + if accessToken == "" { + return "", errors.New("access_token not found in credentials") + } + + // 3. Store in cache + if p.tokenCache != nil { + ttl := 30 * time.Minute + if refreshFailed { + ttl = time.Minute // 刷新失败时使用短 TTL + } else if expiresAt != nil { + until := time.Until(*expiresAt) + if until > openAITokenCacheSkew { + ttl = until - openAITokenCacheSkew + } else if until > 0 { + ttl = until + } else { + ttl = time.Minute + } + } + _ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl) + } + + return accessToken, nil +} + +func TestOpenAITokenProvider_LockRaceCondition(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.simulateLockRace = true + accountRepo := &openAIAccountRepoStub{} + + // Token expires soon + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 103, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "race-token", + "expires_at": expiresAt, + }, + } + accountRepo.account = account + + // Simulate another worker already refreshed and cached + cacheKey := OpenAITokenCacheKey(account) + go func() { + time.Sleep(5 * time.Millisecond) + cache.mu.Lock() + cache.tokens[cacheKey] = "winner-token" + cache.mu.Unlock() + }() + + provider := &testOpenAITokenProvider{ + accountRepo: accountRepo, + tokenCache: cache, + } + + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + // Should get the token set by the "winner" or the original + require.NotEmpty(t, token) +} + +func TestOpenAITokenProvider_NilAccount(t *testing.T) { + provider := NewOpenAITokenProvider(nil, nil, nil) + + token, err := provider.GetAccessToken(context.Background(), nil) + require.Error(t, err) + require.Contains(t, err.Error(), "account is nil") + require.Empty(t, token) +} + +func TestOpenAITokenProvider_WrongPlatform(t *testing.T) { + provider := NewOpenAITokenProvider(nil, nil, nil) + account := &Account{ + ID: 104, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "not an openai oauth account") + require.Empty(t, token) +} + +func TestOpenAITokenProvider_WrongAccountType(t *testing.T) { + provider := NewOpenAITokenProvider(nil, nil, nil) + account := &Account{ + ID: 105, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + } + + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "not an openai oauth account") + require.Empty(t, token) +} + +func TestOpenAITokenProvider_NilCache(t *testing.T) { + // Token doesn't need refresh + expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339) + account := &Account{ + ID: 106, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "nocache-token", + "expires_at": expiresAt, + }, + } + + provider := NewOpenAITokenProvider(nil, nil, nil) + + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "nocache-token", token) +} + +func TestOpenAITokenProvider_CacheGetError(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.getErr = errors.New("redis connection failed") + + // Token doesn't need refresh + expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339) + account := &Account{ + ID: 107, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + provider := NewOpenAITokenProvider(nil, cache, nil) + + // Should gracefully degrade and return from credentials + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "fallback-token", token) +} + +func TestOpenAITokenProvider_CacheSetError(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.setErr = errors.New("redis write failed") + + expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339) + account := &Account{ + ID: 108, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "still-works-token", + "expires_at": expiresAt, + }, + } + + provider := NewOpenAITokenProvider(nil, cache, nil) + + // Should still work even if cache set fails + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "still-works-token", token) +} + +func TestOpenAITokenProvider_MissingAccessToken(t *testing.T) { + cache := newOpenAITokenCacheStub() + expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339) + account := &Account{ + ID: 109, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "expires_at": expiresAt, + // missing access_token + }, + } + + provider := NewOpenAITokenProvider(nil, cache, nil) + + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "access_token not found") + require.Empty(t, token) +} + +func TestOpenAITokenProvider_RefreshError(t *testing.T) { + cache := newOpenAITokenCacheStub() + accountRepo := &openAIAccountRepoStub{} + oauthService := &openAIOAuthServiceStub{ + refreshErr: errors.New("oauth refresh failed"), + } + + // Token expires soon + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 110, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-token", + "refresh_token": "old-refresh-token", + "expires_at": expiresAt, + }, + } + accountRepo.account = account + + provider := &testOpenAITokenProvider{ + accountRepo: accountRepo, + tokenCache: cache, + oauthService: oauthService, + } + + // Now with fallback behavior, should return existing token even if refresh fails + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "old-token", token) // Fallback to existing token +} + +func TestOpenAITokenProvider_OAuthServiceNotConfigured(t *testing.T) { + cache := newOpenAITokenCacheStub() + accountRepo := &openAIAccountRepoStub{} + + // Token expires soon + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 111, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-token", + "expires_at": expiresAt, + }, + } + accountRepo.account = account + + provider := &testOpenAITokenProvider{ + accountRepo: accountRepo, + tokenCache: cache, + oauthService: nil, // not configured + } + + // Now with fallback behavior, should return existing token even if oauth service not configured + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "old-token", token) // Fallback to existing token +} + +func TestOpenAITokenProvider_TTLCalculation(t *testing.T) { + tests := []struct { + name string + expiresIn time.Duration + }{ + { + name: "far_future_expiry", + expiresIn: 1 * time.Hour, + }, + { + name: "medium_expiry", + expiresIn: 10 * time.Minute, + }, + { + name: "near_expiry", + expiresIn: 6 * time.Minute, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cache := newOpenAITokenCacheStub() + expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339) + account := &Account{ + ID: 200, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "test-token", + "expires_at": expiresAt, + }, + } + + provider := NewOpenAITokenProvider(nil, cache, nil) + + _, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + + // Verify token was cached + cacheKey := OpenAITokenCacheKey(account) + require.Equal(t, "test-token", cache.tokens[cacheKey]) + }) + } +} + +func TestOpenAITokenProvider_DoubleCheckAfterLock(t *testing.T) { + cache := newOpenAITokenCacheStub() + accountRepo := &openAIAccountRepoStub{} + oauthService := &openAIOAuthServiceStub{ + tokenInfo: &OpenAITokenInfo{ + AccessToken: "refreshed-token", + RefreshToken: "new-refresh", + ExpiresIn: 3600, + }, + } + + // Token expires soon + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 112, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-token", + "expires_at": expiresAt, + }, + } + accountRepo.account = account + cacheKey := OpenAITokenCacheKey(account) + + // Simulate: first GetAccessToken returns empty, but after lock acquired, cache has token + originalGet := int32(0) + cache.tokens[cacheKey] = "" // Empty initially + + provider := &testOpenAITokenProvider{ + accountRepo: accountRepo, + tokenCache: cache, + oauthService: oauthService, + } + + // In a goroutine, set the cached token after a small delay (simulating race) + go func() { + time.Sleep(5 * time.Millisecond) + cache.mu.Lock() + cache.tokens[cacheKey] = "cached-by-other" + cache.mu.Unlock() + }() + + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + // Should get either the refreshed token or the cached one + require.NotEmpty(t, token) + _ = originalGet // Suppress unused warning +} + +// Tests for real provider - to increase coverage +func TestOpenAITokenProvider_Real_LockFailedWait(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockAcquired = false // Lock acquisition fails + + // Token expires soon (within refresh skew) to trigger lock attempt + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 200, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + // Set token in cache after lock wait period (simulate other worker refreshing) + cacheKey := OpenAITokenCacheKey(account) + go func() { + time.Sleep(100 * time.Millisecond) + cache.mu.Lock() + cache.tokens[cacheKey] = "refreshed-by-other" + cache.mu.Unlock() + }() + + provider := NewOpenAITokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + // Should get either the fallback token or the refreshed one + require.NotEmpty(t, token) +} + +func TestOpenAITokenProvider_Real_CacheHitAfterWait(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockAcquired = false // Lock acquisition fails + + // Token expires soon + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 201, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "original-token", + "expires_at": expiresAt, + }, + } + + cacheKey := OpenAITokenCacheKey(account) + // Set token in cache immediately after wait starts + go func() { + time.Sleep(50 * time.Millisecond) + cache.mu.Lock() + cache.tokens[cacheKey] = "winner-token" + cache.mu.Unlock() + }() + + provider := NewOpenAITokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.NotEmpty(t, token) +} + +func TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockAcquired = false // Prevent entering refresh logic + + // Token with nil expires_at (no expiry set) - should use credentials + account := &Account{ + ID: 202, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "no-expiry-token", + }, + } + + provider := NewOpenAITokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + // Without OAuth service, refresh will fail but token should be returned from credentials + require.NoError(t, err) + require.Equal(t, "no-expiry-token", token) +} + +func TestOpenAITokenProvider_Real_WhitespaceToken(t *testing.T) { + cache := newOpenAITokenCacheStub() + cacheKey := "openai:account:203" + cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty + + expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339) + account := &Account{ + ID: 203, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "real-token", + "expires_at": expiresAt, + }, + } + + provider := NewOpenAITokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "real-token", token) // Should fall back to credentials +} + +func TestOpenAITokenProvider_Real_LockError(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockErr = errors.New("redis lock failed") + + // Token expires soon (within refresh skew) + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 204, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-on-lock-error", + "expires_at": expiresAt, + }, + } + + provider := NewOpenAITokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "fallback-on-lock-error", token) +} + +func TestOpenAITokenProvider_Real_WhitespaceCredentialToken(t *testing.T) { + cache := newOpenAITokenCacheStub() + + expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339) + account := &Account{ + ID: 205, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": " ", // Whitespace only + "expires_at": expiresAt, + }, + } + + provider := NewOpenAITokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "access_token not found") + require.Empty(t, token) +} + +func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) { + cache := newOpenAITokenCacheStub() + + expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339) + account := &Account{ + ID: 206, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "expires_at": expiresAt, + // No access_token + }, + } + + provider := NewOpenAITokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "access_token not found") + require.Empty(t, token) +} diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index ca479486..effb7e9a 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -85,8 +85,8 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc switch statusCode { case 401: - if account.Type == AccountTypeOAuth && - (account.Platform == PlatformAntigravity || account.Platform == PlatformGemini) { + // 对所有 OAuth 账号在 401 错误时调用缓存失效 + if account.Type == AccountTypeOAuth { if s.tokenCacheInvalidator != nil { if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil { slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err) diff --git a/backend/internal/service/token_cache_invalidator.go b/backend/internal/service/token_cache_invalidator.go index aacdf266..1117d2f1 100644 --- a/backend/internal/service/token_cache_invalidator.go +++ b/backend/internal/service/token_cache_invalidator.go @@ -7,29 +7,35 @@ type TokenCacheInvalidator interface { } type CompositeTokenCacheInvalidator struct { - geminiCache GeminiTokenCache + cache GeminiTokenCache // 统一使用一个缓存接口,通过缓存键前缀区分平台 } -func NewCompositeTokenCacheInvalidator(geminiCache GeminiTokenCache) *CompositeTokenCacheInvalidator { +func NewCompositeTokenCacheInvalidator(cache GeminiTokenCache) *CompositeTokenCacheInvalidator { return &CompositeTokenCacheInvalidator{ - geminiCache: geminiCache, + cache: cache, } } func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, account *Account) error { - if c == nil || c.geminiCache == nil || account == nil { + if c == nil || c.cache == nil || account == nil { return nil } if account.Type != AccountTypeOAuth { return nil } + var cacheKey string switch account.Platform { case PlatformGemini: - return c.geminiCache.DeleteAccessToken(ctx, GeminiTokenCacheKey(account)) + cacheKey = GeminiTokenCacheKey(account) case PlatformAntigravity: - return c.geminiCache.DeleteAccessToken(ctx, AntigravityTokenCacheKey(account)) + cacheKey = AntigravityTokenCacheKey(account) + case PlatformOpenAI: + cacheKey = OpenAITokenCacheKey(account) + case PlatformAnthropic: + cacheKey = ClaudeTokenCacheKey(account) default: return nil } + return c.cache.DeleteAccessToken(ctx, cacheKey) } diff --git a/backend/internal/service/token_cache_invalidator_test.go b/backend/internal/service/token_cache_invalidator_test.go index 0090ed24..a33da60d 100644 --- a/backend/internal/service/token_cache_invalidator_test.go +++ b/backend/internal/service/token_cache_invalidator_test.go @@ -4,6 +4,7 @@ package service import ( "context" + "errors" "testing" "time" @@ -70,13 +71,99 @@ func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) { require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys) } -func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) { +func TestCompositeTokenCacheInvalidator_OpenAI(t *testing.T) { cache := &geminiTokenCacheStub{} invalidator := NewCompositeTokenCacheInvalidator(cache) account := &Account{ - ID: 1, - Platform: PlatformGemini, - Type: AccountTypeAPIKey, + ID: 500, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "openai-token", + }, + } + + err := invalidator.InvalidateToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, []string{"openai:account:500"}, cache.deletedKeys) +} + +func TestCompositeTokenCacheInvalidator_Claude(t *testing.T) { + cache := &geminiTokenCacheStub{} + invalidator := NewCompositeTokenCacheInvalidator(cache) + account := &Account{ + ID: 600, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "claude-token", + }, + } + + err := invalidator.InvalidateToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, []string{"claude:account:600"}, cache.deletedKeys) +} + +func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) { + cache := &geminiTokenCacheStub{} + invalidator := NewCompositeTokenCacheInvalidator(cache) + + tests := []struct { + name string + account *Account + }{ + { + name: "gemini_api_key", + account: &Account{ + ID: 1, + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + }, + }, + { + name: "openai_api_key", + account: &Account{ + ID: 2, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + }, + }, + { + name: "claude_api_key", + account: &Account{ + ID: 3, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + }, + }, + { + name: "claude_setup_token", + account: &Account{ + ID: 4, + Platform: PlatformAnthropic, + Type: AccountTypeSetupToken, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cache.deletedKeys = nil + err := invalidator.InvalidateToken(context.Background(), tt.account) + require.NoError(t, err) + require.Empty(t, cache.deletedKeys) + }) + } +} + +func TestCompositeTokenCacheInvalidator_SkipUnsupportedPlatform(t *testing.T) { + cache := &geminiTokenCacheStub{} + invalidator := NewCompositeTokenCacheInvalidator(cache) + account := &Account{ + ID: 100, + Platform: "unknown-platform", + Type: AccountTypeOAuth, } err := invalidator.InvalidateToken(context.Background(), account) @@ -95,3 +182,87 @@ func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) { err := invalidator.InvalidateToken(context.Background(), account) require.NoError(t, err) } + +func TestCompositeTokenCacheInvalidator_NilAccount(t *testing.T) { + cache := &geminiTokenCacheStub{} + invalidator := NewCompositeTokenCacheInvalidator(cache) + + err := invalidator.InvalidateToken(context.Background(), nil) + require.NoError(t, err) + require.Empty(t, cache.deletedKeys) +} + +func TestCompositeTokenCacheInvalidator_NilInvalidator(t *testing.T) { + var invalidator *CompositeTokenCacheInvalidator + account := &Account{ + ID: 5, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + + err := invalidator.InvalidateToken(context.Background(), account) + require.NoError(t, err) +} + +func TestCompositeTokenCacheInvalidator_DeleteError(t *testing.T) { + expectedErr := errors.New("redis connection failed") + cache := &geminiTokenCacheStub{deleteErr: expectedErr} + invalidator := NewCompositeTokenCacheInvalidator(cache) + + tests := []struct { + name string + account *Account + }{ + { + name: "openai_delete_error", + account: &Account{ + ID: 700, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + }, + }, + { + name: "claude_delete_error", + account: &Account{ + ID: 800, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := invalidator.InvalidateToken(context.Background(), tt.account) + require.Error(t, err) + require.Equal(t, expectedErr, err) + }) + } +} + +func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) { + // 测试所有平台的缓存键生成和删除 + cache := &geminiTokenCacheStub{} + invalidator := NewCompositeTokenCacheInvalidator(cache) + + accounts := []*Account{ + {ID: 1, Platform: PlatformGemini, Type: AccountTypeOAuth, Credentials: map[string]any{"project_id": "gemini-proj"}}, + {ID: 2, Platform: PlatformAntigravity, Type: AccountTypeOAuth, Credentials: map[string]any{"project_id": "ag-proj"}}, + {ID: 3, Platform: PlatformOpenAI, Type: AccountTypeOAuth}, + {ID: 4, Platform: PlatformAnthropic, Type: AccountTypeOAuth}, + } + + expectedKeys := []string{ + "gemini-proj", + "ag:ag-proj", + "openai:account:3", + "claude:account:4", + } + + for _, acc := range accounts { + err := invalidator.InvalidateToken(context.Background(), acc) + require.NoError(t, err) + } + + require.Equal(t, expectedKeys, cache.deletedKeys) +} diff --git a/backend/internal/service/token_cache_key.go b/backend/internal/service/token_cache_key.go new file mode 100644 index 00000000..df0c025e --- /dev/null +++ b/backend/internal/service/token_cache_key.go @@ -0,0 +1,15 @@ +package service + +import "strconv" + +// OpenAITokenCacheKey 生成 OpenAI OAuth 账号的缓存键 +// 格式: "openai:account:{account_id}" +func OpenAITokenCacheKey(account *Account) string { + return "openai:account:" + strconv.FormatInt(account.ID, 10) +} + +// ClaudeTokenCacheKey 生成 Claude (Anthropic) OAuth 账号的缓存键 +// 格式: "claude:account:{account_id}" +func ClaudeTokenCacheKey(account *Account) string { + return "claude:account:" + strconv.FormatInt(account.ID, 10) +} diff --git a/backend/internal/service/token_cache_key_test.go b/backend/internal/service/token_cache_key_test.go index 0dc751c6..e6b33747 100644 --- a/backend/internal/service/token_cache_key_test.go +++ b/backend/internal/service/token_cache_key_test.go @@ -151,3 +151,109 @@ func TestAntigravityTokenCacheKey(t *testing.T) { }) } } + +func TestOpenAITokenCacheKey(t *testing.T) { + tests := []struct { + name string + account *Account + expected string + }{ + { + name: "basic_account", + account: &Account{ + ID: 300, + }, + expected: "openai:account:300", + }, + { + name: "account_with_credentials", + account: &Account{ + ID: 301, + Credentials: map[string]any{ + "access_token": "test-token", + }, + }, + expected: "openai:account:301", + }, + { + name: "account_id_zero", + account: &Account{ + ID: 0, + }, + expected: "openai:account:0", + }, + { + name: "large_account_id", + account: &Account{ + ID: 9999999999, + }, + expected: "openai:account:9999999999", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := OpenAITokenCacheKey(tt.account) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestClaudeTokenCacheKey(t *testing.T) { + tests := []struct { + name string + account *Account + expected string + }{ + { + name: "basic_account", + account: &Account{ + ID: 400, + }, + expected: "claude:account:400", + }, + { + name: "account_with_credentials", + account: &Account{ + ID: 401, + Credentials: map[string]any{ + "access_token": "claude-token", + }, + }, + expected: "claude:account:401", + }, + { + name: "account_id_zero", + account: &Account{ + ID: 0, + }, + expected: "claude:account:0", + }, + { + name: "large_account_id", + account: &Account{ + ID: 9999999999, + }, + expected: "claude:account:9999999999", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ClaudeTokenCacheKey(tt.account) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestCacheKeyUniqueness(t *testing.T) { + // 确保不同平台的缓存键不会冲突 + account := &Account{ID: 123} + + openaiKey := OpenAITokenCacheKey(account) + claudeKey := ClaudeTokenCacheKey(account) + + require.NotEqual(t, openaiKey, claudeKey, "OpenAI and Claude cache keys should be different") + require.Contains(t, openaiKey, "openai:") + require.Contains(t, claudeKey, "claude:") +} diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 4d513d07..26cfd97d 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -172,8 +172,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc if err := s.accountRepo.Update(ctx, account); err != nil { return fmt.Errorf("failed to save credentials: %w", err) } - if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth && - (account.Platform == PlatformGemini || account.Platform == PlatformAntigravity) { + // 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理) + if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth { if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil { log.Printf("[TokenRefresh] Failed to invalidate token cache for account %d: %v", account.ID, err) } else { diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go index b11a0adc..c6ab71af 100644 --- a/backend/internal/service/token_refresh_service_test.go +++ b/backend/internal/service/token_refresh_service_test.go @@ -197,7 +197,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) { require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效 } -// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试其他平台的 OAuth 账号不触发缓存失效 +// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试所有 OAuth 平台都触发缓存失效 func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { repo := &tokenRefreshAccountRepo{} invalidator := &tokenCacheInvalidatorStub{} @@ -210,7 +210,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) account := &Account{ ID: 10, - Platform: PlatformOpenAI, // 其他平台 + Platform: PlatformOpenAI, // OpenAI OAuth 账户 Type: AccountTypeOAuth, } refresher := &tokenRefresherStub{ @@ -222,7 +222,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { err := service.refreshWithRetry(context.Background(), account, refresher) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) - require.Equal(t, 0, invalidator.calls) // 其他平台不触发缓存失效 + require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效 } // TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况 diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 05dbb0b0..5ba093a4 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -214,10 +214,13 @@ var ProviderSet = wire.NewSet( NewGeminiOAuthService, NewGeminiQuotaService, NewCompositeTokenCacheInvalidator, + wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)), NewAntigravityOAuthService, NewGeminiTokenProvider, NewGeminiMessagesCompatService, NewAntigravityTokenProvider, + NewOpenAITokenProvider, + NewClaudeTokenProvider, NewAntigravityGatewayService, ProvideRateLimitService, NewAccountUsageService,