feat(rpm): RPM 限流模块优化
P0: - rpm_override 嵌入 Auth Cache Snapshot,消除每请求 DB 查询 (snapshot v6→v7) - 429 RPM 响应返回 Retry-After 头(当前分钟剩余秒数) P1: - ClearAll 按钮直连 DELETE API,带 loading 防重复 - 新增 GET /admin/users/:id/rpm-status 管理员 RPM 用量查询端点 优化: - checkRPM 从级联互斥改为并行取最严,user.rpm_limit 作为全局硬上限始终生效 - Override/Group 变更后自动失效 auth cache - fail-open 语义不变,Redis 故障不阻塞业务
This commit is contained in:
@@ -152,6 +152,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
user.FieldSignupSource,
|
||||
user.FieldLastLoginAt,
|
||||
user.FieldLastActiveAt,
|
||||
user.FieldRpmLimit,
|
||||
)
|
||||
}).
|
||||
WithGroup(func(q *dbent.GroupQuery) {
|
||||
@@ -178,6 +179,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
group.FieldMessagesDispatchModelConfig,
|
||||
group.FieldRpmLimit,
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
@@ -669,6 +671,7 @@ func userEntityToService(u *dbent.User) *service.User {
|
||||
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
|
||||
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
|
||||
TotalRecharged: u.TotalRecharged,
|
||||
RPMLimit: u.RpmLimit,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
@@ -713,6 +716,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
RequirePrivacySet: g.RequirePrivacySet,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
|
||||
RPMLimit: g.RpmLimit,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
|
||||
@@ -63,7 +63,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
||||
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
|
||||
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
|
||||
SetRpmLimit(groupIn.RPMLimit)
|
||||
|
||||
// 设置模型路由配置
|
||||
if groupIn.ModelRouting != nil {
|
||||
@@ -130,7 +131,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
||||
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
|
||||
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
|
||||
SetRpmLimit(groupIn.RPMLimit)
|
||||
|
||||
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
||||
if groupIn.DailyLimitUSD != nil {
|
||||
|
||||
@@ -13,14 +13,14 @@ type userGroupRateRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
|
||||
// NewUserGroupRateRepository 创建用户专属分组倍率/RPM 仓储
|
||||
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
|
||||
return &userGroupRateRepository{sql: sqlDB}
|
||||
}
|
||||
|
||||
// GetByUserID 获取用户的所有专属分组倍率
|
||||
// GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目)
|
||||
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
|
||||
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
|
||||
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NOT NULL`
|
||||
rows, err := r.sql.QueryContext(ctx, query, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByUserIDs 批量获取多个用户的专属分组倍率。
|
||||
// 返回结构:map[userID]map[groupID]rate
|
||||
// GetByUserIDs 批量获取多个用户的专属分组 rate_multiplier(仅返回非 NULL 的条目)
|
||||
func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
|
||||
result := make(map[int64]map[int64]float64, len(userIDs))
|
||||
if len(userIDs) == 0 {
|
||||
@@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT user_id, group_id, rate_multiplier
|
||||
FROM user_group_rate_multipliers
|
||||
WHERE user_id = ANY($1)
|
||||
WHERE user_id = ANY($1) AND rate_multiplier IS NOT NULL
|
||||
`, pq.Array(uniqueIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByGroupID 获取指定分组下所有用户的专属倍率
|
||||
// GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回)
|
||||
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
|
||||
query := `
|
||||
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
|
||||
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier, ugr.rpm_override
|
||||
FROM user_group_rate_multipliers ugr
|
||||
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
|
||||
WHERE ugr.group_id = $1
|
||||
@@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
|
||||
var result []service.UserGroupRateEntry
|
||||
for rows.Next() {
|
||||
var entry service.UserGroupRateEntry
|
||||
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil {
|
||||
var rate sql.NullFloat64
|
||||
var rpm sql.NullInt32
|
||||
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &rate, &rpm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rate.Valid {
|
||||
v := rate.Float64
|
||||
entry.RateMultiplier = &v
|
||||
}
|
||||
if rpm.Valid {
|
||||
v := int(rpm.Int32)
|
||||
entry.RPMOverride = &v
|
||||
}
|
||||
result = append(result, entry)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
@@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil)
|
||||
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||
var rate float64
|
||||
var rate sql.NullFloat64
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
@@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rate, nil
|
||||
if !rate.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
v := rate.Float64
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
// SyncUserGroupRates 同步用户的分组专属倍率
|
||||
// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil)
|
||||
func (r *userGroupRateRepository) GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error) {
|
||||
query := `SELECT rpm_override FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||
var rpm sql.NullInt32
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rpm)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !rpm.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
v := int(rpm.Int32)
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
// SyncUserGroupRates 同步用户的分组专属 rate_multiplier。
|
||||
// - 传入空 map:清空该用户所有行的 rate_multiplier;若 rpm_override 也为 NULL 则整行删除。
|
||||
// - 值为 nil:清空对应行的 rate_multiplier(保留 rpm_override)。
|
||||
// - 值非 nil:upsert rate_multiplier(保留已有 rpm_override)。
|
||||
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
|
||||
if len(rates) == 0 {
|
||||
// 如果传入空 map,删除该用户的所有专属倍率
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rate_multiplier = NULL, updated_at = NOW()
|
||||
WHERE user_id = $1
|
||||
`, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL`,
|
||||
userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// 分离需要删除和需要 upsert 的记录
|
||||
var toDelete []int64
|
||||
var clearGroupIDs []int64
|
||||
upsertGroupIDs := make([]int64, 0, len(rates))
|
||||
upsertRates := make([]float64, 0, len(rates))
|
||||
for groupID, rate := range rates {
|
||||
if rate == nil {
|
||||
toDelete = append(toDelete, groupID)
|
||||
clearGroupIDs = append(clearGroupIDs, groupID)
|
||||
} else {
|
||||
upsertGroupIDs = append(upsertGroupIDs, groupID)
|
||||
upsertRates = append(upsertRates, *rate)
|
||||
}
|
||||
}
|
||||
|
||||
// 删除指定的记录
|
||||
if len(toDelete) > 0 {
|
||||
if len(clearGroupIDs) > 0 {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rate_multiplier = NULL, updated_at = NOW()
|
||||
WHERE user_id = $1 AND group_id = ANY($2)
|
||||
`, userID, pq.Array(clearGroupIDs)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx,
|
||||
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`,
|
||||
userID, pq.Array(toDelete)); err != nil {
|
||||
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2) AND rate_multiplier IS NULL AND rpm_override IS NULL`,
|
||||
userID, pq.Array(clearGroupIDs)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Upsert 记录
|
||||
now := time.Now()
|
||||
if len(upsertGroupIDs) > 0 {
|
||||
now := time.Now()
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
|
||||
SELECT
|
||||
@@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插)
|
||||
// SyncGroupRateMultipliers 同步分组的 rate_multiplier 部分(不触动 rpm_override)。
|
||||
// 语义:
|
||||
// - 未出现在 entries 中的用户行:rate_multiplier 归 NULL;若 rpm_override 也为 NULL 则整行删除。
|
||||
// - 出现的用户行:upsert rate_multiplier。
|
||||
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
|
||||
if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil {
|
||||
keepUserIDs := make([]int64, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
keepUserIDs = append(keepUserIDs, e.UserID)
|
||||
}
|
||||
|
||||
// 未在 entries 列表中的行:清空 rate_multiplier。
|
||||
if len(keepUserIDs) == 0 {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rate_multiplier = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1
|
||||
`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rate_multiplier = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1 AND user_id <> ALL($2)
|
||||
`, groupID, pq.Array(keepUserIDs)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 清空后若整行 NULL 则删除。
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
DELETE FROM user_group_rate_multipliers
|
||||
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
|
||||
`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
userIDs := make([]int64, len(entries))
|
||||
rates := make([]float64, len(entries))
|
||||
for i, e := range entries {
|
||||
@@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context,
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率
|
||||
// SyncGroupRPMOverrides 同步分组的 rpm_override 部分(不触动 rate_multiplier)。
|
||||
// 语义:
|
||||
// - 未出现的用户行:rpm_override 归 NULL;若 rate_multiplier 也为 NULL 则整行删除。
|
||||
// - 出现的用户行:若 RPMOverride 为 nil 则清空;非 nil 则 upsert。
|
||||
func (r *userGroupRateRepository) SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []service.GroupRPMOverrideInput) error {
|
||||
keepUserIDs := make([]int64, 0, len(entries))
|
||||
var clearUserIDs []int64
|
||||
upsertUserIDs := make([]int64, 0, len(entries))
|
||||
upsertValues := make([]int32, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
keepUserIDs = append(keepUserIDs, e.UserID)
|
||||
if e.RPMOverride == nil {
|
||||
clearUserIDs = append(clearUserIDs, e.UserID)
|
||||
} else {
|
||||
upsertUserIDs = append(upsertUserIDs, e.UserID)
|
||||
upsertValues = append(upsertValues, int32(*e.RPMOverride))
|
||||
}
|
||||
}
|
||||
|
||||
// 未在 entries 列表中的行:清空 rpm_override。
|
||||
if len(keepUserIDs) == 0 {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rpm_override = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1
|
||||
`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rpm_override = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1 AND user_id <> ALL($2)
|
||||
`, groupID, pq.Array(keepUserIDs)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 显式 clear 的行。
|
||||
if len(clearUserIDs) > 0 {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rpm_override = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1 AND user_id = ANY($2)
|
||||
`, groupID, pq.Array(clearUserIDs)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 清空后若整行 NULL 则删除。
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
DELETE FROM user_group_rate_multipliers
|
||||
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
|
||||
`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(upsertUserIDs) > 0 {
|
||||
now := time.Now()
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
INSERT INTO user_group_rate_multipliers (user_id, group_id, rpm_override, created_at, updated_at)
|
||||
SELECT data.user_id, $1::bigint, data.rpm_override, $2::timestamptz, $2::timestamptz
|
||||
FROM unnest($3::bigint[], $4::integer[]) AS data(user_id, rpm_override)
|
||||
ON CONFLICT (user_id, group_id)
|
||||
DO UPDATE SET rpm_override = EXCLUDED.rpm_override, updated_at = EXCLUDED.updated_at
|
||||
`, groupID, now, pq.Array(upsertUserIDs), pq.Array(upsertValues))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearGroupRPMOverrides 清空指定分组所有行的 rpm_override。
|
||||
func (r *userGroupRateRepository) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
|
||||
if _, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE user_group_rate_multipliers
|
||||
SET rpm_override = NULL, updated_at = NOW()
|
||||
WHERE group_id = $1
|
||||
`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
DELETE FROM user_group_rate_multipliers
|
||||
WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
|
||||
`, groupID)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属条目
|
||||
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByUserID 删除指定用户的所有专属倍率
|
||||
// DeleteByUserID 删除指定用户的所有专属条目
|
||||
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
||||
return err
|
||||
|
||||
@@ -93,6 +93,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
|
||||
SetNillableLastLoginAt(userIn.LastLoginAt).
|
||||
SetNillableLastActiveAt(userIn.LastActiveAt).
|
||||
SetRpmLimit(userIn.RPMLimit).
|
||||
Save(txCtx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
@@ -219,7 +220,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
|
||||
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
|
||||
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
|
||||
SetTotalRecharged(userIn.TotalRecharged)
|
||||
SetTotalRecharged(userIn.TotalRecharged).
|
||||
SetRpmLimit(userIn.RPMLimit)
|
||||
if userIn.SignupSource != "" {
|
||||
updateOp = updateOp.SetSignupSource(userIn.SignupSource)
|
||||
}
|
||||
|
||||
108
backend/internal/repository/user_rpm_cache.go
Normal file
108
backend/internal/repository/user_rpm_cache.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 用户/分组级 RPM 计数器 Redis 实现。
|
||||
//
|
||||
// 设计说明:
|
||||
// - key 形式:rpm:ug:{uid}:{gid}:{minute}、rpm:u:{uid}:{minute}
|
||||
// - 时间来源:rdb.Time()(Redis 服务端时间),避免多实例时钟漂移。
|
||||
// - 原子操作:TxPipeline (MULTI/EXEC) 执行 INCR+EXPIRE,兼容 Redis Cluster。
|
||||
// - TTL:120s,覆盖当前分钟窗口 + 少量冗余。
|
||||
// - 返回值语义:超限判断由调用方(billing_cache_service.checkRPM)与 RPMLimit 比较完成。
|
||||
const (
|
||||
userGroupRPMKeyPrefix = "rpm:ug:"
|
||||
userRPMKeyPrefix = "rpm:u:"
|
||||
|
||||
userRPMKeyTTL = 120 * time.Second
|
||||
)
|
||||
|
||||
type userRPMCacheImpl struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewUserRPMCache 创建用户/分组级 RPM 计数器。
|
||||
func NewUserRPMCache(rdb *redis.Client) service.UserRPMCache {
|
||||
return &userRPMCacheImpl{rdb: rdb}
|
||||
}
|
||||
|
||||
// minuteTS 获取当前 Redis 服务端分钟时间戳。
|
||||
func (c *userRPMCacheImpl) minuteTS(ctx context.Context) (int64, error) {
|
||||
t, err := c.rdb.Time(ctx).Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("redis TIME: %w", err)
|
||||
}
|
||||
return t.Unix() / 60, nil
|
||||
}
|
||||
|
||||
// atomicIncr 原子 INCR+EXPIRE。
|
||||
func (c *userRPMCacheImpl) atomicIncr(ctx context.Context, key string) (int, error) {
|
||||
pipe := c.rdb.TxPipeline()
|
||||
incr := pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, userRPMKeyTTL)
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return 0, fmt.Errorf("user rpm increment: %w", err)
|
||||
}
|
||||
return int(incr.Val()), nil
|
||||
}
|
||||
|
||||
// IncrementUserGroupRPM 递增 (user, group) 分钟计数。
|
||||
func (c *userRPMCacheImpl) IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
|
||||
minute, err := c.minuteTS(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
|
||||
return c.atomicIncr(ctx, key)
|
||||
}
|
||||
|
||||
// IncrementUserRPM 递增用户分钟计数。
|
||||
func (c *userRPMCacheImpl) IncrementUserRPM(ctx context.Context, userID int64) (int, error) {
|
||||
minute, err := c.minuteTS(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
|
||||
return c.atomicIncr(ctx, key)
|
||||
}
|
||||
|
||||
// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读)。
|
||||
func (c *userRPMCacheImpl) GetUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
|
||||
minute, err := c.minuteTS(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
|
||||
val, err := c.rdb.Get(ctx, key).Int()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("user group rpm get: %w", err)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// GetUserRPM 获取用户当前分钟已用 RPM(只读)。
|
||||
func (c *userRPMCacheImpl) GetUserRPM(ctx context.Context, userID int64) (int, error) {
|
||||
minute, err := c.minuteTS(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
|
||||
val, err := c.rdb.Get(ctx, key).Int()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("user rpm get: %w", err)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
@@ -101,6 +101,7 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideConcurrencyCache,
|
||||
ProvideSessionLimitCache,
|
||||
NewRPMCache,
|
||||
NewUserRPMCache,
|
||||
NewUserMsgQueueCache,
|
||||
NewDashboardCache,
|
||||
NewEmailCache,
|
||||
|
||||
Reference in New Issue
Block a user