Merge pull request #679 from DaydreamCoding/feat/account-rpm-limit
feat: 添加账号级别 RPM(每分钟请求数)限流功能
This commit is contained in:
141
backend/internal/repository/rpm_cache.go
Normal file
141
backend/internal/repository/rpm_cache.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RPM 计数器缓存常量定义
|
||||
//
|
||||
// 设计说明:
|
||||
// 使用 Redis 简单计数器跟踪每个账号每分钟的请求数:
|
||||
// - Key: rpm:{accountID}:{minuteTimestamp}
|
||||
// - Value: 当前分钟内的请求计数
|
||||
// - TTL: 120 秒(覆盖当前分钟 + 一定冗余)
|
||||
//
|
||||
// 使用 TxPipeline(MULTI/EXEC)执行 INCR + EXPIRE,保证原子性且兼容 Redis Cluster。
|
||||
// 通过 rdb.Time() 获取服务端时间,避免多实例时钟不同步。
|
||||
//
|
||||
// 设计决策:
|
||||
// - TxPipeline vs Pipeline:Pipeline 仅合并发送但不保证原子,TxPipeline 使用 MULTI/EXEC 事务保证原子执行。
|
||||
// - rdb.Time() 单独调用:Pipeline/TxPipeline 中无法引用前一命令的结果,因此 TIME 必须单独调用(2 RTT)。
|
||||
// Lua 脚本可以做到 1 RTT,但在 Redis Cluster 中动态拼接 key 存在 CROSSSLOT 风险,选择安全性优先。
|
||||
const (
|
||||
// RPM 计数器键前缀
|
||||
// 格式: rpm:{accountID}:{minuteTimestamp}
|
||||
rpmKeyPrefix = "rpm:"
|
||||
|
||||
// RPM 计数器 TTL(120 秒,覆盖当前分钟窗口 + 冗余)
|
||||
rpmKeyTTL = 120 * time.Second
|
||||
)
|
||||
|
||||
// RPMCacheImpl RPM 计数器缓存 Redis 实现
|
||||
type RPMCacheImpl struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewRPMCache 创建 RPM 计数器缓存
|
||||
func NewRPMCache(rdb *redis.Client) service.RPMCache {
|
||||
return &RPMCacheImpl{rdb: rdb}
|
||||
}
|
||||
|
||||
// currentMinuteKey 获取当前分钟的完整 Redis key
|
||||
// 使用 rdb.Time() 获取 Redis 服务端时间,避免多实例时钟偏差
|
||||
func (c *RPMCacheImpl) currentMinuteKey(ctx context.Context, accountID int64) (string, error) {
|
||||
serverTime, err := c.rdb.Time(ctx).Result()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("redis TIME: %w", err)
|
||||
}
|
||||
minuteTS := serverTime.Unix() / 60
|
||||
return fmt.Sprintf("%s%d:%d", rpmKeyPrefix, accountID, minuteTS), nil
|
||||
}
|
||||
|
||||
// currentMinuteSuffix 获取当前分钟时间戳后缀(供批量操作使用)
|
||||
// 使用 rdb.Time() 获取 Redis 服务端时间
|
||||
func (c *RPMCacheImpl) currentMinuteSuffix(ctx context.Context) (string, error) {
|
||||
serverTime, err := c.rdb.Time(ctx).Result()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("redis TIME: %w", err)
|
||||
}
|
||||
minuteTS := serverTime.Unix() / 60
|
||||
return strconv.FormatInt(minuteTS, 10), nil
|
||||
}
|
||||
|
||||
// IncrementRPM 原子递增并返回当前分钟的计数
|
||||
// 使用 TxPipeline (MULTI/EXEC) 执行 INCR + EXPIRE,保证原子性且兼容 Redis Cluster
|
||||
func (c *RPMCacheImpl) IncrementRPM(ctx context.Context, accountID int64) (int, error) {
|
||||
key, err := c.currentMinuteKey(ctx, accountID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("rpm increment: %w", err)
|
||||
}
|
||||
|
||||
// 使用 TxPipeline (MULTI/EXEC) 保证 INCR + EXPIRE 原子执行
|
||||
// EXPIRE 幂等,每次都设置不影响正确性
|
||||
pipe := c.rdb.TxPipeline()
|
||||
incrCmd := pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, rpmKeyTTL)
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return 0, fmt.Errorf("rpm increment: %w", err)
|
||||
}
|
||||
|
||||
return int(incrCmd.Val()), nil
|
||||
}
|
||||
|
||||
// GetRPM 获取当前分钟的 RPM 计数
|
||||
func (c *RPMCacheImpl) GetRPM(ctx context.Context, accountID int64) (int, error) {
|
||||
key, err := c.currentMinuteKey(ctx, accountID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("rpm get: %w", err)
|
||||
}
|
||||
|
||||
val, err := c.rdb.Get(ctx, key).Int()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return 0, nil // 当前分钟无记录
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("rpm get: %w", err)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// GetRPMBatch 批量获取多个账号的 RPM 计数(使用 Pipeline)
|
||||
func (c *RPMCacheImpl) GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return map[int64]int{}, nil
|
||||
}
|
||||
|
||||
// 获取当前分钟后缀
|
||||
minuteSuffix, err := c.currentMinuteSuffix(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rpm batch get: %w", err)
|
||||
}
|
||||
|
||||
// 使用 Pipeline 批量 GET
|
||||
pipe := c.rdb.Pipeline()
|
||||
cmds := make(map[int64]*redis.StringCmd, len(accountIDs))
|
||||
for _, id := range accountIDs {
|
||||
key := fmt.Sprintf("%s%d:%s", rpmKeyPrefix, id, minuteSuffix)
|
||||
cmds[id] = pipe.Get(ctx, key)
|
||||
}
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
|
||||
return nil, fmt.Errorf("rpm batch get: %w", err)
|
||||
}
|
||||
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for id, cmd := range cmds {
|
||||
if val, err := cmd.Int(); err == nil {
|
||||
result[id] = val
|
||||
} else {
|
||||
result[id] = 0
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -79,6 +79,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewTimeoutCounterCache,
|
||||
ProvideConcurrencyCache,
|
||||
ProvideSessionLimitCache,
|
||||
NewRPMCache,
|
||||
NewDashboardCache,
|
||||
NewEmailCache,
|
||||
NewIdentityCache,
|
||||
|
||||
Reference in New Issue
Block a user