fix(认证): 修复 OAuth token 缓存失效与 401 处理
- 新增 token 缓存失效接口并在刷新后清理
- 401 限流支持自定义规则与可配置冷却时间
- 补齐缓存失效与 401 处理测试
- 测试: make test
This commit is contained in:
@@ -3,7 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -15,15 +15,16 @@ import (
|
||||
|
||||
// RateLimitService 处理限流和过载状态管理
|
||||
type RateLimitService struct {
|
||||
accountRepo AccountRepository
|
||||
usageRepo UsageLogRepository
|
||||
cfg *config.Config
|
||||
geminiQuotaService *GeminiQuotaService
|
||||
tempUnschedCache TempUnschedCache
|
||||
timeoutCounterCache TimeoutCounterCache
|
||||
settingService *SettingService
|
||||
usageCacheMu sync.RWMutex
|
||||
usageCache map[int64]*geminiUsageCacheEntry
|
||||
accountRepo AccountRepository
|
||||
usageRepo UsageLogRepository
|
||||
cfg *config.Config
|
||||
geminiQuotaService *GeminiQuotaService
|
||||
tempUnschedCache TempUnschedCache
|
||||
timeoutCounterCache TimeoutCounterCache
|
||||
settingService *SettingService
|
||||
tokenCacheInvalidator TokenCacheInvalidator
|
||||
usageCacheMu sync.RWMutex
|
||||
usageCache map[int64]*geminiUsageCacheEntry
|
||||
}
|
||||
|
||||
type geminiUsageCacheEntry struct {
|
||||
@@ -56,6 +57,11 @@ func (s *RateLimitService) SetSettingService(settingService *SettingService) {
|
||||
s.settingService = settingService
|
||||
}
|
||||
|
||||
// SetTokenCacheInvalidator 设置 token 缓存清理器(可选依赖)
|
||||
func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvalidator) {
|
||||
s.tokenCacheInvalidator = invalidator
|
||||
}
|
||||
|
||||
// HandleUpstreamError 处理上游错误响应,标记账号状态
|
||||
// 返回是否应该停止该账号的调度
|
||||
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
|
||||
@@ -63,11 +69,16 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
|
||||
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
|
||||
if !account.ShouldHandleErrorCode(statusCode) {
|
||||
log.Printf("Account %d: error %d skipped (not in custom error codes)", account.ID, statusCode)
|
||||
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
|
||||
return false
|
||||
}
|
||||
|
||||
tempMatched := s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
|
||||
isOAuth401 := statusCode == 401 && account.Type == AccountTypeOAuth &&
|
||||
(account.Platform == PlatformAntigravity || account.Platform == PlatformGemini)
|
||||
tempMatched := false
|
||||
if !isOAuth401 || account.IsTempUnschedulableEnabled() {
|
||||
tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
|
||||
}
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if upstreamMsg != "" {
|
||||
@@ -76,7 +87,19 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
|
||||
switch statusCode {
|
||||
case 401:
|
||||
// 认证失败:停止调度,记录错误
|
||||
if isOAuth401 {
|
||||
if tempMatched {
|
||||
if s.tokenCacheInvalidator != nil {
|
||||
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
shouldDisable = true
|
||||
} else {
|
||||
shouldDisable = s.handleOAuth401TempUnschedulable(ctx, account, upstreamMsg)
|
||||
}
|
||||
break
|
||||
}
|
||||
msg := "Authentication failed (401): invalid or expired credentials"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Authentication failed (401): " + upstreamMsg
|
||||
@@ -116,7 +139,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
shouldDisable = true
|
||||
} else if statusCode >= 500 {
|
||||
// 未启用自定义错误码时:仅记录5xx错误
|
||||
log.Printf("Account %d received upstream error %d", account.ID, statusCode)
|
||||
slog.Warn("account_upstream_error", "account_id", account.ID, "status_code", statusCode)
|
||||
shouldDisable = false
|
||||
}
|
||||
}
|
||||
@@ -127,6 +150,63 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
return shouldDisable
|
||||
}
|
||||
|
||||
func (s *RateLimitService) handleOAuth401TempUnschedulable(ctx context.Context, account *Account, upstreamMsg string) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if s.tokenCacheInvalidator != nil {
|
||||
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
until := now.Add(s.oauth401Cooldown())
|
||||
msg := "Authentication failed (401): invalid or expired credentials"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Authentication failed (401): " + upstreamMsg
|
||||
}
|
||||
|
||||
state := &TempUnschedState{
|
||||
UntilUnix: until.Unix(),
|
||||
TriggeredAtUnix: now.Unix(),
|
||||
StatusCode: 401,
|
||||
MatchedKeyword: "oauth_401",
|
||||
RuleIndex: -1, // -1 表示非规则触发,而是 OAuth 401 特殊处理
|
||||
ErrorMessage: msg,
|
||||
}
|
||||
|
||||
reason := ""
|
||||
if raw, err := json.Marshal(state); err == nil {
|
||||
reason = string(raw)
|
||||
}
|
||||
if reason == "" {
|
||||
reason = msg
|
||||
}
|
||||
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if s.tempUnschedCache != nil {
|
||||
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
|
||||
slog.Warn("oauth_401_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("oauth_401_temp_unschedulable", "account_id", account.ID, "until", until)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *RateLimitService) oauth401Cooldown() time.Duration {
|
||||
if s != nil && s.cfg != nil && s.cfg.RateLimit.OAuth401CooldownMinutes > 0 {
|
||||
return time.Duration(s.cfg.RateLimit.OAuth401CooldownMinutes) * time.Minute
|
||||
}
|
||||
return 5 * time.Minute
|
||||
}
|
||||
|
||||
// PreCheckUsage proactively checks local quota before dispatching a request.
|
||||
// Returns false when the account should be skipped.
|
||||
func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) {
|
||||
@@ -188,7 +268,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
// NOTE:
|
||||
// - This is a local precheck to reduce upstream 429s.
|
||||
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
|
||||
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
|
||||
slog.Info("gemini_precheck_daily_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
@@ -231,7 +311,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
if used >= limit {
|
||||
resetAt := start.Add(time.Minute)
|
||||
// Do not persist "rate limited" status from local precheck. See note above.
|
||||
log.Printf("[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
|
||||
slog.Info("gemini_precheck_minute_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
@@ -288,20 +368,20 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
|
||||
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
||||
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
log.Printf("SetError failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
log.Printf("Account %d disabled due to auth error: %s", account.ID, errorMsg)
|
||||
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
|
||||
}
|
||||
|
||||
// handleCustomErrorCode 处理自定义错误码,停止账号调度
|
||||
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
|
||||
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
|
||||
log.Printf("SetError failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err)
|
||||
return
|
||||
}
|
||||
log.Printf("Account %d disabled due to custom error code %d: %s", account.ID, statusCode, errorMsg)
|
||||
slog.Warn("account_disabled_custom_error", "account_id", account.ID, "status_code", statusCode, "error", errorMsg)
|
||||
}
|
||||
|
||||
// handle429 处理429限流错误
|
||||
@@ -313,7 +393,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
// 没有重置时间,使用默认5分钟
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -321,10 +401,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
// 解析Unix时间戳
|
||||
ts, err := strconv.ParseInt(resetTimestamp, 10, 64)
|
||||
if err != nil {
|
||||
log.Printf("Parse reset timestamp failed: %v", err)
|
||||
slog.Warn("rate_limit_reset_parse_failed", "reset_timestamp", resetTimestamp, "error", err)
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -333,7 +413,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
|
||||
// 标记限流状态
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -341,10 +421,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
windowEnd := resetAt
|
||||
windowStart := resetAt.Add(-5 * time.Hour)
|
||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
|
||||
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
|
||||
log.Printf("Account %d rate limited until %v", account.ID, resetAt)
|
||||
slog.Info("account_rate_limited", "account_id", account.ID, "reset_at", resetAt)
|
||||
}
|
||||
|
||||
// handle529 处理529过载错误
|
||||
@@ -357,11 +437,11 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
|
||||
|
||||
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
||||
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
|
||||
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("overload_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Account %d overloaded until %v", account.ID, until)
|
||||
slog.Info("account_overloaded", "account_id", account.ID, "until", until)
|
||||
}
|
||||
|
||||
// UpdateSessionWindow 从成功响应更新5h窗口状态
|
||||
@@ -384,17 +464,17 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
|
||||
end := start.Add(5 * time.Hour)
|
||||
windowStart = &start
|
||||
windowEnd = &end
|
||||
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
|
||||
slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status)
|
||||
}
|
||||
|
||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
|
||||
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
|
||||
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
|
||||
if status == "allowed" && account.IsRateLimited() {
|
||||
if err := s.ClearRateLimit(ctx, account.ID); err != nil {
|
||||
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("rate_limit_clear_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -413,7 +493,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
|
||||
}
|
||||
if s.tempUnschedCache != nil {
|
||||
if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil {
|
||||
log.Printf("DeleteTempUnsched failed for account %d: %v", accountID, err)
|
||||
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -460,7 +540,7 @@ func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID i
|
||||
|
||||
if s.tempUnschedCache != nil {
|
||||
if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil {
|
||||
log.Printf("SetTempUnsched failed for account %d: %v", accountID, err)
|
||||
slog.Warn("temp_unsched_cache_set_failed", "account_id", accountID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -563,17 +643,17 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
|
||||
}
|
||||
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
log.Printf("SetTempUnschedulable failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if s.tempUnschedCache != nil {
|
||||
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
|
||||
log.Printf("SetTempUnsched cache failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("temp_unsched_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Account %d temp unschedulable until %v (rule %d, code %d)", account.ID, until, ruleIndex, statusCode)
|
||||
slog.Info("account_temp_unschedulable", "account_id", account.ID, "until", until, "rule_index", ruleIndex, "status_code", statusCode)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -597,13 +677,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
|
||||
|
||||
// 获取系统设置
|
||||
if s.settingService == nil {
|
||||
log.Printf("[StreamTimeout] settingService not configured, skipping timeout handling for account %d", account.ID)
|
||||
slog.Warn("stream_timeout_setting_service_missing", "account_id", account.ID)
|
||||
return false
|
||||
}
|
||||
|
||||
settings, err := s.settingService.GetStreamTimeoutSettings(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[StreamTimeout] Failed to get settings: %v", err)
|
||||
slog.Warn("stream_timeout_get_settings_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -620,14 +700,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
|
||||
if s.timeoutCounterCache != nil {
|
||||
count, err = s.timeoutCounterCache.IncrementTimeoutCount(ctx, account.ID, settings.ThresholdWindowMinutes)
|
||||
if err != nil {
|
||||
log.Printf("[StreamTimeout] Failed to increment timeout count for account %d: %v", account.ID, err)
|
||||
slog.Warn("stream_timeout_increment_count_failed", "account_id", account.ID, "error", err)
|
||||
// 继续处理,使用 count=1
|
||||
count = 1
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[StreamTimeout] Account %d timeout count: %d/%d (window: %d min, model: %s)",
|
||||
account.ID, count, settings.ThresholdCount, settings.ThresholdWindowMinutes, model)
|
||||
slog.Info("stream_timeout_count", "account_id", account.ID, "count", count, "threshold", settings.ThresholdCount, "window_minutes", settings.ThresholdWindowMinutes, "model", model)
|
||||
|
||||
// 检查是否达到阈值
|
||||
if count < int64(settings.ThresholdCount) {
|
||||
@@ -668,24 +747,24 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
|
||||
}
|
||||
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
log.Printf("[StreamTimeout] SetTempUnschedulable failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if s.tempUnschedCache != nil {
|
||||
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
|
||||
log.Printf("[StreamTimeout] SetTempUnsched cache failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("stream_timeout_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 重置超时计数
|
||||
if s.timeoutCounterCache != nil {
|
||||
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
|
||||
log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[StreamTimeout] Account %d marked as temp unschedulable until %v (model: %s)", account.ID, until, model)
|
||||
slog.Info("stream_timeout_temp_unschedulable", "account_id", account.ID, "until", until, "model", model)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -694,17 +773,17 @@ func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, accoun
|
||||
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
|
||||
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
log.Printf("[StreamTimeout] SetError failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// 重置超时计数
|
||||
if s.timeoutCounterCache != nil {
|
||||
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
|
||||
log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[StreamTimeout] Account %d marked as error (model: %s)", account.ID, model)
|
||||
slog.Warn("stream_timeout_account_error", "account_id", account.ID, "model", model)
|
||||
return true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user