Merge pull request #679 from DaydreamCoding/feat/account-rpm-limit

feat: 添加账号级别 RPM(每分钟请求数)限流功能
This commit is contained in:
Wesley Liddick
2026-02-28 22:37:10 +08:00
committed by GitHub
27 changed files with 1174 additions and 31 deletions

View File

@@ -138,7 +138,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
rpmCache := repository.NewRPMCache(redisClient)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
dataManagementService := service.NewDataManagementService()
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
@@ -160,7 +161,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)

View File

@@ -64,6 +64,7 @@ func setupAccountDataRouter() (*gin.Engine, *stubAdminService) {
nil,
nil,
nil,
nil,
)
router.GET("/api/v1/admin/accounts/data", h.ExportData)

View File

@@ -53,6 +53,7 @@ type AccountHandler struct {
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
sessionLimitCache service.SessionLimitCache
rpmCache service.RPMCache
tokenCacheInvalidator service.TokenCacheInvalidator
}
@@ -69,6 +70,7 @@ func NewAccountHandler(
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
sessionLimitCache service.SessionLimitCache,
rpmCache service.RPMCache,
tokenCacheInvalidator service.TokenCacheInvalidator,
) *AccountHandler {
return &AccountHandler{
@@ -83,6 +85,7 @@ func NewAccountHandler(
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
tokenCacheInvalidator: tokenCacheInvalidator,
}
}
@@ -154,6 +157,7 @@ type AccountWithConcurrency struct {
// 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用
ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数
CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
}
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
@@ -189,6 +193,12 @@ func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, ac
}
}
}
if h.rpmCache != nil && account.GetBaseRPM() > 0 {
if rpm, err := h.rpmCache.GetRPM(ctx, account.ID); err == nil {
item.CurrentRPM = &rpm
}
}
}
return item
@@ -231,9 +241,10 @@ func (h *AccountHandler) List(c *gin.Context) {
concurrencyCounts = make(map[int64]int)
}
// 识别需要查询窗口费用会话数的账号Anthropic OAuth/SetupToken 且启用了相应功能)
// 识别需要查询窗口费用会话数和 RPM 的账号Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
rpmAccountIDs := make([]int64, 0)
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
for i := range accounts {
acc := &accounts[i]
@@ -245,12 +256,24 @@ func (h *AccountHandler) List(c *gin.Context) {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
}
if acc.GetBaseRPM() > 0 {
rpmAccountIDs = append(rpmAccountIDs, acc.ID)
}
}
}
// 并行获取窗口费用活跃会话数
// 并行获取窗口费用活跃会话数和 RPM 计数
var windowCosts map[int64]float64
var activeSessions map[int64]int
var rpmCounts map[int64]int
// 获取 RPM 计数(批量查询)
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
if rpmCounts == nil {
rpmCounts = make(map[int64]int)
}
}
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
@@ -311,6 +334,13 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
// 添加 RPM 计数(仅当启用时)
if rpmCounts != nil {
if rpm, ok := rpmCounts[acc.ID]; ok {
item.CurrentRPM = &rpm
}
}
result[i] = item
}
@@ -453,6 +483,8 @@ func (h *AccountHandler) Create(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -522,6 +554,8 @@ func (h *AccountHandler) Update(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -904,6 +938,9 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
continue
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(item.Extra)
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
@@ -1048,6 +1085,8 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -1706,3 +1745,22 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) {
response.Success(c, domain.DefaultAntigravityModelMapping)
}
// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。
// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。
func sanitizeExtraBaseRPM(extra map[string]any) {
if extra == nil {
return
}
raw, ok := extra["base_rpm"]
if !ok {
return
}
v := service.ParseExtraInt(raw)
if v < 0 {
v = 0
} else if v > 10000 {
v = 10000
}
extra["base_rpm"] = v
}

View File

@@ -15,7 +15,7 @@ import (
func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel)
router.POST("/api/v1/admin/accounts", accountHandler.Create)
router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update)

View File

@@ -28,6 +28,7 @@ func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testi
nil,
nil,
nil,
nil,
)
router := gin.New()

View File

@@ -36,7 +36,7 @@ func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
gin.SetMode(gin.TestMode)
router := gin.New()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
return router, handler
}

View File

@@ -209,6 +209,13 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
out.SessionIdleTimeoutMin = &idleTimeout
}
if rpm := a.GetBaseRPM(); rpm > 0 {
out.BaseRPM = &rpm
strategy := a.GetRPMStrategy()
out.RPMStrategy = &strategy
buffer := a.GetRPMStickyBuffer()
out.RPMStickyBuffer = &buffer
}
// TLS指纹伪装开关
if a.IsTLSFingerprintEnabled() {
enabled := true

View File

@@ -153,6 +153,12 @@ type Account struct {
MaxSessions *int `json:"max_sessions,omitempty"`
SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
// RPM 限制(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
BaseRPM *int `json:"base_rpm,omitempty"`
RPMStrategy *string `json:"rpm_strategy,omitempty"`
RPMStickyBuffer *int `json:"rpm_sticky_buffer,omitempty"`
// TLS指纹伪装仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"`

View File

@@ -403,6 +403,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// RPM 计数递增Forward 成功后)
// 注意TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
// 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil {
reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
@@ -595,7 +604,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
// 兜底重试按直接请求兜底分组处理:清除强制平台,允许按分组平台调度
// 兜底重试按"直接请求兜底分组"处理:清除强制平台,允许按分组平台调度
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "")
c.Request = c.Request.WithContext(ctx)
currentAPIKey = fallbackAPIKey
@@ -629,6 +638,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// RPM 计数递增Forward 成功后)
// 注意TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
// 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil {
reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)

View File

@@ -153,6 +153,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // deferredService
nil, // claudeTokenProvider
nil, // sessionLimitCache
nil, // rpmCache
nil, // digestStore
)

View File

@@ -2184,7 +2184,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
)
}

View File

@@ -426,7 +426,8 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
deferredService,
nil,
testutil.StubSessionLimitCache{},
nil,
nil, // rpmCache
nil, // digestStore
)
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}

View File

@@ -0,0 +1,141 @@
package repository
import (
"context"
"errors"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// RPM 计数器缓存常量定义
//
// 设计说明:
// 使用 Redis 简单计数器跟踪每个账号每分钟的请求数:
// - Key: rpm:{accountID}:{minuteTimestamp}
// - Value: 当前分钟内的请求计数
// - TTL: 120 秒(覆盖当前分钟 + 一定冗余)
//
// 使用 TxPipelineMULTI/EXEC执行 INCR + EXPIRE保证原子性且兼容 Redis Cluster。
// 通过 rdb.Time() 获取服务端时间,避免多实例时钟不同步。
//
// 设计决策:
// - TxPipeline vs PipelinePipeline 仅合并发送但不保证原子TxPipeline 使用 MULTI/EXEC 事务保证原子执行。
// - rdb.Time() 单独调用Pipeline/TxPipeline 中无法引用前一命令的结果,因此 TIME 必须单独调用2 RTT
// Lua 脚本可以做到 1 RTT但在 Redis Cluster 中动态拼接 key 存在 CROSSSLOT 风险,选择安全性优先。
const (
// RPM 计数器键前缀
// 格式: rpm:{accountID}:{minuteTimestamp}
rpmKeyPrefix = "rpm:"
// RPM 计数器 TTL120 秒,覆盖当前分钟窗口 + 冗余)
rpmKeyTTL = 120 * time.Second
)
// RPMCacheImpl RPM 计数器缓存 Redis 实现
type RPMCacheImpl struct {
rdb *redis.Client
}
// NewRPMCache 创建 RPM 计数器缓存
func NewRPMCache(rdb *redis.Client) service.RPMCache {
return &RPMCacheImpl{rdb: rdb}
}
// currentMinuteKey 获取当前分钟的完整 Redis key
// 使用 rdb.Time() 获取 Redis 服务端时间,避免多实例时钟偏差
func (c *RPMCacheImpl) currentMinuteKey(ctx context.Context, accountID int64) (string, error) {
serverTime, err := c.rdb.Time(ctx).Result()
if err != nil {
return "", fmt.Errorf("redis TIME: %w", err)
}
minuteTS := serverTime.Unix() / 60
return fmt.Sprintf("%s%d:%d", rpmKeyPrefix, accountID, minuteTS), nil
}
// currentMinuteSuffix 获取当前分钟时间戳后缀(供批量操作使用)
// 使用 rdb.Time() 获取 Redis 服务端时间
func (c *RPMCacheImpl) currentMinuteSuffix(ctx context.Context) (string, error) {
serverTime, err := c.rdb.Time(ctx).Result()
if err != nil {
return "", fmt.Errorf("redis TIME: %w", err)
}
minuteTS := serverTime.Unix() / 60
return strconv.FormatInt(minuteTS, 10), nil
}
// IncrementRPM 原子递增并返回当前分钟的计数
// 使用 TxPipeline (MULTI/EXEC) 执行 INCR + EXPIRE保证原子性且兼容 Redis Cluster
func (c *RPMCacheImpl) IncrementRPM(ctx context.Context, accountID int64) (int, error) {
key, err := c.currentMinuteKey(ctx, accountID)
if err != nil {
return 0, fmt.Errorf("rpm increment: %w", err)
}
// 使用 TxPipeline (MULTI/EXEC) 保证 INCR + EXPIRE 原子执行
// EXPIRE 幂等,每次都设置不影响正确性
pipe := c.rdb.TxPipeline()
incrCmd := pipe.Incr(ctx, key)
pipe.Expire(ctx, key, rpmKeyTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, fmt.Errorf("rpm increment: %w", err)
}
return int(incrCmd.Val()), nil
}
// GetRPM 获取当前分钟的 RPM 计数
func (c *RPMCacheImpl) GetRPM(ctx context.Context, accountID int64) (int, error) {
key, err := c.currentMinuteKey(ctx, accountID)
if err != nil {
return 0, fmt.Errorf("rpm get: %w", err)
}
val, err := c.rdb.Get(ctx, key).Int()
if errors.Is(err, redis.Nil) {
return 0, nil // 当前分钟无记录
}
if err != nil {
return 0, fmt.Errorf("rpm get: %w", err)
}
return val, nil
}
// GetRPMBatch 批量获取多个账号的 RPM 计数(使用 Pipeline
func (c *RPMCacheImpl) GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 {
return map[int64]int{}, nil
}
// 获取当前分钟后缀
minuteSuffix, err := c.currentMinuteSuffix(ctx)
if err != nil {
return nil, fmt.Errorf("rpm batch get: %w", err)
}
// 使用 Pipeline 批量 GET
pipe := c.rdb.Pipeline()
cmds := make(map[int64]*redis.StringCmd, len(accountIDs))
for _, id := range accountIDs {
key := fmt.Sprintf("%s%d:%s", rpmKeyPrefix, id, minuteSuffix)
cmds[id] = pipe.Get(ctx, key)
}
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
return nil, fmt.Errorf("rpm batch get: %w", err)
}
result := make(map[int64]int, len(accountIDs))
for id, cmd := range cmds {
if val, err := cmd.Int(); err == nil {
result[id] = val
} else {
result[id] = 0
}
}
return result, nil
}

View File

@@ -79,6 +79,7 @@ var ProviderSet = wire.NewSet(
NewTimeoutCounterCache,
ProvideConcurrencyCache,
ProvideSessionLimitCache,
NewRPMCache,
NewDashboardCache,
NewEmailCache,
NewIdentityCache,

View File

@@ -624,7 +624,7 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{

View File

@@ -1137,6 +1137,80 @@ func (a *Account) GetSessionIdleTimeoutMinutes() int {
return 5
}
// GetBaseRPM 获取基础 RPM 限制
// 返回 0 表示未启用(负数视为无效配置,按 0 处理)
func (a *Account) GetBaseRPM() int {
if a.Extra == nil {
return 0
}
if v, ok := a.Extra["base_rpm"]; ok {
val := parseExtraInt(v)
if val > 0 {
return val
}
}
return 0
}
// GetRPMStrategy 获取 RPM 策略
// "tiered" = 三区模型(默认), "sticky_exempt" = 粘性豁免
func (a *Account) GetRPMStrategy() string {
if a.Extra == nil {
return "tiered"
}
if v, ok := a.Extra["rpm_strategy"]; ok {
if s, ok := v.(string); ok && s == "sticky_exempt" {
return "sticky_exempt"
}
}
return "tiered"
}
// GetRPMStickyBuffer 获取 RPM 粘性缓冲数量
// tiered 模式下的黄区大小,默认为 base_rpm 的 20%(至少 1
func (a *Account) GetRPMStickyBuffer() int {
if a.Extra == nil {
return 0
}
if v, ok := a.Extra["rpm_sticky_buffer"]; ok {
val := parseExtraInt(v)
if val > 0 {
return val
}
}
base := a.GetBaseRPM()
buffer := base / 5
if buffer < 1 && base > 0 {
buffer = 1
}
return buffer
}
// CheckRPMSchedulability 根据当前 RPM 计数检查调度状态
// 复用 WindowCostSchedulability 三态Schedulable / StickyOnly / NotSchedulable
func (a *Account) CheckRPMSchedulability(currentRPM int) WindowCostSchedulability {
baseRPM := a.GetBaseRPM()
if baseRPM <= 0 {
return WindowCostSchedulable
}
if currentRPM < baseRPM {
return WindowCostSchedulable
}
strategy := a.GetRPMStrategy()
if strategy == "sticky_exempt" {
return WindowCostStickyOnly // 粘性豁免无红区
}
// tiered: 黄区 + 红区
buffer := a.GetRPMStickyBuffer()
if currentRPM < baseRPM+buffer {
return WindowCostStickyOnly
}
return WindowCostNotSchedulable
}
// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
// - 费用 < 阈值: WindowCostSchedulable可正常调度
// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly仅粘性会话
@@ -1200,6 +1274,12 @@ func parseExtraFloat64(value any) float64 {
}
// parseExtraInt 从 extra 字段解析 int 值
// ParseExtraInt 从 extra 字段的 any 值解析为 int。
// 支持 int, int64, float64, json.Number, string 类型,无法解析时返回 0。
func ParseExtraInt(value any) int {
return parseExtraInt(value)
}
func parseExtraInt(value any) int {
switch v := value.(type) {
case int:

View File

@@ -0,0 +1,120 @@
package service
import (
"encoding/json"
"testing"
)
func TestGetBaseRPM(t *testing.T) {
tests := []struct {
name string
extra map[string]any
expected int
}{
{"nil extra", nil, 0},
{"no key", map[string]any{}, 0},
{"zero", map[string]any{"base_rpm": 0}, 0},
{"int value", map[string]any{"base_rpm": 15}, 15},
{"float value", map[string]any{"base_rpm": 15.0}, 15},
{"string value", map[string]any{"base_rpm": "15"}, 15},
{"negative value", map[string]any{"base_rpm": -5}, 0},
{"int64 value", map[string]any{"base_rpm": int64(20)}, 20},
{"json.Number value", map[string]any{"base_rpm": json.Number("25")}, 25},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Account{Extra: tt.extra}
if got := a.GetBaseRPM(); got != tt.expected {
t.Errorf("GetBaseRPM() = %d, want %d", got, tt.expected)
}
})
}
}
func TestGetRPMStrategy(t *testing.T) {
tests := []struct {
name string
extra map[string]any
expected string
}{
{"nil extra", nil, "tiered"},
{"no key", map[string]any{}, "tiered"},
{"tiered", map[string]any{"rpm_strategy": "tiered"}, "tiered"},
{"sticky_exempt", map[string]any{"rpm_strategy": "sticky_exempt"}, "sticky_exempt"},
{"invalid", map[string]any{"rpm_strategy": "foobar"}, "tiered"},
{"empty string fallback", map[string]any{"rpm_strategy": ""}, "tiered"},
{"numeric value fallback", map[string]any{"rpm_strategy": 123}, "tiered"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Account{Extra: tt.extra}
if got := a.GetRPMStrategy(); got != tt.expected {
t.Errorf("GetRPMStrategy() = %q, want %q", got, tt.expected)
}
})
}
}
func TestCheckRPMSchedulability(t *testing.T) {
tests := []struct {
name string
extra map[string]any
currentRPM int
expected WindowCostSchedulability
}{
{"disabled", map[string]any{}, 100, WindowCostSchedulable},
{"green zone", map[string]any{"base_rpm": 15}, 10, WindowCostSchedulable},
{"yellow zone tiered", map[string]any{"base_rpm": 15}, 15, WindowCostStickyOnly},
{"red zone tiered", map[string]any{"base_rpm": 15}, 18, WindowCostNotSchedulable},
{"sticky_exempt at limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 15, WindowCostStickyOnly},
{"sticky_exempt over limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 100, WindowCostStickyOnly},
{"custom buffer", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 14, WindowCostStickyOnly},
{"custom buffer red", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 15, WindowCostNotSchedulable},
{"base_rpm=1 green", map[string]any{"base_rpm": 1}, 0, WindowCostSchedulable},
{"base_rpm=1 yellow (at limit)", map[string]any{"base_rpm": 1}, 1, WindowCostStickyOnly},
{"base_rpm=1 red (at limit+buffer)", map[string]any{"base_rpm": 1}, 2, WindowCostNotSchedulable},
{"negative currentRPM", map[string]any{"base_rpm": 15}, -1, WindowCostSchedulable},
{"base_rpm negative disabled", map[string]any{"base_rpm": -5}, 10, WindowCostSchedulable},
{"very high currentRPM", map[string]any{"base_rpm": 10}, 9999, WindowCostNotSchedulable},
{"sticky_exempt very high currentRPM", map[string]any{"base_rpm": 10, "rpm_strategy": "sticky_exempt"}, 9999, WindowCostStickyOnly},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Account{Extra: tt.extra}
if got := a.CheckRPMSchedulability(tt.currentRPM); got != tt.expected {
t.Errorf("CheckRPMSchedulability(%d) = %d, want %d", tt.currentRPM, got, tt.expected)
}
})
}
}
func TestGetRPMStickyBuffer(t *testing.T) {
tests := []struct {
name string
extra map[string]any
expected int
}{
{"nil extra", nil, 0},
{"no keys", map[string]any{}, 0},
{"base_rpm=0", map[string]any{"base_rpm": 0}, 0},
{"base_rpm=1 min buffer 1", map[string]any{"base_rpm": 1}, 1},
{"base_rpm=4 min buffer 1", map[string]any{"base_rpm": 4}, 1},
{"base_rpm=5 buffer 1", map[string]any{"base_rpm": 5}, 1},
{"base_rpm=10 buffer 2", map[string]any{"base_rpm": 10}, 2},
{"base_rpm=15 buffer 3", map[string]any{"base_rpm": 15}, 3},
{"base_rpm=100 buffer 20", map[string]any{"base_rpm": 100}, 20},
{"custom buffer=5", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 5},
{"custom buffer=0 fallback to default", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0}, 2},
{"custom buffer negative fallback", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1}, 2},
{"custom buffer with float", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7},
{"json.Number base_rpm", map[string]any{"base_rpm": json.Number("10")}, 2},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Account{Extra: tt.extra}
if got := a.GetRPMStickyBuffer(); got != tt.expected {
t.Errorf("GetRPMStickyBuffer() = %d, want %d", got, tt.expected)
}
})
}
}

View File

@@ -520,6 +520,7 @@ type GatewayService struct {
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken
userGroupRateCache *gocache.Cache
userGroupRateSF singleflight.Group
modelsListCache *gocache.Cache
@@ -549,6 +550,7 @@ func NewGatewayService(
deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider,
sessionLimitCache SessionLimitCache,
rpmCache RPMCache,
digestStore *DigestSessionStore,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
@@ -574,6 +576,7 @@ func NewGatewayService(
deferredService: deferredService,
claudeTokenProvider: claudeTokenProvider,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
modelsListCache: gocache.New(modelsListTTL, time.Minute),
modelsListCacheTTL: modelsListTTL,
@@ -1154,6 +1157,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, errors.New("no available accounts")
}
ctx = s.withWindowCostPrefetch(ctx, accounts)
ctx = s.withRPMPrefetch(ctx, accounts)
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
@@ -1229,6 +1233,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
filteredWindowCost++
continue
}
// RPM 检查(非粘性会话路径)
if !s.isAccountSchedulableForRPM(ctx, account, false) {
continue
}
routingCandidates = append(routingCandidates, account)
}
@@ -1252,7 +1260,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
@@ -1406,7 +1416,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) &&
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
s.isAccountSchedulableForWindowCost(ctx, account, true) &&
s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
@@ -1472,6 +1484,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
// RPM 检查(非粘性会话路径)
if !s.isAccountSchedulableForRPM(ctx, acc, false) {
continue
}
candidates = append(candidates, acc)
}
@@ -2155,6 +2171,88 @@ checkSchedulability:
return true
}
// rpmPrefetchContextKey is the context key for prefetched RPM counts.
type rpmPrefetchContextKeyType struct{}
var rpmPrefetchContextKey = rpmPrefetchContextKeyType{}
func rpmFromPrefetchContext(ctx context.Context, accountID int64) (int, bool) {
if v, ok := ctx.Value(rpmPrefetchContextKey).(map[int64]int); ok {
count, found := v[accountID]
return count, found
}
return 0, false
}
// withRPMPrefetch 批量预取所有候选账号的 RPM 计数
func (s *GatewayService) withRPMPrefetch(ctx context.Context, accounts []Account) context.Context {
if s.rpmCache == nil {
return ctx
}
var ids []int64
for i := range accounts {
if accounts[i].IsAnthropicOAuthOrSetupToken() && accounts[i].GetBaseRPM() > 0 {
ids = append(ids, accounts[i].ID)
}
}
if len(ids) == 0 {
return ctx
}
counts, err := s.rpmCache.GetRPMBatch(ctx, ids)
if err != nil {
return ctx // 失败开放
}
return context.WithValue(ctx, rpmPrefetchContextKey, counts)
}
// isAccountSchedulableForRPM 检查账号是否可根据 RPM 进行调度
// 仅适用于 Anthropic OAuth/SetupToken 账号
func (s *GatewayService) isAccountSchedulableForRPM(ctx context.Context, account *Account, isSticky bool) bool {
if !account.IsAnthropicOAuthOrSetupToken() {
return true
}
baseRPM := account.GetBaseRPM()
if baseRPM <= 0 {
return true
}
// 尝试从预取缓存获取
var currentRPM int
if count, ok := rpmFromPrefetchContext(ctx, account.ID); ok {
currentRPM = count
} else if s.rpmCache != nil {
if count, err := s.rpmCache.GetRPM(ctx, account.ID); err == nil {
currentRPM = count
}
// 失败开放GetRPM 错误时允许调度
}
schedulability := account.CheckRPMSchedulability(currentRPM)
switch schedulability {
case WindowCostSchedulable:
return true
case WindowCostStickyOnly:
return isSticky
case WindowCostNotSchedulable:
return false
}
return true
}
// IncrementAccountRPM increments the RPM counter for the given account.
// 已知 TOCTOU 竞态:调度时读取 RPM 计数与此处递增之间存在时间窗口,
// 高并发下可能短暂超出 RPM 限制。这是与 WindowCost 一致的 soft-limit
// 设计权衡——可接受的少量超额优于加锁带来的延迟和复杂度。
func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int64) error {
if s.rpmCache == nil {
return nil
}
_, err := s.rpmCache.IncrementRPM(ctx, accountID)
return err
}
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
// sessionID: 会话标识符(使用粘性会话的 hash
@@ -2349,7 +2447,7 @@ func sameAccountWithLoadGroup(a, b accountWithLoad) bool {
// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。
//
// 注意:当 preferOAuth=true 时,需要保证 OAuth 账号在同组内仍然优先,否则会把排序时的偏好打散掉。
// 因此这里采用组内分区 + 分区内 shuffle的方式:
// 因此这里采用"组内分区 + 分区内 shuffle"的方式:
// - 先把同组账号按 (OAuth / 非 OAuth) 拆成两段,保持 OAuth 段在前;
// - 再分别在各段内随机打散,避免热点。
func shuffleWithinPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
@@ -2489,7 +2587,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
@@ -2512,6 +2610,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
accountsLoaded = true
// 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存
ctx = s.withWindowCostPrefetch(ctx, accounts)
ctx = s.withRPMPrefetch(ctx, accounts)
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
for _, id := range routingAccountIDs {
if id > 0 {
@@ -2539,6 +2641,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
if !s.isAccountSchedulableForRPM(ctx, acc, false) {
continue
}
if selected == nil {
selected = acc
continue
@@ -2589,7 +2697,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
return account, nil
}
}
@@ -2610,6 +2718,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
}
// 批量预取窗口费用+RPM 计数避免逐个账号查询N+1
ctx = s.withWindowCostPrefetch(ctx, accounts)
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持)
var selected *Account
for i := range accounts {
@@ -2628,6 +2740,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
if !s.isAccountSchedulableForRPM(ctx, acc, false) {
continue
}
if selected == nil {
selected = acc
continue
@@ -2697,7 +2815,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
@@ -2718,6 +2836,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
accountsLoaded = true
// 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存
ctx = s.withWindowCostPrefetch(ctx, accounts)
ctx = s.withRPMPrefetch(ctx, accounts)
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
for _, id := range routingAccountIDs {
if id > 0 {
@@ -2749,6 +2871,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
if !s.isAccountSchedulableForRPM(ctx, acc, false) {
continue
}
if selected == nil {
selected = acc
continue
@@ -2799,7 +2927,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
return account, nil
}
@@ -2818,6 +2946,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
}
// 批量预取窗口费用+RPM 计数避免逐个账号查询N+1
ctx = s.withWindowCostPrefetch(ctx, accounts)
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
var selected *Account
for i := range accounts {
@@ -2840,6 +2972,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
if !s.isAccountSchedulableForRPM(ctx, acc, false) {
continue
}
if selected == nil {
selected = acc
continue
@@ -5185,7 +5323,7 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对可能是兼容性差异导致的 400 允许切换,避免无意义重试。
// 只对"可能是兼容性差异导致"的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {

View File

@@ -0,0 +1,17 @@
package service
import "context"
// RPMCache RPM 计数器缓存接口
// 用于 Anthropic OAuth/SetupToken 账号的每分钟请求数限制
type RPMCache interface {
// IncrementRPM 原子递增并返回当前分钟的计数
// 使用 Redis 服务器时间确定 minute key避免多实例时钟偏差
IncrementRPM(ctx context.Context, accountID int64) (count int, err error)
// GetRPM 获取当前分钟的 RPM 计数
GetRPM(ctx context.Context, accountID int64) (count int, err error)
// GetRPMBatch 批量获取多个账号的 RPM 计数(使用 Pipeline
GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
}