diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index ba232984..37ad5d9f 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -138,7 +138,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) - accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator) + rpmCache := repository.NewRPMCache(redisClient) + accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) dataManagementService := service.NewDataManagementService() dataManagementHandler := admin.NewDataManagementHandler(dataManagementService) @@ -160,7 +161,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) diff --git a/backend/internal/handler/admin/account_data_handler_test.go b/backend/internal/handler/admin/account_data_handler_test.go index c8b04c2a..285033a1 100644 --- a/backend/internal/handler/admin/account_data_handler_test.go +++ b/backend/internal/handler/admin/account_data_handler_test.go @@ -64,6 +64,7 @@ func setupAccountDataRouter() (*gin.Engine, *stubAdminService) { nil, nil, nil, + nil, ) router.GET("/api/v1/admin/accounts/data", h.ExportData) diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 9b732f9c..f6082e09 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -53,6 +53,7 @@ type AccountHandler struct { concurrencyService *service.ConcurrencyService crsSyncService *service.CRSSyncService sessionLimitCache service.SessionLimitCache + rpmCache service.RPMCache tokenCacheInvalidator service.TokenCacheInvalidator } @@ -69,6 +70,7 @@ func NewAccountHandler( concurrencyService *service.ConcurrencyService, crsSyncService *service.CRSSyncService, sessionLimitCache service.SessionLimitCache, + rpmCache service.RPMCache, tokenCacheInvalidator service.TokenCacheInvalidator, ) *AccountHandler { return &AccountHandler{ @@ -83,6 +85,7 @@ func NewAccountHandler( concurrencyService: concurrencyService, crsSyncService: crsSyncService, sessionLimitCache: sessionLimitCache, + rpmCache: rpmCache, tokenCacheInvalidator: tokenCacheInvalidator, } } @@ -154,6 +157,7 @@ type AccountWithConcurrency struct { // 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回 CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用 ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数 + CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数 } func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency { @@ -189,6 +193,12 @@ func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, ac } } } + + if h.rpmCache != nil && account.GetBaseRPM() > 0 { + if rpm, err := h.rpmCache.GetRPM(ctx, account.ID); err == nil { + item.CurrentRPM = &rpm + } + } } return item @@ -231,9 +241,10 @@ func (h *AccountHandler) List(c *gin.Context) { concurrencyCounts = make(map[int64]int) } - // 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能) + // 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能) windowCostAccountIDs := make([]int64, 0) sessionLimitAccountIDs := make([]int64, 0) + rpmAccountIDs := make([]int64, 0) sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置 for i := range accounts { acc := &accounts[i] @@ -245,12 +256,24 @@ func (h *AccountHandler) List(c *gin.Context) { sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID) sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute } + if acc.GetBaseRPM() > 0 { + rpmAccountIDs = append(rpmAccountIDs, acc.ID) + } } } - // 并行获取窗口费用和活跃会话数 + // 并行获取窗口费用、活跃会话数和 RPM 计数 var windowCosts map[int64]float64 var activeSessions map[int64]int + var rpmCounts map[int64]int + + // 获取 RPM 计数(批量查询) + if len(rpmAccountIDs) > 0 && h.rpmCache != nil { + rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs) + if rpmCounts == nil { + rpmCounts = make(map[int64]int) + } + } // 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置) if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil { @@ -311,6 +334,13 @@ func (h *AccountHandler) List(c *gin.Context) { } } + // 添加 RPM 计数(仅当启用时) + if rpmCounts != nil { + if rpm, ok := rpmCounts[acc.ID]; ok { + item.CurrentRPM = &rpm + } + } + result[i] = item } @@ -453,6 +483,8 @@ func (h *AccountHandler) Create(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } + // base_rpm 输入校验:负值归零,超过 10000 截断 + sanitizeExtraBaseRPM(req.Extra) // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk @@ -522,6 +554,8 @@ func (h *AccountHandler) Update(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } + // base_rpm 输入校验:负值归零,超过 10000 截断 + sanitizeExtraBaseRPM(req.Extra) // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk @@ -904,6 +938,9 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { continue } + // base_rpm 输入校验:负值归零,超过 10000 截断 + sanitizeExtraBaseRPM(item.Extra) + skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ @@ -1048,6 +1085,8 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } + // base_rpm 输入校验:负值归零,超过 10000 截断 + sanitizeExtraBaseRPM(req.Extra) // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk @@ -1706,3 +1745,22 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) { response.Success(c, domain.DefaultAntigravityModelMapping) } + +// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。 +// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。 +func sanitizeExtraBaseRPM(extra map[string]any) { + if extra == nil { + return + } + raw, ok := extra["base_rpm"] + if !ok { + return + } + v := service.ParseExtraInt(raw) + if v < 0 { + v = 0 + } else if v > 10000 { + v = 10000 + } + extra["base_rpm"] = v +} diff --git a/backend/internal/handler/admin/account_handler_mixed_channel_test.go b/backend/internal/handler/admin/account_handler_mixed_channel_test.go index ad004844..61b99e03 100644 --- a/backend/internal/handler/admin/account_handler_mixed_channel_test.go +++ b/backend/internal/handler/admin/account_handler_mixed_channel_test.go @@ -15,7 +15,7 @@ import ( func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine { gin.SetMode(gin.TestMode) router := gin.New() - accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel) router.POST("/api/v1/admin/accounts", accountHandler.Create) router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update) diff --git a/backend/internal/handler/admin/account_handler_passthrough_test.go b/backend/internal/handler/admin/account_handler_passthrough_test.go index d09cccd6..d86501c0 100644 --- a/backend/internal/handler/admin/account_handler_passthrough_test.go +++ b/backend/internal/handler/admin/account_handler_passthrough_test.go @@ -28,6 +28,7 @@ func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testi nil, nil, nil, + nil, ) router := gin.New() diff --git a/backend/internal/handler/admin/batch_update_credentials_test.go b/backend/internal/handler/admin/batch_update_credentials_test.go index c8185735..0b1b6691 100644 --- a/backend/internal/handler/admin/batch_update_credentials_test.go +++ b/backend/internal/handler/admin/batch_update_credentials_test.go @@ -36,7 +36,7 @@ func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) { gin.SetMode(gin.TestMode) router := gin.New() - handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials) return router, handler } diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 49c74522..d811c7be 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -209,6 +209,13 @@ func AccountFromServiceShallow(a *service.Account) *Account { if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 { out.SessionIdleTimeoutMin = &idleTimeout } + if rpm := a.GetBaseRPM(); rpm > 0 { + out.BaseRPM = &rpm + strategy := a.GetRPMStrategy() + out.RPMStrategy = &strategy + buffer := a.GetRPMStickyBuffer() + out.RPMStickyBuffer = &buffer + } // TLS指纹伪装开关 if a.IsTLSFingerprintEnabled() { enabled := true diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 73243397..c575c232 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -153,6 +153,12 @@ type Account struct { MaxSessions *int `json:"max_sessions,omitempty"` SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"` + // RPM 限制(仅 Anthropic OAuth/SetupToken 账号有效) + // 从 extra 字段提取,方便前端显示和编辑 + BaseRPM *int `json:"base_rpm,omitempty"` + RPMStrategy *string `json:"rpm_strategy,omitempty"` + RPMStickyBuffer *int `json:"rpm_sticky_buffer,omitempty"` + // TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效) // 从 extra 字段提取,方便前端显示和编辑 EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 9262df7e..3cc52839 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -403,6 +403,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } + // RPM 计数递增(Forward 成功后) + // 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。 + // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。 + if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { + if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil { + reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + } + // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) @@ -595,7 +604,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleStreamingAwareError(c, status, code, message, streamStarted) return } - // 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度 + // 兜底重试按"直接请求兜底分组"处理:清除强制平台,允许按分组平台调度 ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "") c.Request = c.Request.WithContext(ctx) currentAPIKey = fallbackAPIKey @@ -629,6 +638,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } + // RPM 计数递增(Forward 成功后) + // 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。 + // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。 + if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { + if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil { + reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + } + // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 76141521..2afa6440 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -153,6 +153,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // deferredService nil, // claudeTokenProvider nil, // sessionLimitCache + nil, // rpmCache nil, // digestStore ) diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index 085658ad..5df7fa0a 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -2184,7 +2184,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { return service.NewGatewayService( accountRepo, nil, nil, nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, ) } diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 01c684ca..68a04084 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -426,7 +426,8 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { deferredService, nil, testutil.StubSessionLimitCache{}, - nil, + nil, // rpmCache + nil, // digestStore ) soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}} diff --git a/backend/internal/repository/rpm_cache.go b/backend/internal/repository/rpm_cache.go new file mode 100644 index 00000000..4d73ec4b --- /dev/null +++ b/backend/internal/repository/rpm_cache.go @@ -0,0 +1,141 @@ +package repository + +import ( + "context" + "errors" + "fmt" + "strconv" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +// RPM 计数器缓存常量定义 +// +// 设计说明: +// 使用 Redis 简单计数器跟踪每个账号每分钟的请求数: +// - Key: rpm:{accountID}:{minuteTimestamp} +// - Value: 当前分钟内的请求计数 +// - TTL: 120 秒(覆盖当前分钟 + 一定冗余) +// +// 使用 TxPipeline(MULTI/EXEC)执行 INCR + EXPIRE,保证原子性且兼容 Redis Cluster。 +// 通过 rdb.Time() 获取服务端时间,避免多实例时钟不同步。 +// +// 设计决策: +// - TxPipeline vs Pipeline:Pipeline 仅合并发送但不保证原子,TxPipeline 使用 MULTI/EXEC 事务保证原子执行。 +// - rdb.Time() 单独调用:Pipeline/TxPipeline 中无法引用前一命令的结果,因此 TIME 必须单独调用(2 RTT)。 +// Lua 脚本可以做到 1 RTT,但在 Redis Cluster 中动态拼接 key 存在 CROSSSLOT 风险,选择安全性优先。 +const ( + // RPM 计数器键前缀 + // 格式: rpm:{accountID}:{minuteTimestamp} + rpmKeyPrefix = "rpm:" + + // RPM 计数器 TTL(120 秒,覆盖当前分钟窗口 + 冗余) + rpmKeyTTL = 120 * time.Second +) + +// RPMCacheImpl RPM 计数器缓存 Redis 实现 +type RPMCacheImpl struct { + rdb *redis.Client +} + +// NewRPMCache 创建 RPM 计数器缓存 +func NewRPMCache(rdb *redis.Client) service.RPMCache { + return &RPMCacheImpl{rdb: rdb} +} + +// currentMinuteKey 获取当前分钟的完整 Redis key +// 使用 rdb.Time() 获取 Redis 服务端时间,避免多实例时钟偏差 +func (c *RPMCacheImpl) currentMinuteKey(ctx context.Context, accountID int64) (string, error) { + serverTime, err := c.rdb.Time(ctx).Result() + if err != nil { + return "", fmt.Errorf("redis TIME: %w", err) + } + minuteTS := serverTime.Unix() / 60 + return fmt.Sprintf("%s%d:%d", rpmKeyPrefix, accountID, minuteTS), nil +} + +// currentMinuteSuffix 获取当前分钟时间戳后缀(供批量操作使用) +// 使用 rdb.Time() 获取 Redis 服务端时间 +func (c *RPMCacheImpl) currentMinuteSuffix(ctx context.Context) (string, error) { + serverTime, err := c.rdb.Time(ctx).Result() + if err != nil { + return "", fmt.Errorf("redis TIME: %w", err) + } + minuteTS := serverTime.Unix() / 60 + return strconv.FormatInt(minuteTS, 10), nil +} + +// IncrementRPM 原子递增并返回当前分钟的计数 +// 使用 TxPipeline (MULTI/EXEC) 执行 INCR + EXPIRE,保证原子性且兼容 Redis Cluster +func (c *RPMCacheImpl) IncrementRPM(ctx context.Context, accountID int64) (int, error) { + key, err := c.currentMinuteKey(ctx, accountID) + if err != nil { + return 0, fmt.Errorf("rpm increment: %w", err) + } + + // 使用 TxPipeline (MULTI/EXEC) 保证 INCR + EXPIRE 原子执行 + // EXPIRE 幂等,每次都设置不影响正确性 + pipe := c.rdb.TxPipeline() + incrCmd := pipe.Incr(ctx, key) + pipe.Expire(ctx, key, rpmKeyTTL) + + if _, err := pipe.Exec(ctx); err != nil { + return 0, fmt.Errorf("rpm increment: %w", err) + } + + return int(incrCmd.Val()), nil +} + +// GetRPM 获取当前分钟的 RPM 计数 +func (c *RPMCacheImpl) GetRPM(ctx context.Context, accountID int64) (int, error) { + key, err := c.currentMinuteKey(ctx, accountID) + if err != nil { + return 0, fmt.Errorf("rpm get: %w", err) + } + + val, err := c.rdb.Get(ctx, key).Int() + if errors.Is(err, redis.Nil) { + return 0, nil // 当前分钟无记录 + } + if err != nil { + return 0, fmt.Errorf("rpm get: %w", err) + } + return val, nil +} + +// GetRPMBatch 批量获取多个账号的 RPM 计数(使用 Pipeline) +func (c *RPMCacheImpl) GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + if len(accountIDs) == 0 { + return map[int64]int{}, nil + } + + // 获取当前分钟后缀 + minuteSuffix, err := c.currentMinuteSuffix(ctx) + if err != nil { + return nil, fmt.Errorf("rpm batch get: %w", err) + } + + // 使用 Pipeline 批量 GET + pipe := c.rdb.Pipeline() + cmds := make(map[int64]*redis.StringCmd, len(accountIDs)) + for _, id := range accountIDs { + key := fmt.Sprintf("%s%d:%s", rpmKeyPrefix, id, minuteSuffix) + cmds[id] = pipe.Get(ctx, key) + } + + if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("rpm batch get: %w", err) + } + + result := make(map[int64]int, len(accountIDs)) + for id, cmd := range cmds { + if val, err := cmd.Int(); err == nil { + result[id] = val + } else { + result[id] = 0 + } + } + return result, nil +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index eb8ce3fb..2344035c 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -79,6 +79,7 @@ var ProviderSet = wire.NewSet( NewTimeoutCounterCache, ProvideConcurrencyCache, ProvideSessionLimitCache, + NewRPMCache, NewDashboardCache, NewEmailCache, NewIdentityCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index a9a9bbdd..f8a3a9dd 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -624,7 +624,7 @@ func newContractDeps(t *testing.T) *contractDeps { apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil) - adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) jwtAuth := func(c *gin.Context) { c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 1864eb54..c76c817e 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -1137,6 +1137,80 @@ func (a *Account) GetSessionIdleTimeoutMinutes() int { return 5 } +// GetBaseRPM 获取基础 RPM 限制 +// 返回 0 表示未启用(负数视为无效配置,按 0 处理) +func (a *Account) GetBaseRPM() int { + if a.Extra == nil { + return 0 + } + if v, ok := a.Extra["base_rpm"]; ok { + val := parseExtraInt(v) + if val > 0 { + return val + } + } + return 0 +} + +// GetRPMStrategy 获取 RPM 策略 +// "tiered" = 三区模型(默认), "sticky_exempt" = 粘性豁免 +func (a *Account) GetRPMStrategy() string { + if a.Extra == nil { + return "tiered" + } + if v, ok := a.Extra["rpm_strategy"]; ok { + if s, ok := v.(string); ok && s == "sticky_exempt" { + return "sticky_exempt" + } + } + return "tiered" +} + +// GetRPMStickyBuffer 获取 RPM 粘性缓冲数量 +// tiered 模式下的黄区大小,默认为 base_rpm 的 20%(至少 1) +func (a *Account) GetRPMStickyBuffer() int { + if a.Extra == nil { + return 0 + } + if v, ok := a.Extra["rpm_sticky_buffer"]; ok { + val := parseExtraInt(v) + if val > 0 { + return val + } + } + base := a.GetBaseRPM() + buffer := base / 5 + if buffer < 1 && base > 0 { + buffer = 1 + } + return buffer +} + +// CheckRPMSchedulability 根据当前 RPM 计数检查调度状态 +// 复用 WindowCostSchedulability 三态:Schedulable / StickyOnly / NotSchedulable +func (a *Account) CheckRPMSchedulability(currentRPM int) WindowCostSchedulability { + baseRPM := a.GetBaseRPM() + if baseRPM <= 0 { + return WindowCostSchedulable + } + + if currentRPM < baseRPM { + return WindowCostSchedulable + } + + strategy := a.GetRPMStrategy() + if strategy == "sticky_exempt" { + return WindowCostStickyOnly // 粘性豁免无红区 + } + + // tiered: 黄区 + 红区 + buffer := a.GetRPMStickyBuffer() + if currentRPM < baseRPM+buffer { + return WindowCostStickyOnly + } + return WindowCostNotSchedulable +} + // CheckWindowCostSchedulability 根据当前窗口费用检查调度状态 // - 费用 < 阈值: WindowCostSchedulable(可正常调度) // - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话) @@ -1200,6 +1274,12 @@ func parseExtraFloat64(value any) float64 { } // parseExtraInt 从 extra 字段解析 int 值 +// ParseExtraInt 从 extra 字段的 any 值解析为 int。 +// 支持 int, int64, float64, json.Number, string 类型,无法解析时返回 0。 +func ParseExtraInt(value any) int { + return parseExtraInt(value) +} + func parseExtraInt(value any) int { switch v := value.(type) { case int: diff --git a/backend/internal/service/account_rpm_test.go b/backend/internal/service/account_rpm_test.go new file mode 100644 index 00000000..9d91f3e0 --- /dev/null +++ b/backend/internal/service/account_rpm_test.go @@ -0,0 +1,120 @@ +package service + +import ( + "encoding/json" + "testing" +) + +func TestGetBaseRPM(t *testing.T) { + tests := []struct { + name string + extra map[string]any + expected int + }{ + {"nil extra", nil, 0}, + {"no key", map[string]any{}, 0}, + {"zero", map[string]any{"base_rpm": 0}, 0}, + {"int value", map[string]any{"base_rpm": 15}, 15}, + {"float value", map[string]any{"base_rpm": 15.0}, 15}, + {"string value", map[string]any{"base_rpm": "15"}, 15}, + {"negative value", map[string]any{"base_rpm": -5}, 0}, + {"int64 value", map[string]any{"base_rpm": int64(20)}, 20}, + {"json.Number value", map[string]any{"base_rpm": json.Number("25")}, 25}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Account{Extra: tt.extra} + if got := a.GetBaseRPM(); got != tt.expected { + t.Errorf("GetBaseRPM() = %d, want %d", got, tt.expected) + } + }) + } +} + +func TestGetRPMStrategy(t *testing.T) { + tests := []struct { + name string + extra map[string]any + expected string + }{ + {"nil extra", nil, "tiered"}, + {"no key", map[string]any{}, "tiered"}, + {"tiered", map[string]any{"rpm_strategy": "tiered"}, "tiered"}, + {"sticky_exempt", map[string]any{"rpm_strategy": "sticky_exempt"}, "sticky_exempt"}, + {"invalid", map[string]any{"rpm_strategy": "foobar"}, "tiered"}, + {"empty string fallback", map[string]any{"rpm_strategy": ""}, "tiered"}, + {"numeric value fallback", map[string]any{"rpm_strategy": 123}, "tiered"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Account{Extra: tt.extra} + if got := a.GetRPMStrategy(); got != tt.expected { + t.Errorf("GetRPMStrategy() = %q, want %q", got, tt.expected) + } + }) + } +} + +func TestCheckRPMSchedulability(t *testing.T) { + tests := []struct { + name string + extra map[string]any + currentRPM int + expected WindowCostSchedulability + }{ + {"disabled", map[string]any{}, 100, WindowCostSchedulable}, + {"green zone", map[string]any{"base_rpm": 15}, 10, WindowCostSchedulable}, + {"yellow zone tiered", map[string]any{"base_rpm": 15}, 15, WindowCostStickyOnly}, + {"red zone tiered", map[string]any{"base_rpm": 15}, 18, WindowCostNotSchedulable}, + {"sticky_exempt at limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 15, WindowCostStickyOnly}, + {"sticky_exempt over limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 100, WindowCostStickyOnly}, + {"custom buffer", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 14, WindowCostStickyOnly}, + {"custom buffer red", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 15, WindowCostNotSchedulable}, + {"base_rpm=1 green", map[string]any{"base_rpm": 1}, 0, WindowCostSchedulable}, + {"base_rpm=1 yellow (at limit)", map[string]any{"base_rpm": 1}, 1, WindowCostStickyOnly}, + {"base_rpm=1 red (at limit+buffer)", map[string]any{"base_rpm": 1}, 2, WindowCostNotSchedulable}, + {"negative currentRPM", map[string]any{"base_rpm": 15}, -1, WindowCostSchedulable}, + {"base_rpm negative disabled", map[string]any{"base_rpm": -5}, 10, WindowCostSchedulable}, + {"very high currentRPM", map[string]any{"base_rpm": 10}, 9999, WindowCostNotSchedulable}, + {"sticky_exempt very high currentRPM", map[string]any{"base_rpm": 10, "rpm_strategy": "sticky_exempt"}, 9999, WindowCostStickyOnly}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Account{Extra: tt.extra} + if got := a.CheckRPMSchedulability(tt.currentRPM); got != tt.expected { + t.Errorf("CheckRPMSchedulability(%d) = %d, want %d", tt.currentRPM, got, tt.expected) + } + }) + } +} + +func TestGetRPMStickyBuffer(t *testing.T) { + tests := []struct { + name string + extra map[string]any + expected int + }{ + {"nil extra", nil, 0}, + {"no keys", map[string]any{}, 0}, + {"base_rpm=0", map[string]any{"base_rpm": 0}, 0}, + {"base_rpm=1 min buffer 1", map[string]any{"base_rpm": 1}, 1}, + {"base_rpm=4 min buffer 1", map[string]any{"base_rpm": 4}, 1}, + {"base_rpm=5 buffer 1", map[string]any{"base_rpm": 5}, 1}, + {"base_rpm=10 buffer 2", map[string]any{"base_rpm": 10}, 2}, + {"base_rpm=15 buffer 3", map[string]any{"base_rpm": 15}, 3}, + {"base_rpm=100 buffer 20", map[string]any{"base_rpm": 100}, 20}, + {"custom buffer=5", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 5}, + {"custom buffer=0 fallback to default", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0}, 2}, + {"custom buffer negative fallback", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1}, 2}, + {"custom buffer with float", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7}, + {"json.Number base_rpm", map[string]any{"base_rpm": json.Number("10")}, 2}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Account{Extra: tt.extra} + if got := a.GetRPMStickyBuffer(); got != tt.expected { + t.Errorf("GetRPMStickyBuffer() = %d, want %d", got, tt.expected) + } + }) + } +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 0ba9e093..3323f868 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -520,6 +520,7 @@ type GatewayService struct { concurrencyService *ConcurrencyService claudeTokenProvider *ClaudeTokenProvider sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) + rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken) userGroupRateCache *gocache.Cache userGroupRateSF singleflight.Group modelsListCache *gocache.Cache @@ -549,6 +550,7 @@ func NewGatewayService( deferredService *DeferredService, claudeTokenProvider *ClaudeTokenProvider, sessionLimitCache SessionLimitCache, + rpmCache RPMCache, digestStore *DigestSessionStore, ) *GatewayService { userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) @@ -574,6 +576,7 @@ func NewGatewayService( deferredService: deferredService, claudeTokenProvider: claudeTokenProvider, sessionLimitCache: sessionLimitCache, + rpmCache: rpmCache, userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), modelsListCache: gocache.New(modelsListTTL, time.Minute), modelsListCacheTTL: modelsListTTL, @@ -1154,6 +1157,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return nil, errors.New("no available accounts") } ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) isExcluded := func(accountID int64) bool { if excludedIDs == nil { @@ -1229,6 +1233,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro filteredWindowCost++ continue } + // RPM 检查(非粘性会话路径) + if !s.isAccountSchedulableForRPM(ctx, account, false) { + continue + } routingCandidates = append(routingCandidates, account) } @@ -1252,7 +1260,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && - s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查 + s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && + + s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查 result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 @@ -1406,7 +1416,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro s.isAccountAllowedForPlatform(account, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && - s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查 + s.isAccountSchedulableForWindowCost(ctx, account, true) && + + s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查 result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 @@ -1472,6 +1484,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue } + // RPM 检查(非粘性会话路径) + if !s.isAccountSchedulableForRPM(ctx, acc, false) { + continue + } candidates = append(candidates, acc) } @@ -2155,6 +2171,88 @@ checkSchedulability: return true } +// rpmPrefetchContextKey is the context key for prefetched RPM counts. +type rpmPrefetchContextKeyType struct{} + +var rpmPrefetchContextKey = rpmPrefetchContextKeyType{} + +func rpmFromPrefetchContext(ctx context.Context, accountID int64) (int, bool) { + if v, ok := ctx.Value(rpmPrefetchContextKey).(map[int64]int); ok { + count, found := v[accountID] + return count, found + } + return 0, false +} + +// withRPMPrefetch 批量预取所有候选账号的 RPM 计数 +func (s *GatewayService) withRPMPrefetch(ctx context.Context, accounts []Account) context.Context { + if s.rpmCache == nil { + return ctx + } + + var ids []int64 + for i := range accounts { + if accounts[i].IsAnthropicOAuthOrSetupToken() && accounts[i].GetBaseRPM() > 0 { + ids = append(ids, accounts[i].ID) + } + } + if len(ids) == 0 { + return ctx + } + + counts, err := s.rpmCache.GetRPMBatch(ctx, ids) + if err != nil { + return ctx // 失败开放 + } + return context.WithValue(ctx, rpmPrefetchContextKey, counts) +} + +// isAccountSchedulableForRPM 检查账号是否可根据 RPM 进行调度 +// 仅适用于 Anthropic OAuth/SetupToken 账号 +func (s *GatewayService) isAccountSchedulableForRPM(ctx context.Context, account *Account, isSticky bool) bool { + if !account.IsAnthropicOAuthOrSetupToken() { + return true + } + baseRPM := account.GetBaseRPM() + if baseRPM <= 0 { + return true + } + + // 尝试从预取缓存获取 + var currentRPM int + if count, ok := rpmFromPrefetchContext(ctx, account.ID); ok { + currentRPM = count + } else if s.rpmCache != nil { + if count, err := s.rpmCache.GetRPM(ctx, account.ID); err == nil { + currentRPM = count + } + // 失败开放:GetRPM 错误时允许调度 + } + + schedulability := account.CheckRPMSchedulability(currentRPM) + switch schedulability { + case WindowCostSchedulable: + return true + case WindowCostStickyOnly: + return isSticky + case WindowCostNotSchedulable: + return false + } + return true +} + +// IncrementAccountRPM increments the RPM counter for the given account. +// 已知 TOCTOU 竞态:调度时读取 RPM 计数与此处递增之间存在时间窗口, +// 高并发下可能短暂超出 RPM 限制。这是与 WindowCost 一致的 soft-limit +// 设计权衡——可接受的少量超额优于加锁带来的延迟和复杂度。 +func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int64) error { + if s.rpmCache == nil { + return nil + } + _, err := s.rpmCache.IncrementRPM(ctx, accountID) + return err +} + // checkAndRegisterSession 检查并注册会话,用于会话数量限制 // 仅适用于 Anthropic OAuth/SetupToken 账号 // sessionID: 会话标识符(使用粘性会话的 hash) @@ -2349,7 +2447,7 @@ func sameAccountWithLoadGroup(a, b accountWithLoad) bool { // shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。 // // 注意:当 preferOAuth=true 时,需要保证 OAuth 账号在同组内仍然优先,否则会把排序时的偏好打散掉。 -// 因此这里采用“组内分区 + 分区内 shuffle”的方式: +// 因此这里采用"组内分区 + 分区内 shuffle"的方式: // - 先把同组账号按 (OAuth / 非 OAuth) 拆成两段,保持 OAuth 段在前; // - 再分别在各段内随机打散,避免热点。 func shuffleWithinPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { @@ -2489,7 +2587,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } @@ -2512,6 +2610,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } accountsLoaded = true + // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存 + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + routingSet := make(map[int64]struct{}, len(routingAccountIDs)) for _, id := range routingAccountIDs { if id > 0 { @@ -2539,6 +2641,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { + continue + } if selected == nil { selected = acc continue @@ -2589,7 +2697,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { return account, nil } } @@ -2610,6 +2718,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } } + // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1) + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + // 3. 按优先级+最久未用选择(考虑模型支持) var selected *Account for i := range accounts { @@ -2628,6 +2740,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { + continue + } if selected == nil { selected = acc continue @@ -2697,7 +2815,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) @@ -2718,6 +2836,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } accountsLoaded = true + // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存 + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + routingSet := make(map[int64]struct{}, len(routingAccountIDs)) for _, id := range routingAccountIDs { if id > 0 { @@ -2749,6 +2871,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { + continue + } if selected == nil { selected = acc continue @@ -2799,7 +2927,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { return account, nil } @@ -2818,6 +2946,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } } + // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1) + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) var selected *Account for i := range accounts { @@ -2840,6 +2972,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { + continue + } if selected == nil { selected = acc continue @@ -5185,7 +5323,7 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { } func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { - // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。 + // 只对"可能是兼容性差异导致"的 400 允许切换,避免无意义重试。 // 默认保守:无法识别则不切换。 msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) if msg == "" { diff --git a/backend/internal/service/rpm_cache.go b/backend/internal/service/rpm_cache.go new file mode 100644 index 00000000..07036219 --- /dev/null +++ b/backend/internal/service/rpm_cache.go @@ -0,0 +1,17 @@ +package service + +import "context" + +// RPMCache RPM 计数器缓存接口 +// 用于 Anthropic OAuth/SetupToken 账号的每分钟请求数限制 +type RPMCache interface { + // IncrementRPM 原子递增并返回当前分钟的计数 + // 使用 Redis 服务器时间确定 minute key,避免多实例时钟偏差 + IncrementRPM(ctx context.Context, accountID int64) (count int, err error) + + // GetRPM 获取当前分钟的 RPM 计数 + GetRPM(ctx context.Context, accountID int64) (count int, err error) + + // GetRPMBatch 批量获取多个账号的 RPM 计数(使用 Pipeline) + GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) +} diff --git a/frontend/src/components/account/AccountCapacityCell.vue b/frontend/src/components/account/AccountCapacityCell.vue index ae338aca..2a4babf2 100644 --- a/frontend/src/components/account/AccountCapacityCell.vue +++ b/frontend/src/components/account/AccountCapacityCell.vue @@ -52,6 +52,25 @@ {{ account.max_sessions }} + + +
{{ t('admin.accounts.quotaControl.rpmLimit.baseRpmHint') }}
+{{ t('admin.accounts.quotaControl.rpmLimit.stickyBufferHint') }}
++ {{ t('admin.accounts.quotaControl.rpmLimit.hint') }} +
+{{ t('admin.accounts.quotaControl.rpmLimit.baseRpmHint') }}
+{{ t('admin.accounts.quotaControl.rpmLimit.stickyBufferHint') }}
++ {{ t('admin.accounts.quotaControl.rpmLimit.hint') }} +
+{{ t('admin.accounts.quotaControl.rpmLimit.baseRpmHint') }}
+{{ t('admin.accounts.quotaControl.rpmLimit.stickyBufferHint') }}
+