feat: 实现注册优惠码功能
- 支持创建/编辑/删除优惠码,设置赠送金额和使用限制 - 注册页面实时验证优惠码并显示赠送金额 - 支持 URL 参数自动填充 (?promo=CODE) - 添加优惠码验证接口速率限制 - 使用数据库行锁防止并发超限 - 新增后台优惠码管理页面,支持复制注册链接
This commit is contained in:
@@ -52,6 +52,7 @@ type AuthService struct {
|
||||
emailService *EmailService
|
||||
turnstileService *TurnstileService
|
||||
emailQueueService *EmailQueueService
|
||||
promoService *PromoService
|
||||
}
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
@@ -62,6 +63,7 @@ func NewAuthService(
|
||||
emailService *EmailService,
|
||||
turnstileService *TurnstileService,
|
||||
emailQueueService *EmailQueueService,
|
||||
promoService *PromoService,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
userRepo: userRepo,
|
||||
@@ -70,16 +72,17 @@ func NewAuthService(
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
emailQueueService: emailQueueService,
|
||||
promoService: promoService,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 用户注册,返回token和用户
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
||||
return s.RegisterWithVerification(ctx, email, password, "")
|
||||
return s.RegisterWithVerification(ctx, email, password, "", "")
|
||||
}
|
||||
|
||||
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
|
||||
// RegisterWithVerification 用户注册(支持邮件验证和优惠码),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) {
|
||||
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
@@ -150,6 +153,19 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 应用优惠码(如果提供)
|
||||
if promoCode != "" && s.promoService != nil {
|
||||
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
||||
// 优惠码应用失败不影响注册,只记录日志
|
||||
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
|
||||
} else {
|
||||
// 重新获取用户信息以获取更新后的余额
|
||||
if updatedUser, err := s.userRepo.GetByID(ctx, user.ID); err == nil {
|
||||
user = updatedUser
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 生成token
|
||||
token, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
|
||||
@@ -38,6 +38,12 @@ const (
|
||||
RedeemTypeSubscription = "subscription"
|
||||
)
|
||||
|
||||
// PromoCode status constants
|
||||
const (
|
||||
PromoCodeStatusActive = "active"
|
||||
PromoCodeStatusDisabled = "disabled"
|
||||
)
|
||||
|
||||
// Admin adjustment type constants
|
||||
const (
|
||||
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
|
||||
|
||||
73
backend/internal/service/promo_code.go
Normal file
73
backend/internal/service/promo_code.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// PromoCode 注册优惠码
|
||||
type PromoCode struct {
|
||||
ID int64
|
||||
Code string
|
||||
BonusAmount float64
|
||||
MaxUses int
|
||||
UsedCount int
|
||||
Status string
|
||||
ExpiresAt *time.Time
|
||||
Notes string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// 关联
|
||||
UsageRecords []PromoCodeUsage
|
||||
}
|
||||
|
||||
// PromoCodeUsage 优惠码使用记录
|
||||
type PromoCodeUsage struct {
|
||||
ID int64
|
||||
PromoCodeID int64
|
||||
UserID int64
|
||||
BonusAmount float64
|
||||
UsedAt time.Time
|
||||
|
||||
// 关联
|
||||
PromoCode *PromoCode
|
||||
User *User
|
||||
}
|
||||
|
||||
// CanUse 检查优惠码是否可用
|
||||
func (p *PromoCode) CanUse() bool {
|
||||
if p.Status != PromoCodeStatusActive {
|
||||
return false
|
||||
}
|
||||
if p.ExpiresAt != nil && time.Now().After(*p.ExpiresAt) {
|
||||
return false
|
||||
}
|
||||
if p.MaxUses > 0 && p.UsedCount >= p.MaxUses {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// IsExpired 检查是否已过期
|
||||
func (p *PromoCode) IsExpired() bool {
|
||||
return p.ExpiresAt != nil && time.Now().After(*p.ExpiresAt)
|
||||
}
|
||||
|
||||
// CreatePromoCodeInput 创建优惠码输入
|
||||
type CreatePromoCodeInput struct {
|
||||
Code string
|
||||
BonusAmount float64
|
||||
MaxUses int
|
||||
ExpiresAt *time.Time
|
||||
Notes string
|
||||
}
|
||||
|
||||
// UpdatePromoCodeInput 更新优惠码输入
|
||||
type UpdatePromoCodeInput struct {
|
||||
Code *string
|
||||
BonusAmount *float64
|
||||
MaxUses *int
|
||||
Status *string
|
||||
ExpiresAt *time.Time
|
||||
Notes *string
|
||||
}
|
||||
30
backend/internal/service/promo_code_repository.go
Normal file
30
backend/internal/service/promo_code_repository.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
// PromoCodeRepository 优惠码仓储接口
|
||||
type PromoCodeRepository interface {
|
||||
// 基础 CRUD
|
||||
Create(ctx context.Context, code *PromoCode) error
|
||||
GetByID(ctx context.Context, id int64) (*PromoCode, error)
|
||||
GetByCode(ctx context.Context, code string) (*PromoCode, error)
|
||||
GetByCodeForUpdate(ctx context.Context, code string) (*PromoCode, error) // 带行锁的查询,用于并发控制
|
||||
Update(ctx context.Context, code *PromoCode) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
// 列表查询
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]PromoCode, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error)
|
||||
|
||||
// 使用记录
|
||||
CreateUsage(ctx context.Context, usage *PromoCodeUsage) error
|
||||
GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*PromoCodeUsage, error)
|
||||
ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error)
|
||||
|
||||
// 计数操作
|
||||
IncrementUsedCount(ctx context.Context, id int64) error
|
||||
}
|
||||
256
backend/internal/service/promo_service.go
Normal file
256
backend/internal/service/promo_service.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPromoCodeNotFound = infraerrors.NotFound("PROMO_CODE_NOT_FOUND", "promo code not found")
|
||||
ErrPromoCodeExpired = infraerrors.BadRequest("PROMO_CODE_EXPIRED", "promo code has expired")
|
||||
ErrPromoCodeDisabled = infraerrors.BadRequest("PROMO_CODE_DISABLED", "promo code is disabled")
|
||||
ErrPromoCodeMaxUsed = infraerrors.BadRequest("PROMO_CODE_MAX_USED", "promo code has reached maximum uses")
|
||||
ErrPromoCodeAlreadyUsed = infraerrors.Conflict("PROMO_CODE_ALREADY_USED", "you have already used this promo code")
|
||||
ErrPromoCodeInvalid = infraerrors.BadRequest("PROMO_CODE_INVALID", "invalid promo code")
|
||||
)
|
||||
|
||||
// PromoService 优惠码服务
|
||||
type PromoService struct {
|
||||
promoRepo PromoCodeRepository
|
||||
userRepo UserRepository
|
||||
billingCacheService *BillingCacheService
|
||||
entClient *dbent.Client
|
||||
}
|
||||
|
||||
// NewPromoService 创建优惠码服务实例
|
||||
func NewPromoService(
|
||||
promoRepo PromoCodeRepository,
|
||||
userRepo UserRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
entClient *dbent.Client,
|
||||
) *PromoService {
|
||||
return &PromoService{
|
||||
promoRepo: promoRepo,
|
||||
userRepo: userRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
entClient: entClient,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatePromoCode 验证优惠码(注册前调用)
|
||||
// 返回 nil, nil 表示空码(不报错)
|
||||
func (s *PromoService) ValidatePromoCode(ctx context.Context, code string) (*PromoCode, error) {
|
||||
code = strings.TrimSpace(code)
|
||||
if code == "" {
|
||||
return nil, nil // 空码不报错,直接返回
|
||||
}
|
||||
|
||||
promoCode, err := s.promoRepo.GetByCode(ctx, code)
|
||||
if err != nil {
|
||||
// 保留原始错误类型,不要统一映射为 NotFound
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.validatePromoCodeStatus(promoCode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return promoCode, nil
|
||||
}
|
||||
|
||||
// validatePromoCodeStatus 验证优惠码状态
|
||||
func (s *PromoService) validatePromoCodeStatus(promoCode *PromoCode) error {
|
||||
if !promoCode.CanUse() {
|
||||
if promoCode.IsExpired() {
|
||||
return ErrPromoCodeExpired
|
||||
}
|
||||
if promoCode.Status == PromoCodeStatusDisabled {
|
||||
return ErrPromoCodeDisabled
|
||||
}
|
||||
if promoCode.MaxUses > 0 && promoCode.UsedCount >= promoCode.MaxUses {
|
||||
return ErrPromoCodeMaxUsed
|
||||
}
|
||||
return ErrPromoCodeInvalid
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyPromoCode 应用优惠码(注册成功后调用)
|
||||
// 使用事务和行锁确保并发安全
|
||||
func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code string) error {
|
||||
code = strings.TrimSpace(code)
|
||||
if code == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 开启事务
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
|
||||
// 在事务中获取并锁定优惠码记录(FOR UPDATE)
|
||||
promoCode, err := s.promoRepo.GetByCodeForUpdate(txCtx, code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 在事务中验证优惠码状态
|
||||
if err := s.validatePromoCodeStatus(promoCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 在事务中检查用户是否已使用过此优惠码
|
||||
existing, err := s.promoRepo.GetUsageByPromoCodeAndUser(txCtx, promoCode.ID, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check existing usage: %w", err)
|
||||
}
|
||||
if existing != nil {
|
||||
return ErrPromoCodeAlreadyUsed
|
||||
}
|
||||
|
||||
// 增加用户余额
|
||||
if err := s.userRepo.UpdateBalance(txCtx, userID, promoCode.BonusAmount); err != nil {
|
||||
return fmt.Errorf("update user balance: %w", err)
|
||||
}
|
||||
|
||||
// 创建使用记录
|
||||
usage := &PromoCodeUsage{
|
||||
PromoCodeID: promoCode.ID,
|
||||
UserID: userID,
|
||||
BonusAmount: promoCode.BonusAmount,
|
||||
UsedAt: time.Now(),
|
||||
}
|
||||
if err := s.promoRepo.CreateUsage(txCtx, usage); err != nil {
|
||||
return fmt.Errorf("create usage record: %w", err)
|
||||
}
|
||||
|
||||
// 增加使用次数
|
||||
if err := s.promoRepo.IncrementUsedCount(txCtx, promoCode.ID); err != nil {
|
||||
return fmt.Errorf("increment used count: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
||||
// 失效余额缓存
|
||||
if s.billingCacheService != nil {
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateRandomCode 生成随机优惠码
|
||||
func (s *PromoService) GenerateRandomCode() (string, error) {
|
||||
bytes := make([]byte, 8)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
return strings.ToUpper(hex.EncodeToString(bytes)), nil
|
||||
}
|
||||
|
||||
// Create 创建优惠码
|
||||
func (s *PromoService) Create(ctx context.Context, input *CreatePromoCodeInput) (*PromoCode, error) {
|
||||
code := strings.TrimSpace(input.Code)
|
||||
if code == "" {
|
||||
// 自动生成
|
||||
var err error
|
||||
code, err = s.GenerateRandomCode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
promoCode := &PromoCode{
|
||||
Code: strings.ToUpper(code),
|
||||
BonusAmount: input.BonusAmount,
|
||||
MaxUses: input.MaxUses,
|
||||
UsedCount: 0,
|
||||
Status: PromoCodeStatusActive,
|
||||
ExpiresAt: input.ExpiresAt,
|
||||
Notes: input.Notes,
|
||||
}
|
||||
|
||||
if err := s.promoRepo.Create(ctx, promoCode); err != nil {
|
||||
return nil, fmt.Errorf("create promo code: %w", err)
|
||||
}
|
||||
|
||||
return promoCode, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取优惠码
|
||||
func (s *PromoService) GetByID(ctx context.Context, id int64) (*PromoCode, error) {
|
||||
code, err := s.promoRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// Update 更新优惠码
|
||||
func (s *PromoService) Update(ctx context.Context, id int64, input *UpdatePromoCodeInput) (*PromoCode, error) {
|
||||
promoCode, err := s.promoRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if input.Code != nil {
|
||||
promoCode.Code = strings.ToUpper(strings.TrimSpace(*input.Code))
|
||||
}
|
||||
if input.BonusAmount != nil {
|
||||
promoCode.BonusAmount = *input.BonusAmount
|
||||
}
|
||||
if input.MaxUses != nil {
|
||||
promoCode.MaxUses = *input.MaxUses
|
||||
}
|
||||
if input.Status != nil {
|
||||
promoCode.Status = *input.Status
|
||||
}
|
||||
if input.ExpiresAt != nil {
|
||||
promoCode.ExpiresAt = input.ExpiresAt
|
||||
}
|
||||
if input.Notes != nil {
|
||||
promoCode.Notes = *input.Notes
|
||||
}
|
||||
|
||||
if err := s.promoRepo.Update(ctx, promoCode); err != nil {
|
||||
return nil, fmt.Errorf("update promo code: %w", err)
|
||||
}
|
||||
|
||||
return promoCode, nil
|
||||
}
|
||||
|
||||
// Delete 删除优惠码
|
||||
func (s *PromoService) Delete(ctx context.Context, id int64) error {
|
||||
if err := s.promoRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete promo code: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 获取优惠码列表
|
||||
func (s *PromoService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error) {
|
||||
return s.promoRepo.ListWithFilters(ctx, params, status, search)
|
||||
}
|
||||
|
||||
// ListUsages 获取使用记录
|
||||
func (s *PromoService) ListUsages(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error) {
|
||||
return s.promoRepo.ListUsagesByPromoCode(ctx, promoCodeID, params)
|
||||
}
|
||||
@@ -87,6 +87,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAccountService,
|
||||
NewProxyService,
|
||||
NewRedeemService,
|
||||
NewPromoService,
|
||||
NewUsageService,
|
||||
NewDashboardService,
|
||||
ProvidePricingService,
|
||||
|
||||
Reference in New Issue
Block a user