fix(认证): 修复 OAuth token 缓存失效与 401 处理

新增 token 缓存失效接口并在刷新后清理
401 限流支持自定义规则与可配置冷却时间
补齐缓存失效与 401 处理测试
测试: make test
This commit is contained in:
yangjianbo
2026-01-14 15:55:44 +08:00
parent 9c567fad92
commit daf10907e4
19 changed files with 1257 additions and 63 deletions

View File

@@ -98,12 +98,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient) tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient) timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService) geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
tokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, tokenCacheInvalidator)
claudeUsageFetcher := repository.NewClaudeUsageFetcher() claudeUsageFetcher := repository.NewClaudeUsageFetcher()
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache() usageCache := service.NewUsageCache()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient) gatewayCache := repository.NewGatewayCache(redisClient)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
@@ -166,7 +167,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, tokenCacheInvalidator, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository) accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{ application := &Application{

View File

@@ -435,7 +435,8 @@ type DefaultConfig struct {
} }
type RateLimitConfig struct { type RateLimitConfig struct {
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
OAuth401CooldownMinutes int `mapstructure:"oauth_401_cooldown_minutes"` // OAuth 401 临时不可调度冷却时间(分钟)
} }
// APIKeyAuthCacheConfig API Key 认证缓存配置 // APIKeyAuthCacheConfig API Key 认证缓存配置
@@ -709,6 +710,7 @@ func setDefaults() {
// RateLimit // RateLimit
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10) viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 5)
// Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查 // Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json") viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json")

View File

@@ -33,6 +33,11 @@ func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string,
return c.rdb.Set(ctx, key, token, ttl).Err() return c.rdb.Set(ctx, key, token, ttl).Err()
} }
func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err()
}
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) { func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey) key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
return c.rdb.SetNX(ctx, key, 1, ttl).Result() return c.rdb.SetNX(ctx, key, 1, ttl).Result()

View File

@@ -0,0 +1,47 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GeminiTokenCacheSuite struct {
IntegrationRedisSuite
cache service.GeminiTokenCache
}
func (s *GeminiTokenCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewGeminiTokenCache(s.rdb)
}
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() {
cacheKey := "project-123"
token := "token-value"
require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute))
got, err := s.cache.GetAccessToken(s.ctx, cacheKey)
require.NoError(s.T(), err)
require.Equal(s.T(), token, got)
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey))
_, err = s.cache.GetAccessToken(s.ctx, cacheKey)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() {
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key"))
}
func TestGeminiTokenCacheSuite(t *testing.T) {
suite.Run(t, new(GeminiTokenCacheSuite))
}

View File

@@ -0,0 +1,28 @@
//go:build unit
package repository
import (
"context"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
DialTimeout: 50 * time.Millisecond,
ReadTimeout: 50 * time.Millisecond,
WriteTimeout: 50 * time.Millisecond,
})
t.Cleanup(func() {
_ = rdb.Close()
})
cache := NewGeminiTokenCache(rdb)
err := cache.DeleteAccessToken(context.Background(), "broken")
require.Error(t, err)
}

View File

@@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("not an antigravity oauth account") return "", errors.New("not an antigravity oauth account")
} }
cacheKey := antigravityTokenCacheKey(account) cacheKey := AntigravityTokenCacheKey(account)
// 1. 先尝试缓存 // 1. 先尝试缓存
if p.tokenCache != nil { if p.tokenCache != nil {
@@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return accessToken, nil return accessToken, nil
} }
func antigravityTokenCacheKey(account *Account) string { func AntigravityTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id")) projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" { if projectID != "" {
return "ag:" + projectID return "ag:" + projectID

View File

@@ -10,6 +10,7 @@ type GeminiTokenCache interface {
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id. // cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
GetAccessToken(ctx context.Context, cacheKey string) (string, error) GetAccessToken(ctx context.Context, cacheKey string) (string, error)
SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error
DeleteAccessToken(ctx context.Context, cacheKey string) error
AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error)
ReleaseRefreshLock(ctx context.Context, cacheKey string) error ReleaseRefreshLock(ctx context.Context, cacheKey string) error

View File

@@ -40,7 +40,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return "", errors.New("not a gemini oauth account") return "", errors.New("not a gemini oauth account")
} }
cacheKey := geminiTokenCacheKey(account) cacheKey := GeminiTokenCacheKey(account)
// 1) Try cache first. // 1) Try cache first.
if p.tokenCache != nil { if p.tokenCache != nil {
@@ -151,7 +151,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil return accessToken, nil
} }
func geminiTokenCacheKey(account *Account) string { func GeminiTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id")) projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" { if projectID != "" {
return projectID return projectID

View File

@@ -3,7 +3,7 @@ package service
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"log" "log/slog"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@@ -15,15 +15,16 @@ import (
// RateLimitService 处理限流和过载状态管理 // RateLimitService 处理限流和过载状态管理
type RateLimitService struct { type RateLimitService struct {
accountRepo AccountRepository accountRepo AccountRepository
usageRepo UsageLogRepository usageRepo UsageLogRepository
cfg *config.Config cfg *config.Config
geminiQuotaService *GeminiQuotaService geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache tempUnschedCache TempUnschedCache
timeoutCounterCache TimeoutCounterCache timeoutCounterCache TimeoutCounterCache
settingService *SettingService settingService *SettingService
usageCacheMu sync.RWMutex tokenCacheInvalidator TokenCacheInvalidator
usageCache map[int64]*geminiUsageCacheEntry usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
} }
type geminiUsageCacheEntry struct { type geminiUsageCacheEntry struct {
@@ -56,6 +57,11 @@ func (s *RateLimitService) SetSettingService(settingService *SettingService) {
s.settingService = settingService s.settingService = settingService
} }
// SetTokenCacheInvalidator 设置 token 缓存清理器(可选依赖)
func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvalidator) {
s.tokenCacheInvalidator = invalidator
}
// HandleUpstreamError 处理上游错误响应,标记账号状态 // HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度 // 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) { func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
@@ -63,11 +69,16 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载) // 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled() customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
if !account.ShouldHandleErrorCode(statusCode) { if !account.ShouldHandleErrorCode(statusCode) {
log.Printf("Account %d: error %d skipped (not in custom error codes)", account.ID, statusCode) slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
return false return false
} }
tempMatched := s.tryTempUnschedulable(ctx, account, statusCode, responseBody) isOAuth401 := statusCode == 401 && account.Type == AccountTypeOAuth &&
(account.Platform == PlatformAntigravity || account.Platform == PlatformGemini)
tempMatched := false
if !isOAuth401 || account.IsTempUnschedulableEnabled() {
tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if upstreamMsg != "" { if upstreamMsg != "" {
@@ -76,7 +87,19 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
switch statusCode { switch statusCode {
case 401: case 401:
// 认证失败:停止调度,记录错误 if isOAuth401 {
if tempMatched {
if s.tokenCacheInvalidator != nil {
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
}
}
shouldDisable = true
} else {
shouldDisable = s.handleOAuth401TempUnschedulable(ctx, account, upstreamMsg)
}
break
}
msg := "Authentication failed (401): invalid or expired credentials" msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" { if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg msg = "Authentication failed (401): " + upstreamMsg
@@ -116,7 +139,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
shouldDisable = true shouldDisable = true
} else if statusCode >= 500 { } else if statusCode >= 500 {
// 未启用自定义错误码时仅记录5xx错误 // 未启用自定义错误码时仅记录5xx错误
log.Printf("Account %d received upstream error %d", account.ID, statusCode) slog.Warn("account_upstream_error", "account_id", account.ID, "status_code", statusCode)
shouldDisable = false shouldDisable = false
} }
} }
@@ -127,6 +150,63 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
return shouldDisable return shouldDisable
} }
func (s *RateLimitService) handleOAuth401TempUnschedulable(ctx context.Context, account *Account, upstreamMsg string) bool {
if account == nil {
return false
}
if s.tokenCacheInvalidator != nil {
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
}
}
now := time.Now()
until := now.Add(s.oauth401Cooldown())
msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg
}
state := &TempUnschedState{
UntilUnix: until.Unix(),
TriggeredAtUnix: now.Unix(),
StatusCode: 401,
MatchedKeyword: "oauth_401",
RuleIndex: -1, // -1 表示非规则触发,而是 OAuth 401 特殊处理
ErrorMessage: msg,
}
reason := ""
if raw, err := json.Marshal(state); err == nil {
reason = string(raw)
}
if reason == "" {
reason = msg
}
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
return false
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
slog.Warn("oauth_401_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err)
}
}
slog.Info("oauth_401_temp_unschedulable", "account_id", account.ID, "until", until)
return true
}
func (s *RateLimitService) oauth401Cooldown() time.Duration {
if s != nil && s.cfg != nil && s.cfg.RateLimit.OAuth401CooldownMinutes > 0 {
return time.Duration(s.cfg.RateLimit.OAuth401CooldownMinutes) * time.Minute
}
return 5 * time.Minute
}
// PreCheckUsage proactively checks local quota before dispatching a request. // PreCheckUsage proactively checks local quota before dispatching a request.
// Returns false when the account should be skipped. // Returns false when the account should be skipped.
func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) { func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) {
@@ -188,7 +268,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
// NOTE: // NOTE:
// - This is a local precheck to reduce upstream 429s. // - This is a local precheck to reduce upstream 429s.
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s. // - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v", account.ID, used, limit, resetAt) slog.Info("gemini_precheck_daily_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
return false, nil return false, nil
} }
} }
@@ -231,7 +311,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if used >= limit { if used >= limit {
resetAt := start.Add(time.Minute) resetAt := start.Add(time.Minute)
// Do not persist "rate limited" status from local precheck. See note above. // Do not persist "rate limited" status from local precheck. See note above.
log.Printf("[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v", account.ID, used, limit, resetAt) slog.Info("gemini_precheck_minute_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
return false, nil return false, nil
} }
} }
@@ -288,20 +368,20 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
// handleAuthError 处理认证类错误(401/403),停止账号调度 // handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) { func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err) slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err)
return return
} }
log.Printf("Account %d disabled due to auth error: %s", account.ID, errorMsg) slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
} }
// handleCustomErrorCode 处理自定义错误码,停止账号调度 // handleCustomErrorCode 处理自定义错误码,停止账号调度
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) { func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil { if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err) slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err)
return return
} }
log.Printf("Account %d disabled due to custom error code %d: %s", account.ID, statusCode, errorMsg) slog.Warn("account_disabled_custom_error", "account_id", account.ID, "status_code", statusCode, "error", errorMsg)
} }
// handle429 处理429限流错误 // handle429 处理429限流错误
@@ -313,7 +393,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 没有重置时间使用默认5分钟 // 没有重置时间使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute) resetAt := time.Now().Add(5 * time.Minute)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
} }
return return
} }
@@ -321,10 +401,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 解析Unix时间戳 // 解析Unix时间戳
ts, err := strconv.ParseInt(resetTimestamp, 10, 64) ts, err := strconv.ParseInt(resetTimestamp, 10, 64)
if err != nil { if err != nil {
log.Printf("Parse reset timestamp failed: %v", err) slog.Warn("rate_limit_reset_parse_failed", "reset_timestamp", resetTimestamp, "error", err)
resetAt := time.Now().Add(5 * time.Minute) resetAt := time.Now().Add(5 * time.Minute)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
} }
return return
} }
@@ -333,7 +413,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 标记限流状态 // 标记限流状态
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return return
} }
@@ -341,10 +421,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
windowEnd := resetAt windowEnd := resetAt
windowStart := resetAt.Add(-5 * time.Hour) windowStart := resetAt.Add(-5 * time.Hour)
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil { if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err) slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
} }
log.Printf("Account %d rate limited until %v", account.ID, resetAt) slog.Info("account_rate_limited", "account_id", account.ID, "reset_at", resetAt)
} }
// handle529 处理529过载错误 // handle529 处理529过载错误
@@ -357,11 +437,11 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil { if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err) slog.Warn("overload_set_failed", "account_id", account.ID, "error", err)
return return
} }
log.Printf("Account %d overloaded until %v", account.ID, until) slog.Info("account_overloaded", "account_id", account.ID, "until", until)
} }
// UpdateSessionWindow 从成功响应更新5h窗口状态 // UpdateSessionWindow 从成功响应更新5h窗口状态
@@ -384,17 +464,17 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
end := start.Add(5 * time.Hour) end := start.Add(5 * time.Hour)
windowStart = &start windowStart = &start
windowEnd = &end windowEnd = &end
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status) slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status)
} }
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil { if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err) slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err)
} }
// 如果状态为allowed且之前有限流说明窗口已重置清除限流状态 // 如果状态为allowed且之前有限流说明窗口已重置清除限流状态
if status == "allowed" && account.IsRateLimited() { if status == "allowed" && account.IsRateLimited() {
if err := s.ClearRateLimit(ctx, account.ID); err != nil { if err := s.ClearRateLimit(ctx, account.ID); err != nil {
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err) slog.Warn("rate_limit_clear_failed", "account_id", account.ID, "error", err)
} }
} }
} }
@@ -413,7 +493,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
} }
if s.tempUnschedCache != nil { if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil { if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil {
log.Printf("DeleteTempUnsched failed for account %d: %v", accountID, err) slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
} }
} }
return nil return nil
@@ -460,7 +540,7 @@ func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID i
if s.tempUnschedCache != nil { if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil { if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil {
log.Printf("SetTempUnsched failed for account %d: %v", accountID, err) slog.Warn("temp_unsched_cache_set_failed", "account_id", accountID, "error", err)
} }
} }
@@ -563,17 +643,17 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
} }
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
log.Printf("SetTempUnschedulable failed for account %d: %v", account.ID, err) slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err)
return false return false
} }
if s.tempUnschedCache != nil { if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil { if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
log.Printf("SetTempUnsched cache failed for account %d: %v", account.ID, err) slog.Warn("temp_unsched_cache_set_failed", "account_id", account.ID, "error", err)
} }
} }
log.Printf("Account %d temp unschedulable until %v (rule %d, code %d)", account.ID, until, ruleIndex, statusCode) slog.Info("account_temp_unschedulable", "account_id", account.ID, "until", until, "rule_index", ruleIndex, "status_code", statusCode)
return true return true
} }
@@ -597,13 +677,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
// 获取系统设置 // 获取系统设置
if s.settingService == nil { if s.settingService == nil {
log.Printf("[StreamTimeout] settingService not configured, skipping timeout handling for account %d", account.ID) slog.Warn("stream_timeout_setting_service_missing", "account_id", account.ID)
return false return false
} }
settings, err := s.settingService.GetStreamTimeoutSettings(ctx) settings, err := s.settingService.GetStreamTimeoutSettings(ctx)
if err != nil { if err != nil {
log.Printf("[StreamTimeout] Failed to get settings: %v", err) slog.Warn("stream_timeout_get_settings_failed", "account_id", account.ID, "error", err)
return false return false
} }
@@ -620,14 +700,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
if s.timeoutCounterCache != nil { if s.timeoutCounterCache != nil {
count, err = s.timeoutCounterCache.IncrementTimeoutCount(ctx, account.ID, settings.ThresholdWindowMinutes) count, err = s.timeoutCounterCache.IncrementTimeoutCount(ctx, account.ID, settings.ThresholdWindowMinutes)
if err != nil { if err != nil {
log.Printf("[StreamTimeout] Failed to increment timeout count for account %d: %v", account.ID, err) slog.Warn("stream_timeout_increment_count_failed", "account_id", account.ID, "error", err)
// 继续处理,使用 count=1 // 继续处理,使用 count=1
count = 1 count = 1
} }
} }
log.Printf("[StreamTimeout] Account %d timeout count: %d/%d (window: %d min, model: %s)", slog.Info("stream_timeout_count", "account_id", account.ID, "count", count, "threshold", settings.ThresholdCount, "window_minutes", settings.ThresholdWindowMinutes, "model", model)
account.ID, count, settings.ThresholdCount, settings.ThresholdWindowMinutes, model)
// 检查是否达到阈值 // 检查是否达到阈值
if count < int64(settings.ThresholdCount) { if count < int64(settings.ThresholdCount) {
@@ -668,24 +747,24 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
} }
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
log.Printf("[StreamTimeout] SetTempUnschedulable failed for account %d: %v", account.ID, err) slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err)
return false return false
} }
if s.tempUnschedCache != nil { if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil { if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
log.Printf("[StreamTimeout] SetTempUnsched cache failed for account %d: %v", account.ID, err) slog.Warn("stream_timeout_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err)
} }
} }
// 重置超时计数 // 重置超时计数
if s.timeoutCounterCache != nil { if s.timeoutCounterCache != nil {
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil { if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err) slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
} }
} }
log.Printf("[StreamTimeout] Account %d marked as temp unschedulable until %v (model: %s)", account.ID, until, model) slog.Info("stream_timeout_temp_unschedulable", "account_id", account.ID, "until", until, "model", model)
return true return true
} }
@@ -694,17 +773,17 @@ func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, accoun
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("[StreamTimeout] SetError failed for account %d: %v", account.ID, err) slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err)
return false return false
} }
// 重置超时计数 // 重置超时计数
if s.timeoutCounterCache != nil { if s.timeoutCounterCache != nil {
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil { if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err) slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
} }
} }
log.Printf("[StreamTimeout] Account %d marked as error (model: %s)", account.ID, model) slog.Warn("stream_timeout_account_error", "account_id", account.ID, "model", model)
return true return true
} }

View File

@@ -0,0 +1,353 @@
//go:build unit
package service
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type rateLimitAccountRepoStub struct {
mockAccountRepoForGemini
tempCalls int
tempUntil time.Time
tempReason string
setErrorCalls int
}
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++
r.tempUntil = until
r.tempReason = reason
return nil
}
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
return nil
}
type tokenCacheInvalidatorRecorder struct {
accounts []*Account
err error
}
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
r.accounts = append(r.accounts, account)
return r.err
}
func TestRateLimitService_HandleUpstreamError_OAuth401TempUnschedulable(t *testing.T) {
tests := []struct {
name string
platform string
}{
{name: "gemini", platform: PlatformGemini},
{name: "antigravity", platform: PlatformAntigravity},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: tt.platform,
Type: AccountTypeOAuth,
}
start := time.Now()
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.tempCalls)
require.Equal(t, 0, repo.setErrorCalls)
require.Len(t, invalidator.accounts, 1)
require.WithinDuration(t, start.Add(5*time.Minute), repo.tempUntil, 10*time.Second)
require.NotEmpty(t, repo.tempReason)
})
}
}
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 101,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.tempCalls)
require.Equal(t, 0, repo.setErrorCalls)
require.Len(t, invalidator.accounts, 1)
}
func TestRateLimitService_HandleUpstreamError_OAuth401CustomRule(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 103,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": 401,
"keywords": []any{"unauthorized"},
"duration_minutes": 30,
"description": "custom rule",
},
},
},
}
start := time.Now()
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.tempCalls)
require.Equal(t, 0, repo.setErrorCalls)
require.Len(t, invalidator.accounts, 1)
require.WithinDuration(t, start.Add(30*time.Minute), repo.tempUntil, 10*time.Second)
}
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 102,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 0, repo.tempCalls)
require.Equal(t, 1, repo.setErrorCalls)
require.Empty(t, invalidator.accounts)
}
// TestRateLimitService_HandleOAuth401_NilAccount 测试 account 为 nil 的情况
func TestRateLimitService_HandleOAuth401_NilAccount(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
result := service.handleOAuth401TempUnschedulable(context.Background(), nil, "error")
require.False(t, result)
require.Equal(t, 0, repo.tempCalls)
}
// TestRateLimitService_HandleOAuth401_NilInvalidator 测试 tokenCacheInvalidator 为 nil 的情况
func TestRateLimitService_HandleOAuth401_NilInvalidator(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
// 不设置 tokenCacheInvalidator
account := &Account{
ID: 200,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
require.True(t, result)
require.Equal(t, 1, repo.tempCalls)
}
// TestRateLimitService_HandleOAuth401_SetTempUnschedulableFailed 测试 SetTempUnschedulable 失败的情况
func TestRateLimitService_HandleOAuth401_SetTempUnschedulableFailed(t *testing.T) {
repo := &rateLimitAccountRepoStubWithError{
setTempErr: errors.New("db error"),
}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 201,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
require.False(t, result) // 失败应返回 false
require.Len(t, invalidator.accounts, 1) // 但 invalidator 仍然被调用
}
// rateLimitAccountRepoStubWithError 支持返回错误的 stub
type rateLimitAccountRepoStubWithError struct {
mockAccountRepoForGemini
setTempErr error
setErrorCalls int
}
func (r *rateLimitAccountRepoStubWithError) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return r.setTempErr
}
func (r *rateLimitAccountRepoStubWithError) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
return nil
}
// TestRateLimitService_HandleOAuth401_WithTempUnschedCache 测试 tempUnschedCache 存在的情况
func TestRateLimitService_HandleOAuth401_WithTempUnschedCache(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
tempCache := &tempUnschedCacheStub{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, tempCache)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 202,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
require.True(t, result)
require.Equal(t, 1, repo.tempCalls)
require.Equal(t, 1, tempCache.setCalls)
}
// TestRateLimitService_HandleOAuth401_TempUnschedCacheError 测试 tempUnschedCache 设置失败的情况
func TestRateLimitService_HandleOAuth401_TempUnschedCacheError(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
tempCache := &tempUnschedCacheStub{setErr: errors.New("cache error")}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, tempCache)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 203,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
require.True(t, result) // 缓存错误不影响主流程
require.Equal(t, 1, repo.tempCalls)
}
// tempUnschedCacheStub 用于测试的 TempUnschedCache stub
type tempUnschedCacheStub struct {
setCalls int
setErr error
}
func (c *tempUnschedCacheStub) GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error) {
return nil, nil
}
func (c *tempUnschedCacheStub) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error {
c.setCalls++
return c.setErr
}
func (c *tempUnschedCacheStub) DeleteTempUnsched(ctx context.Context, accountID int64) error {
return nil
}
// TestRateLimitService_OAuth401Cooldown 测试 oauth401Cooldown 函数
func TestRateLimitService_OAuth401Cooldown(t *testing.T) {
tests := []struct {
name string
cfg *config.Config
expected time.Duration
}{
{
name: "default_when_config_zero",
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 0}},
expected: 5 * time.Minute,
},
{
name: "custom_cooldown_10_minutes",
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 10}},
expected: 10 * time.Minute,
},
{
name: "custom_cooldown_1_minute",
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 1}},
expected: 1 * time.Minute,
},
{
name: "negative_value_uses_default",
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: -5}},
expected: 5 * time.Minute,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
service := NewRateLimitService(nil, nil, tt.cfg, nil, nil)
result := service.oauth401Cooldown()
require.Equal(t, tt.expected, result)
})
}
}
// TestRateLimitService_OAuth401Cooldown_NilConfig 测试 cfg 为 nil 的情况
func TestRateLimitService_OAuth401Cooldown_NilConfig(t *testing.T) {
service := &RateLimitService{cfg: nil}
result := service.oauth401Cooldown()
require.Equal(t, 5*time.Minute, result)
}
// TestRateLimitService_HandleOAuth401_WithCustomCooldown 测试自定义 cooldown 配置
func TestRateLimitService_HandleOAuth401_WithCustomCooldown(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
cfg := &config.Config{
RateLimit: config.RateLimitConfig{
OAuth401CooldownMinutes: 15,
},
}
service := NewRateLimitService(repo, nil, cfg, nil, nil)
account := &Account{
ID: 204,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
start := time.Now()
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
require.True(t, result)
require.WithinDuration(t, start.Add(15*time.Minute), repo.tempUntil, 10*time.Second)
}
// TestRateLimitService_HandleOAuth401_EmptyUpstreamMsg 测试 upstreamMsg 为空的情况
func TestRateLimitService_HandleOAuth401_EmptyUpstreamMsg(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 205,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "")
require.True(t, result)
require.Contains(t, repo.tempReason, "Authentication failed (401)")
}

View File

@@ -0,0 +1,35 @@
package service
import "context"
type TokenCacheInvalidator interface {
InvalidateToken(ctx context.Context, account *Account) error
}
type CompositeTokenCacheInvalidator struct {
geminiCache GeminiTokenCache
}
func NewCompositeTokenCacheInvalidator(geminiCache GeminiTokenCache) *CompositeTokenCacheInvalidator {
return &CompositeTokenCacheInvalidator{
geminiCache: geminiCache,
}
}
func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, account *Account) error {
if c == nil || c.geminiCache == nil || account == nil {
return nil
}
if account.Type != AccountTypeOAuth {
return nil
}
switch account.Platform {
case PlatformGemini:
return c.geminiCache.DeleteAccessToken(ctx, GeminiTokenCacheKey(account))
case PlatformAntigravity:
return c.geminiCache.DeleteAccessToken(ctx, AntigravityTokenCacheKey(account))
default:
return nil
}
}

View File

@@ -0,0 +1,97 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type geminiTokenCacheStub struct {
deletedKeys []string
deleteErr error
}
func (s *geminiTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
return "", nil
}
func (s *geminiTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
return nil
}
func (s *geminiTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
s.deletedKeys = append(s.deletedKeys, cacheKey)
return s.deleteErr
}
func (s *geminiTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
return true, nil
}
func (s *geminiTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
return nil
}
func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 10,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"project_id": "project-x",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"project-x"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 99,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"project_id": "ag-project",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 1,
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Empty(t, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) {
invalidator := NewCompositeTokenCacheInvalidator(nil)
account := &Account{
ID: 2,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
}

View File

@@ -0,0 +1,153 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGeminiTokenCacheKey(t *testing.T) {
tests := []struct {
name string
account *Account
expected string
}{
{
name: "with_project_id",
account: &Account{
ID: 100,
Credentials: map[string]any{
"project_id": "my-project-123",
},
},
expected: "my-project-123",
},
{
name: "project_id_with_whitespace",
account: &Account{
ID: 101,
Credentials: map[string]any{
"project_id": " project-with-spaces ",
},
},
expected: "project-with-spaces",
},
{
name: "empty_project_id_fallback_to_account_id",
account: &Account{
ID: 102,
Credentials: map[string]any{
"project_id": "",
},
},
expected: "account:102",
},
{
name: "whitespace_only_project_id_fallback_to_account_id",
account: &Account{
ID: 103,
Credentials: map[string]any{
"project_id": " ",
},
},
expected: "account:103",
},
{
name: "no_project_id_key_fallback_to_account_id",
account: &Account{
ID: 104,
Credentials: map[string]any{},
},
expected: "account:104",
},
{
name: "nil_credentials_fallback_to_account_id",
account: &Account{
ID: 105,
Credentials: nil,
},
expected: "account:105",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GeminiTokenCacheKey(tt.account)
require.Equal(t, tt.expected, result)
})
}
}
func TestAntigravityTokenCacheKey(t *testing.T) {
tests := []struct {
name string
account *Account
expected string
}{
{
name: "with_project_id",
account: &Account{
ID: 200,
Credentials: map[string]any{
"project_id": "ag-project-456",
},
},
expected: "ag:ag-project-456",
},
{
name: "project_id_with_whitespace",
account: &Account{
ID: 201,
Credentials: map[string]any{
"project_id": " ag-project-spaces ",
},
},
expected: "ag:ag-project-spaces",
},
{
name: "empty_project_id_fallback_to_account_id",
account: &Account{
ID: 202,
Credentials: map[string]any{
"project_id": "",
},
},
expected: "ag:account:202",
},
{
name: "whitespace_only_project_id_fallback_to_account_id",
account: &Account{
ID: 203,
Credentials: map[string]any{
"project_id": " ",
},
},
expected: "ag:account:203",
},
{
name: "no_project_id_key_fallback_to_account_id",
account: &Account{
ID: 204,
Credentials: map[string]any{},
},
expected: "ag:account:204",
},
{
name: "nil_credentials_fallback_to_account_id",
account: &Account{
ID: 205,
Credentials: nil,
},
expected: "ag:account:205",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := AntigravityTokenCacheKey(tt.account)
require.Equal(t, tt.expected, result)
})
}
}

View File

@@ -14,9 +14,10 @@ import (
// TokenRefreshService OAuth token自动刷新服务 // TokenRefreshService OAuth token自动刷新服务
// 定期检查并刷新即将过期的token // 定期检查并刷新即将过期的token
type TokenRefreshService struct { type TokenRefreshService struct {
accountRepo AccountRepository accountRepo AccountRepository
refreshers []TokenRefresher refreshers []TokenRefresher
cfg *config.TokenRefreshConfig cfg *config.TokenRefreshConfig
cacheInvalidator TokenCacheInvalidator
stopCh chan struct{} stopCh chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
@@ -29,12 +30,14 @@ func NewTokenRefreshService(
openaiOAuthService *OpenAIOAuthService, openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService, geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService, antigravityOAuthService *AntigravityOAuthService,
cacheInvalidator TokenCacheInvalidator,
cfg *config.Config, cfg *config.Config,
) *TokenRefreshService { ) *TokenRefreshService {
s := &TokenRefreshService{ s := &TokenRefreshService{
accountRepo: accountRepo, accountRepo: accountRepo,
cfg: &cfg.TokenRefresh, cfg: &cfg.TokenRefresh,
stopCh: make(chan struct{}), cacheInvalidator: cacheInvalidator,
stopCh: make(chan struct{}),
} }
// 注册平台特定的刷新器 // 注册平台特定的刷新器
@@ -169,6 +172,14 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
if err := s.accountRepo.Update(ctx, account); err != nil { if err := s.accountRepo.Update(ctx, account); err != nil {
return fmt.Errorf("failed to save credentials: %w", err) return fmt.Errorf("failed to save credentials: %w", err)
} }
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth &&
(account.Platform == PlatformGemini || account.Platform == PlatformAntigravity) {
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
log.Printf("[TokenRefresh] Failed to invalidate token cache for account %d: %v", account.ID, err)
} else {
log.Printf("[TokenRefresh] Token cache invalidated for account %d", account.ID)
}
}
return nil return nil
} }

View File

@@ -0,0 +1,361 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type tokenRefreshAccountRepo struct {
mockAccountRepoForGemini
updateCalls int
setErrorCalls int
lastAccount *Account
updateErr error
}
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
r.updateCalls++
r.lastAccount = account
return r.updateErr
}
func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
return nil
}
type tokenCacheInvalidatorStub struct {
calls int
err error
}
func (s *tokenCacheInvalidatorStub) InvalidateToken(ctx context.Context, account *Account) error {
s.calls++
return s.err
}
type tokenRefresherStub struct {
credentials map[string]any
err error
}
func (r *tokenRefresherStub) CanRefresh(account *Account) bool {
return true
}
func (r *tokenRefresherStub) NeedsRefresh(account *Account, refreshWindowDuration time.Duration) bool {
return true
}
func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
if r.err != nil {
return nil, r.err
}
return r.credentials, nil
}
func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 5,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "new-token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls)
require.Equal(t, "new-token", account.GetCredential("access_token"))
}
func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{err: errors.New("invalidate failed")}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 6,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls)
}
func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, cfg)
account := &Account{
ID: 7,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
}
// TestTokenRefreshService_RefreshWithRetry_Antigravity 测试 Antigravity 平台的缓存失效
func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 8,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "ag-token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls) // Antigravity 也应触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount 测试非 OAuth 账号不触发缓存失效
func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 9,
Platform: PlatformGemini,
Type: AccountTypeAPIKey, // 非 OAuth
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试其他平台的 OAuth 账号不触发缓存失效
func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 10,
Platform: PlatformOpenAI, // 其他平台
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 0, invalidator.calls) // 其他平台不触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 11,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Contains(t, err.Error(), "failed to save credentials")
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 0, invalidator.calls) // 更新失败时不应触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试刷新失败的情况
func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 2,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 12,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("refresh failed"),
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效
require.Equal(t, 1, repo.setErrorCalls) // 应设置错误状态
}
// TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed 测试 Antigravity 刷新失败不设置错误状态
func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 13,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("network error"), // 可重试错误
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, invalidator.calls)
require.Equal(t, 0, repo.setErrorCalls) // Antigravity 可重试错误不设置错误状态
}
// TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError 测试 Antigravity 不可重试错误
func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 3,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 14,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("invalid_grant: token revoked"), // 不可重试错误
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, invalidator.calls)
require.Equal(t, 1, repo.setErrorCalls) // 不可重试错误应设置错误状态
}
// TestIsNonRetryableRefreshError 测试不可重试错误判断
func TestIsNonRetryableRefreshError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{name: "nil_error", err: nil, expected: false},
{name: "network_error", err: errors.New("network timeout"), expected: false},
{name: "invalid_grant", err: errors.New("invalid_grant"), expected: true},
{name: "invalid_client", err: errors.New("invalid_client"), expected: true},
{name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true},
{name: "access_denied", err: errors.New("access_denied"), expected: true},
{name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true},
{name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isNonRetryableRefreshError(tt.err)
require.Equal(t, tt.expected, result)
})
}
}

View File

@@ -42,9 +42,10 @@ func ProvideTokenRefreshService(
openaiOAuthService *OpenAIOAuthService, openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService, geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService, antigravityOAuthService *AntigravityOAuthService,
cacheInvalidator TokenCacheInvalidator,
cfg *config.Config, cfg *config.Config,
) *TokenRefreshService { ) *TokenRefreshService {
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cfg) svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg)
svc.Start() svc.Start()
return svc return svc
} }
@@ -108,10 +109,12 @@ func ProvideRateLimitService(
tempUnschedCache TempUnschedCache, tempUnschedCache TempUnschedCache,
timeoutCounterCache TimeoutCounterCache, timeoutCounterCache TimeoutCounterCache,
settingService *SettingService, settingService *SettingService,
tokenCacheInvalidator TokenCacheInvalidator,
) *RateLimitService { ) *RateLimitService {
svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache) svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache)
svc.SetTimeoutCounterCache(timeoutCounterCache) svc.SetTimeoutCounterCache(timeoutCounterCache)
svc.SetSettingService(settingService) svc.SetSettingService(settingService)
svc.SetTokenCacheInvalidator(tokenCacheInvalidator)
return svc return svc
} }
@@ -210,6 +213,7 @@ var ProviderSet = wire.NewSet(
NewOpenAIOAuthService, NewOpenAIOAuthService,
NewGeminiOAuthService, NewGeminiOAuthService,
NewGeminiQuotaService, NewGeminiQuotaService,
NewCompositeTokenCacheInvalidator,
NewAntigravityOAuthService, NewAntigravityOAuthService,
NewGeminiTokenProvider, NewGeminiTokenProvider,
NewGeminiMessagesCompatService, NewGeminiMessagesCompatService,

View File

@@ -387,6 +387,9 @@ rate_limit:
# Cooldown time (in minutes) when upstream returns 529 (overloaded) # Cooldown time (in minutes) when upstream returns 529 (overloaded)
# 上游返回 529过载时的冷却时间分钟 # 上游返回 529过载时的冷却时间分钟
overload_cooldown_minutes: 10 overload_cooldown_minutes: 10
# Cooldown time (in minutes) for OAuth 401 temporary unschedulable
# OAuth 401 临时不可调度冷却时间(分钟)
oauth_401_cooldown_minutes: 5
# ============================================================================= # =============================================================================
# Pricing Data Source (Optional) # Pricing Data Source (Optional)

View File

@@ -69,6 +69,17 @@ JWT_EXPIRE_HOUR=24
# Leave unset to use default ./config.yaml # Leave unset to use default ./config.yaml
#CONFIG_FILE=./config.yaml #CONFIG_FILE=./config.yaml
# -----------------------------------------------------------------------------
# Rate Limiting (Optional)
# 速率限制(可选)
# -----------------------------------------------------------------------------
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
# 上游返回 529过载时的冷却时间分钟
RATE_LIMIT_OVERLOAD_COOLDOWN_MINUTES=10
# Cooldown time (in minutes) for OAuth 401 temporary unschedulable
# OAuth 401 临时不可调度冷却时间(分钟)
RATE_LIMIT_OAUTH_401_COOLDOWN_MINUTES=5
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Gateway Scheduling (Optional) # Gateway Scheduling (Optional)
# 调度缓存与受控回源配置(缓存就绪且命中时不读 DB # 调度缓存与受控回源配置(缓存就绪且命中时不读 DB

View File

@@ -429,6 +429,9 @@ rate_limit:
# Cooldown time (in minutes) when upstream returns 529 (overloaded) # Cooldown time (in minutes) when upstream returns 529 (overloaded)
# 上游返回 529过载时的冷却时间分钟 # 上游返回 529过载时的冷却时间分钟
overload_cooldown_minutes: 10 overload_cooldown_minutes: 10
# Cooldown time (in minutes) for OAuth 401 temporary unschedulable
# OAuth 401 临时不可调度冷却时间(分钟)
oauth_401_cooldown_minutes: 5
# ============================================================================= # =============================================================================
# Pricing Data Source (Optional) # Pricing Data Source (Optional)