From 777be05348b1ca695adf453e8996f7ea6001d13a Mon Sep 17 00:00:00 2001 From: QTom Date: Sat, 28 Feb 2026 01:14:55 +0800 Subject: [PATCH] feat: add RPMCache interface and Redis implementation with Lua scripts --- backend/internal/repository/rpm_cache.go | 88 ++++++++++++++++++++++++ backend/internal/service/rpm_cache.go | 17 +++++ 2 files changed, 105 insertions(+) create mode 100644 backend/internal/repository/rpm_cache.go create mode 100644 backend/internal/service/rpm_cache.go diff --git a/backend/internal/repository/rpm_cache.go b/backend/internal/repository/rpm_cache.go new file mode 100644 index 00000000..6ec7f739 --- /dev/null +++ b/backend/internal/repository/rpm_cache.go @@ -0,0 +1,88 @@ +package repository + +import ( + "context" + "fmt" + + "github.com/redis/go-redis/v9" +) + +const rpmKeyPrefix = "rpm:" + +// Lua scripts use Redis TIME for server-side minute key calculation +var rpmIncrScript = redis.NewScript(` +local timeResult = redis.call('TIME') +local minuteKey = math.floor(tonumber(timeResult[1]) / 60) +local key = ARGV[1] .. ':' .. minuteKey +local count = redis.call('INCR', key) +if count == 1 then + redis.call('EXPIRE', key, 120) +end +return count +`) + +var rpmGetScript = redis.NewScript(` +local timeResult = redis.call('TIME') +local minuteKey = math.floor(tonumber(timeResult[1]) / 60) +local key = ARGV[1] .. ':' .. minuteKey +local count = redis.call('GET', key) +if count == false then + return 0 +end +return tonumber(count) +`) + +type RPMCacheImpl struct { + rdb *redis.Client +} + +func NewRPMCache(rdb *redis.Client) *RPMCacheImpl { + return &RPMCacheImpl{rdb: rdb} +} + +func rpmKeyBase(accountID int64) string { + return fmt.Sprintf("%s%d", rpmKeyPrefix, accountID) +} + +func (c *RPMCacheImpl) IncrementRPM(ctx context.Context, accountID int64) (int, error) { + result, err := rpmIncrScript.Run(ctx, c.rdb, nil, rpmKeyBase(accountID)).Int() + if err != nil { + return 0, fmt.Errorf("rpm increment: %w", err) + } + return result, nil +} + +func (c *RPMCacheImpl) GetRPM(ctx context.Context, accountID int64) (int, error) { + result, err := rpmGetScript.Run(ctx, c.rdb, nil, rpmKeyBase(accountID)).Int() + if err != nil { + return 0, fmt.Errorf("rpm get: %w", err) + } + return result, nil +} + +func (c *RPMCacheImpl) GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + if len(accountIDs) == 0 { + return map[int64]int{}, nil + } + + pipe := c.rdb.Pipeline() + cmds := make(map[int64]*redis.Cmd, len(accountIDs)) + for _, id := range accountIDs { + cmds[id] = rpmGetScript.Run(ctx, pipe, nil, rpmKeyBase(id)) + } + + _, err := pipe.Exec(ctx) + if err != nil && err != redis.Nil { + return nil, fmt.Errorf("rpm batch get: %w", err) + } + + result := make(map[int64]int, len(accountIDs)) + for id, cmd := range cmds { + if val, err := cmd.Int(); err == nil { + result[id] = val + } else { + result[id] = 0 + } + } + return result, nil +} diff --git a/backend/internal/service/rpm_cache.go b/backend/internal/service/rpm_cache.go new file mode 100644 index 00000000..07036219 --- /dev/null +++ b/backend/internal/service/rpm_cache.go @@ -0,0 +1,17 @@ +package service + +import "context" + +// RPMCache RPM 计数器缓存接口 +// 用于 Anthropic OAuth/SetupToken 账号的每分钟请求数限制 +type RPMCache interface { + // IncrementRPM 原子递增并返回当前分钟的计数 + // 使用 Redis 服务器时间确定 minute key,避免多实例时钟偏差 + IncrementRPM(ctx context.Context, accountID int64) (count int, err error) + + // GetRPM 获取当前分钟的 RPM 计数 + GetRPM(ctx context.Context, accountID int64) (count int, err error) + + // GetRPMBatch 批量获取多个账号的 RPM 计数(使用 Pipeline) + GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) +}