diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 8184bc1c..6b0c6370 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -66,7 +66,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) - subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) + subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, configConfig) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) secretEncryptor, err := repository.NewAESEncryptor(configConfig) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 6c932ae2..64132a2f 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -38,31 +38,32 @@ const ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - Ops OpsConfig `mapstructure:"ops"` - JWT JWTConfig `mapstructure:"jwt"` - Totp TotpConfig `mapstructure:"totp"` - LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` - Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` - DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` - UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + Ops OpsConfig `mapstructure:"ops"` + JWT JWTConfig `mapstructure:"jwt"` + Totp TotpConfig `mapstructure:"totp"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` + SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"` + Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` + DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` + UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` } type GeminiConfig struct { @@ -528,6 +529,13 @@ type APIKeyAuthCacheConfig struct { Singleflight bool `mapstructure:"singleflight"` } +// SubscriptionCacheConfig 订阅认证 L1 缓存配置 +type SubscriptionCacheConfig struct { + L1Size int `mapstructure:"l1_size"` + L1TTLSeconds int `mapstructure:"l1_ttl_seconds"` + JitterPercent int `mapstructure:"jitter_percent"` +} + // DashboardCacheConfig 仪表盘统计缓存配置 type DashboardCacheConfig struct { // Enabled: 是否启用仪表盘缓存 @@ -852,6 +860,11 @@ func setDefaults() { viper.SetDefault("api_key_auth_cache.jitter_percent", 10) viper.SetDefault("api_key_auth_cache.singleflight", true) + // Subscription auth L1 cache + viper.SetDefault("subscription_cache.l1_size", 16384) + viper.SetDefault("subscription_cache.l1_ttl_seconds", 10) + viper.SetDefault("subscription_cache.jitter_percent", 10) + // Dashboard cache viper.SetDefault("dashboard_cache.enabled", true) viper.SetDefault("dashboard_cache.key_prefix", "sub2api:") diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 2f739357..4525aee7 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -3,7 +3,6 @@ package middleware import ( "context" "errors" - "log" "strings" "github.com/Wei-Shaw/sub2api/internal/config" @@ -134,7 +133,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType() if isSubscriptionType && subscriptionService != nil { - // 订阅模式:验证订阅 + // 订阅模式:获取订阅(L1 缓存 + singleflight) subscription, err := subscriptionService.GetActiveSubscription( c.Request.Context(), apiKey.User.ID, @@ -145,30 +144,30 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti return } - // 验证订阅状态(是否过期、暂停等) - if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil { - AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error()) - return - } - - // 激活滑动窗口(首次使用时) - if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil { - log.Printf("Failed to activate subscription windows: %v", err) - } - - // 检查并重置过期窗口 - if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil { - log.Printf("Failed to reset subscription windows: %v", err) - } - - // 预检查用量限制(使用0作为额外费用进行预检查) - if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil { - AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error()) + // 合并验证 + 限额检查(纯内存操作) + needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group) + if err != nil { + code := "SUBSCRIPTION_INVALID" + status := 403 + if errors.Is(err, service.ErrDailyLimitExceeded) || + errors.Is(err, service.ErrWeeklyLimitExceeded) || + errors.Is(err, service.ErrMonthlyLimitExceeded) { + code = "USAGE_LIMIT_EXCEEDED" + status = 429 + } + AbortWithError(c, status, code, err.Error()) return } // 将订阅信息存入上下文 c.Set(string(ContextKeySubscription), subscription) + + // 窗口维护异步化(不阻塞请求) + // 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race + if needsMaintenance { + maintenanceCopy := *subscription + go subscriptionService.DoWindowMaintenance(&maintenanceCopy) + } } else { // 余额模式:检查用户余额 if apiKey.User.Balance <= 0 { diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index 3c42852e..21694d41 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -4,10 +4,15 @@ import ( "context" "fmt" "log" + "math/rand/v2" + "strconv" "time" + "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/dgraph-io/ristretto" + "golang.org/x/sync/singleflight" ) // MaxExpiresAt is the maximum allowed expiration date (year 2099) @@ -35,15 +40,76 @@ type SubscriptionService struct { groupRepo GroupRepository userSubRepo UserSubscriptionRepository billingCacheService *BillingCacheService + + // L1 缓存:加速中间件热路径的订阅查询 + subCacheL1 *ristretto.Cache + subCacheGroup singleflight.Group + subCacheTTL time.Duration + subCacheJitter int // 抖动百分比 } // NewSubscriptionService 创建订阅服务 -func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService) *SubscriptionService { - return &SubscriptionService{ +func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, cfg *config.Config) *SubscriptionService { + svc := &SubscriptionService{ groupRepo: groupRepo, userSubRepo: userSubRepo, billingCacheService: billingCacheService, } + svc.initSubCache(cfg) + return svc +} + +// initSubCache 初始化订阅 L1 缓存 +func (s *SubscriptionService) initSubCache(cfg *config.Config) { + if cfg == nil { + return + } + sc := cfg.SubscriptionCache + if sc.L1Size <= 0 || sc.L1TTLSeconds <= 0 { + return + } + cache, err := ristretto.NewCache(&ristretto.Config{ + NumCounters: int64(sc.L1Size) * 10, + MaxCost: int64(sc.L1Size), + BufferItems: 64, + }) + if err != nil { + log.Printf("Warning: failed to init subscription L1 cache: %v", err) + return + } + s.subCacheL1 = cache + s.subCacheTTL = time.Duration(sc.L1TTLSeconds) * time.Second + s.subCacheJitter = sc.JitterPercent +} + +// subCacheKey 生成订阅缓存 key(热路径,避免 fmt.Sprintf 开销) +func subCacheKey(userID, groupID int64) string { + return "sub:" + strconv.FormatInt(userID, 10) + ":" + strconv.FormatInt(groupID, 10) +} + +// jitteredTTL 为 TTL 添加抖动,避免集中过期 +func (s *SubscriptionService) jitteredTTL(ttl time.Duration) time.Duration { + if ttl <= 0 || s.subCacheJitter <= 0 { + return ttl + } + pct := s.subCacheJitter + if pct > 100 { + pct = 100 + } + delta := float64(pct) / 100 + factor := 1 - delta + rand.Float64()*(2*delta) + if factor <= 0 { + return ttl + } + return time.Duration(float64(ttl) * factor) +} + +// InvalidateSubCache 失效指定用户+分组的订阅 L1 缓存 +func (s *SubscriptionService) InvalidateSubCache(userID, groupID int64) { + if s.subCacheL1 == nil { + return + } + s.subCacheL1.Del(subCacheKey(userID, groupID)) } // AssignSubscriptionInput 分配订阅输入 @@ -81,6 +147,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass } // 失效订阅缓存 + s.InvalidateSubCache(input.UserID, input.GroupID) if s.billingCacheService != nil { userID, groupID := input.UserID, input.GroupID go func() { @@ -167,6 +234,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in } // 失效订阅缓存 + s.InvalidateSubCache(input.UserID, input.GroupID) if s.billingCacheService != nil { userID, groupID := input.UserID, input.GroupID go func() { @@ -188,6 +256,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in } // 失效订阅缓存 + s.InvalidateSubCache(input.UserID, input.GroupID) if s.billingCacheService != nil { userID, groupID := input.UserID, input.GroupID go func() { @@ -297,6 +366,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti } // 失效订阅缓存 + s.InvalidateSubCache(sub.UserID, sub.GroupID) if s.billingCacheService != nil { userID, groupID := sub.UserID, sub.GroupID go func() { @@ -363,6 +433,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti } // 失效订阅缓存 + s.InvalidateSubCache(sub.UserID, sub.GroupID) if s.billingCacheService != nil { userID, groupID := sub.UserID, sub.GroupID go func() { @@ -381,12 +452,39 @@ func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*UserSubsc } // GetActiveSubscription 获取用户对特定分组的有效订阅 +// 使用 L1 缓存 + singleflight 加速中间件热路径。 +// 返回缓存对象的浅拷贝,调用方可安全修改字段而不会污染缓存或触发 data race。 func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*UserSubscription, error) { - sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID) - if err != nil { - return nil, ErrSubscriptionNotFound + key := subCacheKey(userID, groupID) + + // L1 缓存命中:返回浅拷贝 + if s.subCacheL1 != nil { + if v, ok := s.subCacheL1.Get(key); ok { + if sub, ok := v.(*UserSubscription); ok { + cp := *sub + return &cp, nil + } + } } - return sub, nil + + // singleflight 防止并发击穿 + value, err, _ := s.subCacheGroup.Do(key, func() (any, error) { + sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID) + if err != nil { + return nil, ErrSubscriptionNotFound + } + // 写入 L1 缓存 + if s.subCacheL1 != nil { + _ = s.subCacheL1.SetWithTTL(key, sub, 1, s.jitteredTTL(s.subCacheTTL)) + } + return sub, nil + }) + if err != nil { + return nil, err + } + // singleflight 返回的也是缓存指针,需要浅拷贝 + cp := *value.(*UserSubscription) + return &cp, nil } // ListUserSubscriptions 获取用户的所有订阅 @@ -521,9 +619,12 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use needsInvalidateCache = true } - // 如果有窗口被重置,失效 Redis 缓存以保持一致性 - if needsInvalidateCache && s.billingCacheService != nil { - _ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID) + // 如果有窗口被重置,失效缓存以保持一致性 + if needsInvalidateCache { + s.InvalidateSubCache(sub.UserID, sub.GroupID) + if s.billingCacheService != nil { + _ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID) + } } return nil @@ -544,6 +645,78 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSub return nil } +// ValidateAndCheckLimits 合并验证+限额检查(中间件热路径专用) +// 仅做内存检查,不触发 DB 写入。窗口重置的 DB 写入由 DoWindowMaintenance 异步完成。 +// 返回 needsMaintenance 表示是否需要异步执行窗口维护。 +func (s *SubscriptionService) ValidateAndCheckLimits(sub *UserSubscription, group *Group) (needsMaintenance bool, err error) { + // 1. 验证订阅状态 + if sub.Status == SubscriptionStatusExpired { + return false, ErrSubscriptionExpired + } + if sub.Status == SubscriptionStatusSuspended { + return false, ErrSubscriptionSuspended + } + if sub.IsExpired() { + return false, ErrSubscriptionExpired + } + + // 2. 内存中修正过期窗口的用量,确保 CheckUsageLimits 不会误拒绝用户 + // 实际的 DB 窗口重置由 DoWindowMaintenance 异步完成 + if sub.NeedsDailyReset() { + sub.DailyUsageUSD = 0 + needsMaintenance = true + } + if sub.NeedsWeeklyReset() { + sub.WeeklyUsageUSD = 0 + needsMaintenance = true + } + if sub.NeedsMonthlyReset() { + sub.MonthlyUsageUSD = 0 + needsMaintenance = true + } + if !sub.IsWindowActivated() { + needsMaintenance = true + } + + // 3. 检查用量限额 + if !sub.CheckDailyLimit(group, 0) { + return needsMaintenance, ErrDailyLimitExceeded + } + if !sub.CheckWeeklyLimit(group, 0) { + return needsMaintenance, ErrWeeklyLimitExceeded + } + if !sub.CheckMonthlyLimit(group, 0) { + return needsMaintenance, ErrMonthlyLimitExceeded + } + + return needsMaintenance, nil +} + +// DoWindowMaintenance 异步执行窗口维护(激活+重置) +// 使用独立 context,不受请求取消影响。 +// 注意:此方法仅在 ValidateAndCheckLimits 返回 needsMaintenance=true 时调用, +// 而 IsExpired()=true 的订阅在 ValidateAndCheckLimits 中已被拦截返回错误, +// 因此进入此方法的订阅一定未过期,无需处理过期状态同步。 +func (s *SubscriptionService) DoWindowMaintenance(sub *UserSubscription) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 激活窗口(首次使用时) + if !sub.IsWindowActivated() { + if err := s.CheckAndActivateWindow(ctx, sub); err != nil { + log.Printf("Failed to activate subscription windows: %v", err) + } + } + + // 重置过期窗口 + if err := s.CheckAndResetWindows(ctx, sub); err != nil { + log.Printf("Failed to reset subscription windows: %v", err) + } + + // 失效 L1 缓存,确保后续请求拿到更新后的数据 + s.InvalidateSubCache(sub.UserID, sub.GroupID) +} + // RecordUsage 记录使用量到订阅 func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error { return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD)