fix(认证): 补齐余额与删除场景缓存失效

为 Usage/Promo/Redeem 注入认证缓存失效逻辑
删除用户与分组前先失效认证缓存降低窗口
补充回归测试验证失效调用

测试: make test
This commit is contained in:
yangjianbo
2026-01-10 22:52:13 +08:00
committed by shaw
parent 44a93c1922
commit cb3e08dda4
8 changed files with 113 additions and 42 deletions

View File

@@ -55,24 +55,24 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingCache := repository.NewBillingCache(redisClient) billingCache := repository.NewBillingCache(redisClient)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
apiKeyRepository := repository.NewAPIKeyRepository(client) apiKeyRepository := repository.NewAPIKeyRepository(client)
groupRepository := repository.NewGroupRepository(client, db) groupRepository := repository.NewGroupRepository(client, db)
apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
userHandler := handler.NewUserHandler(userService) userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db) usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(client) redeemCodeRepository := repository.NewRedeemCodeRepository(client)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(redisClient) redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
redeemHandler := handler.NewRedeemHandler(redeemService) redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
dashboardService := service.NewDashboardService(usageLogRepository) dashboardService := service.NewDashboardService(usageLogRepository)

View File

@@ -394,7 +394,7 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo() usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil) usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)

View File

@@ -0,0 +1,31 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestUsageService_InvalidateUsageCaches(t *testing.T) {
invalidator := &authCacheInvalidatorStub{}
svc := &UsageService{authCacheInvalidator: invalidator}
svc.invalidateUsageCaches(context.Background(), 7, false)
require.Empty(t, invalidator.userIDs)
svc.invalidateUsageCaches(context.Background(), 7, true)
require.Equal(t, []int64{7}, invalidator.userIDs)
}
func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) {
invalidator := &authCacheInvalidatorStub{}
svc := &RedeemService{authCacheInvalidator: invalidator}
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance})
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency})
require.Equal(t, []int64{11, 11}, invalidator.userIDs)
}

View File

@@ -172,12 +172,12 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
return fmt.Errorf("get group: %w", err) return fmt.Errorf("get group: %w", err)
} }
if err := s.groupRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete group: %w", err)
}
if s.authCacheInvalidator != nil { if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
} }
if err := s.groupRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete group: %w", err)
}
return nil return nil
} }

View File

@@ -24,10 +24,11 @@ var (
// PromoService 优惠码服务 // PromoService 优惠码服务
type PromoService struct { type PromoService struct {
promoRepo PromoCodeRepository promoRepo PromoCodeRepository
userRepo UserRepository userRepo UserRepository
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
entClient *dbent.Client entClient *dbent.Client
authCacheInvalidator APIKeyAuthCacheInvalidator
} }
// NewPromoService 创建优惠码服务实例 // NewPromoService 创建优惠码服务实例
@@ -36,12 +37,14 @@ func NewPromoService(
userRepo UserRepository, userRepo UserRepository,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
entClient *dbent.Client, entClient *dbent.Client,
authCacheInvalidator APIKeyAuthCacheInvalidator,
) *PromoService { ) *PromoService {
return &PromoService{ return &PromoService{
promoRepo: promoRepo, promoRepo: promoRepo,
userRepo: userRepo, userRepo: userRepo,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
entClient: entClient, entClient: entClient,
authCacheInvalidator: authCacheInvalidator,
} }
} }
@@ -145,6 +148,8 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st
return fmt.Errorf("commit transaction: %w", err) return fmt.Errorf("commit transaction: %w", err)
} }
s.invalidatePromoCaches(ctx, userID, promoCode.BonusAmount)
// 失效余额缓存 // 失效余额缓存
if s.billingCacheService != nil { if s.billingCacheService != nil {
go func() { go func() {
@@ -157,6 +162,13 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st
return nil return nil
} }
func (s *PromoService) invalidatePromoCaches(ctx context.Context, userID int64, bonusAmount float64) {
if bonusAmount == 0 || s.authCacheInvalidator == nil {
return
}
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
// GenerateRandomCode 生成随机优惠码 // GenerateRandomCode 生成随机优惠码
func (s *PromoService) GenerateRandomCode() (string, error) { func (s *PromoService) GenerateRandomCode() (string, error) {
bytes := make([]byte, 8) bytes := make([]byte, 8)

View File

@@ -68,12 +68,13 @@ type RedeemCodeResponse struct {
// RedeemService 兑换码服务 // RedeemService 兑换码服务
type RedeemService struct { type RedeemService struct {
redeemRepo RedeemCodeRepository redeemRepo RedeemCodeRepository
userRepo UserRepository userRepo UserRepository
subscriptionService *SubscriptionService subscriptionService *SubscriptionService
cache RedeemCache cache RedeemCache
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
entClient *dbent.Client entClient *dbent.Client
authCacheInvalidator APIKeyAuthCacheInvalidator
} }
// NewRedeemService 创建兑换码服务实例 // NewRedeemService 创建兑换码服务实例
@@ -84,14 +85,16 @@ func NewRedeemService(
cache RedeemCache, cache RedeemCache,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
entClient *dbent.Client, entClient *dbent.Client,
authCacheInvalidator APIKeyAuthCacheInvalidator,
) *RedeemService { ) *RedeemService {
return &RedeemService{ return &RedeemService{
redeemRepo: redeemRepo, redeemRepo: redeemRepo,
userRepo: userRepo, userRepo: userRepo,
subscriptionService: subscriptionService, subscriptionService: subscriptionService,
cache: cache, cache: cache,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
entClient: entClient, entClient: entClient,
authCacheInvalidator: authCacheInvalidator,
} }
} }
@@ -324,18 +327,30 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// invalidateRedeemCaches 失效兑换相关的缓存 // invalidateRedeemCaches 失效兑换相关的缓存
func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) { func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) {
if s.billingCacheService == nil {
return
}
switch redeemCode.Type { switch redeemCode.Type {
case RedeemTypeBalance: case RedeemTypeBalance:
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService == nil {
return
}
go func() { go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
}() }()
case RedeemTypeConcurrency:
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService == nil {
return
}
case RedeemTypeSubscription: case RedeemTypeSubscription:
if s.billingCacheService == nil {
return
}
if redeemCode.GroupID != nil { if redeemCode.GroupID != nil {
groupID := *redeemCode.GroupID groupID := *redeemCode.GroupID
go func() { go func() {

View File

@@ -54,17 +54,19 @@ type UsageStats struct {
// UsageService 使用统计服务 // UsageService 使用统计服务
type UsageService struct { type UsageService struct {
usageRepo UsageLogRepository usageRepo UsageLogRepository
userRepo UserRepository userRepo UserRepository
entClient *dbent.Client entClient *dbent.Client
authCacheInvalidator APIKeyAuthCacheInvalidator
} }
// NewUsageService 创建使用统计服务实例 // NewUsageService 创建使用统计服务实例
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService { func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client, authCacheInvalidator APIKeyAuthCacheInvalidator) *UsageService {
return &UsageService{ return &UsageService{
usageRepo: usageRepo, usageRepo: usageRepo,
userRepo: userRepo, userRepo: userRepo,
entClient: entClient, entClient: entClient,
authCacheInvalidator: authCacheInvalidator,
} }
} }
@@ -118,10 +120,12 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
} }
// 扣除用户余额 // 扣除用户余额
balanceUpdated := false
if inserted && req.ActualCost > 0 { if inserted && req.ActualCost > 0 {
if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil { if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil {
return nil, fmt.Errorf("update user balance: %w", err) return nil, fmt.Errorf("update user balance: %w", err)
} }
balanceUpdated = true
} }
if tx != nil { if tx != nil {
@@ -130,9 +134,18 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
} }
} }
s.invalidateUsageCaches(ctx, req.UserID, balanceUpdated)
return usageLog, nil return usageLog, nil
} }
func (s *UsageService) invalidateUsageCaches(ctx context.Context, userID int64, balanceUpdated bool) {
if !balanceUpdated || s.authCacheInvalidator == nil {
return
}
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
// GetByID 根据ID获取使用日志 // GetByID 根据ID获取使用日志
func (s *UsageService) GetByID(ctx context.Context, id int64) (*UsageLog, error) { func (s *UsageService) GetByID(ctx context.Context, id int64) (*UsageLog, error) {
log, err := s.usageRepo.GetByID(ctx, id) log, err := s.usageRepo.GetByID(ctx, id)

View File

@@ -213,11 +213,11 @@ func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status str
// Delete 删除用户(管理员功能) // Delete 删除用户(管理员功能)
func (s *UserService) Delete(ctx context.Context, userID int64) error { func (s *UserService) Delete(ctx context.Context, userID int64) error {
if err := s.userRepo.Delete(ctx, userID); err != nil {
return fmt.Errorf("delete user: %w", err)
}
if s.authCacheInvalidator != nil { if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
} }
if err := s.userRepo.Delete(ctx, userID); err != nil {
return fmt.Errorf("delete user: %w", err)
}
return nil return nil
} }