P0: - rpm_override 嵌入 Auth Cache Snapshot,消除每请求 DB 查询 (snapshot v6→v7) - 429 RPM 响应返回 Retry-After 头(当前分钟剩余秒数) P1: - ClearAll 按钮直连 DELETE API,带 loading 防重复 - 新增 GET /admin/users/:id/rpm-status 管理员 RPM 用量查询端点 优化: - checkRPM 从级联互斥改为并行取最严,user.rpm_limit 作为全局硬上限始终生效 - Override/Group 变更后自动失效 auth cache - fail-open 语义不变,Redis 故障不阻塞业务
254 lines
9.8 KiB
Go
254 lines
9.8 KiB
Go
//go:build unit
|
||
|
||
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"sync/atomic"
|
||
"testing"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
// userRPMCacheStub 记录每种计数器被调用的次数,并可注入返回值与错误。
|
||
type userRPMCacheStub struct {
|
||
userGroupCalls int32
|
||
userCalls int32
|
||
|
||
userGroupCounts []int // 依次返回的计数值
|
||
userGroupErr error
|
||
userCounts []int
|
||
userErr error
|
||
}
|
||
|
||
func (s *userRPMCacheStub) IncrementUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
|
||
idx := int(atomic.AddInt32(&s.userGroupCalls, 1)) - 1
|
||
if s.userGroupErr != nil {
|
||
return 0, s.userGroupErr
|
||
}
|
||
if idx < len(s.userGroupCounts) {
|
||
return s.userGroupCounts[idx], nil
|
||
}
|
||
return 1, nil
|
||
}
|
||
|
||
func (s *userRPMCacheStub) IncrementUserRPM(_ context.Context, _ int64) (int, error) {
|
||
idx := int(atomic.AddInt32(&s.userCalls, 1)) - 1
|
||
if s.userErr != nil {
|
||
return 0, s.userErr
|
||
}
|
||
if idx < len(s.userCounts) {
|
||
return s.userCounts[idx], nil
|
||
}
|
||
return 1, nil
|
||
}
|
||
|
||
func (s *userRPMCacheStub) GetUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
|
||
return 0, nil
|
||
}
|
||
|
||
func (s *userRPMCacheStub) GetUserRPM(_ context.Context, _ int64) (int, error) {
|
||
return 0, nil
|
||
}
|
||
|
||
// rpmOverrideRepoStub 专用于 checkRPM 分支测试,只实现必要方法。
|
||
type rpmOverrideRepoStub struct {
|
||
UserGroupRateRepository
|
||
|
||
override *int
|
||
err error
|
||
calls int32
|
||
}
|
||
|
||
func (s *rpmOverrideRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
|
||
atomic.AddInt32(&s.calls, 1)
|
||
if s.err != nil {
|
||
return nil, s.err
|
||
}
|
||
return s.override, nil
|
||
}
|
||
|
||
func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGroupRateRepository) *BillingCacheService {
|
||
t.Helper()
|
||
// 用 nil BillingCache 走 "无缓存" 分支,避免 CheckBillingEligibility 副作用。
|
||
// 我们只直接测 checkRPM。
|
||
svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{})
|
||
t.Cleanup(svc.Stop)
|
||
return svc
|
||
}
|
||
|
||
func TestBillingCacheService_CheckRPM_OverrideTakesPrecedenceOverGroup(t *testing.T) {
|
||
override := 2
|
||
// user-group 计数: 1, 2, 3;user 计数: 默认返回 1(远小于 RPMLimit=100,不干扰)
|
||
cache := &userRPMCacheStub{userGroupCounts: []int{1, 2, 3}}
|
||
repo := &rpmOverrideRepoStub{override: &override}
|
||
svc := newBillingServiceForRPM(t, cache, repo)
|
||
|
||
user := &User{ID: 1, RPMLimit: 100} // 全局上限设高,不干扰 override 测试
|
||
group := &Group{ID: 10, RPMLimit: 100}
|
||
|
||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded)
|
||
|
||
require.EqualValues(t, 3, atomic.LoadInt32(&cache.userGroupCalls), "override 命中分支应走 user-group 计数")
|
||
// 并行设计:前 2 次 override 未超→继续检查 user;第 3 次 override 超了→直接 return,不检查 user
|
||
require.EqualValues(t, 2, atomic.LoadInt32(&cache.userCalls), "override 超限前 user 计数器应被调用")
|
||
require.EqualValues(t, 3, atomic.LoadInt32(&repo.calls))
|
||
}
|
||
|
||
func TestBillingCacheService_CheckRPM_UserLimitIsGlobalHardCap(t *testing.T) {
|
||
override := 100 // override 很高
|
||
// user-group 计数: 默认返回 1(远小于 override);user 计数: 1, 2, 3
|
||
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
|
||
repo := &rpmOverrideRepoStub{override: &override}
|
||
svc := newBillingServiceForRPM(t, cache, repo)
|
||
|
||
user := &User{ID: 1, RPMLimit: 2} // 全局硬上限=2,应覆盖 override=100
|
||
group := &Group{ID: 10, RPMLimit: 100}
|
||
|
||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded, "user 全局硬上限应优先于 override")
|
||
}
|
||
|
||
func TestBillingCacheService_CheckRPM_OverrideZeroSkipsGroupButUserStillApplies(t *testing.T) {
|
||
zero := 0
|
||
// user 计数: 依次返回 1..6
|
||
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3, 4, 5, 6}}
|
||
repo := &rpmOverrideRepoStub{override: &zero}
|
||
svc := newBillingServiceForRPM(t, cache, repo)
|
||
|
||
user := &User{ID: 1, RPMLimit: 5}
|
||
group := &Group{ID: 10, RPMLimit: 100}
|
||
|
||
// override=0 跳过分组计数,但 user.RPMLimit=5 仍生效
|
||
for i := 0; i < 5; i++ {
|
||
require.NoError(t, svc.checkRPM(context.Background(), user, group), "request %d should pass", i+1)
|
||
}
|
||
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded,
|
||
"override=0 跳过分组但 user 全局上限仍应生效")
|
||
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不应触发分组计数器")
|
||
require.EqualValues(t, 6, atomic.LoadInt32(&cache.userCalls), "user 计数器应被调用")
|
||
}
|
||
|
||
func TestBillingCacheService_CheckRPM_OverrideZeroAndUserZeroIsFullyUnlimited(t *testing.T) {
|
||
zero := 0
|
||
cache := &userRPMCacheStub{}
|
||
repo := &rpmOverrideRepoStub{override: &zero}
|
||
svc := newBillingServiceForRPM(t, cache, repo)
|
||
|
||
user := &User{ID: 1, RPMLimit: 0} // user 也不限
|
||
group := &Group{ID: 10, RPMLimit: 100}
|
||
|
||
for i := 0; i < 50; i++ {
|
||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||
}
|
||
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不触发分组计数")
|
||
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls), "user.RPMLimit=0 也不触发用户计数")
|
||
}
|
||
|
||
func TestBillingCacheService_CheckRPM_NilOverrideFallsThroughToGroup(t *testing.T) {
|
||
// user-group 计数: 5, 6;user 计数: 默认 1(不干扰)
|
||
cache := &userRPMCacheStub{userGroupCounts: []int{5, 6}}
|
||
repo := &rpmOverrideRepoStub{override: nil}
|
||
svc := newBillingServiceForRPM(t, cache, repo)
|
||
|
||
user := &User{ID: 1, RPMLimit: 999} // 全局上限很高,group 先超
|
||
group := &Group{ID: 10, RPMLimit: 5}
|
||
|
||
require.NoError(t, svc.checkRPM(context.Background(), user, group)) // ug=5, user=1, 都没超
|
||
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded) // ug=6 > 5
|
||
|
||
require.EqualValues(t, 2, atomic.LoadInt32(&cache.userGroupCalls))
|
||
// 并行模式:第 1 次 group 没超 → 继续检查 user;第 2 次 group 超了 → 直接 return,不检查 user
|
||
require.EqualValues(t, 1, atomic.LoadInt32(&cache.userCalls), "group 未超时 user 也应检查;group 超时直接返回")
|
||
}
|
||
|
||
func TestBillingCacheService_CheckRPM_OverrideLookupErrorFallsThroughToGroup(t *testing.T) {
|
||
cache := &userRPMCacheStub{userGroupCounts: []int{3}}
|
||
repo := &rpmOverrideRepoStub{err: errors.New("db down")}
|
||
svc := newBillingServiceForRPM(t, cache, repo)
|
||
|
||
user := &User{ID: 1, RPMLimit: 0}
|
||
group := &Group{ID: 10, RPMLimit: 10}
|
||
|
||
// override 查询失败后应继续尝试 group 分支(不直接拒绝)
|
||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||
require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
|
||
require.EqualValues(t, 1, atomic.LoadInt32(&repo.calls))
|
||
}
|
||
|
||
func TestBillingCacheService_CheckRPM_UserLevelFallbackWhenGroupUnlimited(t *testing.T) {
|
||
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
|
||
repo := &rpmOverrideRepoStub{override: nil}
|
||
svc := newBillingServiceForRPM(t, cache, repo)
|
||
|
||
user := &User{ID: 1, RPMLimit: 2}
|
||
group := &Group{ID: 10, RPMLimit: 0} // 分组未设限
|
||
|
||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||
require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded)
|
||
|
||
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "group 未设限时不应 INCR user-group 键")
|
||
require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
|
||
}
|
||
|
||
func TestBillingCacheService_CheckRPM_NoLimitsConfiguredIsNoop(t *testing.T) {
|
||
cache := &userRPMCacheStub{}
|
||
repo := &rpmOverrideRepoStub{override: nil}
|
||
svc := newBillingServiceForRPM(t, cache, repo)
|
||
|
||
user := &User{ID: 1, RPMLimit: 0}
|
||
group := &Group{ID: 10, RPMLimit: 0}
|
||
|
||
for i := 0; i < 10; i++ {
|
||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||
}
|
||
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
|
||
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
|
||
}
|
||
|
||
func TestBillingCacheService_CheckRPM_RedisErrorFailOpen(t *testing.T) {
|
||
cache := &userRPMCacheStub{userGroupErr: errors.New("redis unavailable")}
|
||
repo := &rpmOverrideRepoStub{override: nil}
|
||
svc := newBillingServiceForRPM(t, cache, repo)
|
||
|
||
user := &User{ID: 1, RPMLimit: 0}
|
||
group := &Group{ID: 10, RPMLimit: 5}
|
||
|
||
// Redis 故障时应 fail-open,不拒绝请求
|
||
require.NoError(t, svc.checkRPM(context.Background(), user, group))
|
||
require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
|
||
}
|
||
|
||
func TestBillingCacheService_CheckRPM_NoGroupUsesUserOnly(t *testing.T) {
|
||
cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
|
||
repo := &rpmOverrideRepoStub{}
|
||
svc := newBillingServiceForRPM(t, cache, repo)
|
||
|
||
user := &User{ID: 1, RPMLimit: 2}
|
||
|
||
// 无 group(纯用户级限流场景),不应查询 rpm_override。
|
||
require.NoError(t, svc.checkRPM(context.Background(), user, nil))
|
||
require.NoError(t, svc.checkRPM(context.Background(), user, nil))
|
||
require.ErrorIs(t, svc.checkRPM(context.Background(), user, nil), ErrUserRPMExceeded)
|
||
|
||
require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls), "无 group 时不应查询 rpm_override")
|
||
require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
|
||
}
|
||
|
||
func TestBillingCacheService_CheckRPM_NilUserIsNoop(t *testing.T) {
|
||
cache := &userRPMCacheStub{}
|
||
repo := &rpmOverrideRepoStub{}
|
||
svc := newBillingServiceForRPM(t, cache, repo)
|
||
|
||
require.NoError(t, svc.checkRPM(context.Background(), nil, &Group{ID: 1, RPMLimit: 10}))
|
||
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
|
||
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
|
||
require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls))
|
||
}
|