fix(认证): 修复 OAuth token 缓存失效与 401 处理
新增 token 缓存失效接口并在刷新后清理 401 限流支持自定义规则与可配置冷却时间 补齐缓存失效与 401 处理测试 测试: make test
This commit is contained in:
@@ -98,12 +98,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
||||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||||
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService)
|
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||||
|
tokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||||
|
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, tokenCacheInvalidator)
|
||||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||||
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
||||||
usageCache := service.NewUsageCache()
|
usageCache := service.NewUsageCache()
|
||||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
|
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
|
||||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
|
||||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
||||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||||
@@ -166,7 +167,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
||||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, tokenCacheInvalidator, configConfig)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||||
application := &Application{
|
application := &Application{
|
||||||
|
|||||||
@@ -435,7 +435,8 @@ type DefaultConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type RateLimitConfig struct {
|
type RateLimitConfig struct {
|
||||||
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
|
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
|
||||||
|
OAuth401CooldownMinutes int `mapstructure:"oauth_401_cooldown_minutes"` // OAuth 401 临时不可调度冷却时间(分钟)
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyAuthCacheConfig API Key 认证缓存配置
|
// APIKeyAuthCacheConfig API Key 认证缓存配置
|
||||||
@@ -709,6 +710,7 @@ func setDefaults() {
|
|||||||
|
|
||||||
// RateLimit
|
// RateLimit
|
||||||
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
|
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
|
||||||
|
viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 5)
|
||||||
|
|
||||||
// Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查
|
// Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查
|
||||||
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json")
|
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json")
|
||||||
|
|||||||
@@ -33,6 +33,11 @@ func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string,
|
|||||||
return c.rdb.Set(ctx, key, token, ttl).Err()
|
return c.rdb.Set(ctx, key, token, ttl).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
|
||||||
|
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||||
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
|
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
|
||||||
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
|
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
|
||||||
|
|||||||
@@ -0,0 +1,47 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GeminiTokenCacheSuite struct {
|
||||||
|
IntegrationRedisSuite
|
||||||
|
cache service.GeminiTokenCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GeminiTokenCacheSuite) SetupTest() {
|
||||||
|
s.IntegrationRedisSuite.SetupTest()
|
||||||
|
s.cache = NewGeminiTokenCache(s.rdb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() {
|
||||||
|
cacheKey := "project-123"
|
||||||
|
token := "token-value"
|
||||||
|
require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute))
|
||||||
|
|
||||||
|
got, err := s.cache.GetAccessToken(s.ctx, cacheKey)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), token, got)
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey))
|
||||||
|
|
||||||
|
_, err = s.cache.GetAccessToken(s.ctx, cacheKey)
|
||||||
|
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() {
|
||||||
|
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiTokenCacheSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(GeminiTokenCacheSuite))
|
||||||
|
}
|
||||||
28
backend/internal/repository/gemini_token_cache_test.go
Normal file
28
backend/internal/repository/gemini_token_cache_test.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) {
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: "127.0.0.1:1",
|
||||||
|
DialTimeout: 50 * time.Millisecond,
|
||||||
|
ReadTimeout: 50 * time.Millisecond,
|
||||||
|
WriteTimeout: 50 * time.Millisecond,
|
||||||
|
})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = rdb.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
cache := NewGeminiTokenCache(rdb)
|
||||||
|
err := cache.DeleteAccessToken(context.Background(), "broken")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
@@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
return "", errors.New("not an antigravity oauth account")
|
return "", errors.New("not an antigravity oauth account")
|
||||||
}
|
}
|
||||||
|
|
||||||
cacheKey := antigravityTokenCacheKey(account)
|
cacheKey := AntigravityTokenCacheKey(account)
|
||||||
|
|
||||||
// 1. 先尝试缓存
|
// 1. 先尝试缓存
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
@@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
return accessToken, nil
|
return accessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func antigravityTokenCacheKey(account *Account) string {
|
func AntigravityTokenCacheKey(account *Account) string {
|
||||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||||
if projectID != "" {
|
if projectID != "" {
|
||||||
return "ag:" + projectID
|
return "ag:" + projectID
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ type GeminiTokenCache interface {
|
|||||||
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
|
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
|
||||||
GetAccessToken(ctx context.Context, cacheKey string) (string, error)
|
GetAccessToken(ctx context.Context, cacheKey string) (string, error)
|
||||||
SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error
|
SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error
|
||||||
|
DeleteAccessToken(ctx context.Context, cacheKey string) error
|
||||||
|
|
||||||
AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error)
|
AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error)
|
||||||
ReleaseRefreshLock(ctx context.Context, cacheKey string) error
|
ReleaseRefreshLock(ctx context.Context, cacheKey string) error
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
return "", errors.New("not a gemini oauth account")
|
return "", errors.New("not a gemini oauth account")
|
||||||
}
|
}
|
||||||
|
|
||||||
cacheKey := geminiTokenCacheKey(account)
|
cacheKey := GeminiTokenCacheKey(account)
|
||||||
|
|
||||||
// 1) Try cache first.
|
// 1) Try cache first.
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
@@ -151,7 +151,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
return accessToken, nil
|
return accessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func geminiTokenCacheKey(account *Account) string {
|
func GeminiTokenCacheKey(account *Account) string {
|
||||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||||
if projectID != "" {
|
if projectID != "" {
|
||||||
return projectID
|
return projectID
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -15,15 +15,16 @@ import (
|
|||||||
|
|
||||||
// RateLimitService 处理限流和过载状态管理
|
// RateLimitService 处理限流和过载状态管理
|
||||||
type RateLimitService struct {
|
type RateLimitService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
usageRepo UsageLogRepository
|
usageRepo UsageLogRepository
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
geminiQuotaService *GeminiQuotaService
|
geminiQuotaService *GeminiQuotaService
|
||||||
tempUnschedCache TempUnschedCache
|
tempUnschedCache TempUnschedCache
|
||||||
timeoutCounterCache TimeoutCounterCache
|
timeoutCounterCache TimeoutCounterCache
|
||||||
settingService *SettingService
|
settingService *SettingService
|
||||||
usageCacheMu sync.RWMutex
|
tokenCacheInvalidator TokenCacheInvalidator
|
||||||
usageCache map[int64]*geminiUsageCacheEntry
|
usageCacheMu sync.RWMutex
|
||||||
|
usageCache map[int64]*geminiUsageCacheEntry
|
||||||
}
|
}
|
||||||
|
|
||||||
type geminiUsageCacheEntry struct {
|
type geminiUsageCacheEntry struct {
|
||||||
@@ -56,6 +57,11 @@ func (s *RateLimitService) SetSettingService(settingService *SettingService) {
|
|||||||
s.settingService = settingService
|
s.settingService = settingService
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTokenCacheInvalidator 设置 token 缓存清理器(可选依赖)
|
||||||
|
func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvalidator) {
|
||||||
|
s.tokenCacheInvalidator = invalidator
|
||||||
|
}
|
||||||
|
|
||||||
// HandleUpstreamError 处理上游错误响应,标记账号状态
|
// HandleUpstreamError 处理上游错误响应,标记账号状态
|
||||||
// 返回是否应该停止该账号的调度
|
// 返回是否应该停止该账号的调度
|
||||||
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
|
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
|
||||||
@@ -63,11 +69,16 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
|||||||
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
|
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
|
||||||
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
|
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
|
||||||
if !account.ShouldHandleErrorCode(statusCode) {
|
if !account.ShouldHandleErrorCode(statusCode) {
|
||||||
log.Printf("Account %d: error %d skipped (not in custom error codes)", account.ID, statusCode)
|
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
tempMatched := s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
|
isOAuth401 := statusCode == 401 && account.Type == AccountTypeOAuth &&
|
||||||
|
(account.Platform == PlatformAntigravity || account.Platform == PlatformGemini)
|
||||||
|
tempMatched := false
|
||||||
|
if !isOAuth401 || account.IsTempUnschedulableEnabled() {
|
||||||
|
tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
|
||||||
|
}
|
||||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
|
||||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
if upstreamMsg != "" {
|
if upstreamMsg != "" {
|
||||||
@@ -76,7 +87,19 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
|||||||
|
|
||||||
switch statusCode {
|
switch statusCode {
|
||||||
case 401:
|
case 401:
|
||||||
// 认证失败:停止调度,记录错误
|
if isOAuth401 {
|
||||||
|
if tempMatched {
|
||||||
|
if s.tokenCacheInvalidator != nil {
|
||||||
|
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||||
|
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
shouldDisable = true
|
||||||
|
} else {
|
||||||
|
shouldDisable = s.handleOAuth401TempUnschedulable(ctx, account, upstreamMsg)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
msg := "Authentication failed (401): invalid or expired credentials"
|
msg := "Authentication failed (401): invalid or expired credentials"
|
||||||
if upstreamMsg != "" {
|
if upstreamMsg != "" {
|
||||||
msg = "Authentication failed (401): " + upstreamMsg
|
msg = "Authentication failed (401): " + upstreamMsg
|
||||||
@@ -116,7 +139,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
|||||||
shouldDisable = true
|
shouldDisable = true
|
||||||
} else if statusCode >= 500 {
|
} else if statusCode >= 500 {
|
||||||
// 未启用自定义错误码时:仅记录5xx错误
|
// 未启用自定义错误码时:仅记录5xx错误
|
||||||
log.Printf("Account %d received upstream error %d", account.ID, statusCode)
|
slog.Warn("account_upstream_error", "account_id", account.ID, "status_code", statusCode)
|
||||||
shouldDisable = false
|
shouldDisable = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -127,6 +150,63 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
|||||||
return shouldDisable
|
return shouldDisable
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *RateLimitService) handleOAuth401TempUnschedulable(ctx context.Context, account *Account, upstreamMsg string) bool {
|
||||||
|
if account == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.tokenCacheInvalidator != nil {
|
||||||
|
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||||
|
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
until := now.Add(s.oauth401Cooldown())
|
||||||
|
msg := "Authentication failed (401): invalid or expired credentials"
|
||||||
|
if upstreamMsg != "" {
|
||||||
|
msg = "Authentication failed (401): " + upstreamMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
state := &TempUnschedState{
|
||||||
|
UntilUnix: until.Unix(),
|
||||||
|
TriggeredAtUnix: now.Unix(),
|
||||||
|
StatusCode: 401,
|
||||||
|
MatchedKeyword: "oauth_401",
|
||||||
|
RuleIndex: -1, // -1 表示非规则触发,而是 OAuth 401 特殊处理
|
||||||
|
ErrorMessage: msg,
|
||||||
|
}
|
||||||
|
|
||||||
|
reason := ""
|
||||||
|
if raw, err := json.Marshal(state); err == nil {
|
||||||
|
reason = string(raw)
|
||||||
|
}
|
||||||
|
if reason == "" {
|
||||||
|
reason = msg
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||||
|
slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.tempUnschedCache != nil {
|
||||||
|
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
|
||||||
|
slog.Warn("oauth_401_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("oauth_401_temp_unschedulable", "account_id", account.ID, "until", until)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RateLimitService) oauth401Cooldown() time.Duration {
|
||||||
|
if s != nil && s.cfg != nil && s.cfg.RateLimit.OAuth401CooldownMinutes > 0 {
|
||||||
|
return time.Duration(s.cfg.RateLimit.OAuth401CooldownMinutes) * time.Minute
|
||||||
|
}
|
||||||
|
return 5 * time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
// PreCheckUsage proactively checks local quota before dispatching a request.
|
// PreCheckUsage proactively checks local quota before dispatching a request.
|
||||||
// Returns false when the account should be skipped.
|
// Returns false when the account should be skipped.
|
||||||
func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) {
|
func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) {
|
||||||
@@ -188,7 +268,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
|||||||
// NOTE:
|
// NOTE:
|
||||||
// - This is a local precheck to reduce upstream 429s.
|
// - This is a local precheck to reduce upstream 429s.
|
||||||
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
|
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
|
||||||
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
|
slog.Info("gemini_precheck_daily_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -231,7 +311,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
|||||||
if used >= limit {
|
if used >= limit {
|
||||||
resetAt := start.Add(time.Minute)
|
resetAt := start.Add(time.Minute)
|
||||||
// Do not persist "rate limited" status from local precheck. See note above.
|
// Do not persist "rate limited" status from local precheck. See note above.
|
||||||
log.Printf("[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
|
slog.Info("gemini_precheck_minute_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -288,20 +368,20 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
|
|||||||
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
||||||
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
|
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
|
||||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||||
log.Printf("SetError failed for account %d: %v", account.ID, err)
|
slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Printf("Account %d disabled due to auth error: %s", account.ID, errorMsg)
|
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleCustomErrorCode 处理自定义错误码,停止账号调度
|
// handleCustomErrorCode 处理自定义错误码,停止账号调度
|
||||||
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
|
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
|
||||||
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
|
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
|
||||||
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
|
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
|
||||||
log.Printf("SetError failed for account %d: %v", account.ID, err)
|
slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Printf("Account %d disabled due to custom error code %d: %s", account.ID, statusCode, errorMsg)
|
slog.Warn("account_disabled_custom_error", "account_id", account.ID, "status_code", statusCode, "error", errorMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle429 处理429限流错误
|
// handle429 处理429限流错误
|
||||||
@@ -313,7 +393,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
|||||||
// 没有重置时间,使用默认5分钟
|
// 没有重置时间,使用默认5分钟
|
||||||
resetAt := time.Now().Add(5 * time.Minute)
|
resetAt := time.Now().Add(5 * time.Minute)
|
||||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -321,10 +401,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
|||||||
// 解析Unix时间戳
|
// 解析Unix时间戳
|
||||||
ts, err := strconv.ParseInt(resetTimestamp, 10, 64)
|
ts, err := strconv.ParseInt(resetTimestamp, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Parse reset timestamp failed: %v", err)
|
slog.Warn("rate_limit_reset_parse_failed", "reset_timestamp", resetTimestamp, "error", err)
|
||||||
resetAt := time.Now().Add(5 * time.Minute)
|
resetAt := time.Now().Add(5 * time.Minute)
|
||||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -333,7 +413,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
|||||||
|
|
||||||
// 标记限流状态
|
// 标记限流状态
|
||||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -341,10 +421,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
|||||||
windowEnd := resetAt
|
windowEnd := resetAt
|
||||||
windowStart := resetAt.Add(-5 * time.Hour)
|
windowStart := resetAt.Add(-5 * time.Hour)
|
||||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
|
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
|
||||||
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
|
slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Account %d rate limited until %v", account.ID, resetAt)
|
slog.Info("account_rate_limited", "account_id", account.ID, "reset_at", resetAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle529 处理529过载错误
|
// handle529 处理529过载错误
|
||||||
@@ -357,11 +437,11 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
|
|||||||
|
|
||||||
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
||||||
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
|
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
|
||||||
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
|
slog.Warn("overload_set_failed", "account_id", account.ID, "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Account %d overloaded until %v", account.ID, until)
|
slog.Info("account_overloaded", "account_id", account.ID, "until", until)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSessionWindow 从成功响应更新5h窗口状态
|
// UpdateSessionWindow 从成功响应更新5h窗口状态
|
||||||
@@ -384,17 +464,17 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
|
|||||||
end := start.Add(5 * time.Hour)
|
end := start.Add(5 * time.Hour)
|
||||||
windowStart = &start
|
windowStart = &start
|
||||||
windowEnd = &end
|
windowEnd = &end
|
||||||
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
|
slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
|
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
|
||||||
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
|
slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
|
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
|
||||||
if status == "allowed" && account.IsRateLimited() {
|
if status == "allowed" && account.IsRateLimited() {
|
||||||
if err := s.ClearRateLimit(ctx, account.ID); err != nil {
|
if err := s.ClearRateLimit(ctx, account.ID); err != nil {
|
||||||
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
|
slog.Warn("rate_limit_clear_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -413,7 +493,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
|
|||||||
}
|
}
|
||||||
if s.tempUnschedCache != nil {
|
if s.tempUnschedCache != nil {
|
||||||
if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil {
|
if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil {
|
||||||
log.Printf("DeleteTempUnsched failed for account %d: %v", accountID, err)
|
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -460,7 +540,7 @@ func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID i
|
|||||||
|
|
||||||
if s.tempUnschedCache != nil {
|
if s.tempUnschedCache != nil {
|
||||||
if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil {
|
if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil {
|
||||||
log.Printf("SetTempUnsched failed for account %d: %v", accountID, err)
|
slog.Warn("temp_unsched_cache_set_failed", "account_id", accountID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -563,17 +643,17 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||||
log.Printf("SetTempUnschedulable failed for account %d: %v", account.ID, err)
|
slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.tempUnschedCache != nil {
|
if s.tempUnschedCache != nil {
|
||||||
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
|
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
|
||||||
log.Printf("SetTempUnsched cache failed for account %d: %v", account.ID, err)
|
slog.Warn("temp_unsched_cache_set_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Account %d temp unschedulable until %v (rule %d, code %d)", account.ID, until, ruleIndex, statusCode)
|
slog.Info("account_temp_unschedulable", "account_id", account.ID, "until", until, "rule_index", ruleIndex, "status_code", statusCode)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -597,13 +677,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
|
|||||||
|
|
||||||
// 获取系统设置
|
// 获取系统设置
|
||||||
if s.settingService == nil {
|
if s.settingService == nil {
|
||||||
log.Printf("[StreamTimeout] settingService not configured, skipping timeout handling for account %d", account.ID)
|
slog.Warn("stream_timeout_setting_service_missing", "account_id", account.ID)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := s.settingService.GetStreamTimeoutSettings(ctx)
|
settings, err := s.settingService.GetStreamTimeoutSettings(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[StreamTimeout] Failed to get settings: %v", err)
|
slog.Warn("stream_timeout_get_settings_failed", "account_id", account.ID, "error", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -620,14 +700,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
|
|||||||
if s.timeoutCounterCache != nil {
|
if s.timeoutCounterCache != nil {
|
||||||
count, err = s.timeoutCounterCache.IncrementTimeoutCount(ctx, account.ID, settings.ThresholdWindowMinutes)
|
count, err = s.timeoutCounterCache.IncrementTimeoutCount(ctx, account.ID, settings.ThresholdWindowMinutes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[StreamTimeout] Failed to increment timeout count for account %d: %v", account.ID, err)
|
slog.Warn("stream_timeout_increment_count_failed", "account_id", account.ID, "error", err)
|
||||||
// 继续处理,使用 count=1
|
// 继续处理,使用 count=1
|
||||||
count = 1
|
count = 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("[StreamTimeout] Account %d timeout count: %d/%d (window: %d min, model: %s)",
|
slog.Info("stream_timeout_count", "account_id", account.ID, "count", count, "threshold", settings.ThresholdCount, "window_minutes", settings.ThresholdWindowMinutes, "model", model)
|
||||||
account.ID, count, settings.ThresholdCount, settings.ThresholdWindowMinutes, model)
|
|
||||||
|
|
||||||
// 检查是否达到阈值
|
// 检查是否达到阈值
|
||||||
if count < int64(settings.ThresholdCount) {
|
if count < int64(settings.ThresholdCount) {
|
||||||
@@ -668,24 +747,24 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||||
log.Printf("[StreamTimeout] SetTempUnschedulable failed for account %d: %v", account.ID, err)
|
slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.tempUnschedCache != nil {
|
if s.tempUnschedCache != nil {
|
||||||
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
|
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
|
||||||
log.Printf("[StreamTimeout] SetTempUnsched cache failed for account %d: %v", account.ID, err)
|
slog.Warn("stream_timeout_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 重置超时计数
|
// 重置超时计数
|
||||||
if s.timeoutCounterCache != nil {
|
if s.timeoutCounterCache != nil {
|
||||||
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
|
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
|
||||||
log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err)
|
slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("[StreamTimeout] Account %d marked as temp unschedulable until %v (model: %s)", account.ID, until, model)
|
slog.Info("stream_timeout_temp_unschedulable", "account_id", account.ID, "until", until, "model", model)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -694,17 +773,17 @@ func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, accoun
|
|||||||
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
|
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
|
||||||
|
|
||||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||||
log.Printf("[StreamTimeout] SetError failed for account %d: %v", account.ID, err)
|
slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// 重置超时计数
|
// 重置超时计数
|
||||||
if s.timeoutCounterCache != nil {
|
if s.timeoutCounterCache != nil {
|
||||||
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
|
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
|
||||||
log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err)
|
slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("[StreamTimeout] Account %d marked as error (model: %s)", account.ID, model)
|
slog.Warn("stream_timeout_account_error", "account_id", account.ID, "model", model)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
353
backend/internal/service/ratelimit_service_401_test.go
Normal file
353
backend/internal/service/ratelimit_service_401_test.go
Normal file
@@ -0,0 +1,353 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type rateLimitAccountRepoStub struct {
|
||||||
|
mockAccountRepoForGemini
|
||||||
|
tempCalls int
|
||||||
|
tempUntil time.Time
|
||||||
|
tempReason string
|
||||||
|
setErrorCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||||
|
r.tempCalls++
|
||||||
|
r.tempUntil = until
|
||||||
|
r.tempReason = reason
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||||
|
r.setErrorCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type tokenCacheInvalidatorRecorder struct {
|
||||||
|
accounts []*Account
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
|
||||||
|
r.accounts = append(r.accounts, account)
|
||||||
|
return r.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimitService_HandleUpstreamError_OAuth401TempUnschedulable(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
platform string
|
||||||
|
}{
|
||||||
|
{name: "gemini", platform: PlatformGemini},
|
||||||
|
{name: "antigravity", platform: PlatformAntigravity},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
service.SetTokenCacheInvalidator(invalidator)
|
||||||
|
account := &Account{
|
||||||
|
ID: 100,
|
||||||
|
Platform: tt.platform,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||||
|
|
||||||
|
require.True(t, shouldDisable)
|
||||||
|
require.Equal(t, 1, repo.tempCalls)
|
||||||
|
require.Equal(t, 0, repo.setErrorCalls)
|
||||||
|
require.Len(t, invalidator.accounts, 1)
|
||||||
|
require.WithinDuration(t, start.Add(5*time.Minute), repo.tempUntil, 10*time.Second)
|
||||||
|
require.NotEmpty(t, repo.tempReason)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
service.SetTokenCacheInvalidator(invalidator)
|
||||||
|
account := &Account{
|
||||||
|
ID: 101,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||||
|
|
||||||
|
require.True(t, shouldDisable)
|
||||||
|
require.Equal(t, 1, repo.tempCalls)
|
||||||
|
require.Equal(t, 0, repo.setErrorCalls)
|
||||||
|
require.Len(t, invalidator.accounts, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimitService_HandleUpstreamError_OAuth401CustomRule(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
service.SetTokenCacheInvalidator(invalidator)
|
||||||
|
account := &Account{
|
||||||
|
ID: 103,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"temp_unschedulable_enabled": true,
|
||||||
|
"temp_unschedulable_rules": []any{
|
||||||
|
map[string]any{
|
||||||
|
"error_code": 401,
|
||||||
|
"keywords": []any{"unauthorized"},
|
||||||
|
"duration_minutes": 30,
|
||||||
|
"description": "custom rule",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||||
|
|
||||||
|
require.True(t, shouldDisable)
|
||||||
|
require.Equal(t, 1, repo.tempCalls)
|
||||||
|
require.Equal(t, 0, repo.setErrorCalls)
|
||||||
|
require.Len(t, invalidator.accounts, 1)
|
||||||
|
require.WithinDuration(t, start.Add(30*time.Minute), repo.tempUntil, 10*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
service.SetTokenCacheInvalidator(invalidator)
|
||||||
|
account := &Account{
|
||||||
|
ID: 102,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||||
|
|
||||||
|
require.True(t, shouldDisable)
|
||||||
|
require.Equal(t, 0, repo.tempCalls)
|
||||||
|
require.Equal(t, 1, repo.setErrorCalls)
|
||||||
|
require.Empty(t, invalidator.accounts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitService_HandleOAuth401_NilAccount 测试 account 为 nil 的情况
|
||||||
|
func TestRateLimitService_HandleOAuth401_NilAccount(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
|
||||||
|
result := service.handleOAuth401TempUnschedulable(context.Background(), nil, "error")
|
||||||
|
|
||||||
|
require.False(t, result)
|
||||||
|
require.Equal(t, 0, repo.tempCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitService_HandleOAuth401_NilInvalidator 测试 tokenCacheInvalidator 为 nil 的情况
|
||||||
|
func TestRateLimitService_HandleOAuth401_NilInvalidator(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
// 不设置 tokenCacheInvalidator
|
||||||
|
account := &Account{
|
||||||
|
ID: 200,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
|
||||||
|
|
||||||
|
require.True(t, result)
|
||||||
|
require.Equal(t, 1, repo.tempCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitService_HandleOAuth401_SetTempUnschedulableFailed 测试 SetTempUnschedulable 失败的情况
|
||||||
|
func TestRateLimitService_HandleOAuth401_SetTempUnschedulableFailed(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStubWithError{
|
||||||
|
setTempErr: errors.New("db error"),
|
||||||
|
}
|
||||||
|
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
service.SetTokenCacheInvalidator(invalidator)
|
||||||
|
account := &Account{
|
||||||
|
ID: 201,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
|
||||||
|
|
||||||
|
require.False(t, result) // 失败应返回 false
|
||||||
|
require.Len(t, invalidator.accounts, 1) // 但 invalidator 仍然被调用
|
||||||
|
}
|
||||||
|
|
||||||
|
// rateLimitAccountRepoStubWithError 支持返回错误的 stub
|
||||||
|
type rateLimitAccountRepoStubWithError struct {
|
||||||
|
mockAccountRepoForGemini
|
||||||
|
setTempErr error
|
||||||
|
setErrorCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rateLimitAccountRepoStubWithError) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||||
|
return r.setTempErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rateLimitAccountRepoStubWithError) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||||
|
r.setErrorCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitService_HandleOAuth401_WithTempUnschedCache 测试 tempUnschedCache 存在的情况
|
||||||
|
func TestRateLimitService_HandleOAuth401_WithTempUnschedCache(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||||
|
tempCache := &tempUnschedCacheStub{}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, tempCache)
|
||||||
|
service.SetTokenCacheInvalidator(invalidator)
|
||||||
|
account := &Account{
|
||||||
|
ID: 202,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
|
||||||
|
|
||||||
|
require.True(t, result)
|
||||||
|
require.Equal(t, 1, repo.tempCalls)
|
||||||
|
require.Equal(t, 1, tempCache.setCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitService_HandleOAuth401_TempUnschedCacheError 测试 tempUnschedCache 设置失败的情况
|
||||||
|
func TestRateLimitService_HandleOAuth401_TempUnschedCacheError(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||||
|
tempCache := &tempUnschedCacheStub{setErr: errors.New("cache error")}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, tempCache)
|
||||||
|
service.SetTokenCacheInvalidator(invalidator)
|
||||||
|
account := &Account{
|
||||||
|
ID: 203,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
|
||||||
|
|
||||||
|
require.True(t, result) // 缓存错误不影响主流程
|
||||||
|
require.Equal(t, 1, repo.tempCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tempUnschedCacheStub 用于测试的 TempUnschedCache stub
|
||||||
|
type tempUnschedCacheStub struct {
|
||||||
|
setCalls int
|
||||||
|
setErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tempUnschedCacheStub) GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tempUnschedCacheStub) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error {
|
||||||
|
c.setCalls++
|
||||||
|
return c.setErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tempUnschedCacheStub) DeleteTempUnsched(ctx context.Context, accountID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitService_OAuth401Cooldown 测试 oauth401Cooldown 函数
|
||||||
|
func TestRateLimitService_OAuth401Cooldown(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg *config.Config
|
||||||
|
expected time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default_when_config_zero",
|
||||||
|
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 0}},
|
||||||
|
expected: 5 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom_cooldown_10_minutes",
|
||||||
|
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 10}},
|
||||||
|
expected: 10 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom_cooldown_1_minute",
|
||||||
|
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 1}},
|
||||||
|
expected: 1 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative_value_uses_default",
|
||||||
|
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: -5}},
|
||||||
|
expected: 5 * time.Minute,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
service := NewRateLimitService(nil, nil, tt.cfg, nil, nil)
|
||||||
|
result := service.oauth401Cooldown()
|
||||||
|
require.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitService_OAuth401Cooldown_NilConfig 测试 cfg 为 nil 的情况
|
||||||
|
func TestRateLimitService_OAuth401Cooldown_NilConfig(t *testing.T) {
|
||||||
|
service := &RateLimitService{cfg: nil}
|
||||||
|
result := service.oauth401Cooldown()
|
||||||
|
require.Equal(t, 5*time.Minute, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitService_HandleOAuth401_WithCustomCooldown 测试自定义 cooldown 配置
|
||||||
|
func TestRateLimitService_HandleOAuth401_WithCustomCooldown(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
cfg := &config.Config{
|
||||||
|
RateLimit: config.RateLimitConfig{
|
||||||
|
OAuth401CooldownMinutes: 15,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewRateLimitService(repo, nil, cfg, nil, nil)
|
||||||
|
account := &Account{
|
||||||
|
ID: 204,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
|
||||||
|
|
||||||
|
require.True(t, result)
|
||||||
|
require.WithinDuration(t, start.Add(15*time.Minute), repo.tempUntil, 10*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitService_HandleOAuth401_EmptyUpstreamMsg 测试 upstreamMsg 为空的情况
|
||||||
|
func TestRateLimitService_HandleOAuth401_EmptyUpstreamMsg(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
account := &Account{
|
||||||
|
ID: 205,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "")
|
||||||
|
|
||||||
|
require.True(t, result)
|
||||||
|
require.Contains(t, repo.tempReason, "Authentication failed (401)")
|
||||||
|
}
|
||||||
35
backend/internal/service/token_cache_invalidator.go
Normal file
35
backend/internal/service/token_cache_invalidator.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type TokenCacheInvalidator interface {
|
||||||
|
InvalidateToken(ctx context.Context, account *Account) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompositeTokenCacheInvalidator struct {
|
||||||
|
geminiCache GeminiTokenCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCompositeTokenCacheInvalidator(geminiCache GeminiTokenCache) *CompositeTokenCacheInvalidator {
|
||||||
|
return &CompositeTokenCacheInvalidator{
|
||||||
|
geminiCache: geminiCache,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, account *Account) error {
|
||||||
|
if c == nil || c.geminiCache == nil || account == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if account.Type != AccountTypeOAuth {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch account.Platform {
|
||||||
|
case PlatformGemini:
|
||||||
|
return c.geminiCache.DeleteAccessToken(ctx, GeminiTokenCacheKey(account))
|
||||||
|
case PlatformAntigravity:
|
||||||
|
return c.geminiCache.DeleteAccessToken(ctx, AntigravityTokenCacheKey(account))
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
97
backend/internal/service/token_cache_invalidator_test.go
Normal file
97
backend/internal/service/token_cache_invalidator_test.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type geminiTokenCacheStub struct {
|
||||||
|
deletedKeys []string
|
||||||
|
deleteErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *geminiTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *geminiTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *geminiTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
|
||||||
|
s.deletedKeys = append(s.deletedKeys, cacheKey)
|
||||||
|
return s.deleteErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *geminiTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *geminiTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
|
||||||
|
cache := &geminiTokenCacheStub{}
|
||||||
|
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||||
|
account := &Account{
|
||||||
|
ID: 10,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"project_id": "project-x",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := invalidator.InvalidateToken(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []string{"project-x"}, cache.deletedKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
|
||||||
|
cache := &geminiTokenCacheStub{}
|
||||||
|
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||||
|
account := &Account{
|
||||||
|
ID: 99,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"project_id": "ag-project",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := invalidator.InvalidateToken(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) {
|
||||||
|
cache := &geminiTokenCacheStub{}
|
||||||
|
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := invalidator.InvalidateToken(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, cache.deletedKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) {
|
||||||
|
invalidator := NewCompositeTokenCacheInvalidator(nil)
|
||||||
|
account := &Account{
|
||||||
|
ID: 2,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := invalidator.InvalidateToken(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
153
backend/internal/service/token_cache_key_test.go
Normal file
153
backend/internal/service/token_cache_key_test.go
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGeminiTokenCacheKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
account *Account
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with_project_id",
|
||||||
|
account: &Account{
|
||||||
|
ID: 100,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"project_id": "my-project-123",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "my-project-123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "project_id_with_whitespace",
|
||||||
|
account: &Account{
|
||||||
|
ID: 101,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"project_id": " project-with-spaces ",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "project-with-spaces",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty_project_id_fallback_to_account_id",
|
||||||
|
account: &Account{
|
||||||
|
ID: 102,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"project_id": "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "account:102",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "whitespace_only_project_id_fallback_to_account_id",
|
||||||
|
account: &Account{
|
||||||
|
ID: 103,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"project_id": " ",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "account:103",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no_project_id_key_fallback_to_account_id",
|
||||||
|
account: &Account{
|
||||||
|
ID: 104,
|
||||||
|
Credentials: map[string]any{},
|
||||||
|
},
|
||||||
|
expected: "account:104",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil_credentials_fallback_to_account_id",
|
||||||
|
account: &Account{
|
||||||
|
ID: 105,
|
||||||
|
Credentials: nil,
|
||||||
|
},
|
||||||
|
expected: "account:105",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GeminiTokenCacheKey(tt.account)
|
||||||
|
require.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityTokenCacheKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
account *Account
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with_project_id",
|
||||||
|
account: &Account{
|
||||||
|
ID: 200,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"project_id": "ag-project-456",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "ag:ag-project-456",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "project_id_with_whitespace",
|
||||||
|
account: &Account{
|
||||||
|
ID: 201,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"project_id": " ag-project-spaces ",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "ag:ag-project-spaces",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty_project_id_fallback_to_account_id",
|
||||||
|
account: &Account{
|
||||||
|
ID: 202,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"project_id": "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "ag:account:202",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "whitespace_only_project_id_fallback_to_account_id",
|
||||||
|
account: &Account{
|
||||||
|
ID: 203,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"project_id": " ",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "ag:account:203",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no_project_id_key_fallback_to_account_id",
|
||||||
|
account: &Account{
|
||||||
|
ID: 204,
|
||||||
|
Credentials: map[string]any{},
|
||||||
|
},
|
||||||
|
expected: "ag:account:204",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil_credentials_fallback_to_account_id",
|
||||||
|
account: &Account{
|
||||||
|
ID: 205,
|
||||||
|
Credentials: nil,
|
||||||
|
},
|
||||||
|
expected: "ag:account:205",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := AntigravityTokenCacheKey(tt.account)
|
||||||
|
require.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -14,9 +14,10 @@ import (
|
|||||||
// TokenRefreshService OAuth token自动刷新服务
|
// TokenRefreshService OAuth token自动刷新服务
|
||||||
// 定期检查并刷新即将过期的token
|
// 定期检查并刷新即将过期的token
|
||||||
type TokenRefreshService struct {
|
type TokenRefreshService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
refreshers []TokenRefresher
|
refreshers []TokenRefresher
|
||||||
cfg *config.TokenRefreshConfig
|
cfg *config.TokenRefreshConfig
|
||||||
|
cacheInvalidator TokenCacheInvalidator
|
||||||
|
|
||||||
stopCh chan struct{}
|
stopCh chan struct{}
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
@@ -29,12 +30,14 @@ func NewTokenRefreshService(
|
|||||||
openaiOAuthService *OpenAIOAuthService,
|
openaiOAuthService *OpenAIOAuthService,
|
||||||
geminiOAuthService *GeminiOAuthService,
|
geminiOAuthService *GeminiOAuthService,
|
||||||
antigravityOAuthService *AntigravityOAuthService,
|
antigravityOAuthService *AntigravityOAuthService,
|
||||||
|
cacheInvalidator TokenCacheInvalidator,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
) *TokenRefreshService {
|
) *TokenRefreshService {
|
||||||
s := &TokenRefreshService{
|
s := &TokenRefreshService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
cfg: &cfg.TokenRefresh,
|
cfg: &cfg.TokenRefresh,
|
||||||
stopCh: make(chan struct{}),
|
cacheInvalidator: cacheInvalidator,
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
// 注册平台特定的刷新器
|
// 注册平台特定的刷新器
|
||||||
@@ -169,6 +172,14 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
|||||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||||
return fmt.Errorf("failed to save credentials: %w", err)
|
return fmt.Errorf("failed to save credentials: %w", err)
|
||||||
}
|
}
|
||||||
|
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth &&
|
||||||
|
(account.Platform == PlatformGemini || account.Platform == PlatformAntigravity) {
|
||||||
|
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||||
|
log.Printf("[TokenRefresh] Failed to invalidate token cache for account %d: %v", account.ID, err)
|
||||||
|
} else {
|
||||||
|
log.Printf("[TokenRefresh] Token cache invalidated for account %d", account.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
361
backend/internal/service/token_refresh_service_test.go
Normal file
361
backend/internal/service/token_refresh_service_test.go
Normal file
@@ -0,0 +1,361 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type tokenRefreshAccountRepo struct {
|
||||||
|
mockAccountRepoForGemini
|
||||||
|
updateCalls int
|
||||||
|
setErrorCalls int
|
||||||
|
lastAccount *Account
|
||||||
|
updateErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
|
||||||
|
r.updateCalls++
|
||||||
|
r.lastAccount = account
|
||||||
|
return r.updateErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||||
|
r.setErrorCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type tokenCacheInvalidatorStub struct {
|
||||||
|
calls int
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tokenCacheInvalidatorStub) InvalidateToken(ctx context.Context, account *Account) error {
|
||||||
|
s.calls++
|
||||||
|
return s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
type tokenRefresherStub struct {
|
||||||
|
credentials map[string]any
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tokenRefresherStub) CanRefresh(account *Account) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tokenRefresherStub) NeedsRefresh(account *Account, refreshWindowDuration time.Duration) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||||
|
if r.err != nil {
|
||||||
|
return nil, r.err
|
||||||
|
}
|
||||||
|
return r.credentials, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
|
||||||
|
repo := &tokenRefreshAccountRepo{}
|
||||||
|
invalidator := &tokenCacheInvalidatorStub{}
|
||||||
|
cfg := &config.Config{
|
||||||
|
TokenRefresh: config.TokenRefreshConfig{
|
||||||
|
MaxRetries: 1,
|
||||||
|
RetryBackoffSeconds: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||||
|
account := &Account{
|
||||||
|
ID: 5,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
refresher := &tokenRefresherStub{
|
||||||
|
credentials: map[string]any{
|
||||||
|
"access_token": "new-token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
|
require.Equal(t, 1, invalidator.calls)
|
||||||
|
require.Equal(t, "new-token", account.GetCredential("access_token"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing.T) {
|
||||||
|
repo := &tokenRefreshAccountRepo{}
|
||||||
|
invalidator := &tokenCacheInvalidatorStub{err: errors.New("invalidate failed")}
|
||||||
|
cfg := &config.Config{
|
||||||
|
TokenRefresh: config.TokenRefreshConfig{
|
||||||
|
MaxRetries: 1,
|
||||||
|
RetryBackoffSeconds: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||||
|
account := &Account{
|
||||||
|
ID: 6,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
refresher := &tokenRefresherStub{
|
||||||
|
credentials: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
|
require.Equal(t, 1, invalidator.calls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
|
||||||
|
repo := &tokenRefreshAccountRepo{}
|
||||||
|
cfg := &config.Config{
|
||||||
|
TokenRefresh: config.TokenRefreshConfig{
|
||||||
|
MaxRetries: 1,
|
||||||
|
RetryBackoffSeconds: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, cfg)
|
||||||
|
account := &Account{
|
||||||
|
ID: 7,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
refresher := &tokenRefresherStub{
|
||||||
|
credentials: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTokenRefreshService_RefreshWithRetry_Antigravity 测试 Antigravity 平台的缓存失效
|
||||||
|
func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
|
||||||
|
repo := &tokenRefreshAccountRepo{}
|
||||||
|
invalidator := &tokenCacheInvalidatorStub{}
|
||||||
|
cfg := &config.Config{
|
||||||
|
TokenRefresh: config.TokenRefreshConfig{
|
||||||
|
MaxRetries: 1,
|
||||||
|
RetryBackoffSeconds: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||||
|
account := &Account{
|
||||||
|
ID: 8,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
refresher := &tokenRefresherStub{
|
||||||
|
credentials: map[string]any{
|
||||||
|
"access_token": "ag-token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
|
require.Equal(t, 1, invalidator.calls) // Antigravity 也应触发缓存失效
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount 测试非 OAuth 账号不触发缓存失效
|
||||||
|
func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
|
||||||
|
repo := &tokenRefreshAccountRepo{}
|
||||||
|
invalidator := &tokenCacheInvalidatorStub{}
|
||||||
|
cfg := &config.Config{
|
||||||
|
TokenRefresh: config.TokenRefreshConfig{
|
||||||
|
MaxRetries: 1,
|
||||||
|
RetryBackoffSeconds: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||||
|
account := &Account{
|
||||||
|
ID: 9,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeAPIKey, // 非 OAuth
|
||||||
|
}
|
||||||
|
refresher := &tokenRefresherStub{
|
||||||
|
credentials: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
|
require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试其他平台的 OAuth 账号不触发缓存失效
|
||||||
|
func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
|
||||||
|
repo := &tokenRefreshAccountRepo{}
|
||||||
|
invalidator := &tokenCacheInvalidatorStub{}
|
||||||
|
cfg := &config.Config{
|
||||||
|
TokenRefresh: config.TokenRefreshConfig{
|
||||||
|
MaxRetries: 1,
|
||||||
|
RetryBackoffSeconds: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||||
|
account := &Account{
|
||||||
|
ID: 10,
|
||||||
|
Platform: PlatformOpenAI, // 其他平台
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
refresher := &tokenRefresherStub{
|
||||||
|
credentials: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
|
require.Equal(t, 0, invalidator.calls) // 其他平台不触发缓存失效
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
|
||||||
|
func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
|
||||||
|
repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")}
|
||||||
|
invalidator := &tokenCacheInvalidatorStub{}
|
||||||
|
cfg := &config.Config{
|
||||||
|
TokenRefresh: config.TokenRefreshConfig{
|
||||||
|
MaxRetries: 1,
|
||||||
|
RetryBackoffSeconds: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||||
|
account := &Account{
|
||||||
|
ID: 11,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
refresher := &tokenRefresherStub{
|
||||||
|
credentials: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "failed to save credentials")
|
||||||
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
|
require.Equal(t, 0, invalidator.calls) // 更新失败时不应触发缓存失效
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试刷新失败的情况
|
||||||
|
func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
|
||||||
|
repo := &tokenRefreshAccountRepo{}
|
||||||
|
invalidator := &tokenCacheInvalidatorStub{}
|
||||||
|
cfg := &config.Config{
|
||||||
|
TokenRefresh: config.TokenRefreshConfig{
|
||||||
|
MaxRetries: 2,
|
||||||
|
RetryBackoffSeconds: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||||
|
account := &Account{
|
||||||
|
ID: 12,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
refresher := &tokenRefresherStub{
|
||||||
|
err: errors.New("refresh failed"),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
|
||||||
|
require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效
|
||||||
|
require.Equal(t, 1, repo.setErrorCalls) // 应设置错误状态
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed 测试 Antigravity 刷新失败不设置错误状态
|
||||||
|
func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testing.T) {
|
||||||
|
repo := &tokenRefreshAccountRepo{}
|
||||||
|
invalidator := &tokenCacheInvalidatorStub{}
|
||||||
|
cfg := &config.Config{
|
||||||
|
TokenRefresh: config.TokenRefreshConfig{
|
||||||
|
MaxRetries: 1,
|
||||||
|
RetryBackoffSeconds: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||||
|
account := &Account{
|
||||||
|
ID: 13,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
refresher := &tokenRefresherStub{
|
||||||
|
err: errors.New("network error"), // 可重试错误
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, 0, repo.updateCalls)
|
||||||
|
require.Equal(t, 0, invalidator.calls)
|
||||||
|
require.Equal(t, 0, repo.setErrorCalls) // Antigravity 可重试错误不设置错误状态
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError 测试 Antigravity 不可重试错误
|
||||||
|
func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *testing.T) {
|
||||||
|
repo := &tokenRefreshAccountRepo{}
|
||||||
|
invalidator := &tokenCacheInvalidatorStub{}
|
||||||
|
cfg := &config.Config{
|
||||||
|
TokenRefresh: config.TokenRefreshConfig{
|
||||||
|
MaxRetries: 3,
|
||||||
|
RetryBackoffSeconds: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||||
|
account := &Account{
|
||||||
|
ID: 14,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
refresher := &tokenRefresherStub{
|
||||||
|
err: errors.New("invalid_grant: token revoked"), // 不可重试错误
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, 0, repo.updateCalls)
|
||||||
|
require.Equal(t, 0, invalidator.calls)
|
||||||
|
require.Equal(t, 1, repo.setErrorCalls) // 不可重试错误应设置错误状态
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIsNonRetryableRefreshError 测试不可重试错误判断
|
||||||
|
func TestIsNonRetryableRefreshError(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{name: "nil_error", err: nil, expected: false},
|
||||||
|
{name: "network_error", err: errors.New("network timeout"), expected: false},
|
||||||
|
{name: "invalid_grant", err: errors.New("invalid_grant"), expected: true},
|
||||||
|
{name: "invalid_client", err: errors.New("invalid_client"), expected: true},
|
||||||
|
{name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true},
|
||||||
|
{name: "access_denied", err: errors.New("access_denied"), expected: true},
|
||||||
|
{name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true},
|
||||||
|
{name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := isNonRetryableRefreshError(tt.err)
|
||||||
|
require.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -42,9 +42,10 @@ func ProvideTokenRefreshService(
|
|||||||
openaiOAuthService *OpenAIOAuthService,
|
openaiOAuthService *OpenAIOAuthService,
|
||||||
geminiOAuthService *GeminiOAuthService,
|
geminiOAuthService *GeminiOAuthService,
|
||||||
antigravityOAuthService *AntigravityOAuthService,
|
antigravityOAuthService *AntigravityOAuthService,
|
||||||
|
cacheInvalidator TokenCacheInvalidator,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
) *TokenRefreshService {
|
) *TokenRefreshService {
|
||||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cfg)
|
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg)
|
||||||
svc.Start()
|
svc.Start()
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
@@ -108,10 +109,12 @@ func ProvideRateLimitService(
|
|||||||
tempUnschedCache TempUnschedCache,
|
tempUnschedCache TempUnschedCache,
|
||||||
timeoutCounterCache TimeoutCounterCache,
|
timeoutCounterCache TimeoutCounterCache,
|
||||||
settingService *SettingService,
|
settingService *SettingService,
|
||||||
|
tokenCacheInvalidator TokenCacheInvalidator,
|
||||||
) *RateLimitService {
|
) *RateLimitService {
|
||||||
svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache)
|
svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache)
|
||||||
svc.SetTimeoutCounterCache(timeoutCounterCache)
|
svc.SetTimeoutCounterCache(timeoutCounterCache)
|
||||||
svc.SetSettingService(settingService)
|
svc.SetSettingService(settingService)
|
||||||
|
svc.SetTokenCacheInvalidator(tokenCacheInvalidator)
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -210,6 +213,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewOpenAIOAuthService,
|
NewOpenAIOAuthService,
|
||||||
NewGeminiOAuthService,
|
NewGeminiOAuthService,
|
||||||
NewGeminiQuotaService,
|
NewGeminiQuotaService,
|
||||||
|
NewCompositeTokenCacheInvalidator,
|
||||||
NewAntigravityOAuthService,
|
NewAntigravityOAuthService,
|
||||||
NewGeminiTokenProvider,
|
NewGeminiTokenProvider,
|
||||||
NewGeminiMessagesCompatService,
|
NewGeminiMessagesCompatService,
|
||||||
|
|||||||
@@ -387,6 +387,9 @@ rate_limit:
|
|||||||
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
|
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
|
||||||
# 上游返回 529(过载)时的冷却时间(分钟)
|
# 上游返回 529(过载)时的冷却时间(分钟)
|
||||||
overload_cooldown_minutes: 10
|
overload_cooldown_minutes: 10
|
||||||
|
# Cooldown time (in minutes) for OAuth 401 temporary unschedulable
|
||||||
|
# OAuth 401 临时不可调度冷却时间(分钟)
|
||||||
|
oauth_401_cooldown_minutes: 5
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Pricing Data Source (Optional)
|
# Pricing Data Source (Optional)
|
||||||
|
|||||||
@@ -69,6 +69,17 @@ JWT_EXPIRE_HOUR=24
|
|||||||
# Leave unset to use default ./config.yaml
|
# Leave unset to use default ./config.yaml
|
||||||
#CONFIG_FILE=./config.yaml
|
#CONFIG_FILE=./config.yaml
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Rate Limiting (Optional)
|
||||||
|
# 速率限制(可选)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
|
||||||
|
# 上游返回 529(过载)时的冷却时间(分钟)
|
||||||
|
RATE_LIMIT_OVERLOAD_COOLDOWN_MINUTES=10
|
||||||
|
# Cooldown time (in minutes) for OAuth 401 temporary unschedulable
|
||||||
|
# OAuth 401 临时不可调度冷却时间(分钟)
|
||||||
|
RATE_LIMIT_OAUTH_401_COOLDOWN_MINUTES=5
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Gateway Scheduling (Optional)
|
# Gateway Scheduling (Optional)
|
||||||
# 调度缓存与受控回源配置(缓存就绪且命中时不读 DB)
|
# 调度缓存与受控回源配置(缓存就绪且命中时不读 DB)
|
||||||
|
|||||||
@@ -429,6 +429,9 @@ rate_limit:
|
|||||||
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
|
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
|
||||||
# 上游返回 529(过载)时的冷却时间(分钟)
|
# 上游返回 529(过载)时的冷却时间(分钟)
|
||||||
overload_cooldown_minutes: 10
|
overload_cooldown_minutes: 10
|
||||||
|
# Cooldown time (in minutes) for OAuth 401 temporary unschedulable
|
||||||
|
# OAuth 401 临时不可调度冷却时间(分钟)
|
||||||
|
oauth_401_cooldown_minutes: 5
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Pricing Data Source (Optional)
|
# Pricing Data Source (Optional)
|
||||||
|
|||||||
Reference in New Issue
Block a user