From 99e2391b2ac7c4927ed69a82025f30ca48ac0a92 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 10 Jan 2026 22:52:13 +0800 Subject: [PATCH] =?UTF-8?q?fix(=E8=AE=A4=E8=AF=81):=20=E8=A1=A5=E9=BD=90?= =?UTF-8?q?=E4=BD=99=E9=A2=9D=E4=B8=8E=E5=88=A0=E9=99=A4=E5=9C=BA=E6=99=AF?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E5=A4=B1=E6=95=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为 Usage/Promo/Redeem 注入认证缓存失效逻辑 删除用户与分组前先失效认证缓存降低窗口 补充回归测试验证失效调用 测试: make test --- backend/cmd/server/wire_gen.go | 8 ++-- backend/internal/server/api_contract_test.go | 2 +- .../service/auth_cache_invalidation_test.go | 31 ++++++++++++ backend/internal/service/group_service.go | 6 +-- backend/internal/service/promo_service.go | 28 +++++++---- backend/internal/service/redeem_service.go | 47 ++++++++++++------- backend/internal/service/usage_service.go | 27 ++++++++--- backend/internal/service/user_service.go | 6 +-- 8 files changed, 113 insertions(+), 42 deletions(-) create mode 100644 backend/internal/service/auth_cache_invalidation_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index a372f673..95a7b30b 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -55,24 +55,24 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingCache := repository.NewBillingCache(redisClient) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) - promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client) - authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) apiKeyRepository := repository.NewAPIKeyRepository(client) groupRepository := repository.NewGroupRepository(client, db) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) + promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) + authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService) userHandler := handler.NewUserHandler(userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) - usageService := service.NewUsageService(usageLogRepository, userRepository, client) + usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) redeemCodeRepository := repository.NewRedeemCodeRepository(client) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) redeemCache := repository.NewRedeemCache(redisClient) - redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client) + redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) redeemHandler := handler.NewRedeemHandler(redeemService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) dashboardService := service.NewDashboardService(usageLogRepository) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index bd02f47d..4949f14b 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -393,7 +393,7 @@ func newContractDeps(t *testing.T) *contractDeps { apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() - usageService := service.NewUsageService(usageRepo, userRepo, nil) + usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) diff --git a/backend/internal/service/auth_cache_invalidation_test.go b/backend/internal/service/auth_cache_invalidation_test.go new file mode 100644 index 00000000..3b4217c6 --- /dev/null +++ b/backend/internal/service/auth_cache_invalidation_test.go @@ -0,0 +1,31 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUsageService_InvalidateUsageCaches(t *testing.T) { + invalidator := &authCacheInvalidatorStub{} + svc := &UsageService{authCacheInvalidator: invalidator} + + svc.invalidateUsageCaches(context.Background(), 7, false) + require.Empty(t, invalidator.userIDs) + + svc.invalidateUsageCaches(context.Background(), 7, true) + require.Equal(t, []int64{7}, invalidator.userIDs) +} + +func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) { + invalidator := &authCacheInvalidatorStub{} + svc := &RedeemService{authCacheInvalidator: invalidator} + + svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance}) + svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency}) + + require.Equal(t, []int64{11, 11}, invalidator.userIDs) +} diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index a9214c82..324f347b 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -172,12 +172,12 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error { return fmt.Errorf("get group: %w", err) } - if err := s.groupRepo.Delete(ctx, id); err != nil { - return fmt.Errorf("delete group: %w", err) - } if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) } + if err := s.groupRepo.Delete(ctx, id); err != nil { + return fmt.Errorf("delete group: %w", err) + } return nil } diff --git a/backend/internal/service/promo_service.go b/backend/internal/service/promo_service.go index 9acd5868..5ff63bdc 100644 --- a/backend/internal/service/promo_service.go +++ b/backend/internal/service/promo_service.go @@ -24,10 +24,11 @@ var ( // PromoService 优惠码服务 type PromoService struct { - promoRepo PromoCodeRepository - userRepo UserRepository - billingCacheService *BillingCacheService - entClient *dbent.Client + promoRepo PromoCodeRepository + userRepo UserRepository + billingCacheService *BillingCacheService + entClient *dbent.Client + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewPromoService 创建优惠码服务实例 @@ -36,12 +37,14 @@ func NewPromoService( userRepo UserRepository, billingCacheService *BillingCacheService, entClient *dbent.Client, + authCacheInvalidator APIKeyAuthCacheInvalidator, ) *PromoService { return &PromoService{ - promoRepo: promoRepo, - userRepo: userRepo, - billingCacheService: billingCacheService, - entClient: entClient, + promoRepo: promoRepo, + userRepo: userRepo, + billingCacheService: billingCacheService, + entClient: entClient, + authCacheInvalidator: authCacheInvalidator, } } @@ -145,6 +148,8 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st return fmt.Errorf("commit transaction: %w", err) } + s.invalidatePromoCaches(ctx, userID, promoCode.BonusAmount) + // 失效余额缓存 if s.billingCacheService != nil { go func() { @@ -157,6 +162,13 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st return nil } +func (s *PromoService) invalidatePromoCaches(ctx context.Context, userID int64, bonusAmount float64) { + if bonusAmount == 0 || s.authCacheInvalidator == nil { + return + } + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) +} + // GenerateRandomCode 生成随机优惠码 func (s *PromoService) GenerateRandomCode() (string, error) { bytes := make([]byte, 8) diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index b6324235..81767aa9 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -68,12 +68,13 @@ type RedeemCodeResponse struct { // RedeemService 兑换码服务 type RedeemService struct { - redeemRepo RedeemCodeRepository - userRepo UserRepository - subscriptionService *SubscriptionService - cache RedeemCache - billingCacheService *BillingCacheService - entClient *dbent.Client + redeemRepo RedeemCodeRepository + userRepo UserRepository + subscriptionService *SubscriptionService + cache RedeemCache + billingCacheService *BillingCacheService + entClient *dbent.Client + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewRedeemService 创建兑换码服务实例 @@ -84,14 +85,16 @@ func NewRedeemService( cache RedeemCache, billingCacheService *BillingCacheService, entClient *dbent.Client, + authCacheInvalidator APIKeyAuthCacheInvalidator, ) *RedeemService { return &RedeemService{ - redeemRepo: redeemRepo, - userRepo: userRepo, - subscriptionService: subscriptionService, - cache: cache, - billingCacheService: billingCacheService, - entClient: entClient, + redeemRepo: redeemRepo, + userRepo: userRepo, + subscriptionService: subscriptionService, + cache: cache, + billingCacheService: billingCacheService, + entClient: entClient, + authCacheInvalidator: authCacheInvalidator, } } @@ -324,18 +327,30 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( // invalidateRedeemCaches 失效兑换相关的缓存 func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) { - if s.billingCacheService == nil { - return - } - switch redeemCode.Type { case RedeemTypeBalance: + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + if s.billingCacheService == nil { + return + } go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) }() + case RedeemTypeConcurrency: + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + if s.billingCacheService == nil { + return + } case RedeemTypeSubscription: + if s.billingCacheService == nil { + return + } if redeemCode.GroupID != nil { groupID := *redeemCode.GroupID go func() { diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index 10a294ae..aa0a5b87 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -54,17 +54,19 @@ type UsageStats struct { // UsageService 使用统计服务 type UsageService struct { - usageRepo UsageLogRepository - userRepo UserRepository - entClient *dbent.Client + usageRepo UsageLogRepository + userRepo UserRepository + entClient *dbent.Client + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewUsageService 创建使用统计服务实例 -func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService { +func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client, authCacheInvalidator APIKeyAuthCacheInvalidator) *UsageService { return &UsageService{ - usageRepo: usageRepo, - userRepo: userRepo, - entClient: entClient, + usageRepo: usageRepo, + userRepo: userRepo, + entClient: entClient, + authCacheInvalidator: authCacheInvalidator, } } @@ -118,10 +120,12 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* } // 扣除用户余额 + balanceUpdated := false if inserted && req.ActualCost > 0 { if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil { return nil, fmt.Errorf("update user balance: %w", err) } + balanceUpdated = true } if tx != nil { @@ -130,9 +134,18 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* } } + s.invalidateUsageCaches(ctx, req.UserID, balanceUpdated) + return usageLog, nil } +func (s *UsageService) invalidateUsageCaches(ctx context.Context, userID int64, balanceUpdated bool) { + if !balanceUpdated || s.authCacheInvalidator == nil { + return + } + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) +} + // GetByID 根据ID获取使用日志 func (s *UsageService) GetByID(ctx context.Context, id int64) (*UsageLog, error) { log, err := s.usageRepo.GetByID(ctx, id) diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index a7a36760..1734914a 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -213,11 +213,11 @@ func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status str // Delete 删除用户(管理员功能) func (s *UserService) Delete(ctx context.Context, userID int64) error { - if err := s.userRepo.Delete(ctx, userID); err != nil { - return fmt.Errorf("delete user: %w", err) - } if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) } + if err := s.userRepo.Delete(ctx, userID); err != nil { + return fmt.Errorf("delete user: %w", err) + } return nil }