fix: address code review issues for RPM limiting feature

- Use TxPipeline (MULTI/EXEC) instead of Pipeline for atomic INCR+EXPIRE
- Filter negative values in GetBaseRPM(), update test expectation
- Add RPM batch query (GetRPMBatch) to account List API
- Add warn logs for RPM increment failures in gateway handler
- Reset enableRpmLimit on BulkEditAccountModal close
- Use union type 'tiered' | 'sticky_exempt' for rpmStrategy refs
- Add design decision comments for rdb.Time() RTT trade-off
This commit is contained in:
QTom
2026-02-28 10:16:34 +08:00
parent 28ca7df297
commit 607237571f
13 changed files with 509 additions and 80 deletions

View File

@@ -241,9 +241,10 @@ func (h *AccountHandler) List(c *gin.Context) {
concurrencyCounts = make(map[int64]int)
}
// 识别需要查询窗口费用会话数的账号Anthropic OAuth/SetupToken 且启用了相应功能)
// 识别需要查询窗口费用会话数和 RPM 的账号Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
rpmAccountIDs := make([]int64, 0)
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
for i := range accounts {
acc := &accounts[i]
@@ -255,12 +256,24 @@ func (h *AccountHandler) List(c *gin.Context) {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
}
if acc.GetBaseRPM() > 0 {
rpmAccountIDs = append(rpmAccountIDs, acc.ID)
}
}
}
// 并行获取窗口费用活跃会话数
// 并行获取窗口费用活跃会话数和 RPM 计数
var windowCosts map[int64]float64
var activeSessions map[int64]int
var rpmCounts map[int64]int
// 获取 RPM 计数(批量查询)
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
if rpmCounts == nil {
rpmCounts = make(map[int64]int)
}
}
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
@@ -321,6 +334,13 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
// 添加 RPM 计数(仅当启用时)
if rpmCounts != nil {
if rpm, ok := rpmCounts[acc.ID]; ok {
item.CurrentRPM = &rpm
}
}
result[i] = item
}

View File

@@ -368,8 +368,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// RPM 计数递增调度成功后、Forward 前)
if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
if h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID) != nil {
// 失败开放:不阻塞请求
if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil {
reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
@@ -558,8 +558,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// RPM 计数递增调度成功后、Forward 前)
if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
if h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID) != nil {
// 失败开放:不阻塞请求
if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil {
reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}

View File

@@ -2,78 +2,130 @@ package repository
import (
"context"
"errors"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const rpmKeyPrefix = "rpm:"
// RPM 计数器缓存常量定义
//
// 设计说明:
// 使用 Redis 简单计数器跟踪每个账号每分钟的请求数:
// - Key: rpm:{accountID}:{minuteTimestamp}
// - Value: 当前分钟内的请求计数
// - TTL: 120 秒(覆盖当前分钟 + 一定冗余)
//
// 使用 TxPipelineMULTI/EXEC执行 INCR + EXPIRE保证原子性且兼容 Redis Cluster。
// 通过 rdb.Time() 获取服务端时间,避免多实例时钟不同步。
//
// 设计决策:
// - TxPipeline vs PipelinePipeline 仅合并发送但不保证原子TxPipeline 使用 MULTI/EXEC 事务保证原子执行。
// - rdb.Time() 单独调用Pipeline/TxPipeline 中无法引用前一命令的结果,因此 TIME 必须单独调用2 RTT
// Lua 脚本可以做到 1 RTT但在 Redis Cluster 中动态拼接 key 存在 CROSSSLOT 风险,选择安全性优先。
const (
// RPM 计数器键前缀
// 格式: rpm:{accountID}:{minuteTimestamp}
rpmKeyPrefix = "rpm:"
// Lua scripts use Redis TIME for server-side minute key calculation
var rpmIncrScript = redis.NewScript(`
local timeResult = redis.call('TIME')
local minuteKey = math.floor(tonumber(timeResult[1]) / 60)
local key = ARGV[1] .. ':' .. minuteKey
local count = redis.call('INCR', key)
if count == 1 then
redis.call('EXPIRE', key, 120)
end
return count
`)
var rpmGetScript = redis.NewScript(`
local timeResult = redis.call('TIME')
local minuteKey = math.floor(tonumber(timeResult[1]) / 60)
local key = ARGV[1] .. ':' .. minuteKey
local count = redis.call('GET', key)
if count == false then
return 0
end
return tonumber(count)
`)
// RPM 计数器 TTL120 秒,覆盖当前分钟窗口 + 冗余)
rpmKeyTTL = 120 * time.Second
)
// RPMCacheImpl RPM 计数器缓存 Redis 实现
type RPMCacheImpl struct {
rdb *redis.Client
}
// NewRPMCache 创建 RPM 计数器缓存
func NewRPMCache(rdb *redis.Client) service.RPMCache {
return &RPMCacheImpl{rdb: rdb}
}
func rpmKeyBase(accountID int64) string {
return fmt.Sprintf("%s%d", rpmKeyPrefix, accountID)
// currentMinuteKey 获取当前分钟的完整 Redis key
// 使用 rdb.Time() 获取 Redis 服务端时间,避免多实例时钟偏差
func (c *RPMCacheImpl) currentMinuteKey(ctx context.Context, accountID int64) (string, error) {
serverTime, err := c.rdb.Time(ctx).Result()
if err != nil {
return "", fmt.Errorf("redis TIME: %w", err)
}
minuteTS := serverTime.Unix() / 60
return fmt.Sprintf("%s%d:%d", rpmKeyPrefix, accountID, minuteTS), nil
}
// currentMinuteSuffix 获取当前分钟时间戳后缀(供批量操作使用)
// 使用 rdb.Time() 获取 Redis 服务端时间
func (c *RPMCacheImpl) currentMinuteSuffix(ctx context.Context) (string, error) {
serverTime, err := c.rdb.Time(ctx).Result()
if err != nil {
return "", fmt.Errorf("redis TIME: %w", err)
}
minuteTS := serverTime.Unix() / 60
return strconv.FormatInt(minuteTS, 10), nil
}
// IncrementRPM 原子递增并返回当前分钟的计数
// 使用 TxPipeline (MULTI/EXEC) 执行 INCR + EXPIRE保证原子性且兼容 Redis Cluster
func (c *RPMCacheImpl) IncrementRPM(ctx context.Context, accountID int64) (int, error) {
result, err := rpmIncrScript.Run(ctx, c.rdb, nil, rpmKeyBase(accountID)).Int()
key, err := c.currentMinuteKey(ctx, accountID)
if err != nil {
return 0, fmt.Errorf("rpm increment: %w", err)
}
return result, nil
// 使用 TxPipeline (MULTI/EXEC) 保证 INCR + EXPIRE 原子执行
// EXPIRE 幂等,每次都设置不影响正确性
pipe := c.rdb.TxPipeline()
incrCmd := pipe.Incr(ctx, key)
pipe.Expire(ctx, key, rpmKeyTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, fmt.Errorf("rpm increment: %w", err)
}
return int(incrCmd.Val()), nil
}
// GetRPM 获取当前分钟的 RPM 计数
func (c *RPMCacheImpl) GetRPM(ctx context.Context, accountID int64) (int, error) {
result, err := rpmGetScript.Run(ctx, c.rdb, nil, rpmKeyBase(accountID)).Int()
key, err := c.currentMinuteKey(ctx, accountID)
if err != nil {
return 0, fmt.Errorf("rpm get: %w", err)
}
return result, nil
val, err := c.rdb.Get(ctx, key).Int()
if errors.Is(err, redis.Nil) {
return 0, nil // 当前分钟无记录
}
if err != nil {
return 0, fmt.Errorf("rpm get: %w", err)
}
return val, nil
}
// GetRPMBatch 批量获取多个账号的 RPM 计数(使用 Pipeline
func (c *RPMCacheImpl) GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 {
return map[int64]int{}, nil
}
pipe := c.rdb.Pipeline()
cmds := make(map[int64]*redis.Cmd, len(accountIDs))
for _, id := range accountIDs {
cmds[id] = rpmGetScript.Run(ctx, pipe, nil, rpmKeyBase(id))
// 获取当前分钟后缀
minuteSuffix, err := c.currentMinuteSuffix(ctx)
if err != nil {
return nil, fmt.Errorf("rpm batch get: %w", err)
}
_, err := pipe.Exec(ctx)
if err != nil && err != redis.Nil {
// 使用 Pipeline 批量 GET
pipe := c.rdb.Pipeline()
cmds := make(map[int64]*redis.StringCmd, len(accountIDs))
for _, id := range accountIDs {
key := fmt.Sprintf("%s%d:%s", rpmKeyPrefix, id, minuteSuffix)
cmds[id] = pipe.Get(ctx, key)
}
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
return nil, fmt.Errorf("rpm batch get: %w", err)
}

View File

@@ -1138,13 +1138,16 @@ func (a *Account) GetSessionIdleTimeoutMinutes() int {
}
// GetBaseRPM 获取基础 RPM 限制
// 返回 0 表示未启用
// 返回 0 表示未启用(负数视为无效配置,按 0 处理)
func (a *Account) GetBaseRPM() int {
if a.Extra == nil {
return 0
}
if v, ok := a.Extra["base_rpm"]; ok {
return parseExtraInt(v)
val := parseExtraInt(v)
if val > 0 {
return val
}
}
return 0
}

View File

@@ -13,6 +13,9 @@ func TestGetBaseRPM(t *testing.T) {
{"zero", map[string]any{"base_rpm": 0}, 0},
{"int value", map[string]any{"base_rpm": 15}, 15},
{"float value", map[string]any{"base_rpm": 15.0}, 15},
{"string value", map[string]any{"base_rpm": "15"}, 15},
{"negative value", map[string]any{"base_rpm": -5}, 0},
{"int64 value", map[string]any{"base_rpm": int64(20)}, 20},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -35,6 +38,8 @@ func TestGetRPMStrategy(t *testing.T) {
{"tiered", map[string]any{"rpm_strategy": "tiered"}, "tiered"},
{"sticky_exempt", map[string]any{"rpm_strategy": "sticky_exempt"}, "sticky_exempt"},
{"invalid", map[string]any{"rpm_strategy": "foobar"}, "tiered"},
{"empty string fallback", map[string]any{"rpm_strategy": ""}, "tiered"},
{"numeric value fallback", map[string]any{"rpm_strategy": 123}, "tiered"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -61,6 +66,13 @@ func TestCheckRPMSchedulability(t *testing.T) {
{"sticky_exempt over limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 100, WindowCostStickyOnly},
{"custom buffer", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 14, WindowCostStickyOnly},
{"custom buffer red", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 15, WindowCostNotSchedulable},
{"base_rpm=1 green", map[string]any{"base_rpm": 1}, 0, WindowCostSchedulable},
{"base_rpm=1 yellow (at limit)", map[string]any{"base_rpm": 1}, 1, WindowCostStickyOnly},
{"base_rpm=1 red (at limit+buffer)", map[string]any{"base_rpm": 1}, 2, WindowCostNotSchedulable},
{"negative currentRPM", map[string]any{"base_rpm": 15}, -1, WindowCostSchedulable},
{"base_rpm negative disabled", map[string]any{"base_rpm": -5}, 10, WindowCostSchedulable},
{"very high currentRPM", map[string]any{"base_rpm": 10}, 9999, WindowCostNotSchedulable},
{"sticky_exempt very high currentRPM", map[string]any{"base_rpm": 10, "rpm_strategy": "sticky_exempt"}, 9999, WindowCostStickyOnly},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -71,3 +83,33 @@ func TestCheckRPMSchedulability(t *testing.T) {
})
}
}
func TestGetRPMStickyBuffer(t *testing.T) {
tests := []struct {
name string
extra map[string]any
expected int
}{
{"nil extra", nil, 0},
{"no keys", map[string]any{}, 0},
{"base_rpm=0", map[string]any{"base_rpm": 0}, 0},
{"base_rpm=1 min buffer 1", map[string]any{"base_rpm": 1}, 1},
{"base_rpm=4 min buffer 1", map[string]any{"base_rpm": 4}, 1},
{"base_rpm=5 buffer 1", map[string]any{"base_rpm": 5}, 1},
{"base_rpm=10 buffer 2", map[string]any{"base_rpm": 10}, 2},
{"base_rpm=15 buffer 3", map[string]any{"base_rpm": 15}, 3},
{"base_rpm=100 buffer 20", map[string]any{"base_rpm": 100}, 20},
{"custom buffer=5", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 5},
{"custom buffer=0 fallback to default", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0}, 2},
{"custom buffer negative fallback", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1}, 2},
{"custom buffer with float", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Account{Extra: tt.extra}
if got := a.GetRPMStickyBuffer(); got != tt.expected {
t.Errorf("GetRPMStickyBuffer() = %d, want %d", got, tt.expected)
}
})
}
}

View File

@@ -2708,6 +2708,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
}
// 批量预取 RPM 计数避免逐个账号查询N+1
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持)
var selected *Account
for i := range accounts {
@@ -2922,6 +2925,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
}
// 批量预取 RPM 计数避免逐个账号查询N+1
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
var selected *Account
for i := range accounts {