diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 1e14b1c4..c8304831 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -99,12 +99,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) tempUnschedCache := repository.NewTempUnschedCache(redisClient) timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient) - rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService) + geminiTokenCache := repository.NewGeminiTokenCache(redisClient) + tokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache) + rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, tokenCacheInvalidator) claudeUsageFetcher := repository.NewClaudeUsageFetcher() antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) usageCache := service.NewUsageCache() accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache) - geminiTokenCache := repository.NewGeminiTokenCache(redisClient) geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) gatewayCache := repository.NewGatewayCache(redisClient) antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) @@ -167,7 +168,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, tokenCacheInvalidator, configConfig) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ diff --git a/backend/internal/repository/gemini_token_cache.go b/backend/internal/repository/gemini_token_cache.go index a7270556..82c14def 100644 --- a/backend/internal/repository/gemini_token_cache.go +++ b/backend/internal/repository/gemini_token_cache.go @@ -33,6 +33,11 @@ func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, return c.rdb.Set(ctx, key, token, ttl).Err() } +func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error { + key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey) + return c.rdb.Del(ctx, key).Err() +} + func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) { key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey) return c.rdb.SetNX(ctx, key, 1, ttl).Result() diff --git a/backend/internal/repository/gemini_token_cache_integration_test.go b/backend/internal/repository/gemini_token_cache_integration_test.go new file mode 100644 index 00000000..4fe89865 --- /dev/null +++ b/backend/internal/repository/gemini_token_cache_integration_test.go @@ -0,0 +1,47 @@ +//go:build integration + +package repository + +import ( + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type GeminiTokenCacheSuite struct { + IntegrationRedisSuite + cache service.GeminiTokenCache +} + +func (s *GeminiTokenCacheSuite) SetupTest() { + s.IntegrationRedisSuite.SetupTest() + s.cache = NewGeminiTokenCache(s.rdb) +} + +func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() { + cacheKey := "project-123" + token := "token-value" + require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute)) + + got, err := s.cache.GetAccessToken(s.ctx, cacheKey) + require.NoError(s.T(), err) + require.Equal(s.T(), token, got) + + require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey)) + + _, err = s.cache.GetAccessToken(s.ctx, cacheKey) + require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete") +} + +func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() { + require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key")) +} + +func TestGeminiTokenCacheSuite(t *testing.T) { + suite.Run(t, new(GeminiTokenCacheSuite)) +} diff --git a/backend/internal/repository/gemini_token_cache_test.go b/backend/internal/repository/gemini_token_cache_test.go new file mode 100644 index 00000000..4fcebfdd --- /dev/null +++ b/backend/internal/repository/gemini_token_cache_test.go @@ -0,0 +1,28 @@ +//go:build unit + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:1", + DialTimeout: 50 * time.Millisecond, + ReadTimeout: 50 * time.Millisecond, + WriteTimeout: 50 * time.Millisecond, + }) + t.Cleanup(func() { + _ = rdb.Close() + }) + + cache := NewGeminiTokenCache(rdb) + err := cache.DeleteAccessToken(context.Background(), "broken") + require.Error(t, err) +} diff --git a/backend/internal/service/antigravity_token_provider.go b/backend/internal/service/antigravity_token_provider.go index cbd1bef4..c5dc55db 100644 --- a/backend/internal/service/antigravity_token_provider.go +++ b/backend/internal/service/antigravity_token_provider.go @@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * return "", errors.New("not an antigravity oauth account") } - cacheKey := antigravityTokenCacheKey(account) + cacheKey := AntigravityTokenCacheKey(account) // 1. 先尝试缓存 if p.tokenCache != nil { @@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * return accessToken, nil } -func antigravityTokenCacheKey(account *Account) string { +func AntigravityTokenCacheKey(account *Account) string { projectID := strings.TrimSpace(account.GetCredential("project_id")) if projectID != "" { return "ag:" + projectID diff --git a/backend/internal/service/gemini_token_cache.go b/backend/internal/service/gemini_token_cache.go index d5e64f9a..70f246da 100644 --- a/backend/internal/service/gemini_token_cache.go +++ b/backend/internal/service/gemini_token_cache.go @@ -10,6 +10,7 @@ type GeminiTokenCache interface { // cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id. GetAccessToken(ctx context.Context, cacheKey string) (string, error) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error + DeleteAccessToken(ctx context.Context, cacheKey string) error AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) ReleaseRefreshLock(ctx context.Context, cacheKey string) error diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index 0257d19f..a5cacc9a 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -40,7 +40,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou return "", errors.New("not a gemini oauth account") } - cacheKey := geminiTokenCacheKey(account) + cacheKey := GeminiTokenCacheKey(account) // 1) Try cache first. if p.tokenCache != nil { @@ -151,7 +151,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou return accessToken, nil } -func geminiTokenCacheKey(account *Account) string { +func GeminiTokenCacheKey(account *Account) string { projectID := strings.TrimSpace(account.GetCredential("project_id")) if projectID != "" { return projectID diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index a2281f12..ca479486 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -3,7 +3,7 @@ package service import ( "context" "encoding/json" - "log" + "log/slog" "net/http" "strconv" "strings" @@ -15,15 +15,16 @@ import ( // RateLimitService 处理限流和过载状态管理 type RateLimitService struct { - accountRepo AccountRepository - usageRepo UsageLogRepository - cfg *config.Config - geminiQuotaService *GeminiQuotaService - tempUnschedCache TempUnschedCache - timeoutCounterCache TimeoutCounterCache - settingService *SettingService - usageCacheMu sync.RWMutex - usageCache map[int64]*geminiUsageCacheEntry + accountRepo AccountRepository + usageRepo UsageLogRepository + cfg *config.Config + geminiQuotaService *GeminiQuotaService + tempUnschedCache TempUnschedCache + timeoutCounterCache TimeoutCounterCache + settingService *SettingService + tokenCacheInvalidator TokenCacheInvalidator + usageCacheMu sync.RWMutex + usageCache map[int64]*geminiUsageCacheEntry } type geminiUsageCacheEntry struct { @@ -56,6 +57,11 @@ func (s *RateLimitService) SetSettingService(settingService *SettingService) { s.settingService = settingService } +// SetTokenCacheInvalidator 设置 token 缓存清理器(可选依赖) +func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvalidator) { + s.tokenCacheInvalidator = invalidator +} + // HandleUpstreamError 处理上游错误响应,标记账号状态 // 返回是否应该停止该账号的调度 func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) { @@ -63,11 +69,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc // 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载) customErrorCodesEnabled := account.IsCustomErrorCodesEnabled() if !account.ShouldHandleErrorCode(statusCode) { - log.Printf("Account %d: error %d skipped (not in custom error codes)", account.ID, statusCode) + slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode) return false } - tempMatched := s.tryTempUnschedulable(ctx, account, statusCode, responseBody) + tempMatched := false + if statusCode != 401 { + tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody) + } upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) if upstreamMsg != "" { @@ -76,7 +85,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc switch statusCode { case 401: - // 认证失败:停止调度,记录错误 + if account.Type == AccountTypeOAuth && + (account.Platform == PlatformAntigravity || account.Platform == PlatformGemini) { + if s.tokenCacheInvalidator != nil { + if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil { + slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err) + } + } + } msg := "Authentication failed (401): invalid or expired credentials" if upstreamMsg != "" { msg = "Authentication failed (401): " + upstreamMsg @@ -116,7 +132,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc shouldDisable = true } else if statusCode >= 500 { // 未启用自定义错误码时:仅记录5xx错误 - log.Printf("Account %d received upstream error %d", account.ID, statusCode) + slog.Warn("account_upstream_error", "account_id", account.ID, "status_code", statusCode) shouldDisable = false } } @@ -188,7 +204,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, // NOTE: // - This is a local precheck to reduce upstream 429s. // - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s. - log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v", account.ID, used, limit, resetAt) + slog.Info("gemini_precheck_daily_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt) return false, nil } } @@ -231,7 +247,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, if used >= limit { resetAt := start.Add(time.Minute) // Do not persist "rate limited" status from local precheck. See note above. - log.Printf("[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v", account.ID, used, limit, resetAt) + slog.Info("gemini_precheck_minute_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt) return false, nil } } @@ -288,20 +304,20 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account) // handleAuthError 处理认证类错误(401/403),停止账号调度 func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) { if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { - log.Printf("SetError failed for account %d: %v", account.ID, err) + slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err) return } - log.Printf("Account %d disabled due to auth error: %s", account.ID, errorMsg) + slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg) } // handleCustomErrorCode 处理自定义错误码,停止账号调度 func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) { msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil { - log.Printf("SetError failed for account %d: %v", account.ID, err) + slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err) return } - log.Printf("Account %d disabled due to custom error code %d: %s", account.ID, statusCode, errorMsg) + slog.Warn("account_disabled_custom_error", "account_id", account.ID, "status_code", statusCode, "error", errorMsg) } // handle429 处理429限流错误 @@ -313,7 +329,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head // 没有重置时间,使用默认5分钟 resetAt := time.Now().Add(5 * time.Minute) if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { - log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) + slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) } return } @@ -321,10 +337,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head // 解析Unix时间戳 ts, err := strconv.ParseInt(resetTimestamp, 10, 64) if err != nil { - log.Printf("Parse reset timestamp failed: %v", err) + slog.Warn("rate_limit_reset_parse_failed", "reset_timestamp", resetTimestamp, "error", err) resetAt := time.Now().Add(5 * time.Minute) if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { - log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) + slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) } return } @@ -333,7 +349,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head // 标记限流状态 if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { - log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) + slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) return } @@ -341,10 +357,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head windowEnd := resetAt windowStart := resetAt.Add(-5 * time.Hour) if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil { - log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err) + slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err) } - log.Printf("Account %d rate limited until %v", account.ID, resetAt) + slog.Info("account_rate_limited", "account_id", account.ID, "reset_at", resetAt) } // handle529 处理529过载错误 @@ -357,11 +373,11 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) { until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil { - log.Printf("SetOverloaded failed for account %d: %v", account.ID, err) + slog.Warn("overload_set_failed", "account_id", account.ID, "error", err) return } - log.Printf("Account %d overloaded until %v", account.ID, until) + slog.Info("account_overloaded", "account_id", account.ID, "until", until) } // UpdateSessionWindow 从成功响应更新5h窗口状态 @@ -384,17 +400,17 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc end := start.Add(5 * time.Hour) windowStart = &start windowEnd = &end - log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status) + slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status) } if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil { - log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err) + slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err) } // 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态 if status == "allowed" && account.IsRateLimited() { if err := s.ClearRateLimit(ctx, account.ID); err != nil { - log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err) + slog.Warn("rate_limit_clear_failed", "account_id", account.ID, "error", err) } } } @@ -413,7 +429,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID } if s.tempUnschedCache != nil { if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil { - log.Printf("DeleteTempUnsched failed for account %d: %v", accountID, err) + slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err) } } return nil @@ -460,7 +476,7 @@ func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID i if s.tempUnschedCache != nil { if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil { - log.Printf("SetTempUnsched failed for account %d: %v", accountID, err) + slog.Warn("temp_unsched_cache_set_failed", "account_id", accountID, "error", err) } } @@ -563,17 +579,17 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account } if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { - log.Printf("SetTempUnschedulable failed for account %d: %v", account.ID, err) + slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err) return false } if s.tempUnschedCache != nil { if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil { - log.Printf("SetTempUnsched cache failed for account %d: %v", account.ID, err) + slog.Warn("temp_unsched_cache_set_failed", "account_id", account.ID, "error", err) } } - log.Printf("Account %d temp unschedulable until %v (rule %d, code %d)", account.ID, until, ruleIndex, statusCode) + slog.Info("account_temp_unschedulable", "account_id", account.ID, "until", until, "rule_index", ruleIndex, "status_code", statusCode) return true } @@ -597,13 +613,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc // 获取系统设置 if s.settingService == nil { - log.Printf("[StreamTimeout] settingService not configured, skipping timeout handling for account %d", account.ID) + slog.Warn("stream_timeout_setting_service_missing", "account_id", account.ID) return false } settings, err := s.settingService.GetStreamTimeoutSettings(ctx) if err != nil { - log.Printf("[StreamTimeout] Failed to get settings: %v", err) + slog.Warn("stream_timeout_get_settings_failed", "account_id", account.ID, "error", err) return false } @@ -620,14 +636,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc if s.timeoutCounterCache != nil { count, err = s.timeoutCounterCache.IncrementTimeoutCount(ctx, account.ID, settings.ThresholdWindowMinutes) if err != nil { - log.Printf("[StreamTimeout] Failed to increment timeout count for account %d: %v", account.ID, err) + slog.Warn("stream_timeout_increment_count_failed", "account_id", account.ID, "error", err) // 继续处理,使用 count=1 count = 1 } } - log.Printf("[StreamTimeout] Account %d timeout count: %d/%d (window: %d min, model: %s)", - account.ID, count, settings.ThresholdCount, settings.ThresholdWindowMinutes, model) + slog.Info("stream_timeout_count", "account_id", account.ID, "count", count, "threshold", settings.ThresholdCount, "window_minutes", settings.ThresholdWindowMinutes, "model", model) // 检查是否达到阈值 if count < int64(settings.ThresholdCount) { @@ -668,24 +683,24 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context, } if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { - log.Printf("[StreamTimeout] SetTempUnschedulable failed for account %d: %v", account.ID, err) + slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err) return false } if s.tempUnschedCache != nil { if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil { - log.Printf("[StreamTimeout] SetTempUnsched cache failed for account %d: %v", account.ID, err) + slog.Warn("stream_timeout_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err) } } // 重置超时计数 if s.timeoutCounterCache != nil { if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil { - log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err) + slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err) } } - log.Printf("[StreamTimeout] Account %d marked as temp unschedulable until %v (model: %s)", account.ID, until, model) + slog.Info("stream_timeout_temp_unschedulable", "account_id", account.ID, "until", until, "model", model) return true } @@ -694,17 +709,17 @@ func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, accoun errorMsg := "Stream data interval timeout (repeated failures) for model: " + model if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { - log.Printf("[StreamTimeout] SetError failed for account %d: %v", account.ID, err) + slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err) return false } // 重置超时计数 if s.timeoutCounterCache != nil { if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil { - log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err) + slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err) } } - log.Printf("[StreamTimeout] Account %d marked as error (model: %s)", account.ID, model) + slog.Warn("stream_timeout_account_error", "account_id", account.ID, "model", model) return true } diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go new file mode 100644 index 00000000..36357a4b --- /dev/null +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -0,0 +1,121 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type rateLimitAccountRepoStub struct { + mockAccountRepoForGemini + setErrorCalls int + tempCalls int + lastErrorMsg string +} + +func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { + r.setErrorCalls++ + r.lastErrorMsg = errorMsg + return nil +} + +func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + r.tempCalls++ + return nil +} + +type tokenCacheInvalidatorRecorder struct { + accounts []*Account + err error +} + +func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error { + r.accounts = append(r.accounts, account) + return r.err +} + +func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) { + tests := []struct { + name string + platform string + }{ + {name: "gemini", platform: PlatformGemini}, + {name: "antigravity", platform: PlatformAntigravity}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + invalidator := &tokenCacheInvalidatorRecorder{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetTokenCacheInvalidator(invalidator) + account := &Account{ + ID: 100, + Platform: tt.platform, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": 401, + "keywords": []any{"unauthorized"}, + "duration_minutes": 30, + "description": "custom rule", + }, + }, + }, + } + + shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) + require.Equal(t, 0, repo.tempCalls) + require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)") + require.Len(t, invalidator.accounts, 1) + }) + } +} + +func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetTokenCacheInvalidator(invalidator) + account := &Account{ + ID: 101, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + + shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) + require.Len(t, invalidator.accounts, 1) +} + +func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + invalidator := &tokenCacheInvalidatorRecorder{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetTokenCacheInvalidator(invalidator) + account := &Account{ + ID: 102, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + } + + shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) + require.Empty(t, invalidator.accounts) +} diff --git a/backend/internal/service/token_cache_invalidator.go b/backend/internal/service/token_cache_invalidator.go new file mode 100644 index 00000000..aacdf266 --- /dev/null +++ b/backend/internal/service/token_cache_invalidator.go @@ -0,0 +1,35 @@ +package service + +import "context" + +type TokenCacheInvalidator interface { + InvalidateToken(ctx context.Context, account *Account) error +} + +type CompositeTokenCacheInvalidator struct { + geminiCache GeminiTokenCache +} + +func NewCompositeTokenCacheInvalidator(geminiCache GeminiTokenCache) *CompositeTokenCacheInvalidator { + return &CompositeTokenCacheInvalidator{ + geminiCache: geminiCache, + } +} + +func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, account *Account) error { + if c == nil || c.geminiCache == nil || account == nil { + return nil + } + if account.Type != AccountTypeOAuth { + return nil + } + + switch account.Platform { + case PlatformGemini: + return c.geminiCache.DeleteAccessToken(ctx, GeminiTokenCacheKey(account)) + case PlatformAntigravity: + return c.geminiCache.DeleteAccessToken(ctx, AntigravityTokenCacheKey(account)) + default: + return nil + } +} diff --git a/backend/internal/service/token_cache_invalidator_test.go b/backend/internal/service/token_cache_invalidator_test.go new file mode 100644 index 00000000..0090ed24 --- /dev/null +++ b/backend/internal/service/token_cache_invalidator_test.go @@ -0,0 +1,97 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type geminiTokenCacheStub struct { + deletedKeys []string + deleteErr error +} + +func (s *geminiTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) { + return "", nil +} + +func (s *geminiTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error { + return nil +} + +func (s *geminiTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error { + s.deletedKeys = append(s.deletedKeys, cacheKey) + return s.deleteErr +} + +func (s *geminiTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) { + return true, nil +} + +func (s *geminiTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error { + return nil +} + +func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) { + cache := &geminiTokenCacheStub{} + invalidator := NewCompositeTokenCacheInvalidator(cache) + account := &Account{ + ID: 10, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "project_id": "project-x", + }, + } + + err := invalidator.InvalidateToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, []string{"project-x"}, cache.deletedKeys) +} + +func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) { + cache := &geminiTokenCacheStub{} + invalidator := NewCompositeTokenCacheInvalidator(cache) + account := &Account{ + ID: 99, + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "project_id": "ag-project", + }, + } + + err := invalidator.InvalidateToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys) +} + +func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) { + cache := &geminiTokenCacheStub{} + invalidator := NewCompositeTokenCacheInvalidator(cache) + account := &Account{ + ID: 1, + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + } + + err := invalidator.InvalidateToken(context.Background(), account) + require.NoError(t, err) + require.Empty(t, cache.deletedKeys) +} + +func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) { + invalidator := NewCompositeTokenCacheInvalidator(nil) + account := &Account{ + ID: 2, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + + err := invalidator.InvalidateToken(context.Background(), account) + require.NoError(t, err) +} diff --git a/backend/internal/service/token_cache_key_test.go b/backend/internal/service/token_cache_key_test.go new file mode 100644 index 00000000..0dc751c6 --- /dev/null +++ b/backend/internal/service/token_cache_key_test.go @@ -0,0 +1,153 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGeminiTokenCacheKey(t *testing.T) { + tests := []struct { + name string + account *Account + expected string + }{ + { + name: "with_project_id", + account: &Account{ + ID: 100, + Credentials: map[string]any{ + "project_id": "my-project-123", + }, + }, + expected: "my-project-123", + }, + { + name: "project_id_with_whitespace", + account: &Account{ + ID: 101, + Credentials: map[string]any{ + "project_id": " project-with-spaces ", + }, + }, + expected: "project-with-spaces", + }, + { + name: "empty_project_id_fallback_to_account_id", + account: &Account{ + ID: 102, + Credentials: map[string]any{ + "project_id": "", + }, + }, + expected: "account:102", + }, + { + name: "whitespace_only_project_id_fallback_to_account_id", + account: &Account{ + ID: 103, + Credentials: map[string]any{ + "project_id": " ", + }, + }, + expected: "account:103", + }, + { + name: "no_project_id_key_fallback_to_account_id", + account: &Account{ + ID: 104, + Credentials: map[string]any{}, + }, + expected: "account:104", + }, + { + name: "nil_credentials_fallback_to_account_id", + account: &Account{ + ID: 105, + Credentials: nil, + }, + expected: "account:105", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GeminiTokenCacheKey(tt.account) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestAntigravityTokenCacheKey(t *testing.T) { + tests := []struct { + name string + account *Account + expected string + }{ + { + name: "with_project_id", + account: &Account{ + ID: 200, + Credentials: map[string]any{ + "project_id": "ag-project-456", + }, + }, + expected: "ag:ag-project-456", + }, + { + name: "project_id_with_whitespace", + account: &Account{ + ID: 201, + Credentials: map[string]any{ + "project_id": " ag-project-spaces ", + }, + }, + expected: "ag:ag-project-spaces", + }, + { + name: "empty_project_id_fallback_to_account_id", + account: &Account{ + ID: 202, + Credentials: map[string]any{ + "project_id": "", + }, + }, + expected: "ag:account:202", + }, + { + name: "whitespace_only_project_id_fallback_to_account_id", + account: &Account{ + ID: 203, + Credentials: map[string]any{ + "project_id": " ", + }, + }, + expected: "ag:account:203", + }, + { + name: "no_project_id_key_fallback_to_account_id", + account: &Account{ + ID: 204, + Credentials: map[string]any{}, + }, + expected: "ag:account:204", + }, + { + name: "nil_credentials_fallback_to_account_id", + account: &Account{ + ID: 205, + Credentials: nil, + }, + expected: "ag:account:205", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := AntigravityTokenCacheKey(tt.account) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 3ed35f04..4d513d07 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -14,9 +14,10 @@ import ( // TokenRefreshService OAuth token自动刷新服务 // 定期检查并刷新即将过期的token type TokenRefreshService struct { - accountRepo AccountRepository - refreshers []TokenRefresher - cfg *config.TokenRefreshConfig + accountRepo AccountRepository + refreshers []TokenRefresher + cfg *config.TokenRefreshConfig + cacheInvalidator TokenCacheInvalidator stopCh chan struct{} wg sync.WaitGroup @@ -29,12 +30,14 @@ func NewTokenRefreshService( openaiOAuthService *OpenAIOAuthService, geminiOAuthService *GeminiOAuthService, antigravityOAuthService *AntigravityOAuthService, + cacheInvalidator TokenCacheInvalidator, cfg *config.Config, ) *TokenRefreshService { s := &TokenRefreshService{ - accountRepo: accountRepo, - cfg: &cfg.TokenRefresh, - stopCh: make(chan struct{}), + accountRepo: accountRepo, + cfg: &cfg.TokenRefresh, + cacheInvalidator: cacheInvalidator, + stopCh: make(chan struct{}), } // 注册平台特定的刷新器 @@ -169,6 +172,14 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc if err := s.accountRepo.Update(ctx, account); err != nil { return fmt.Errorf("failed to save credentials: %w", err) } + if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth && + (account.Platform == PlatformGemini || account.Platform == PlatformAntigravity) { + if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil { + log.Printf("[TokenRefresh] Failed to invalidate token cache for account %d: %v", account.ID, err) + } else { + log.Printf("[TokenRefresh] Token cache invalidated for account %d", account.ID) + } + } return nil } diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go new file mode 100644 index 00000000..b11a0adc --- /dev/null +++ b/backend/internal/service/token_refresh_service_test.go @@ -0,0 +1,361 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type tokenRefreshAccountRepo struct { + mockAccountRepoForGemini + updateCalls int + setErrorCalls int + lastAccount *Account + updateErr error +} + +func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error { + r.updateCalls++ + r.lastAccount = account + return r.updateErr +} + +func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { + r.setErrorCalls++ + return nil +} + +type tokenCacheInvalidatorStub struct { + calls int + err error +} + +func (s *tokenCacheInvalidatorStub) InvalidateToken(ctx context.Context, account *Account) error { + s.calls++ + return s.err +} + +type tokenRefresherStub struct { + credentials map[string]any + err error +} + +func (r *tokenRefresherStub) CanRefresh(account *Account) bool { + return true +} + +func (r *tokenRefresherStub) NeedsRefresh(account *Account, refreshWindowDuration time.Duration) bool { + return true +} + +func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) { + if r.err != nil { + return nil, r.err + } + return r.credentials, nil +} + +func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + invalidator := &tokenCacheInvalidatorStub{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) + account := &Account{ + ID: 5, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + refresher := &tokenRefresherStub{ + credentials: map[string]any{ + "access_token": "new-token", + }, + } + + err := service.refreshWithRetry(context.Background(), account, refresher) + require.NoError(t, err) + require.Equal(t, 1, repo.updateCalls) + require.Equal(t, 1, invalidator.calls) + require.Equal(t, "new-token", account.GetCredential("access_token")) +} + +func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + invalidator := &tokenCacheInvalidatorStub{err: errors.New("invalidate failed")} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) + account := &Account{ + ID: 6, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + refresher := &tokenRefresherStub{ + credentials: map[string]any{ + "access_token": "token", + }, + } + + err := service.refreshWithRetry(context.Background(), account, refresher) + require.NoError(t, err) + require.Equal(t, 1, repo.updateCalls) + require.Equal(t, 1, invalidator.calls) +} + +func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, cfg) + account := &Account{ + ID: 7, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + refresher := &tokenRefresherStub{ + credentials: map[string]any{ + "access_token": "token", + }, + } + + err := service.refreshWithRetry(context.Background(), account, refresher) + require.NoError(t, err) + require.Equal(t, 1, repo.updateCalls) +} + +// TestTokenRefreshService_RefreshWithRetry_Antigravity 测试 Antigravity 平台的缓存失效 +func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + invalidator := &tokenCacheInvalidatorStub{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) + account := &Account{ + ID: 8, + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + } + refresher := &tokenRefresherStub{ + credentials: map[string]any{ + "access_token": "ag-token", + }, + } + + err := service.refreshWithRetry(context.Background(), account, refresher) + require.NoError(t, err) + require.Equal(t, 1, repo.updateCalls) + require.Equal(t, 1, invalidator.calls) // Antigravity 也应触发缓存失效 +} + +// TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount 测试非 OAuth 账号不触发缓存失效 +func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + invalidator := &tokenCacheInvalidatorStub{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) + account := &Account{ + ID: 9, + Platform: PlatformGemini, + Type: AccountTypeAPIKey, // 非 OAuth + } + refresher := &tokenRefresherStub{ + credentials: map[string]any{ + "access_token": "token", + }, + } + + err := service.refreshWithRetry(context.Background(), account, refresher) + require.NoError(t, err) + require.Equal(t, 1, repo.updateCalls) + require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效 +} + +// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试其他平台的 OAuth 账号不触发缓存失效 +func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + invalidator := &tokenCacheInvalidatorStub{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) + account := &Account{ + ID: 10, + Platform: PlatformOpenAI, // 其他平台 + Type: AccountTypeOAuth, + } + refresher := &tokenRefresherStub{ + credentials: map[string]any{ + "access_token": "token", + }, + } + + err := service.refreshWithRetry(context.Background(), account, refresher) + require.NoError(t, err) + require.Equal(t, 1, repo.updateCalls) + require.Equal(t, 0, invalidator.calls) // 其他平台不触发缓存失效 +} + +// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况 +func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) { + repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")} + invalidator := &tokenCacheInvalidatorStub{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) + account := &Account{ + ID: 11, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + refresher := &tokenRefresherStub{ + credentials: map[string]any{ + "access_token": "token", + }, + } + + err := service.refreshWithRetry(context.Background(), account, refresher) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to save credentials") + require.Equal(t, 1, repo.updateCalls) + require.Equal(t, 0, invalidator.calls) // 更新失败时不应触发缓存失效 +} + +// TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试刷新失败的情况 +func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + invalidator := &tokenCacheInvalidatorStub{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 2, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) + account := &Account{ + ID: 12, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + refresher := &tokenRefresherStub{ + err: errors.New("refresh failed"), + } + + err := service.refreshWithRetry(context.Background(), account, refresher) + require.Error(t, err) + require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新 + require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效 + require.Equal(t, 1, repo.setErrorCalls) // 应设置错误状态 +} + +// TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed 测试 Antigravity 刷新失败不设置错误状态 +func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + invalidator := &tokenCacheInvalidatorStub{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) + account := &Account{ + ID: 13, + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + } + refresher := &tokenRefresherStub{ + err: errors.New("network error"), // 可重试错误 + } + + err := service.refreshWithRetry(context.Background(), account, refresher) + require.Error(t, err) + require.Equal(t, 0, repo.updateCalls) + require.Equal(t, 0, invalidator.calls) + require.Equal(t, 0, repo.setErrorCalls) // Antigravity 可重试错误不设置错误状态 +} + +// TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError 测试 Antigravity 不可重试错误 +func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + invalidator := &tokenCacheInvalidatorStub{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 3, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) + account := &Account{ + ID: 14, + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + } + refresher := &tokenRefresherStub{ + err: errors.New("invalid_grant: token revoked"), // 不可重试错误 + } + + err := service.refreshWithRetry(context.Background(), account, refresher) + require.Error(t, err) + require.Equal(t, 0, repo.updateCalls) + require.Equal(t, 0, invalidator.calls) + require.Equal(t, 1, repo.setErrorCalls) // 不可重试错误应设置错误状态 +} + +// TestIsNonRetryableRefreshError 测试不可重试错误判断 +func TestIsNonRetryableRefreshError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + {name: "nil_error", err: nil, expected: false}, + {name: "network_error", err: errors.New("network timeout"), expected: false}, + {name: "invalid_grant", err: errors.New("invalid_grant"), expected: true}, + {name: "invalid_client", err: errors.New("invalid_client"), expected: true}, + {name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true}, + {name: "access_denied", err: errors.New("access_denied"), expected: true}, + {name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true}, + {name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isNonRetryableRefreshError(tt.err) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 5326bace..05dbb0b0 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -42,9 +42,10 @@ func ProvideTokenRefreshService( openaiOAuthService *OpenAIOAuthService, geminiOAuthService *GeminiOAuthService, antigravityOAuthService *AntigravityOAuthService, + cacheInvalidator TokenCacheInvalidator, cfg *config.Config, ) *TokenRefreshService { - svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cfg) + svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg) svc.Start() return svc } @@ -108,10 +109,12 @@ func ProvideRateLimitService( tempUnschedCache TempUnschedCache, timeoutCounterCache TimeoutCounterCache, settingService *SettingService, + tokenCacheInvalidator TokenCacheInvalidator, ) *RateLimitService { svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache) svc.SetTimeoutCounterCache(timeoutCounterCache) svc.SetSettingService(settingService) + svc.SetTokenCacheInvalidator(tokenCacheInvalidator) return svc } @@ -210,6 +213,7 @@ var ProviderSet = wire.NewSet( NewOpenAIOAuthService, NewGeminiOAuthService, NewGeminiQuotaService, + NewCompositeTokenCacheInvalidator, NewAntigravityOAuthService, NewGeminiTokenProvider, NewGeminiMessagesCompatService, diff --git a/deploy/.env.example b/deploy/.env.example index e5cf8b32..f21a3c62 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -69,6 +69,14 @@ JWT_EXPIRE_HOUR=24 # Leave unset to use default ./config.yaml #CONFIG_FILE=./config.yaml +# ----------------------------------------------------------------------------- +# Rate Limiting (Optional) +# 速率限制(可选) +# ----------------------------------------------------------------------------- +# Cooldown time (in minutes) when upstream returns 529 (overloaded) +# 上游返回 529(过载)时的冷却时间(分钟) +RATE_LIMIT_OVERLOAD_COOLDOWN_MINUTES=10 + # ----------------------------------------------------------------------------- # Gateway Scheduling (Optional) # 调度缓存与受控回源配置(缓存就绪且命中时不读 DB) diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index f465c001..cf484303 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -357,9 +357,6 @@ const handleBulkToggleSchedulable = async (schedulable: boolean) => { } else { selIds.value = hasIds ? [] : accountIds } - load().catch((error) => { - console.error('Failed to refresh accounts:', error) - }) } catch (error) { console.error('Failed to bulk toggle schedulable:', error) appStore.showError(t('common.error')) @@ -383,9 +380,6 @@ const handleToggleSchedulable = async (a: Account) => { try { const updated = await adminAPI.accounts.setSchedulable(a.id, nextSchedulable) updateSchedulableInList([a.id], updated?.schedulable ?? nextSchedulable) - load().catch((error) => { - console.error('Failed to refresh accounts:', error) - }) } catch (error) { console.error('Failed to toggle schedulable:', error) appStore.showError(t('admin.accounts.failedToToggleSchedulable'))