feat: 实现注册优惠码功能
- 支持创建/编辑/删除优惠码,设置赠送金额和使用限制 - 注册页面实时验证优惠码并显示赠送金额 - 支持 URL 参数自动填充 (?promo=CODE) - 添加优惠码验证接口速率限制 - 使用数据库行锁防止并发超限 - 新增后台优惠码管理页面,支持复制注册链接
This commit is contained in:
209
backend/internal/handler/admin/promo_handler.go
Normal file
209
backend/internal/handler/admin/promo_handler.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// PromoHandler handles admin promo code management
|
||||
type PromoHandler struct {
|
||||
promoService *service.PromoService
|
||||
}
|
||||
|
||||
// NewPromoHandler creates a new admin promo handler
|
||||
func NewPromoHandler(promoService *service.PromoService) *PromoHandler {
|
||||
return &PromoHandler{
|
||||
promoService: promoService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreatePromoCodeRequest represents create promo code request
|
||||
type CreatePromoCodeRequest struct {
|
||||
Code string `json:"code"` // 可选,为空则自动生成
|
||||
BonusAmount float64 `json:"bonus_amount" binding:"required,min=0"` // 赠送余额
|
||||
MaxUses int `json:"max_uses" binding:"min=0"` // 最大使用次数,0=无限
|
||||
ExpiresAt *int64 `json:"expires_at"` // 过期时间戳(秒)
|
||||
Notes string `json:"notes"` // 备注
|
||||
}
|
||||
|
||||
// UpdatePromoCodeRequest represents update promo code request
|
||||
type UpdatePromoCodeRequest struct {
|
||||
Code *string `json:"code"`
|
||||
BonusAmount *float64 `json:"bonus_amount" binding:"omitempty,min=0"`
|
||||
MaxUses *int `json:"max_uses" binding:"omitempty,min=0"`
|
||||
Status *string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
Notes *string `json:"notes"`
|
||||
}
|
||||
|
||||
// List handles listing all promo codes with pagination
|
||||
// GET /api/v1/admin/promo-codes
|
||||
func (h *PromoHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
status := c.Query("status")
|
||||
search := strings.TrimSpace(c.Query("search"))
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
params := pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
codes, paginationResult, err := h.promoService.List(c.Request.Context(), params, status, search)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.PromoCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.PromoCodeFromService(&codes[i]))
|
||||
}
|
||||
response.Paginated(c, out, paginationResult.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a promo code by ID
|
||||
// GET /api/v1/admin/promo-codes/:id
|
||||
func (h *PromoHandler) GetByID(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid promo code ID")
|
||||
return
|
||||
}
|
||||
|
||||
code, err := h.promoService.GetByID(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.PromoCodeFromService(code))
|
||||
}
|
||||
|
||||
// Create handles creating a new promo code
|
||||
// POST /api/v1/admin/promo-codes
|
||||
func (h *PromoHandler) Create(c *gin.Context) {
|
||||
var req CreatePromoCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
input := &service.CreatePromoCodeInput{
|
||||
Code: req.Code,
|
||||
BonusAmount: req.BonusAmount,
|
||||
MaxUses: req.MaxUses,
|
||||
Notes: req.Notes,
|
||||
}
|
||||
|
||||
if req.ExpiresAt != nil {
|
||||
t := time.Unix(*req.ExpiresAt, 0)
|
||||
input.ExpiresAt = &t
|
||||
}
|
||||
|
||||
code, err := h.promoService.Create(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.PromoCodeFromService(code))
|
||||
}
|
||||
|
||||
// Update handles updating a promo code
|
||||
// PUT /api/v1/admin/promo-codes/:id
|
||||
func (h *PromoHandler) Update(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid promo code ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdatePromoCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
input := &service.UpdatePromoCodeInput{
|
||||
Code: req.Code,
|
||||
BonusAmount: req.BonusAmount,
|
||||
MaxUses: req.MaxUses,
|
||||
Status: req.Status,
|
||||
Notes: req.Notes,
|
||||
}
|
||||
|
||||
if req.ExpiresAt != nil {
|
||||
if *req.ExpiresAt == 0 {
|
||||
// 0 表示清除过期时间
|
||||
input.ExpiresAt = nil
|
||||
} else {
|
||||
t := time.Unix(*req.ExpiresAt, 0)
|
||||
input.ExpiresAt = &t
|
||||
}
|
||||
}
|
||||
|
||||
code, err := h.promoService.Update(c.Request.Context(), codeID, input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.PromoCodeFromService(code))
|
||||
}
|
||||
|
||||
// Delete handles deleting a promo code
|
||||
// DELETE /api/v1/admin/promo-codes/:id
|
||||
func (h *PromoHandler) Delete(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid promo code ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.promoService.Delete(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Promo code deleted successfully"})
|
||||
}
|
||||
|
||||
// GetUsages handles getting usage records for a promo code
|
||||
// GET /api/v1/admin/promo-codes/:id/usages
|
||||
func (h *PromoHandler) GetUsages(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid promo code ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
usages, paginationResult, err := h.promoService.ListUsages(c.Request.Context(), codeID, params)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.PromoCodeUsage, 0, len(usages))
|
||||
for i := range usages {
|
||||
out = append(out, *dto.PromoCodeUsageFromService(&usages[i]))
|
||||
}
|
||||
response.Paginated(c, out, paginationResult.Total, page, pageSize)
|
||||
}
|
||||
@@ -12,19 +12,21 @@ import (
|
||||
|
||||
// AuthHandler handles authentication-related requests
|
||||
type AuthHandler struct {
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
promoService *service.PromoService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService) *AuthHandler {
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
promoService: promoService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,6 +36,7 @@ type RegisterRequest struct {
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
VerifyCode string `json:"verify_code"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
PromoCode string `json:"promo_code"` // 注册优惠码
|
||||
}
|
||||
|
||||
// SendVerifyCodeRequest 发送验证码请求
|
||||
@@ -79,7 +82,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -174,3 +177,63 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
|
||||
|
||||
response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
|
||||
}
|
||||
|
||||
// ValidatePromoCodeRequest 验证优惠码请求
|
||||
type ValidatePromoCodeRequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
// ValidatePromoCodeResponse 验证优惠码响应
|
||||
type ValidatePromoCodeResponse struct {
|
||||
Valid bool `json:"valid"`
|
||||
BonusAmount float64 `json:"bonus_amount,omitempty"`
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
|
||||
// POST /api/v1/auth/validate-promo-code
|
||||
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
||||
var req ValidatePromoCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
promoCode, err := h.promoService.ValidatePromoCode(c.Request.Context(), req.Code)
|
||||
if err != nil {
|
||||
// 根据错误类型返回对应的错误码
|
||||
errorCode := "PROMO_CODE_INVALID"
|
||||
switch err {
|
||||
case service.ErrPromoCodeNotFound:
|
||||
errorCode = "PROMO_CODE_NOT_FOUND"
|
||||
case service.ErrPromoCodeExpired:
|
||||
errorCode = "PROMO_CODE_EXPIRED"
|
||||
case service.ErrPromoCodeDisabled:
|
||||
errorCode = "PROMO_CODE_DISABLED"
|
||||
case service.ErrPromoCodeMaxUsed:
|
||||
errorCode = "PROMO_CODE_MAX_USED"
|
||||
case service.ErrPromoCodeAlreadyUsed:
|
||||
errorCode = "PROMO_CODE_ALREADY_USED"
|
||||
}
|
||||
|
||||
response.Success(c, ValidatePromoCodeResponse{
|
||||
Valid: false,
|
||||
ErrorCode: errorCode,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if promoCode == nil {
|
||||
response.Success(c, ValidatePromoCodeResponse{
|
||||
Valid: false,
|
||||
ErrorCode: "PROMO_CODE_INVALID",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, ValidatePromoCodeResponse{
|
||||
Valid: true,
|
||||
BonusAmount: promoCode.BonusAmount,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -370,3 +370,35 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
|
||||
Errors: r.Errors,
|
||||
}
|
||||
}
|
||||
|
||||
func PromoCodeFromService(pc *service.PromoCode) *PromoCode {
|
||||
if pc == nil {
|
||||
return nil
|
||||
}
|
||||
return &PromoCode{
|
||||
ID: pc.ID,
|
||||
Code: pc.Code,
|
||||
BonusAmount: pc.BonusAmount,
|
||||
MaxUses: pc.MaxUses,
|
||||
UsedCount: pc.UsedCount,
|
||||
Status: pc.Status,
|
||||
ExpiresAt: pc.ExpiresAt,
|
||||
Notes: pc.Notes,
|
||||
CreatedAt: pc.CreatedAt,
|
||||
UpdatedAt: pc.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func PromoCodeUsageFromService(u *service.PromoCodeUsage) *PromoCodeUsage {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
return &PromoCodeUsage{
|
||||
ID: u.ID,
|
||||
PromoCodeID: u.PromoCodeID,
|
||||
UserID: u.UserID,
|
||||
BonusAmount: u.BonusAmount,
|
||||
UsedAt: u.UsedAt,
|
||||
User: UserFromServiceShallow(u.User),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,3 +250,28 @@ type BulkAssignResult struct {
|
||||
Subscriptions []UserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
}
|
||||
|
||||
// PromoCode 注册优惠码
|
||||
type PromoCode struct {
|
||||
ID int64 `json:"id"`
|
||||
Code string `json:"code"`
|
||||
BonusAmount float64 `json:"bonus_amount"`
|
||||
MaxUses int `json:"max_uses"`
|
||||
UsedCount int `json:"used_count"`
|
||||
Status string `json:"status"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
Notes string `json:"notes"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// PromoCodeUsage 优惠码使用记录
|
||||
type PromoCodeUsage struct {
|
||||
ID int64 `json:"id"`
|
||||
PromoCodeID int64 `json:"promo_code_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
BonusAmount float64 `json:"bonus_amount"`
|
||||
UsedAt time.Time `json:"used_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ type AdminHandlers struct {
|
||||
AntigravityOAuth *admin.AntigravityOAuthHandler
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Promo *admin.PromoHandler
|
||||
Setting *admin.SettingHandler
|
||||
System *admin.SystemHandler
|
||||
Subscription *admin.SubscriptionHandler
|
||||
|
||||
@@ -19,6 +19,7 @@ func ProvideAdminHandlers(
|
||||
antigravityOAuthHandler *admin.AntigravityOAuthHandler,
|
||||
proxyHandler *admin.ProxyHandler,
|
||||
redeemHandler *admin.RedeemHandler,
|
||||
promoHandler *admin.PromoHandler,
|
||||
settingHandler *admin.SettingHandler,
|
||||
systemHandler *admin.SystemHandler,
|
||||
subscriptionHandler *admin.SubscriptionHandler,
|
||||
@@ -36,6 +37,7 @@ func ProvideAdminHandlers(
|
||||
AntigravityOAuth: antigravityOAuthHandler,
|
||||
Proxy: proxyHandler,
|
||||
Redeem: redeemHandler,
|
||||
Promo: promoHandler,
|
||||
Setting: settingHandler,
|
||||
System: systemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
@@ -105,6 +107,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewAntigravityOAuthHandler,
|
||||
admin.NewProxyHandler,
|
||||
admin.NewRedeemHandler,
|
||||
admin.NewPromoHandler,
|
||||
admin.NewSettingHandler,
|
||||
ProvideSystemHandler,
|
||||
admin.NewSubscriptionHandler,
|
||||
|
||||
60
backend/internal/middleware/rate_limiter.go
Normal file
60
backend/internal/middleware/rate_limiter.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RateLimiter Redis 速率限制器
|
||||
type RateLimiter struct {
|
||||
redis *redis.Client
|
||||
prefix string
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建速率限制器实例
|
||||
func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
redis: redisClient,
|
||||
prefix: "rate_limit:",
|
||||
}
|
||||
}
|
||||
|
||||
// Limit 返回速率限制中间件
|
||||
// key: 限制类型标识
|
||||
// limit: 时间窗口内最大请求数
|
||||
// window: 时间窗口
|
||||
func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
redisKey := r.prefix + key + ":" + ip
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// 使用 INCR 原子操作增加计数
|
||||
count, err := r.redis.Incr(ctx, redisKey).Result()
|
||||
if err != nil {
|
||||
// Redis 错误时放行,避免影响正常服务
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 首次访问时设置过期时间
|
||||
if count == 1 {
|
||||
r.redis.Expire(ctx, redisKey, window)
|
||||
}
|
||||
|
||||
// 超过限制
|
||||
if count > int64(limit) {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate limit exceeded",
|
||||
"message": "Too many requests, please try again later",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
273
backend/internal/repository/promo_code_repo.go
Normal file
273
backend/internal/repository/promo_code_repo.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type promoCodeRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewPromoCodeRepository(client *dbent.Client) service.PromoCodeRepository {
|
||||
return &promoCodeRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Create(ctx context.Context, code *service.PromoCode) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.PromoCode.Create().
|
||||
SetCode(code.Code).
|
||||
SetBonusAmount(code.BonusAmount).
|
||||
SetMaxUses(code.MaxUses).
|
||||
SetUsedCount(code.UsedCount).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes)
|
||||
|
||||
if code.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*code.ExpiresAt)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
code.ID = created.ID
|
||||
code.CreatedAt = created.CreatedAt
|
||||
code.UpdatedAt = created.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByID(ctx context.Context, id int64) (*service.PromoCode, error) {
|
||||
m, err := r.client.PromoCode.Query().
|
||||
Where(promocode.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByCode(ctx context.Context, code string) (*service.PromoCode, error) {
|
||||
m, err := r.client.PromoCode.Query().
|
||||
Where(promocode.CodeEqualFold(code)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByCodeForUpdate(ctx context.Context, code string) (*service.PromoCode, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
m, err := client.PromoCode.Query().
|
||||
Where(promocode.CodeEqualFold(code)).
|
||||
ForUpdate().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Update(ctx context.Context, code *service.PromoCode) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.PromoCode.UpdateOneID(code.ID).
|
||||
SetCode(code.Code).
|
||||
SetBonusAmount(code.BonusAmount).
|
||||
SetMaxUses(code.MaxUses).
|
||||
SetUsedCount(code.UsedCount).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes)
|
||||
|
||||
if code.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*code.ExpiresAt)
|
||||
} else {
|
||||
builder.ClearExpiresAt()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrPromoCodeNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
code.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Delete(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.PromoCode.Delete().Where(promocode.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.PromoCode, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "")
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.PromoCode, *pagination.PaginationResult, error) {
|
||||
q := r.client.PromoCode.Query()
|
||||
|
||||
if status != "" {
|
||||
q = q.Where(promocode.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(promocode.CodeContainsFold(search))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
codes, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(promocode.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outCodes := promoCodeEntitiesToService(codes)
|
||||
|
||||
return outCodes, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
created, err := client.PromoCodeUsage.Create().
|
||||
SetPromoCodeID(usage.PromoCodeID).
|
||||
SetUserID(usage.UserID).
|
||||
SetBonusAmount(usage.BonusAmount).
|
||||
SetUsedAt(usage.UsedAt).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usage.ID = created.ID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*service.PromoCodeUsage, error) {
|
||||
m, err := r.client.PromoCodeUsage.Query().
|
||||
Where(
|
||||
promocodeusage.PromoCodeIDEQ(promoCodeID),
|
||||
promocodeusage.UserIDEQ(userID),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeUsageEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]service.PromoCodeUsage, *pagination.PaginationResult, error) {
|
||||
q := r.client.PromoCodeUsage.Query().
|
||||
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
usages, err := q.
|
||||
WithUser().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(promocodeusage.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outUsages := promoCodeUsageEntitiesToService(usages)
|
||||
|
||||
return outUsages, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) IncrementUsedCount(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.PromoCode.UpdateOneID(id).
|
||||
AddUsedCount(1).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// Entity to Service conversions
|
||||
|
||||
func promoCodeEntityToService(m *dbent.PromoCode) *service.PromoCode {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.PromoCode{
|
||||
ID: m.ID,
|
||||
Code: m.Code,
|
||||
BonusAmount: m.BonusAmount,
|
||||
MaxUses: m.MaxUses,
|
||||
UsedCount: m.UsedCount,
|
||||
Status: m.Status,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
Notes: derefString(m.Notes),
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func promoCodeEntitiesToService(models []*dbent.PromoCode) []service.PromoCode {
|
||||
out := make([]service.PromoCode, 0, len(models))
|
||||
for i := range models {
|
||||
if s := promoCodeEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func promoCodeUsageEntityToService(m *dbent.PromoCodeUsage) *service.PromoCodeUsage {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.PromoCodeUsage{
|
||||
ID: m.ID,
|
||||
PromoCodeID: m.PromoCodeID,
|
||||
UserID: m.UserID,
|
||||
BonusAmount: m.BonusAmount,
|
||||
UsedAt: m.UsedAt,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func promoCodeUsageEntitiesToService(models []*dbent.PromoCodeUsage) []service.PromoCodeUsage {
|
||||
out := make([]service.PromoCodeUsage, 0, len(models))
|
||||
for i := range models {
|
||||
if s := promoCodeUsageEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -45,6 +45,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAccountRepository,
|
||||
NewProxyRepository,
|
||||
NewRedeemCodeRepository,
|
||||
NewPromoCodeRepository,
|
||||
NewUsageLogRepository,
|
||||
NewSettingRepository,
|
||||
NewUserSubscriptionRepository,
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ProviderSet 提供服务器层的依赖
|
||||
@@ -30,6 +31,7 @@ func ProvideRouter(
|
||||
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
redisClient *redis.Client,
|
||||
) *gin.Engine {
|
||||
if cfg.Server.Mode == "release" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -47,7 +49,7 @@ func ProvideRouter(
|
||||
}
|
||||
}
|
||||
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg, redisClient)
|
||||
}
|
||||
|
||||
// ProvideHTTPServer 提供 HTTP 服务器
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/web"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// SetupRouter 配置路由器中间件和路由
|
||||
@@ -21,6 +22,7 @@ func SetupRouter(
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
) *gin.Engine {
|
||||
// 应用中间件
|
||||
r.Use(middleware2.Logger())
|
||||
@@ -33,7 +35,7 @@ func SetupRouter(
|
||||
}
|
||||
|
||||
// 注册路由
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg, redisClient)
|
||||
|
||||
return r
|
||||
}
|
||||
@@ -48,6 +50,7 @@ func registerRoutes(
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
) {
|
||||
// 通用路由(健康检查、状态等)
|
||||
routes.RegisterCommonRoutes(r)
|
||||
@@ -56,7 +59,7 @@ func registerRoutes(
|
||||
v1 := r.Group("/api/v1")
|
||||
|
||||
// 注册各模块路由
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, cfg)
|
||||
|
||||
@@ -44,6 +44,9 @@ func RegisterAdminRoutes(
|
||||
// 卡密管理
|
||||
registerRedeemCodeRoutes(admin, h)
|
||||
|
||||
// 优惠码管理
|
||||
registerPromoCodeRoutes(admin, h)
|
||||
|
||||
// 系统设置
|
||||
registerSettingsRoutes(admin, h)
|
||||
|
||||
@@ -201,6 +204,18 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerPromoCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
promoCodes := admin.Group("/promo-codes")
|
||||
{
|
||||
promoCodes.GET("", h.Admin.Promo.List)
|
||||
promoCodes.GET("/:id", h.Admin.Promo.GetByID)
|
||||
promoCodes.POST("", h.Admin.Promo.Create)
|
||||
promoCodes.PUT("/:id", h.Admin.Promo.Update)
|
||||
promoCodes.DELETE("/:id", h.Admin.Promo.Delete)
|
||||
promoCodes.GET("/:id/usages", h.Admin.Promo.GetUsages)
|
||||
}
|
||||
}
|
||||
|
||||
func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
adminSettings := admin.Group("/settings")
|
||||
{
|
||||
|
||||
@@ -1,24 +1,34 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RegisterAuthRoutes 注册认证相关路由
|
||||
func RegisterAuthRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
jwtAuth servermiddleware.JWTAuthMiddleware,
|
||||
redisClient *redis.Client,
|
||||
) {
|
||||
// 创建速率限制器
|
||||
rateLimiter := middleware.NewRateLimiter(redisClient)
|
||||
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
// 优惠码验证接口添加速率限制:每分钟最多 10 次
|
||||
auth.POST("/validate-promo-code", rateLimiter.Limit("validate-promo", 10, time.Minute), h.Auth.ValidatePromoCode)
|
||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||
}
|
||||
|
||||
@@ -52,6 +52,7 @@ type AuthService struct {
|
||||
emailService *EmailService
|
||||
turnstileService *TurnstileService
|
||||
emailQueueService *EmailQueueService
|
||||
promoService *PromoService
|
||||
}
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
@@ -62,6 +63,7 @@ func NewAuthService(
|
||||
emailService *EmailService,
|
||||
turnstileService *TurnstileService,
|
||||
emailQueueService *EmailQueueService,
|
||||
promoService *PromoService,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
userRepo: userRepo,
|
||||
@@ -70,16 +72,17 @@ func NewAuthService(
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
emailQueueService: emailQueueService,
|
||||
promoService: promoService,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 用户注册,返回token和用户
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
||||
return s.RegisterWithVerification(ctx, email, password, "")
|
||||
return s.RegisterWithVerification(ctx, email, password, "", "")
|
||||
}
|
||||
|
||||
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
|
||||
// RegisterWithVerification 用户注册(支持邮件验证和优惠码),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) {
|
||||
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
@@ -150,6 +153,19 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 应用优惠码(如果提供)
|
||||
if promoCode != "" && s.promoService != nil {
|
||||
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
||||
// 优惠码应用失败不影响注册,只记录日志
|
||||
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
|
||||
} else {
|
||||
// 重新获取用户信息以获取更新后的余额
|
||||
if updatedUser, err := s.userRepo.GetByID(ctx, user.ID); err == nil {
|
||||
user = updatedUser
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 生成token
|
||||
token, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
|
||||
@@ -38,6 +38,12 @@ const (
|
||||
RedeemTypeSubscription = "subscription"
|
||||
)
|
||||
|
||||
// PromoCode status constants
|
||||
const (
|
||||
PromoCodeStatusActive = "active"
|
||||
PromoCodeStatusDisabled = "disabled"
|
||||
)
|
||||
|
||||
// Admin adjustment type constants
|
||||
const (
|
||||
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
|
||||
|
||||
73
backend/internal/service/promo_code.go
Normal file
73
backend/internal/service/promo_code.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// PromoCode 注册优惠码
|
||||
type PromoCode struct {
|
||||
ID int64
|
||||
Code string
|
||||
BonusAmount float64
|
||||
MaxUses int
|
||||
UsedCount int
|
||||
Status string
|
||||
ExpiresAt *time.Time
|
||||
Notes string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// 关联
|
||||
UsageRecords []PromoCodeUsage
|
||||
}
|
||||
|
||||
// PromoCodeUsage 优惠码使用记录
|
||||
type PromoCodeUsage struct {
|
||||
ID int64
|
||||
PromoCodeID int64
|
||||
UserID int64
|
||||
BonusAmount float64
|
||||
UsedAt time.Time
|
||||
|
||||
// 关联
|
||||
PromoCode *PromoCode
|
||||
User *User
|
||||
}
|
||||
|
||||
// CanUse 检查优惠码是否可用
|
||||
func (p *PromoCode) CanUse() bool {
|
||||
if p.Status != PromoCodeStatusActive {
|
||||
return false
|
||||
}
|
||||
if p.ExpiresAt != nil && time.Now().After(*p.ExpiresAt) {
|
||||
return false
|
||||
}
|
||||
if p.MaxUses > 0 && p.UsedCount >= p.MaxUses {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// IsExpired 检查是否已过期
|
||||
func (p *PromoCode) IsExpired() bool {
|
||||
return p.ExpiresAt != nil && time.Now().After(*p.ExpiresAt)
|
||||
}
|
||||
|
||||
// CreatePromoCodeInput 创建优惠码输入
|
||||
type CreatePromoCodeInput struct {
|
||||
Code string
|
||||
BonusAmount float64
|
||||
MaxUses int
|
||||
ExpiresAt *time.Time
|
||||
Notes string
|
||||
}
|
||||
|
||||
// UpdatePromoCodeInput 更新优惠码输入
|
||||
type UpdatePromoCodeInput struct {
|
||||
Code *string
|
||||
BonusAmount *float64
|
||||
MaxUses *int
|
||||
Status *string
|
||||
ExpiresAt *time.Time
|
||||
Notes *string
|
||||
}
|
||||
30
backend/internal/service/promo_code_repository.go
Normal file
30
backend/internal/service/promo_code_repository.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
// PromoCodeRepository 优惠码仓储接口
|
||||
type PromoCodeRepository interface {
|
||||
// 基础 CRUD
|
||||
Create(ctx context.Context, code *PromoCode) error
|
||||
GetByID(ctx context.Context, id int64) (*PromoCode, error)
|
||||
GetByCode(ctx context.Context, code string) (*PromoCode, error)
|
||||
GetByCodeForUpdate(ctx context.Context, code string) (*PromoCode, error) // 带行锁的查询,用于并发控制
|
||||
Update(ctx context.Context, code *PromoCode) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
// 列表查询
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]PromoCode, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error)
|
||||
|
||||
// 使用记录
|
||||
CreateUsage(ctx context.Context, usage *PromoCodeUsage) error
|
||||
GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*PromoCodeUsage, error)
|
||||
ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error)
|
||||
|
||||
// 计数操作
|
||||
IncrementUsedCount(ctx context.Context, id int64) error
|
||||
}
|
||||
256
backend/internal/service/promo_service.go
Normal file
256
backend/internal/service/promo_service.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPromoCodeNotFound = infraerrors.NotFound("PROMO_CODE_NOT_FOUND", "promo code not found")
|
||||
ErrPromoCodeExpired = infraerrors.BadRequest("PROMO_CODE_EXPIRED", "promo code has expired")
|
||||
ErrPromoCodeDisabled = infraerrors.BadRequest("PROMO_CODE_DISABLED", "promo code is disabled")
|
||||
ErrPromoCodeMaxUsed = infraerrors.BadRequest("PROMO_CODE_MAX_USED", "promo code has reached maximum uses")
|
||||
ErrPromoCodeAlreadyUsed = infraerrors.Conflict("PROMO_CODE_ALREADY_USED", "you have already used this promo code")
|
||||
ErrPromoCodeInvalid = infraerrors.BadRequest("PROMO_CODE_INVALID", "invalid promo code")
|
||||
)
|
||||
|
||||
// PromoService 优惠码服务
|
||||
type PromoService struct {
|
||||
promoRepo PromoCodeRepository
|
||||
userRepo UserRepository
|
||||
billingCacheService *BillingCacheService
|
||||
entClient *dbent.Client
|
||||
}
|
||||
|
||||
// NewPromoService 创建优惠码服务实例
|
||||
func NewPromoService(
|
||||
promoRepo PromoCodeRepository,
|
||||
userRepo UserRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
entClient *dbent.Client,
|
||||
) *PromoService {
|
||||
return &PromoService{
|
||||
promoRepo: promoRepo,
|
||||
userRepo: userRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
entClient: entClient,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatePromoCode 验证优惠码(注册前调用)
|
||||
// 返回 nil, nil 表示空码(不报错)
|
||||
func (s *PromoService) ValidatePromoCode(ctx context.Context, code string) (*PromoCode, error) {
|
||||
code = strings.TrimSpace(code)
|
||||
if code == "" {
|
||||
return nil, nil // 空码不报错,直接返回
|
||||
}
|
||||
|
||||
promoCode, err := s.promoRepo.GetByCode(ctx, code)
|
||||
if err != nil {
|
||||
// 保留原始错误类型,不要统一映射为 NotFound
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.validatePromoCodeStatus(promoCode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return promoCode, nil
|
||||
}
|
||||
|
||||
// validatePromoCodeStatus 验证优惠码状态
|
||||
func (s *PromoService) validatePromoCodeStatus(promoCode *PromoCode) error {
|
||||
if !promoCode.CanUse() {
|
||||
if promoCode.IsExpired() {
|
||||
return ErrPromoCodeExpired
|
||||
}
|
||||
if promoCode.Status == PromoCodeStatusDisabled {
|
||||
return ErrPromoCodeDisabled
|
||||
}
|
||||
if promoCode.MaxUses > 0 && promoCode.UsedCount >= promoCode.MaxUses {
|
||||
return ErrPromoCodeMaxUsed
|
||||
}
|
||||
return ErrPromoCodeInvalid
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyPromoCode 应用优惠码(注册成功后调用)
|
||||
// 使用事务和行锁确保并发安全
|
||||
func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code string) error {
|
||||
code = strings.TrimSpace(code)
|
||||
if code == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 开启事务
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
|
||||
// 在事务中获取并锁定优惠码记录(FOR UPDATE)
|
||||
promoCode, err := s.promoRepo.GetByCodeForUpdate(txCtx, code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 在事务中验证优惠码状态
|
||||
if err := s.validatePromoCodeStatus(promoCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 在事务中检查用户是否已使用过此优惠码
|
||||
existing, err := s.promoRepo.GetUsageByPromoCodeAndUser(txCtx, promoCode.ID, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check existing usage: %w", err)
|
||||
}
|
||||
if existing != nil {
|
||||
return ErrPromoCodeAlreadyUsed
|
||||
}
|
||||
|
||||
// 增加用户余额
|
||||
if err := s.userRepo.UpdateBalance(txCtx, userID, promoCode.BonusAmount); err != nil {
|
||||
return fmt.Errorf("update user balance: %w", err)
|
||||
}
|
||||
|
||||
// 创建使用记录
|
||||
usage := &PromoCodeUsage{
|
||||
PromoCodeID: promoCode.ID,
|
||||
UserID: userID,
|
||||
BonusAmount: promoCode.BonusAmount,
|
||||
UsedAt: time.Now(),
|
||||
}
|
||||
if err := s.promoRepo.CreateUsage(txCtx, usage); err != nil {
|
||||
return fmt.Errorf("create usage record: %w", err)
|
||||
}
|
||||
|
||||
// 增加使用次数
|
||||
if err := s.promoRepo.IncrementUsedCount(txCtx, promoCode.ID); err != nil {
|
||||
return fmt.Errorf("increment used count: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
||||
// 失效余额缓存
|
||||
if s.billingCacheService != nil {
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateRandomCode 生成随机优惠码
|
||||
func (s *PromoService) GenerateRandomCode() (string, error) {
|
||||
bytes := make([]byte, 8)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
return strings.ToUpper(hex.EncodeToString(bytes)), nil
|
||||
}
|
||||
|
||||
// Create 创建优惠码
|
||||
func (s *PromoService) Create(ctx context.Context, input *CreatePromoCodeInput) (*PromoCode, error) {
|
||||
code := strings.TrimSpace(input.Code)
|
||||
if code == "" {
|
||||
// 自动生成
|
||||
var err error
|
||||
code, err = s.GenerateRandomCode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
promoCode := &PromoCode{
|
||||
Code: strings.ToUpper(code),
|
||||
BonusAmount: input.BonusAmount,
|
||||
MaxUses: input.MaxUses,
|
||||
UsedCount: 0,
|
||||
Status: PromoCodeStatusActive,
|
||||
ExpiresAt: input.ExpiresAt,
|
||||
Notes: input.Notes,
|
||||
}
|
||||
|
||||
if err := s.promoRepo.Create(ctx, promoCode); err != nil {
|
||||
return nil, fmt.Errorf("create promo code: %w", err)
|
||||
}
|
||||
|
||||
return promoCode, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取优惠码
|
||||
func (s *PromoService) GetByID(ctx context.Context, id int64) (*PromoCode, error) {
|
||||
code, err := s.promoRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// Update 更新优惠码
|
||||
func (s *PromoService) Update(ctx context.Context, id int64, input *UpdatePromoCodeInput) (*PromoCode, error) {
|
||||
promoCode, err := s.promoRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if input.Code != nil {
|
||||
promoCode.Code = strings.ToUpper(strings.TrimSpace(*input.Code))
|
||||
}
|
||||
if input.BonusAmount != nil {
|
||||
promoCode.BonusAmount = *input.BonusAmount
|
||||
}
|
||||
if input.MaxUses != nil {
|
||||
promoCode.MaxUses = *input.MaxUses
|
||||
}
|
||||
if input.Status != nil {
|
||||
promoCode.Status = *input.Status
|
||||
}
|
||||
if input.ExpiresAt != nil {
|
||||
promoCode.ExpiresAt = input.ExpiresAt
|
||||
}
|
||||
if input.Notes != nil {
|
||||
promoCode.Notes = *input.Notes
|
||||
}
|
||||
|
||||
if err := s.promoRepo.Update(ctx, promoCode); err != nil {
|
||||
return nil, fmt.Errorf("update promo code: %w", err)
|
||||
}
|
||||
|
||||
return promoCode, nil
|
||||
}
|
||||
|
||||
// Delete 删除优惠码
|
||||
func (s *PromoService) Delete(ctx context.Context, id int64) error {
|
||||
if err := s.promoRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete promo code: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 获取优惠码列表
|
||||
func (s *PromoService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error) {
|
||||
return s.promoRepo.ListWithFilters(ctx, params, status, search)
|
||||
}
|
||||
|
||||
// ListUsages 获取使用记录
|
||||
func (s *PromoService) ListUsages(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error) {
|
||||
return s.promoRepo.ListUsagesByPromoCode(ctx, promoCodeID, params)
|
||||
}
|
||||
@@ -87,6 +87,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAccountService,
|
||||
NewProxyService,
|
||||
NewRedeemService,
|
||||
NewPromoService,
|
||||
NewUsageService,
|
||||
NewDashboardService,
|
||||
ProvidePricingService,
|
||||
|
||||
Reference in New Issue
Block a user