Files
sub2api/backend/internal/service/promo_service.go
yangjianbo cb3e08dda4 fix(认证): 补齐余额与删除场景缓存失效
为 Usage/Promo/Redeem 注入认证缓存失效逻辑
删除用户与分组前先失效认证缓存降低窗口
补充回归测试验证失效调用

测试: make test
2026-01-11 10:55:25 +08:00

269 lines
7.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// Sentinel errors for promo code handling. Each one is built through the
// infraerrors helpers with a stable upper-case code string and an English
// message.
var (
// ErrPromoCodeNotFound is a NotFound error for a missing promo code.
// NOTE(review): not returned directly in this file; presumably produced by the repository — confirm.
ErrPromoCodeNotFound = infraerrors.NotFound("PROMO_CODE_NOT_FOUND", "promo code not found")
// ErrPromoCodeExpired is returned by validatePromoCodeStatus when IsExpired reports true.
ErrPromoCodeExpired = infraerrors.BadRequest("PROMO_CODE_EXPIRED", "promo code has expired")
// ErrPromoCodeDisabled is returned when the code's Status equals PromoCodeStatusDisabled.
ErrPromoCodeDisabled = infraerrors.BadRequest("PROMO_CODE_DISABLED", "promo code is disabled")
// ErrPromoCodeMaxUsed is returned when MaxUses is positive and UsedCount has reached it.
ErrPromoCodeMaxUsed = infraerrors.BadRequest("PROMO_CODE_MAX_USED", "promo code has reached maximum uses")
// ErrPromoCodeAlreadyUsed is returned by ApplyPromoCode when the user already has a usage record for the code.
ErrPromoCodeAlreadyUsed = infraerrors.Conflict("PROMO_CODE_ALREADY_USED", "you have already used this promo code")
// ErrPromoCodeInvalid is the fallback when a code cannot be used for none of the specific reasons above.
ErrPromoCodeInvalid = infraerrors.BadRequest("PROMO_CODE_INVALID", "invalid promo code")
)
// PromoService is the promo code service: it validates and redeems promo
// codes and offers create/read/update/delete/list operations on them.
type PromoService struct {
// promoRepo stores and looks up promo codes and their usage records.
promoRepo PromoCodeRepository
// userRepo is used to credit the bonus amount to a user's balance.
userRepo UserRepository
// billingCacheService invalidates the cached user balance after a redemption; may be nil.
billingCacheService *BillingCacheService
// entClient opens the database transaction used by ApplyPromoCode.
entClient *dbent.Client
// authCacheInvalidator invalidates auth cache entries by user ID after a redemption; may be nil.
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewPromoService creates a PromoService from the given dependencies.
// Every argument is stored as passed; nothing is validated or copied here.
func NewPromoService(
	promoRepo PromoCodeRepository,
	userRepo UserRepository,
	billingCacheService *BillingCacheService,
	entClient *dbent.Client,
	authCacheInvalidator APIKeyAuthCacheInvalidator,
) *PromoService {
	svc := new(PromoService)
	svc.promoRepo = promoRepo
	svc.userRepo = userRepo
	svc.billingCacheService = billingCacheService
	svc.entClient = entClient
	svc.authCacheInvalidator = authCacheInvalidator
	return svc
}
// ValidatePromoCode checks whether code can be used (called before registration).
//
// Surrounding whitespace is trimmed first. A blank code is not an error and
// yields nil, nil. Repository lookup errors are returned unchanged, and a code
// that exists but cannot be used yields the error from validatePromoCodeStatus.
func (s *PromoService) ValidatePromoCode(ctx context.Context, code string) (*PromoCode, error) {
	trimmed := strings.TrimSpace(code)
	if trimmed == "" {
		// Blank code: nothing to validate, and not an error.
		return nil, nil
	}

	found, lookupErr := s.promoRepo.GetByCode(ctx, trimmed)
	if lookupErr != nil {
		// Keep the original error type; do not collapse everything into NotFound.
		return nil, lookupErr
	}

	statusErr := s.validatePromoCodeStatus(found)
	if statusErr != nil {
		return nil, statusErr
	}
	return found, nil
}
// validatePromoCodeStatus reports why promoCode cannot be used, or nil when
// CanUse says it can. The reasons are checked in a fixed order: expired,
// disabled, usage limit reached; anything else maps to ErrPromoCodeInvalid.
func (s *PromoService) validatePromoCodeStatus(promoCode *PromoCode) error {
	if promoCode.CanUse() {
		return nil
	}

	switch {
	case promoCode.IsExpired():
		return ErrPromoCodeExpired
	case promoCode.Status == PromoCodeStatusDisabled:
		return ErrPromoCodeDisabled
	case promoCode.MaxUses > 0 && promoCode.UsedCount >= promoCode.MaxUses:
		// A MaxUses of zero or less never triggers this case.
		return ErrPromoCodeMaxUsed
	default:
		return ErrPromoCodeInvalid
	}
}
// ApplyPromoCode redeems code for userID (called after a successful registration).
//
// A blank code (after trimming) is a no-op. Otherwise the redemption runs in a
// single database transaction: the promo code is fetched via
// GetByCodeForUpdate, its status is validated, a prior redemption by the same
// user is rejected with ErrPromoCodeAlreadyUsed, and then the user's balance is
// credited, a usage record is created and the code's used count is incremented.
// After a successful commit the user's auth cache and cached balance are
// invalidated.
func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code string) error {
	trimmed := strings.TrimSpace(code)
	if trimmed == "" {
		return nil
	}

	// Begin the transaction.
	tx, err := s.entClient.Tx(ctx)
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	// Roll back on every exit path; the result is deliberately ignored.
	defer func() { _ = tx.Rollback() }()
	txCtx := dbent.NewTxContext(ctx, tx)

	// Fetch the promo code inside the transaction (FOR UPDATE lookup).
	promo, err := s.promoRepo.GetByCodeForUpdate(txCtx, trimmed)
	if err != nil {
		return err
	}

	// Validate the status inside the transaction.
	if err = s.validatePromoCodeStatus(promo); err != nil {
		return err
	}

	// Reject a second redemption of the same code by the same user.
	priorUsage, err := s.promoRepo.GetUsageByPromoCodeAndUser(txCtx, promo.ID, userID)
	if err != nil {
		return fmt.Errorf("check existing usage: %w", err)
	}
	if priorUsage != nil {
		return ErrPromoCodeAlreadyUsed
	}

	// Credit the bonus to the user's balance.
	if err = s.userRepo.UpdateBalance(txCtx, userID, promo.BonusAmount); err != nil {
		return fmt.Errorf("update user balance: %w", err)
	}

	// Record the redemption.
	record := &PromoCodeUsage{
		PromoCodeID: promo.ID,
		UserID:      userID,
		BonusAmount: promo.BonusAmount,
		UsedAt:      time.Now(),
	}
	if err = s.promoRepo.CreateUsage(txCtx, record); err != nil {
		return fmt.Errorf("create usage record: %w", err)
	}

	// Bump the code's used count.
	if err = s.promoRepo.IncrementUsedCount(txCtx, promo.ID); err != nil {
		return fmt.Errorf("increment used count: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction: %w", err)
	}

	s.invalidatePromoCaches(ctx, userID, promo.BonusAmount)

	// Invalidate the cached balance, if a billing cache is configured.
	if s.billingCacheService == nil {
		return nil
	}
	// Runs in the background with its own 5-second timeout, detached from ctx;
	// the invalidation error is ignored.
	go func() {
		cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
	}()
	return nil
}
// invalidatePromoCaches invalidates the auth cache for userID after a promo
// bonus has been applied. It does nothing when bonusAmount is zero or when no
// auth cache invalidator is configured.
func (s *PromoService) invalidatePromoCaches(ctx context.Context, userID int64, bonusAmount float64) {
	if bonusAmount != 0 && s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}
}
// GenerateRandomCode returns a random promo code: 8 bytes read from
// crypto/rand, hex-encoded and upper-cased, i.e. 16 characters from 0-9 and A-F.
// An error is returned when the random source cannot be read.
func (s *PromoService) GenerateRandomCode() (string, error) {
	var raw [8]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", fmt.Errorf("generate random bytes: %w", err)
	}
	encoded := hex.EncodeToString(raw[:])
	return strings.ToUpper(encoded), nil
}
// Create stores a new promo code built from input.
//
// input.Code is trimmed; when it is blank a code is generated with
// GenerateRandomCode. The code is stored upper-cased, with a used count of
// zero and status PromoCodeStatusActive. Repository failures are wrapped with
// "create promo code".
func (s *PromoService) Create(ctx context.Context, input *CreatePromoCodeInput) (*PromoCode, error) {
	code := strings.TrimSpace(input.Code)
	if code == "" {
		// No code supplied: generate one.
		generated, err := s.GenerateRandomCode()
		if err != nil {
			return nil, err
		}
		code = generated
	}

	created := new(PromoCode)
	created.Code = strings.ToUpper(code)
	created.BonusAmount = input.BonusAmount
	created.MaxUses = input.MaxUses
	created.UsedCount = 0
	created.Status = PromoCodeStatusActive
	created.ExpiresAt = input.ExpiresAt
	created.Notes = input.Notes

	if err := s.promoRepo.Create(ctx, created); err != nil {
		return nil, fmt.Errorf("create promo code: %w", err)
	}
	return created, nil
}
// GetByID fetches the promo code with the given ID from the repository.
// Repository errors are returned unchanged, with a nil promo code.
func (s *PromoService) GetByID(ctx context.Context, id int64) (*PromoCode, error) {
	promoCode, err := s.promoRepo.GetByID(ctx, id)
	if err == nil {
		return promoCode, nil
	}
	return nil, err
}
// Update applies the non-nil fields of input to the promo code with the given
// ID and saves it.
//
// The existing record is loaded first; a lookup error is returned unchanged.
// Only fields whose pointer in input is non-nil are overwritten. A new code is
// trimmed and upper-cased, and a code that is blank after trimming is rejected
// with ErrPromoCodeInvalid before anything is modified or saved. Repository
// failures on save are wrapped with "update promo code".
func (s *PromoService) Update(ctx context.Context, id int64, input *UpdatePromoCodeInput) (*PromoCode, error) {
	promoCode, err := s.promoRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}

	if input.Code != nil {
		normalized := strings.ToUpper(strings.TrimSpace(*input.Code))
		if normalized == "" {
			// A blank code could never be redeemed: ValidatePromoCode and
			// ApplyPromoCode treat an empty code as "no code supplied".
			return nil, ErrPromoCodeInvalid
		}
		promoCode.Code = normalized
	}
	if input.BonusAmount != nil {
		promoCode.BonusAmount = *input.BonusAmount
	}
	if input.MaxUses != nil {
		promoCode.MaxUses = *input.MaxUses
	}
	if input.Status != nil {
		promoCode.Status = *input.Status
	}
	if input.ExpiresAt != nil {
		// ExpiresAt is itself a pointer, so it is stored as-is rather than dereferenced.
		promoCode.ExpiresAt = input.ExpiresAt
	}
	if input.Notes != nil {
		promoCode.Notes = *input.Notes
	}

	if err := s.promoRepo.Update(ctx, promoCode); err != nil {
		return nil, fmt.Errorf("update promo code: %w", err)
	}
	return promoCode, nil
}
// Delete removes the promo code with the given ID through the repository.
// Repository failures are wrapped with "delete promo code".
func (s *PromoService) Delete(ctx context.Context, id int64) error {
	err := s.promoRepo.Delete(ctx, id)
	if err == nil {
		return nil
	}
	return fmt.Errorf("delete promo code: %w", err)
}
// List returns one page of promo codes. params, status and search are passed
// to the repository's ListWithFilters unchanged, and its results are returned
// as-is.
func (s *PromoService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error) {
	codes, page, err := s.promoRepo.ListWithFilters(ctx, params, status, search)
	return codes, page, err
}
// ListUsages returns one page of usage records for the promo code identified
// by promoCodeID. It delegates to the repository's ListUsagesByPromoCode and
// returns its results as-is.
func (s *PromoService) ListUsages(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error) {
	usages, page, err := s.promoRepo.ListUsagesByPromoCode(ctx, promoCodeID, params)
	return usages, page, err
}