perf(middleware): 优化订阅模式认证中间件,5次串行调用降至2步同步+1步异步

- 为 GetActiveSubscription 添加 ristretto L1 缓存 + singleflight 防击穿
- 合并 ValidateSubscription + CheckUsageLimits 为纯内存 ValidateAndCheckLimits
- 窗口维护操作(激活/重置)异步化,不再阻塞首字节
- 缓存返回浅拷贝,避免并发 data race 和缓存污染
- 所有管理操作(分配/续期/撤销/扩展/窗口重置)同步失效 L1 缓存
- 新增 SubscriptionCacheConfig 可配置 L1 缓存大小/TTL/抖动

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-07 14:43:12 +08:00
parent 782a54a8a1
commit 0e514ed80b
4 changed files with 241 additions and 56 deletions

View File

@@ -38,31 +38,32 @@ const (
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"`
Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"`
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
Server ServerConfig `mapstructure:"server"`
CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"`
Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"`
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"`
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
}
type GeminiConfig struct {
@@ -528,6 +529,13 @@ type APIKeyAuthCacheConfig struct {
Singleflight bool `mapstructure:"singleflight"`
}
// SubscriptionCacheConfig 订阅认证 L1 缓存配置
type SubscriptionCacheConfig struct {
L1Size int `mapstructure:"l1_size"`
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
JitterPercent int `mapstructure:"jitter_percent"`
}
// DashboardCacheConfig 仪表盘统计缓存配置
type DashboardCacheConfig struct {
// Enabled: 是否启用仪表盘缓存
@@ -852,6 +860,11 @@ func setDefaults() {
viper.SetDefault("api_key_auth_cache.jitter_percent", 10)
viper.SetDefault("api_key_auth_cache.singleflight", true)
// Subscription auth L1 cache
viper.SetDefault("subscription_cache.l1_size", 16384)
viper.SetDefault("subscription_cache.l1_ttl_seconds", 10)
viper.SetDefault("subscription_cache.jitter_percent", 10)
// Dashboard cache
viper.SetDefault("dashboard_cache.enabled", true)
viper.SetDefault("dashboard_cache.key_prefix", "sub2api:")

View File

@@ -3,7 +3,6 @@ package middleware
import (
"context"
"errors"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -134,7 +133,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
if isSubscriptionType && subscriptionService != nil {
// 订阅模式:验证订阅
// 订阅模式:获取订阅L1 缓存 + singleflight
subscription, err := subscriptionService.GetActiveSubscription(
c.Request.Context(),
apiKey.User.ID,
@@ -145,30 +144,30 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return
}
// 验证订阅状态(是否过期、暂停等
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
return
}
// 激活滑动窗口(首次使用时)
if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to activate subscription windows: %v", err)
}
// 检查并重置过期窗口
if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to reset subscription windows: %v", err)
}
// 预检查用量限制使用0作为额外费用进行预检查
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
// 合并验证 + 限额检查(纯内存操作
needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
if err != nil {
code := "SUBSCRIPTION_INVALID"
status := 403
if errors.Is(err, service.ErrDailyLimitExceeded) ||
errors.Is(err, service.ErrWeeklyLimitExceeded) ||
errors.Is(err, service.ErrMonthlyLimitExceeded) {
code = "USAGE_LIMIT_EXCEEDED"
status = 429
}
AbortWithError(c, status, code, err.Error())
return
}
// 将订阅信息存入上下文
c.Set(string(ContextKeySubscription), subscription)
// 窗口维护异步化(不阻塞请求)
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
if needsMaintenance {
maintenanceCopy := *subscription
go subscriptionService.DoWindowMaintenance(&maintenanceCopy)
}
} else {
// 余额模式:检查用户余额
if apiKey.User.Balance <= 0 {

View File

@@ -4,10 +4,15 @@ import (
"context"
"fmt"
"log"
"math/rand/v2"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/dgraph-io/ristretto"
"golang.org/x/sync/singleflight"
)
// MaxExpiresAt is the maximum allowed expiration date (year 2099)
@@ -35,15 +40,76 @@ type SubscriptionService struct {
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
billingCacheService *BillingCacheService
// L1 缓存:加速中间件热路径的订阅查询
subCacheL1 *ristretto.Cache
subCacheGroup singleflight.Group
subCacheTTL time.Duration
subCacheJitter int // 抖动百分比
}
// NewSubscriptionService 创建订阅服务
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService) *SubscriptionService {
return &SubscriptionService{
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, cfg *config.Config) *SubscriptionService {
svc := &SubscriptionService{
groupRepo: groupRepo,
userSubRepo: userSubRepo,
billingCacheService: billingCacheService,
}
svc.initSubCache(cfg)
return svc
}
// initSubCache 初始化订阅 L1 缓存
func (s *SubscriptionService) initSubCache(cfg *config.Config) {
if cfg == nil {
return
}
sc := cfg.SubscriptionCache
if sc.L1Size <= 0 || sc.L1TTLSeconds <= 0 {
return
}
cache, err := ristretto.NewCache(&ristretto.Config{
NumCounters: int64(sc.L1Size) * 10,
MaxCost: int64(sc.L1Size),
BufferItems: 64,
})
if err != nil {
log.Printf("Warning: failed to init subscription L1 cache: %v", err)
return
}
s.subCacheL1 = cache
s.subCacheTTL = time.Duration(sc.L1TTLSeconds) * time.Second
s.subCacheJitter = sc.JitterPercent
}
// subCacheKey 生成订阅缓存 key热路径避免 fmt.Sprintf 开销)
func subCacheKey(userID, groupID int64) string {
return "sub:" + strconv.FormatInt(userID, 10) + ":" + strconv.FormatInt(groupID, 10)
}
// jitteredTTL 为 TTL 添加抖动,避免集中过期
func (s *SubscriptionService) jitteredTTL(ttl time.Duration) time.Duration {
if ttl <= 0 || s.subCacheJitter <= 0 {
return ttl
}
pct := s.subCacheJitter
if pct > 100 {
pct = 100
}
delta := float64(pct) / 100
factor := 1 - delta + rand.Float64()*(2*delta)
if factor <= 0 {
return ttl
}
return time.Duration(float64(ttl) * factor)
}
// InvalidateSubCache 失效指定用户+分组的订阅 L1 缓存
func (s *SubscriptionService) InvalidateSubCache(userID, groupID int64) {
if s.subCacheL1 == nil {
return
}
s.subCacheL1.Del(subCacheKey(userID, groupID))
}
// AssignSubscriptionInput 分配订阅输入
@@ -81,6 +147,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
}
// 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
@@ -167,6 +234,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
@@ -188,6 +256,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
@@ -297,6 +366,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
}
// 失效订阅缓存
s.InvalidateSubCache(sub.UserID, sub.GroupID)
if s.billingCacheService != nil {
userID, groupID := sub.UserID, sub.GroupID
go func() {
@@ -363,6 +433,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
}
// 失效订阅缓存
s.InvalidateSubCache(sub.UserID, sub.GroupID)
if s.billingCacheService != nil {
userID, groupID := sub.UserID, sub.GroupID
go func() {
@@ -381,12 +452,39 @@ func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*UserSubsc
}
// GetActiveSubscription 获取用户对特定分组的有效订阅
// 使用 L1 缓存 + singleflight 加速中间件热路径。
// 返回缓存对象的浅拷贝,调用方可安全修改字段而不会污染缓存或触发 data race。
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*UserSubscription, error) {
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, ErrSubscriptionNotFound
key := subCacheKey(userID, groupID)
// L1 缓存命中:返回浅拷贝
if s.subCacheL1 != nil {
if v, ok := s.subCacheL1.Get(key); ok {
if sub, ok := v.(*UserSubscription); ok {
cp := *sub
return &cp, nil
}
}
}
return sub, nil
// singleflight 防止并发击穿
value, err, _ := s.subCacheGroup.Do(key, func() (any, error) {
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
// 写入 L1 缓存
if s.subCacheL1 != nil {
_ = s.subCacheL1.SetWithTTL(key, sub, 1, s.jitteredTTL(s.subCacheTTL))
}
return sub, nil
})
if err != nil {
return nil, err
}
// singleflight 返回的也是缓存指针,需要浅拷贝
cp := *value.(*UserSubscription)
return &cp, nil
}
// ListUserSubscriptions 获取用户的所有订阅
@@ -521,9 +619,12 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use
needsInvalidateCache = true
}
// 如果有窗口被重置,失效 Redis 缓存以保持一致性
if needsInvalidateCache && s.billingCacheService != nil {
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
// 如果有窗口被重置,失效缓存以保持一致性
if needsInvalidateCache {
s.InvalidateSubCache(sub.UserID, sub.GroupID)
if s.billingCacheService != nil {
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
}
}
return nil
@@ -544,6 +645,78 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSub
return nil
}
// ValidateAndCheckLimits 合并验证+限额检查(中间件热路径专用)
// 仅做内存检查,不触发 DB 写入。窗口重置的 DB 写入由 DoWindowMaintenance 异步完成。
// 返回 needsMaintenance 表示是否需要异步执行窗口维护。
func (s *SubscriptionService) ValidateAndCheckLimits(sub *UserSubscription, group *Group) (needsMaintenance bool, err error) {
// 1. 验证订阅状态
if sub.Status == SubscriptionStatusExpired {
return false, ErrSubscriptionExpired
}
if sub.Status == SubscriptionStatusSuspended {
return false, ErrSubscriptionSuspended
}
if sub.IsExpired() {
return false, ErrSubscriptionExpired
}
// 2. 内存中修正过期窗口的用量,确保 CheckUsageLimits 不会误拒绝用户
// 实际的 DB 窗口重置由 DoWindowMaintenance 异步完成
if sub.NeedsDailyReset() {
sub.DailyUsageUSD = 0
needsMaintenance = true
}
if sub.NeedsWeeklyReset() {
sub.WeeklyUsageUSD = 0
needsMaintenance = true
}
if sub.NeedsMonthlyReset() {
sub.MonthlyUsageUSD = 0
needsMaintenance = true
}
if !sub.IsWindowActivated() {
needsMaintenance = true
}
// 3. 检查用量限额
if !sub.CheckDailyLimit(group, 0) {
return needsMaintenance, ErrDailyLimitExceeded
}
if !sub.CheckWeeklyLimit(group, 0) {
return needsMaintenance, ErrWeeklyLimitExceeded
}
if !sub.CheckMonthlyLimit(group, 0) {
return needsMaintenance, ErrMonthlyLimitExceeded
}
return needsMaintenance, nil
}
// DoWindowMaintenance 异步执行窗口维护(激活+重置)
// 使用独立 context不受请求取消影响。
// 注意:此方法仅在 ValidateAndCheckLimits 返回 needsMaintenance=true 时调用,
// 而 IsExpired()=true 的订阅在 ValidateAndCheckLimits 中已被拦截返回错误,
// 因此进入此方法的订阅一定未过期,无需处理过期状态同步。
func (s *SubscriptionService) DoWindowMaintenance(sub *UserSubscription) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 激活窗口(首次使用时)
if !sub.IsWindowActivated() {
if err := s.CheckAndActivateWindow(ctx, sub); err != nil {
log.Printf("Failed to activate subscription windows: %v", err)
}
}
// 重置过期窗口
if err := s.CheckAndResetWindows(ctx, sub); err != nil {
log.Printf("Failed to reset subscription windows: %v", err)
}
// 失效 L1 缓存,确保后续请求拿到更新后的数据
s.InvalidateSubCache(sub.UserID, sub.GroupID)
}
// RecordUsage 记录使用量到订阅
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD)