Merge PR #296: OAuth token cache invalidation on 401 and refresh

This commit is contained in:
shaw
2026-01-15 16:31:37 +08:00
17 changed files with 949 additions and 68 deletions

View File

@@ -99,12 +99,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
tokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, tokenCacheInvalidator)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
@@ -167,7 +168,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, tokenCacheInvalidator, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{

View File

@@ -33,6 +33,11 @@ func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string,
return c.rdb.Set(ctx, key, token, ttl).Err()
}
func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err()
}
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
return c.rdb.SetNX(ctx, key, 1, ttl).Result()

View File

@@ -0,0 +1,47 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GeminiTokenCacheSuite struct {
IntegrationRedisSuite
cache service.GeminiTokenCache
}
func (s *GeminiTokenCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewGeminiTokenCache(s.rdb)
}
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() {
cacheKey := "project-123"
token := "token-value"
require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute))
got, err := s.cache.GetAccessToken(s.ctx, cacheKey)
require.NoError(s.T(), err)
require.Equal(s.T(), token, got)
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey))
_, err = s.cache.GetAccessToken(s.ctx, cacheKey)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() {
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key"))
}
func TestGeminiTokenCacheSuite(t *testing.T) {
suite.Run(t, new(GeminiTokenCacheSuite))
}

View File

@@ -0,0 +1,28 @@
//go:build unit
package repository
import (
"context"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
DialTimeout: 50 * time.Millisecond,
ReadTimeout: 50 * time.Millisecond,
WriteTimeout: 50 * time.Millisecond,
})
t.Cleanup(func() {
_ = rdb.Close()
})
cache := NewGeminiTokenCache(rdb)
err := cache.DeleteAccessToken(context.Background(), "broken")
require.Error(t, err)
}

View File

@@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("not an antigravity oauth account")
}
cacheKey := antigravityTokenCacheKey(account)
cacheKey := AntigravityTokenCacheKey(account)
// 1. 先尝试缓存
if p.tokenCache != nil {
@@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return accessToken, nil
}
func antigravityTokenCacheKey(account *Account) string {
func AntigravityTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return "ag:" + projectID

View File

@@ -10,6 +10,7 @@ type GeminiTokenCache interface {
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
GetAccessToken(ctx context.Context, cacheKey string) (string, error)
SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error
DeleteAccessToken(ctx context.Context, cacheKey string) error
AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error)
ReleaseRefreshLock(ctx context.Context, cacheKey string) error

View File

@@ -40,7 +40,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return "", errors.New("not a gemini oauth account")
}
cacheKey := geminiTokenCacheKey(account)
cacheKey := GeminiTokenCacheKey(account)
// 1) Try cache first.
if p.tokenCache != nil {
@@ -151,7 +151,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
func geminiTokenCacheKey(account *Account) string {
func GeminiTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return projectID

View File

@@ -3,7 +3,7 @@ package service
import (
"context"
"encoding/json"
"log"
"log/slog"
"net/http"
"strconv"
"strings"
@@ -15,15 +15,16 @@ import (
// RateLimitService 处理限流和过载状态管理
type RateLimitService struct {
accountRepo AccountRepository
usageRepo UsageLogRepository
cfg *config.Config
geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache
timeoutCounterCache TimeoutCounterCache
settingService *SettingService
usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
accountRepo AccountRepository
usageRepo UsageLogRepository
cfg *config.Config
geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache
timeoutCounterCache TimeoutCounterCache
settingService *SettingService
tokenCacheInvalidator TokenCacheInvalidator
usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
}
type geminiUsageCacheEntry struct {
@@ -56,6 +57,11 @@ func (s *RateLimitService) SetSettingService(settingService *SettingService) {
s.settingService = settingService
}
// SetTokenCacheInvalidator 设置 token 缓存清理器(可选依赖)
func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvalidator) {
s.tokenCacheInvalidator = invalidator
}
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
@@ -63,11 +69,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
if !account.ShouldHandleErrorCode(statusCode) {
log.Printf("Account %d: error %d skipped (not in custom error codes)", account.ID, statusCode)
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
return false
}
tempMatched := s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
tempMatched := false
if statusCode != 401 {
tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if upstreamMsg != "" {
@@ -76,7 +85,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
switch statusCode {
case 401:
// 认证失败:停止调度,记录错误
if account.Type == AccountTypeOAuth &&
(account.Platform == PlatformAntigravity || account.Platform == PlatformGemini) {
if s.tokenCacheInvalidator != nil {
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
}
}
}
msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg
@@ -116,7 +132,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
shouldDisable = true
} else if statusCode >= 500 {
// 未启用自定义错误码时仅记录5xx错误
log.Printf("Account %d received upstream error %d", account.ID, statusCode)
slog.Warn("account_upstream_error", "account_id", account.ID, "status_code", statusCode)
shouldDisable = false
}
}
@@ -188,7 +204,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
// NOTE:
// - This is a local precheck to reduce upstream 429s.
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
slog.Info("gemini_precheck_daily_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
return false, nil
}
}
@@ -231,7 +247,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if used >= limit {
resetAt := start.Add(time.Minute)
// Do not persist "rate limited" status from local precheck. See note above.
log.Printf("[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
slog.Info("gemini_precheck_minute_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
return false, nil
}
}
@@ -288,20 +304,20 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
// handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err)
slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err)
return
}
log.Printf("Account %d disabled due to auth error: %s", account.ID, errorMsg)
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
}
// handleCustomErrorCode 处理自定义错误码,停止账号调度
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err)
slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err)
return
}
log.Printf("Account %d disabled due to custom error code %d: %s", account.ID, statusCode, errorMsg)
slog.Warn("account_disabled_custom_error", "account_id", account.ID, "status_code", statusCode, "error", errorMsg)
}
// handle429 处理429限流错误
@@ -313,7 +329,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 没有重置时间使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
}
return
}
@@ -321,10 +337,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 解析Unix时间戳
ts, err := strconv.ParseInt(resetTimestamp, 10, 64)
if err != nil {
log.Printf("Parse reset timestamp failed: %v", err)
slog.Warn("rate_limit_reset_parse_failed", "reset_timestamp", resetTimestamp, "error", err)
resetAt := time.Now().Add(5 * time.Minute)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
}
return
}
@@ -333,7 +349,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 标记限流状态
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
@@ -341,10 +357,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
windowEnd := resetAt
windowStart := resetAt.Add(-5 * time.Hour)
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
}
log.Printf("Account %d rate limited until %v", account.ID, resetAt)
slog.Info("account_rate_limited", "account_id", account.ID, "reset_at", resetAt)
}
// handle529 处理529过载错误
@@ -357,11 +373,11 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
slog.Warn("overload_set_failed", "account_id", account.ID, "error", err)
return
}
log.Printf("Account %d overloaded until %v", account.ID, until)
slog.Info("account_overloaded", "account_id", account.ID, "until", until)
}
// UpdateSessionWindow 从成功响应更新5h窗口状态
@@ -384,17 +400,17 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
end := start.Add(5 * time.Hour)
windowStart = &start
windowEnd = &end
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status)
}
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err)
}
// 如果状态为allowed且之前有限流说明窗口已重置清除限流状态
if status == "allowed" && account.IsRateLimited() {
if err := s.ClearRateLimit(ctx, account.ID); err != nil {
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
slog.Warn("rate_limit_clear_failed", "account_id", account.ID, "error", err)
}
}
}
@@ -413,7 +429,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil {
log.Printf("DeleteTempUnsched failed for account %d: %v", accountID, err)
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
}
}
return nil
@@ -460,7 +476,7 @@ func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID i
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil {
log.Printf("SetTempUnsched failed for account %d: %v", accountID, err)
slog.Warn("temp_unsched_cache_set_failed", "account_id", accountID, "error", err)
}
}
@@ -563,17 +579,17 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
}
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
log.Printf("SetTempUnschedulable failed for account %d: %v", account.ID, err)
slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err)
return false
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
log.Printf("SetTempUnsched cache failed for account %d: %v", account.ID, err)
slog.Warn("temp_unsched_cache_set_failed", "account_id", account.ID, "error", err)
}
}
log.Printf("Account %d temp unschedulable until %v (rule %d, code %d)", account.ID, until, ruleIndex, statusCode)
slog.Info("account_temp_unschedulable", "account_id", account.ID, "until", until, "rule_index", ruleIndex, "status_code", statusCode)
return true
}
@@ -597,13 +613,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
// 获取系统设置
if s.settingService == nil {
log.Printf("[StreamTimeout] settingService not configured, skipping timeout handling for account %d", account.ID)
slog.Warn("stream_timeout_setting_service_missing", "account_id", account.ID)
return false
}
settings, err := s.settingService.GetStreamTimeoutSettings(ctx)
if err != nil {
log.Printf("[StreamTimeout] Failed to get settings: %v", err)
slog.Warn("stream_timeout_get_settings_failed", "account_id", account.ID, "error", err)
return false
}
@@ -620,14 +636,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
if s.timeoutCounterCache != nil {
count, err = s.timeoutCounterCache.IncrementTimeoutCount(ctx, account.ID, settings.ThresholdWindowMinutes)
if err != nil {
log.Printf("[StreamTimeout] Failed to increment timeout count for account %d: %v", account.ID, err)
slog.Warn("stream_timeout_increment_count_failed", "account_id", account.ID, "error", err)
// 继续处理,使用 count=1
count = 1
}
}
log.Printf("[StreamTimeout] Account %d timeout count: %d/%d (window: %d min, model: %s)",
account.ID, count, settings.ThresholdCount, settings.ThresholdWindowMinutes, model)
slog.Info("stream_timeout_count", "account_id", account.ID, "count", count, "threshold", settings.ThresholdCount, "window_minutes", settings.ThresholdWindowMinutes, "model", model)
// 检查是否达到阈值
if count < int64(settings.ThresholdCount) {
@@ -668,24 +683,24 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
}
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
log.Printf("[StreamTimeout] SetTempUnschedulable failed for account %d: %v", account.ID, err)
slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err)
return false
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
log.Printf("[StreamTimeout] SetTempUnsched cache failed for account %d: %v", account.ID, err)
slog.Warn("stream_timeout_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err)
}
}
// 重置超时计数
if s.timeoutCounterCache != nil {
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err)
slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
}
}
log.Printf("[StreamTimeout] Account %d marked as temp unschedulable until %v (model: %s)", account.ID, until, model)
slog.Info("stream_timeout_temp_unschedulable", "account_id", account.ID, "until", until, "model", model)
return true
}
@@ -694,17 +709,17 @@ func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, accoun
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("[StreamTimeout] SetError failed for account %d: %v", account.ID, err)
slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err)
return false
}
// 重置超时计数
if s.timeoutCounterCache != nil {
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err)
slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
}
}
log.Printf("[StreamTimeout] Account %d marked as error (model: %s)", account.ID, model)
slog.Warn("stream_timeout_account_error", "account_id", account.ID, "model", model)
return true
}

View File

@@ -0,0 +1,121 @@
//go:build unit
package service
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type rateLimitAccountRepoStub struct {
mockAccountRepoForGemini
setErrorCalls int
tempCalls int
lastErrorMsg string
}
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
r.lastErrorMsg = errorMsg
return nil
}
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++
return nil
}
type tokenCacheInvalidatorRecorder struct {
accounts []*Account
err error
}
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
r.accounts = append(r.accounts, account)
return r.err
}
func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) {
tests := []struct {
name string
platform string
}{
{name: "gemini", platform: PlatformGemini},
{name: "antigravity", platform: PlatformAntigravity},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: tt.platform,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": 401,
"keywords": []any{"unauthorized"},
"duration_minutes": 30,
"description": "custom rule",
},
},
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.tempCalls)
require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)")
require.Len(t, invalidator.accounts, 1)
})
}
}
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 101,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Len(t, invalidator.accounts, 1)
}
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 102,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Empty(t, invalidator.accounts)
}

View File

@@ -0,0 +1,35 @@
package service
import "context"
type TokenCacheInvalidator interface {
InvalidateToken(ctx context.Context, account *Account) error
}
type CompositeTokenCacheInvalidator struct {
geminiCache GeminiTokenCache
}
func NewCompositeTokenCacheInvalidator(geminiCache GeminiTokenCache) *CompositeTokenCacheInvalidator {
return &CompositeTokenCacheInvalidator{
geminiCache: geminiCache,
}
}
func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, account *Account) error {
if c == nil || c.geminiCache == nil || account == nil {
return nil
}
if account.Type != AccountTypeOAuth {
return nil
}
switch account.Platform {
case PlatformGemini:
return c.geminiCache.DeleteAccessToken(ctx, GeminiTokenCacheKey(account))
case PlatformAntigravity:
return c.geminiCache.DeleteAccessToken(ctx, AntigravityTokenCacheKey(account))
default:
return nil
}
}

View File

@@ -0,0 +1,97 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type geminiTokenCacheStub struct {
deletedKeys []string
deleteErr error
}
func (s *geminiTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
return "", nil
}
func (s *geminiTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
return nil
}
func (s *geminiTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
s.deletedKeys = append(s.deletedKeys, cacheKey)
return s.deleteErr
}
func (s *geminiTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
return true, nil
}
func (s *geminiTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
return nil
}
func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 10,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"project_id": "project-x",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"project-x"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 99,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"project_id": "ag-project",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 1,
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Empty(t, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) {
invalidator := NewCompositeTokenCacheInvalidator(nil)
account := &Account{
ID: 2,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
}

View File

@@ -0,0 +1,153 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGeminiTokenCacheKey(t *testing.T) {
tests := []struct {
name string
account *Account
expected string
}{
{
name: "with_project_id",
account: &Account{
ID: 100,
Credentials: map[string]any{
"project_id": "my-project-123",
},
},
expected: "my-project-123",
},
{
name: "project_id_with_whitespace",
account: &Account{
ID: 101,
Credentials: map[string]any{
"project_id": " project-with-spaces ",
},
},
expected: "project-with-spaces",
},
{
name: "empty_project_id_fallback_to_account_id",
account: &Account{
ID: 102,
Credentials: map[string]any{
"project_id": "",
},
},
expected: "account:102",
},
{
name: "whitespace_only_project_id_fallback_to_account_id",
account: &Account{
ID: 103,
Credentials: map[string]any{
"project_id": " ",
},
},
expected: "account:103",
},
{
name: "no_project_id_key_fallback_to_account_id",
account: &Account{
ID: 104,
Credentials: map[string]any{},
},
expected: "account:104",
},
{
name: "nil_credentials_fallback_to_account_id",
account: &Account{
ID: 105,
Credentials: nil,
},
expected: "account:105",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GeminiTokenCacheKey(tt.account)
require.Equal(t, tt.expected, result)
})
}
}
func TestAntigravityTokenCacheKey(t *testing.T) {
tests := []struct {
name string
account *Account
expected string
}{
{
name: "with_project_id",
account: &Account{
ID: 200,
Credentials: map[string]any{
"project_id": "ag-project-456",
},
},
expected: "ag:ag-project-456",
},
{
name: "project_id_with_whitespace",
account: &Account{
ID: 201,
Credentials: map[string]any{
"project_id": " ag-project-spaces ",
},
},
expected: "ag:ag-project-spaces",
},
{
name: "empty_project_id_fallback_to_account_id",
account: &Account{
ID: 202,
Credentials: map[string]any{
"project_id": "",
},
},
expected: "ag:account:202",
},
{
name: "whitespace_only_project_id_fallback_to_account_id",
account: &Account{
ID: 203,
Credentials: map[string]any{
"project_id": " ",
},
},
expected: "ag:account:203",
},
{
name: "no_project_id_key_fallback_to_account_id",
account: &Account{
ID: 204,
Credentials: map[string]any{},
},
expected: "ag:account:204",
},
{
name: "nil_credentials_fallback_to_account_id",
account: &Account{
ID: 205,
Credentials: nil,
},
expected: "ag:account:205",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := AntigravityTokenCacheKey(tt.account)
require.Equal(t, tt.expected, result)
})
}
}

View File

@@ -14,9 +14,10 @@ import (
// TokenRefreshService OAuth token自动刷新服务
// 定期检查并刷新即将过期的token
type TokenRefreshService struct {
accountRepo AccountRepository
refreshers []TokenRefresher
cfg *config.TokenRefreshConfig
accountRepo AccountRepository
refreshers []TokenRefresher
cfg *config.TokenRefreshConfig
cacheInvalidator TokenCacheInvalidator
stopCh chan struct{}
wg sync.WaitGroup
@@ -29,12 +30,14 @@ func NewTokenRefreshService(
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
cacheInvalidator TokenCacheInvalidator,
cfg *config.Config,
) *TokenRefreshService {
s := &TokenRefreshService{
accountRepo: accountRepo,
cfg: &cfg.TokenRefresh,
stopCh: make(chan struct{}),
accountRepo: accountRepo,
cfg: &cfg.TokenRefresh,
cacheInvalidator: cacheInvalidator,
stopCh: make(chan struct{}),
}
// 注册平台特定的刷新器
@@ -169,6 +172,14 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
if err := s.accountRepo.Update(ctx, account); err != nil {
return fmt.Errorf("failed to save credentials: %w", err)
}
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth &&
(account.Platform == PlatformGemini || account.Platform == PlatformAntigravity) {
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
log.Printf("[TokenRefresh] Failed to invalidate token cache for account %d: %v", account.ID, err)
} else {
log.Printf("[TokenRefresh] Token cache invalidated for account %d", account.ID)
}
}
return nil
}

View File

@@ -0,0 +1,361 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type tokenRefreshAccountRepo struct {
mockAccountRepoForGemini
updateCalls int
setErrorCalls int
lastAccount *Account
updateErr error
}
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
r.updateCalls++
r.lastAccount = account
return r.updateErr
}
func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
return nil
}
type tokenCacheInvalidatorStub struct {
calls int
err error
}
func (s *tokenCacheInvalidatorStub) InvalidateToken(ctx context.Context, account *Account) error {
s.calls++
return s.err
}
type tokenRefresherStub struct {
credentials map[string]any
err error
}
func (r *tokenRefresherStub) CanRefresh(account *Account) bool {
return true
}
func (r *tokenRefresherStub) NeedsRefresh(account *Account, refreshWindowDuration time.Duration) bool {
return true
}
func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
if r.err != nil {
return nil, r.err
}
return r.credentials, nil
}
func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 5,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "new-token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls)
require.Equal(t, "new-token", account.GetCredential("access_token"))
}
func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{err: errors.New("invalidate failed")}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 6,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls)
}
func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, cfg)
account := &Account{
ID: 7,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
}
// TestTokenRefreshService_RefreshWithRetry_Antigravity 测试 Antigravity 平台的缓存失效
func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 8,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "ag-token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls) // Antigravity 也应触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount 测试非 OAuth 账号不触发缓存失效
func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 9,
Platform: PlatformGemini,
Type: AccountTypeAPIKey, // 非 OAuth
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试其他平台的 OAuth 账号不触发缓存失效
func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 10,
Platform: PlatformOpenAI, // 其他平台
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 0, invalidator.calls) // 其他平台不触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 11,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Contains(t, err.Error(), "failed to save credentials")
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 0, invalidator.calls) // 更新失败时不应触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试刷新失败的情况
func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 2,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 12,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("refresh failed"),
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效
require.Equal(t, 1, repo.setErrorCalls) // 应设置错误状态
}
// TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed 测试 Antigravity 刷新失败不设置错误状态
func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 13,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("network error"), // 可重试错误
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, invalidator.calls)
require.Equal(t, 0, repo.setErrorCalls) // Antigravity 可重试错误不设置错误状态
}
// TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError 测试 Antigravity 不可重试错误
func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 3,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{
ID: 14,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("invalid_grant: token revoked"), // 不可重试错误
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, invalidator.calls)
require.Equal(t, 1, repo.setErrorCalls) // 不可重试错误应设置错误状态
}
// TestIsNonRetryableRefreshError 测试不可重试错误判断
func TestIsNonRetryableRefreshError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{name: "nil_error", err: nil, expected: false},
{name: "network_error", err: errors.New("network timeout"), expected: false},
{name: "invalid_grant", err: errors.New("invalid_grant"), expected: true},
{name: "invalid_client", err: errors.New("invalid_client"), expected: true},
{name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true},
{name: "access_denied", err: errors.New("access_denied"), expected: true},
{name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true},
{name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isNonRetryableRefreshError(tt.err)
require.Equal(t, tt.expected, result)
})
}
}

View File

@@ -42,9 +42,10 @@ func ProvideTokenRefreshService(
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
cacheInvalidator TokenCacheInvalidator,
cfg *config.Config,
) *TokenRefreshService {
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cfg)
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg)
svc.Start()
return svc
}
@@ -108,10 +109,12 @@ func ProvideRateLimitService(
tempUnschedCache TempUnschedCache,
timeoutCounterCache TimeoutCounterCache,
settingService *SettingService,
tokenCacheInvalidator TokenCacheInvalidator,
) *RateLimitService {
svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache)
svc.SetTimeoutCounterCache(timeoutCounterCache)
svc.SetSettingService(settingService)
svc.SetTokenCacheInvalidator(tokenCacheInvalidator)
return svc
}
@@ -210,6 +213,7 @@ var ProviderSet = wire.NewSet(
NewOpenAIOAuthService,
NewGeminiOAuthService,
NewGeminiQuotaService,
NewCompositeTokenCacheInvalidator,
NewAntigravityOAuthService,
NewGeminiTokenProvider,
NewGeminiMessagesCompatService,

View File

@@ -69,6 +69,14 @@ JWT_EXPIRE_HOUR=24
# Leave unset to use default ./config.yaml
#CONFIG_FILE=./config.yaml
# -----------------------------------------------------------------------------
# Rate Limiting (Optional)
# 速率限制(可选)
# -----------------------------------------------------------------------------
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
# 上游返回 529过载时的冷却时间分钟
RATE_LIMIT_OVERLOAD_COOLDOWN_MINUTES=10
# -----------------------------------------------------------------------------
# Gateway Scheduling (Optional)
# 调度缓存与受控回源配置(缓存就绪且命中时不读 DB

View File

@@ -357,9 +357,6 @@ const handleBulkToggleSchedulable = async (schedulable: boolean) => {
} else {
selIds.value = hasIds ? [] : accountIds
}
load().catch((error) => {
console.error('Failed to refresh accounts:', error)
})
} catch (error) {
console.error('Failed to bulk toggle schedulable:', error)
appStore.showError(t('common.error'))
@@ -383,9 +380,6 @@ const handleToggleSchedulable = async (a: Account) => {
try {
const updated = await adminAPI.accounts.setSchedulable(a.id, nextSchedulable)
updateSchedulableInList([a.id], updated?.schedulable ?? nextSchedulable)
load().catch((error) => {
console.error('Failed to refresh accounts:', error)
})
} catch (error) {
console.error('Failed to toggle schedulable:', error)
appStore.showError(t('admin.accounts.failedToToggleSchedulable'))