package service import ( "context" "encoding/json" "log/slog" "net/http" "strconv" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" ) // RateLimitService 处理限流和过载状态管理 type RateLimitService struct { accountRepo AccountRepository usageRepo UsageLogRepository cfg *config.Config geminiQuotaService *GeminiQuotaService tempUnschedCache TempUnschedCache timeoutCounterCache TimeoutCounterCache settingService *SettingService tokenCacheInvalidator TokenCacheInvalidator usageCacheMu sync.RWMutex usageCache map[int64]*geminiUsageCacheEntry } type geminiUsageCacheEntry struct { windowStart time.Time cachedAt time.Time totals GeminiUsageTotals } const geminiPrecheckCacheTTL = time.Minute // NewRateLimitService 创建RateLimitService实例 func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService { return &RateLimitService{ accountRepo: accountRepo, usageRepo: usageRepo, cfg: cfg, geminiQuotaService: geminiQuotaService, tempUnschedCache: tempUnschedCache, usageCache: make(map[int64]*geminiUsageCacheEntry), } } // SetTimeoutCounterCache 设置超时计数器缓存(可选依赖) func (s *RateLimitService) SetTimeoutCounterCache(cache TimeoutCounterCache) { s.timeoutCounterCache = cache } // SetSettingService 设置系统设置服务(可选依赖) func (s *RateLimitService) SetSettingService(settingService *SettingService) { s.settingService = settingService } // SetTokenCacheInvalidator 设置 token 缓存清理器(可选依赖) func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvalidator) { s.tokenCacheInvalidator = invalidator } // HandleUpstreamError 处理上游错误响应,标记账号状态 // 返回是否应该停止该账号的调度 func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) { // apikey 类型账号:检查自定义错误码配置 // 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载) customErrorCodesEnabled := account.IsCustomErrorCodesEnabled() if !account.ShouldHandleErrorCode(statusCode) { slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode) return false } tempMatched := false if statusCode != 401 { tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody) } upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) if upstreamMsg != "" { upstreamMsg = truncateForLog([]byte(upstreamMsg), 512) } switch statusCode { case 401: if account.Type == AccountTypeOAuth && (account.Platform == PlatformAntigravity || account.Platform == PlatformGemini) { if s.tokenCacheInvalidator != nil { if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil { slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err) } } } msg := "Authentication failed (401): invalid or expired credentials" if upstreamMsg != "" { msg = "Authentication failed (401): " + upstreamMsg } s.handleAuthError(ctx, account, msg) shouldDisable = true case 402: // 支付要求:余额不足或计费问题,停止调度 msg := "Payment required (402): insufficient balance or billing issue" if upstreamMsg != "" { msg = "Payment required (402): " + upstreamMsg } s.handleAuthError(ctx, account, msg) shouldDisable = true case 403: // 禁止访问:停止调度,记录错误 msg := "Access forbidden (403): account may be suspended or lack permissions" if upstreamMsg != "" { msg = "Access forbidden (403): " + upstreamMsg } s.handleAuthError(ctx, account, msg) shouldDisable = true case 429: s.handle429(ctx, account, headers) shouldDisable = false case 529: s.handle529(ctx, account) shouldDisable = false default: // 自定义错误码启用时:在列表中的错误码都应该停止调度 if customErrorCodesEnabled { msg := "Custom error code triggered" if upstreamMsg != "" { msg = upstreamMsg } s.handleCustomErrorCode(ctx, account, statusCode, msg) shouldDisable = true } else if statusCode >= 500 { // 未启用自定义错误码时:仅记录5xx错误 slog.Warn("account_upstream_error", "account_id", account.ID, "status_code", statusCode) shouldDisable = false } } if tempMatched { return true } return shouldDisable } // PreCheckUsage proactively checks local quota before dispatching a request. // Returns false when the account should be skipped. func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) { if account == nil || account.Platform != PlatformGemini { return true, nil } if s.usageRepo == nil || s.geminiQuotaService == nil { return true, nil } quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account) if !ok { return true, nil } now := time.Now() modelClass := geminiModelClassFromName(requestedModel) // 1) Daily quota precheck (RPD; resets at PST midnight) { var limit int64 if quota.SharedRPD > 0 { limit = quota.SharedRPD } else { switch modelClass { case geminiModelFlash: limit = quota.FlashRPD default: limit = quota.ProRPD } } if limit > 0 { start := geminiDailyWindowStart(now) totals, ok := s.getGeminiUsageTotals(account.ID, start, now) if !ok { stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID) if err != nil { return true, err } totals = geminiAggregateUsage(stats) s.setGeminiUsageTotals(account.ID, start, now, totals) } var used int64 if quota.SharedRPD > 0 { used = totals.ProRequests + totals.FlashRequests } else { switch modelClass { case geminiModelFlash: used = totals.FlashRequests default: used = totals.ProRequests } } if used >= limit { resetAt := geminiDailyResetTime(now) // NOTE: // - This is a local precheck to reduce upstream 429s. // - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s. slog.Info("gemini_precheck_daily_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt) return false, nil } } } // 2) Minute quota precheck (RPM; fixed window current minute) { var limit int64 if quota.SharedRPM > 0 { limit = quota.SharedRPM } else { switch modelClass { case geminiModelFlash: limit = quota.FlashRPM default: limit = quota.ProRPM } } if limit > 0 { start := now.Truncate(time.Minute) stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID) if err != nil { return true, err } totals := geminiAggregateUsage(stats) var used int64 if quota.SharedRPM > 0 { used = totals.ProRequests + totals.FlashRequests } else { switch modelClass { case geminiModelFlash: used = totals.FlashRequests default: used = totals.ProRequests } } if used >= limit { resetAt := start.Add(time.Minute) // Do not persist "rate limited" status from local precheck. See note above. slog.Info("gemini_precheck_minute_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt) return false, nil } } } return true, nil } func (s *RateLimitService) getGeminiUsageTotals(accountID int64, windowStart, now time.Time) (GeminiUsageTotals, bool) { s.usageCacheMu.RLock() defer s.usageCacheMu.RUnlock() if s.usageCache == nil { return GeminiUsageTotals{}, false } entry, ok := s.usageCache[accountID] if !ok || entry == nil { return GeminiUsageTotals{}, false } if !entry.windowStart.Equal(windowStart) { return GeminiUsageTotals{}, false } if now.Sub(entry.cachedAt) >= geminiPrecheckCacheTTL { return GeminiUsageTotals{}, false } return entry.totals, true } func (s *RateLimitService) setGeminiUsageTotals(accountID int64, windowStart, now time.Time, totals GeminiUsageTotals) { s.usageCacheMu.Lock() defer s.usageCacheMu.Unlock() if s.usageCache == nil { s.usageCache = make(map[int64]*geminiUsageCacheEntry) } s.usageCache[accountID] = &geminiUsageCacheEntry{ windowStart: windowStart, cachedAt: now, totals: totals, } } // GeminiCooldown returns the fallback cooldown duration for Gemini 429s based on tier. func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account) time.Duration { if account == nil { return 5 * time.Minute } if s.geminiQuotaService == nil { return 5 * time.Minute } return s.geminiQuotaService.CooldownForAccount(ctx, account) } // handleAuthError 处理认证类错误(401/403),停止账号调度 func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) { if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err) return } slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg) } // handleCustomErrorCode 处理自定义错误码,停止账号调度 func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) { msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil { slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err) return } slog.Warn("account_disabled_custom_error", "account_id", account.ID, "status_code", statusCode, "error", errorMsg) } // handle429 处理429限流错误 // 解析响应头获取重置时间,标记账号为限流状态 func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header) { // 解析重置时间戳 resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset") if resetTimestamp == "" { // 没有重置时间,使用默认5分钟 resetAt := time.Now().Add(5 * time.Minute) if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) } return } // 解析Unix时间戳 ts, err := strconv.ParseInt(resetTimestamp, 10, 64) if err != nil { slog.Warn("rate_limit_reset_parse_failed", "reset_timestamp", resetTimestamp, "error", err) resetAt := time.Now().Add(5 * time.Minute) if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) } return } resetAt := time.Unix(ts, 0) // 标记限流状态 if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) return } // 根据重置时间反推5h窗口 windowEnd := resetAt windowStart := resetAt.Add(-5 * time.Hour) if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil { slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err) } slog.Info("account_rate_limited", "account_id", account.ID, "reset_at", resetAt) } // handle529 处理529过载错误 // 根据配置设置过载冷却时间 func (s *RateLimitService) handle529(ctx context.Context, account *Account) { cooldownMinutes := s.cfg.RateLimit.OverloadCooldownMinutes if cooldownMinutes <= 0 { cooldownMinutes = 10 // 默认10分钟 } until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil { slog.Warn("overload_set_failed", "account_id", account.ID, "error", err) return } slog.Info("account_overloaded", "account_id", account.ID, "until", until) } // UpdateSessionWindow 从成功响应更新5h窗口状态 func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Account, headers http.Header) { status := headers.Get("anthropic-ratelimit-unified-5h-status") if status == "" { return } // 检查是否需要初始化时间窗口 // 对于 Setup Token 账号,首次成功请求时需要预测时间窗口 var windowStart, windowEnd *time.Time needInitWindow := account.SessionWindowEnd == nil || time.Now().After(*account.SessionWindowEnd) if needInitWindow && (status == "allowed" || status == "allowed_warning") { // 预测时间窗口:从当前时间的整点开始,+5小时为结束 // 例如:现在是 14:30,窗口为 14:00 ~ 19:00 now := time.Now() start := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location()) end := start.Add(5 * time.Hour) windowStart = &start windowEnd = &end slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status) } if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil { slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err) } // 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态 if status == "allowed" && account.IsRateLimited() { if err := s.ClearRateLimit(ctx, account.ID); err != nil { slog.Warn("rate_limit_clear_failed", "account_id", account.ID, "error", err) } } } // ClearRateLimit 清除账号的限流状态 func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error { if err := s.accountRepo.ClearRateLimit(ctx, accountID); err != nil { return err } return s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID) } func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error { if err := s.accountRepo.ClearTempUnschedulable(ctx, accountID); err != nil { return err } if s.tempUnschedCache != nil { if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil { slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err) } } return nil } func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID int64) (*TempUnschedState, error) { now := time.Now().Unix() if s.tempUnschedCache != nil { state, err := s.tempUnschedCache.GetTempUnsched(ctx, accountID) if err != nil { return nil, err } if state != nil && state.UntilUnix > now { return state, nil } } account, err := s.accountRepo.GetByID(ctx, accountID) if err != nil { return nil, err } if account.TempUnschedulableUntil == nil { return nil, nil } if account.TempUnschedulableUntil.Unix() <= now { return nil, nil } state := &TempUnschedState{ UntilUnix: account.TempUnschedulableUntil.Unix(), } if account.TempUnschedulableReason != "" { var parsed TempUnschedState if err := json.Unmarshal([]byte(account.TempUnschedulableReason), &parsed); err == nil { if parsed.UntilUnix == 0 { parsed.UntilUnix = state.UntilUnix } state = &parsed } else { state.ErrorMessage = account.TempUnschedulableReason } } if s.tempUnschedCache != nil { if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil { slog.Warn("temp_unsched_cache_set_failed", "account_id", accountID, "error", err) } } return state, nil } func (s *RateLimitService) HandleTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool { if account == nil { return false } if !account.ShouldHandleErrorCode(statusCode) { return false } return s.tryTempUnschedulable(ctx, account, statusCode, responseBody) } const tempUnschedBodyMaxBytes = 64 << 10 const tempUnschedMessageMaxBytes = 2048 func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool { if account == nil { return false } if !account.IsTempUnschedulableEnabled() { return false } rules := account.GetTempUnschedulableRules() if len(rules) == 0 { return false } if statusCode <= 0 || len(responseBody) == 0 { return false } body := responseBody if len(body) > tempUnschedBodyMaxBytes { body = body[:tempUnschedBodyMaxBytes] } bodyLower := strings.ToLower(string(body)) for idx, rule := range rules { if rule.ErrorCode != statusCode || len(rule.Keywords) == 0 { continue } matchedKeyword := matchTempUnschedKeyword(bodyLower, rule.Keywords) if matchedKeyword == "" { continue } if s.triggerTempUnschedulable(ctx, account, rule, idx, statusCode, matchedKeyword, responseBody) { return true } } return false } func matchTempUnschedKeyword(bodyLower string, keywords []string) string { if bodyLower == "" { return "" } for _, keyword := range keywords { k := strings.TrimSpace(keyword) if k == "" { continue } if strings.Contains(bodyLower, strings.ToLower(k)) { return k } } return "" } func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account *Account, rule TempUnschedulableRule, ruleIndex int, statusCode int, matchedKeyword string, responseBody []byte) bool { if account == nil { return false } if rule.DurationMinutes <= 0 { return false } now := time.Now() until := now.Add(time.Duration(rule.DurationMinutes) * time.Minute) state := &TempUnschedState{ UntilUnix: until.Unix(), TriggeredAtUnix: now.Unix(), StatusCode: statusCode, MatchedKeyword: matchedKeyword, RuleIndex: ruleIndex, ErrorMessage: truncateTempUnschedMessage(responseBody, tempUnschedMessageMaxBytes), } reason := "" if raw, err := json.Marshal(state); err == nil { reason = string(raw) } if reason == "" { reason = strings.TrimSpace(state.ErrorMessage) } if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err) return false } if s.tempUnschedCache != nil { if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil { slog.Warn("temp_unsched_cache_set_failed", "account_id", account.ID, "error", err) } } slog.Info("account_temp_unschedulable", "account_id", account.ID, "until", until, "rule_index", ruleIndex, "status_code", statusCode) return true } func truncateTempUnschedMessage(body []byte, maxBytes int) string { if maxBytes <= 0 || len(body) == 0 { return "" } if len(body) > maxBytes { body = body[:maxBytes] } return strings.TrimSpace(string(body)) } // HandleStreamTimeout 处理流数据超时 // 根据系统设置决定是否标记账户为临时不可调度或错误状态 // 返回是否应该停止该账号的调度 func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Account, model string) bool { if account == nil { return false } // 获取系统设置 if s.settingService == nil { slog.Warn("stream_timeout_setting_service_missing", "account_id", account.ID) return false } settings, err := s.settingService.GetStreamTimeoutSettings(ctx) if err != nil { slog.Warn("stream_timeout_get_settings_failed", "account_id", account.ID, "error", err) return false } if !settings.Enabled { return false } if settings.Action == StreamTimeoutActionNone { return false } // 增加超时计数 var count int64 = 1 if s.timeoutCounterCache != nil { count, err = s.timeoutCounterCache.IncrementTimeoutCount(ctx, account.ID, settings.ThresholdWindowMinutes) if err != nil { slog.Warn("stream_timeout_increment_count_failed", "account_id", account.ID, "error", err) // 继续处理,使用 count=1 count = 1 } } slog.Info("stream_timeout_count", "account_id", account.ID, "count", count, "threshold", settings.ThresholdCount, "window_minutes", settings.ThresholdWindowMinutes, "model", model) // 检查是否达到阈值 if count < int64(settings.ThresholdCount) { return false } // 达到阈值,执行相应操作 switch settings.Action { case StreamTimeoutActionTempUnsched: return s.triggerStreamTimeoutTempUnsched(ctx, account, settings, model) case StreamTimeoutActionError: return s.triggerStreamTimeoutError(ctx, account, model) default: return false } } // triggerStreamTimeoutTempUnsched 触发流超时临时不可调度 func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context, account *Account, settings *StreamTimeoutSettings, model string) bool { now := time.Now() until := now.Add(time.Duration(settings.TempUnschedMinutes) * time.Minute) state := &TempUnschedState{ UntilUnix: until.Unix(), TriggeredAtUnix: now.Unix(), StatusCode: 0, // 超时没有状态码 MatchedKeyword: "stream_timeout", RuleIndex: -1, // 表示系统级规则 ErrorMessage: "Stream data interval timeout for model: " + model, } reason := "" if raw, err := json.Marshal(state); err == nil { reason = string(raw) } if reason == "" { reason = state.ErrorMessage } if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err) return false } if s.tempUnschedCache != nil { if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil { slog.Warn("stream_timeout_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err) } } // 重置超时计数 if s.timeoutCounterCache != nil { if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil { slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err) } } slog.Info("stream_timeout_temp_unschedulable", "account_id", account.ID, "until", until, "model", model) return true } // triggerStreamTimeoutError 触发流超时错误状态 func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, account *Account, model string) bool { errorMsg := "Stream data interval timeout (repeated failures) for model: " + model if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err) return false } // 重置超时计数 if s.timeoutCounterCache != nil { if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil { slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err) } } slog.Warn("stream_timeout_account_error", "account_id", account.ID, "model", model) return true }