feat(rpm): RPM 限流模块优化

P0:
- rpm_override 嵌入 Auth Cache Snapshot,消除每请求 DB 查询 (snapshot v6→v7)
- 429 RPM 响应返回 Retry-After 头(当前分钟剩余秒数)

P1:
- ClearAll 按钮直连 DELETE API,带 loading 防重复
- 新增 GET /admin/users/:id/rpm-status 管理员 RPM 用量查询端点

优化:
- checkRPM 从级联互斥改为并行取最严,user.rpm_limit 作为全局硬上限始终生效
- Override/Group 变更后自动失效 auth cache
- fail-open 语义不变,Redis 故障不阻塞业务
This commit is contained in:
james-6-23
2026-04-23 03:33:52 +08:00
parent ef967d8f8a
commit dc5d42addc
79 changed files with 2831 additions and 140 deletions

View File

@@ -152,6 +152,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
user.FieldSignupSource,
user.FieldLastLoginAt,
user.FieldLastActiveAt,
user.FieldRpmLimit,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
@@ -178,6 +179,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldAllowMessagesDispatch,
group.FieldDefaultMappedModel,
group.FieldMessagesDispatchModelConfig,
group.FieldRpmLimit,
)
}).
Only(ctx)
@@ -669,6 +671,7 @@ func userEntityToService(u *dbent.User) *service.User {
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
TotalRecharged: u.TotalRecharged,
RPMLimit: u.RpmLimit,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
@@ -713,6 +716,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RequirePrivacySet: g.RequirePrivacySet,
DefaultMappedModel: g.DefaultMappedModel,
MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
RPMLimit: g.RpmLimit,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}

View File

@@ -63,7 +63,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
SetRpmLimit(groupIn.RPMLimit)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
@@ -130,7 +131,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
SetRpmLimit(groupIn.RPMLimit)
// 显式处理可空字段nil 需要 clear非 nil 需要 set。
if groupIn.DailyLimitUSD != nil {

View File

@@ -13,14 +13,14 @@ type userGroupRateRepository struct {
sql sqlExecutor
}
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
// NewUserGroupRateRepository 创建用户专属分组倍率/RPM 仓储
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
return &userGroupRateRepository{sql: sqlDB}
}
// GetByUserID 获取用户所有专属分组倍率
// GetByUserID 获取用户所有专属分组 rate_multiplier仅返回非 NULL 的条目)
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NOT NULL`
rows, err := r.sql.QueryContext(ctx, query, userID)
if err != nil {
return nil, err
@@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
return result, nil
}
// GetByUserIDs 批量获取多个用户的专属分组倍率。
// 返回结构map[userID]map[groupID]rate
// GetByUserIDs 批量获取多个用户的专属分组 rate_multiplier仅返回非 NULL 的条目)
func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
result := make(map[int64]map[int64]float64, len(userIDs))
if len(userIDs) == 0 {
@@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
rows, err := r.sql.QueryContext(ctx, `
SELECT user_id, group_id, rate_multiplier
FROM user_group_rate_multipliers
WHERE user_id = ANY($1)
WHERE user_id = ANY($1) AND rate_multiplier IS NOT NULL
`, pq.Array(uniqueIDs))
if err != nil {
return nil, err
@@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
return result, nil
}
// GetByGroupID 获取指定分组下所有用户的专属倍率
// GetByGroupID 获取指定分组下所有用户的专属配置rate 与 rpm_override 任一非 NULL 即返回)
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
query := `
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier, ugr.rpm_override
FROM user_group_rate_multipliers ugr
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
WHERE ugr.group_id = $1
@@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
var result []service.UserGroupRateEntry
for rows.Next() {
var entry service.UserGroupRateEntry
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil {
var rate sql.NullFloat64
var rpm sql.NullInt32
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &rate, &rpm); err != nil {
return nil, err
}
if rate.Valid {
v := rate.Float64
entry.RateMultiplier = &v
}
if rpm.Valid {
v := int(rpm.Int32)
entry.RPMOverride = &v
}
result = append(result, entry)
}
if err := rows.Err(); err != nil {
@@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
return result, nil
}
// GetByUserAndGroup 获取用户在特定分组的专属倍率
// GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplierNULL 返回 nil
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var rate float64
var rate sql.NullFloat64
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
if err == sql.ErrNoRows {
return nil, nil
@@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID,
if err != nil {
return nil, err
}
return &rate, nil
if !rate.Valid {
return nil, nil
}
v := rate.Float64
return &v, nil
}
// SyncUserGroupRates 同步用户的分组专属倍率
// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_overrideNULL 返回 nil
func (r *userGroupRateRepository) GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error) {
query := `SELECT rpm_override FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var rpm sql.NullInt32
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rpm)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
if !rpm.Valid {
return nil, nil
}
v := int(rpm.Int32)
return &v, nil
}
// SyncUserGroupRates 同步用户的分组专属 rate_multiplier。
// - 传入空 map清空该用户所有行的 rate_multiplier若 rpm_override 也为 NULL 则整行删除。
// - 值为 nil清空对应行的 rate_multiplier保留 rpm_override
// - 值非 nilupsert rate_multiplier保留已有 rpm_override
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
if len(rates) == 0 {
// 如果传入空 map删除该用户的所有专属倍率
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE user_id = $1
`, userID); err != nil {
return err
}
_, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL`,
userID)
return err
}
// 分离需要删除和需要 upsert 的记录
var toDelete []int64
var clearGroupIDs []int64
upsertGroupIDs := make([]int64, 0, len(rates))
upsertRates := make([]float64, 0, len(rates))
for groupID, rate := range rates {
if rate == nil {
toDelete = append(toDelete, groupID)
clearGroupIDs = append(clearGroupIDs, groupID)
} else {
upsertGroupIDs = append(upsertGroupIDs, groupID)
upsertRates = append(upsertRates, *rate)
}
}
// 删除指定的记录
if len(toDelete) > 0 {
if len(clearGroupIDs) > 0 {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE user_id = $1 AND group_id = ANY($2)
`, userID, pq.Array(clearGroupIDs)); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`,
userID, pq.Array(toDelete)); err != nil {
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2) AND rate_multiplier IS NULL AND rpm_override IS NULL`,
userID, pq.Array(clearGroupIDs)); err != nil {
return err
}
}
// Upsert 记录
now := time.Now()
if len(upsertGroupIDs) > 0 {
now := time.Now()
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
SELECT
@@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
return nil
}
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插)
// SyncGroupRateMultipliers 同步分组的 rate_multiplier 部分(不触动 rpm_override
// 语义:
// - 未出现在 entries 中的用户行rate_multiplier 归 NULL若 rpm_override 也为 NULL 则整行删除。
// - 出现的用户行upsert rate_multiplier。
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil {
keepUserIDs := make([]int64, 0, len(entries))
for _, e := range entries {
keepUserIDs = append(keepUserIDs, e.UserID)
}
// 未在 entries 列表中的行:清空 rate_multiplier。
if len(keepUserIDs) == 0 {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE group_id = $1
`, groupID); err != nil {
return err
}
} else {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rate_multiplier = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id <> ALL($2)
`, groupID, pq.Array(keepUserIDs)); err != nil {
return err
}
}
// 清空后若整行 NULL 则删除。
if _, err := r.sql.ExecContext(ctx, `
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`, groupID); err != nil {
return err
}
if len(entries) == 0 {
return nil
}
userIDs := make([]int64, len(entries))
rates := make([]float64, len(entries))
for i, e := range entries {
@@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context,
return err
}
// DeleteByGroupID 删除指定分组的所有用户专属倍率
// SyncGroupRPMOverrides 同步分组的 rpm_override 部分(不触动 rate_multiplier
// 语义:
// - 未出现的用户行rpm_override 归 NULL若 rate_multiplier 也为 NULL 则整行删除。
// - 出现的用户行:若 RPMOverride 为 nil 则清空;非 nil 则 upsert。
func (r *userGroupRateRepository) SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []service.GroupRPMOverrideInput) error {
keepUserIDs := make([]int64, 0, len(entries))
var clearUserIDs []int64
upsertUserIDs := make([]int64, 0, len(entries))
upsertValues := make([]int32, 0, len(entries))
for _, e := range entries {
keepUserIDs = append(keepUserIDs, e.UserID)
if e.RPMOverride == nil {
clearUserIDs = append(clearUserIDs, e.UserID)
} else {
upsertUserIDs = append(upsertUserIDs, e.UserID)
upsertValues = append(upsertValues, int32(*e.RPMOverride))
}
}
// 未在 entries 列表中的行:清空 rpm_override。
if len(keepUserIDs) == 0 {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1
`, groupID); err != nil {
return err
}
} else {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id <> ALL($2)
`, groupID, pq.Array(keepUserIDs)); err != nil {
return err
}
}
// 显式 clear 的行。
if len(clearUserIDs) > 0 {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1 AND user_id = ANY($2)
`, groupID, pq.Array(clearUserIDs)); err != nil {
return err
}
}
// 清空后若整行 NULL 则删除。
if _, err := r.sql.ExecContext(ctx, `
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`, groupID); err != nil {
return err
}
if len(upsertUserIDs) > 0 {
now := time.Now()
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rpm_override, created_at, updated_at)
SELECT data.user_id, $1::bigint, data.rpm_override, $2::timestamptz, $2::timestamptz
FROM unnest($3::bigint[], $4::integer[]) AS data(user_id, rpm_override)
ON CONFLICT (user_id, group_id)
DO UPDATE SET rpm_override = EXCLUDED.rpm_override, updated_at = EXCLUDED.updated_at
`, groupID, now, pq.Array(upsertUserIDs), pq.Array(upsertValues))
if err != nil {
return err
}
}
return nil
}
// ClearGroupRPMOverrides 清空指定分组所有行的 rpm_override。
func (r *userGroupRateRepository) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
if _, err := r.sql.ExecContext(ctx, `
UPDATE user_group_rate_multipliers
SET rpm_override = NULL, updated_at = NOW()
WHERE group_id = $1
`, groupID); err != nil {
return err
}
_, err := r.sql.ExecContext(ctx, `
DELETE FROM user_group_rate_multipliers
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
`, groupID)
return err
}
// DeleteByGroupID 删除指定分组的所有用户专属条目
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
return err
}
// DeleteByUserID 删除指定用户的所有专属倍率
// DeleteByUserID 删除指定用户的所有专属条目
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err

View File

@@ -93,6 +93,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
SetNillableLastLoginAt(userIn.LastLoginAt).
SetNillableLastActiveAt(userIn.LastActiveAt).
SetRpmLimit(userIn.RPMLimit).
Save(txCtx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
@@ -219,7 +220,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
SetTotalRecharged(userIn.TotalRecharged)
SetTotalRecharged(userIn.TotalRecharged).
SetRpmLimit(userIn.RPMLimit)
if userIn.SignupSource != "" {
updateOp = updateOp.SetSignupSource(userIn.SignupSource)
}

View File

@@ -0,0 +1,108 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 用户/分组级 RPM 计数器 Redis 实现。
//
// 设计说明:
// - key 形式rpm:ug:{uid}:{gid}:{minute}、rpm:u:{uid}:{minute}
// - 时间来源rdb.Time()Redis 服务端时间),避免多实例时钟漂移。
// - 原子操作TxPipeline (MULTI/EXEC) 执行 INCR+EXPIRE兼容 Redis Cluster。
// - TTL120s覆盖当前分钟窗口 + 少量冗余。
// - 返回值语义超限判断由调用方billing_cache_service.checkRPM与 RPMLimit 比较完成。
const (
userGroupRPMKeyPrefix = "rpm:ug:"
userRPMKeyPrefix = "rpm:u:"
userRPMKeyTTL = 120 * time.Second
)
type userRPMCacheImpl struct {
rdb *redis.Client
}
// NewUserRPMCache 创建用户/分组级 RPM 计数器。
func NewUserRPMCache(rdb *redis.Client) service.UserRPMCache {
return &userRPMCacheImpl{rdb: rdb}
}
// minuteTS 获取当前 Redis 服务端分钟时间戳。
func (c *userRPMCacheImpl) minuteTS(ctx context.Context) (int64, error) {
t, err := c.rdb.Time(ctx).Result()
if err != nil {
return 0, fmt.Errorf("redis TIME: %w", err)
}
return t.Unix() / 60, nil
}
// atomicIncr 原子 INCR+EXPIRE。
func (c *userRPMCacheImpl) atomicIncr(ctx context.Context, key string) (int, error) {
pipe := c.rdb.TxPipeline()
incr := pipe.Incr(ctx, key)
pipe.Expire(ctx, key, userRPMKeyTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, fmt.Errorf("user rpm increment: %w", err)
}
return int(incr.Val()), nil
}
// IncrementUserGroupRPM 递增 (user, group) 分钟计数。
func (c *userRPMCacheImpl) IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
minute, err := c.minuteTS(ctx)
if err != nil {
return 0, err
}
key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
return c.atomicIncr(ctx, key)
}
// IncrementUserRPM 递增用户分钟计数。
func (c *userRPMCacheImpl) IncrementUserRPM(ctx context.Context, userID int64) (int, error) {
minute, err := c.minuteTS(ctx)
if err != nil {
return 0, err
}
key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
return c.atomicIncr(ctx, key)
}
// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM只读
func (c *userRPMCacheImpl) GetUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
minute, err := c.minuteTS(ctx)
if err != nil {
return 0, err
}
key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
val, err := c.rdb.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("user group rpm get: %w", err)
}
return val, nil
}
// GetUserRPM 获取用户当前分钟已用 RPM只读
func (c *userRPMCacheImpl) GetUserRPM(ctx context.Context, userID int64) (int, error) {
minute, err := c.minuteTS(ctx)
if err != nil {
return 0, err
}
key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
val, err := c.rdb.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("user rpm get: %w", err)
}
return val, nil
}

View File

@@ -101,6 +101,7 @@ var ProviderSet = wire.NewSet(
ProvideConcurrencyCache,
ProvideSessionLimitCache,
NewRPMCache,
NewUserRPMCache,
NewUserMsgQueueCache,
NewDashboardCache,
NewEmailCache,