Files
sub2api/backend/internal/service/redeem_service.go
2026-02-03 13:38:44 +08:00

446 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// Sentinel errors exposed by the redeem service. Each one is built with an
// infraerrors constructor from a stable machine-readable code and a
// human-readable message.
var (
// ErrRedeemCodeNotFound is returned by Redeem when the code lookup reports that no such code exists.
ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found")
// ErrRedeemCodeUsed is returned by Redeem when the code can no longer be used.
ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used")
// ErrInsufficientBalance is not referenced in this file; presumably used by other services in the package.
ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance")
// ErrRedeemRateLimited is returned once a user's attempt counter reaches redeemMaxErrorsPerHour.
ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later")
// ErrRedeemCodeLocked is returned by Redeem when the per-code lock is held by another request.
ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again")
)
// Tuning parameters for redeem rate limiting and per-code locking.
const (
// redeemMaxErrorsPerHour is the attempt count at which checkRedeemRateLimit starts rejecting a user.
redeemMaxErrorsPerHour = 20
// redeemRateLimitDuration is not referenced in this file; presumably the
// counter window applied by the RedeemCache implementation — verify there.
redeemRateLimitDuration = time.Hour
redeemLockDuration = 10 * time.Second // Lock TTL; the expiry prevents a stuck holder from blocking a code forever.
)
// RedeemCache defines the cache operations the redeem service relies on:
// a per-user attempt counter and a per-code lock. RedeemService treats a nil
// RedeemCache as "feature disabled" and ignores errors from every method.
type RedeemCache interface {
// GetRedeemAttemptCount returns the current attempt counter for userID.
GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error)
// IncrementRedeemAttemptCount bumps the attempt counter for userID.
IncrementRedeemAttemptCount(ctx context.Context, userID int64) error
// AcquireRedeemLock tries to take the lock for code with the given ttl and
// reports whether it was acquired.
AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error)
// ReleaseRedeemLock releases the lock for code.
ReleaseRedeemLock(ctx context.Context, code string) error
}
// RedeemCodeRepository is the persistence contract for redeem codes used by
// RedeemService. Implementations live outside this file.
type RedeemCodeRepository interface {
// Create stores a single redeem code.
Create(ctx context.Context, code *RedeemCode) error
// CreateBatch stores several redeem codes in one call.
CreateBatch(ctx context.Context, codes []RedeemCode) error
// GetByID looks a redeem code up by its numeric ID.
GetByID(ctx context.Context, id int64) (*RedeemCode, error)
// GetByCode looks a redeem code up by its code string. Redeem expects
// ErrRedeemCodeNotFound (matched via errors.Is) when the code does not exist.
GetByCode(ctx context.Context, code string) (*RedeemCode, error)
// Update persists changes to an existing redeem code.
Update(ctx context.Context, code *RedeemCode) error
// Delete removes the redeem code with the given ID.
Delete(ctx context.Context, id int64) error
// Use marks code id as used by userID. Redeem treats ErrRedeemCodeNotFound
// and ErrRedeemCodeUsed from this call as "already used".
Use(ctx context.Context, id, userID int64) error
// List returns one page of redeem codes together with pagination metadata.
List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error)
// ListWithFilters is List narrowed by codeType, status and a search string;
// the exact matching rules are defined by the implementation.
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error)
// ListByUser returns up to limit redeem codes associated with userID.
ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error)
}
// GenerateCodesRequest is the input of RedeemService.GenerateCodes.
type GenerateCodesRequest struct {
// Count is the number of codes to create; GenerateCodes accepts 1 through 1000.
Count int `json:"count"`
// Value is the value carried by each code; it must be positive unless Type is RedeemTypeInvitation.
Value float64 `json:"value"`
// Type is the code type; an empty string makes GenerateCodes use RedeemTypeBalance.
Type string `json:"type"`
}
// RedeemCodeResponse is a JSON-serializable view of a redeem code.
// It is not referenced elsewhere in this file.
type RedeemCodeResponse struct {
Code string `json:"code"`
Value float64 `json:"value"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
}
// RedeemService implements redeem-code generation, redemption and lookup.
type RedeemService struct {
// redeemRepo persists and queries redeem codes.
redeemRepo RedeemCodeRepository
// userRepo is used to look up the redeeming user and to apply balance/concurrency changes.
userRepo UserRepository
// subscriptionService assigns or extends subscriptions for subscription-type codes.
subscriptionService *SubscriptionService
// cache backs rate limiting and per-code locking; when nil both are skipped.
cache RedeemCache
// billingCacheService is used to invalidate billing caches after a redemption; may be nil.
billingCacheService *BillingCacheService
// entClient opens the database transaction that wraps a redemption.
entClient *dbent.Client
// authCacheInvalidator invalidates the user's auth cache after a redemption; may be nil.
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewRedeemService wires the given dependencies into a new RedeemService.
// The arguments are stored as-is; no validation is performed here.
func NewRedeemService(
	redeemRepo RedeemCodeRepository,
	userRepo UserRepository,
	subscriptionService *SubscriptionService,
	cache RedeemCache,
	billingCacheService *BillingCacheService,
	entClient *dbent.Client,
	authCacheInvalidator APIKeyAuthCacheInvalidator,
) *RedeemService {
	svc := new(RedeemService)

	// Persistence and domain collaborators.
	svc.redeemRepo = redeemRepo
	svc.userRepo = userRepo
	svc.subscriptionService = subscriptionService
	svc.entClient = entClient

	// Cache-related collaborators.
	svc.cache = cache
	svc.billingCacheService = billingCacheService
	svc.authCacheInvalidator = authCacheInvalidator

	return svc
}
// GenerateRandomCode returns a fresh random redeem code.
//
// The code is built from 16 bytes read from crypto/rand, hex-encoded in upper
// case (32 characters) and split into four dash-separated groups of eight,
// i.e. XXXXXXXX-XXXXXXXX-XXXXXXXX-XXXXXXXX. An error is returned only when the
// random source fails.
func (s *RedeemService) GenerateRandomCode() (string, error) {
	const groupLen = 8

	raw := make([]byte, 16)
	if _, err := rand.Read(raw); err != nil {
		return "", fmt.Errorf("generate random bytes: %w", err)
	}

	// Upper-case once, then slice into fixed-width groups.
	encoded := strings.ToUpper(hex.EncodeToString(raw))
	groups := make([]string, 0, len(encoded)/groupLen)
	for start := 0; start < len(encoded); start += groupLen {
		groups = append(groups, encoded[start:start+groupLen])
	}
	return strings.Join(groups, "-"), nil
}
// GenerateCodes creates req.Count redeem codes and stores them in one batch.
//
// Validation, in order: Count must be positive; Value must be positive unless
// the type is RedeemTypeInvitation; Count must not exceed 1000. An empty Type
// defaults to RedeemTypeBalance, and invitation codes always carry a value of
// 0. Every code starts with StatusUnused. The created codes are returned as
// passed to the repository.
func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]RedeemCode, error) {
	switch {
	case req.Count <= 0:
		return nil, errors.New("count must be greater than 0")
	case req.Type != RedeemTypeInvitation && req.Value <= 0:
		// Invitation codes carry no value; every other type needs one.
		return nil, errors.New("value must be greater than 0")
	case req.Count > 1000:
		return nil, errors.New("cannot generate more than 1000 codes at once")
	}

	// Everything except the code string is shared by the whole batch.
	template := RedeemCode{
		Type:   req.Type,
		Value:  req.Value,
		Status: StatusUnused,
	}
	if template.Type == "" {
		template.Type = RedeemTypeBalance
	}
	if template.Type == RedeemTypeInvitation {
		template.Value = 0
	}

	batch := make([]RedeemCode, req.Count)
	for i := range batch {
		generated, err := s.GenerateRandomCode()
		if err != nil {
			return nil, fmt.Errorf("generate code: %w", err)
		}
		batch[i] = template
		batch[i].Code = generated
	}

	// Persist the whole batch in a single repository call.
	if err := s.redeemRepo.CreateBatch(ctx, batch); err != nil {
		return nil, fmt.Errorf("create batch codes: %w", err)
	}
	return batch, nil
}
// checkRedeemRateLimit reports ErrRedeemRateLimited when the user's attempt
// counter has reached redeemMaxErrorsPerHour. It returns nil when no cache is
// configured and also when the counter cannot be read, so a cache outage
// never blocks redemption.
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
	if s.cache == nil {
		return nil
	}

	attempts, err := s.cache.GetRedeemAttemptCount(ctx, userID)
	switch {
	case err != nil:
		// Fail open: a cache error must not stop the user.
		return nil
	case attempts < redeemMaxErrorsPerHour:
		return nil
	default:
		return ErrRedeemRateLimited
	}
}
// incrementRedeemErrorCount bumps the user's attempt counter in the cache.
// It does nothing when no cache is configured.
func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) {
	if s.cache != nil {
		// Best effort: a failed increment must not affect the redeem flow.
		_ = s.cache.IncrementRedeemAttemptCount(ctx, userID)
	}
}
// acquireRedeemLock tries to take the per-code lock for redeemLockDuration.
// It returns false only when the cache answered and the lock is already held.
// Without a cache, or when the cache call fails, it returns true so the
// redemption proceeds unlocked and relies on the database-level status check.
func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool {
	if s.cache == nil {
		return true
	}
	acquired, err := s.cache.AcquireRedeemLock(ctx, code, redeemLockDuration)
	// A cache error counts as "go ahead".
	return err != nil || acquired
}
// releaseRedeemLock releases the per-code lock in the cache.
// It does nothing when no cache is configured.
func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
	if s.cache != nil {
		// Best effort: the release error is intentionally discarded.
		_ = s.cache.ReleaseRedeemLock(ctx, code)
	}
}
// Redeem applies the redeem code `code` to the user `userID`.
//
// Flow:
//  1. Reject the call with ErrRedeemRateLimited when the user's attempt
//     counter is at the limit.
//  2. Take the per-code lock; ErrRedeemCodeLocked is returned when another
//     request holds it. The lock is released when Redeem returns.
//  3. Load the code. A missing code yields ErrRedeemCodeNotFound and a code
//     that cannot be used yields ErrRedeemCodeUsed; both bump the user's
//     attempt counter.
//  4. Inside one database transaction, mark the code as used and grant the
//     benefit: balance, concurrency or a subscription (a non-positive
//     ValidityDays falls back to 30 days). Any failure rolls everything back.
//  5. After the commit, invalidate the related caches and return the
//     re-read code.
//
// Once the transaction has committed Redeem always returns a nil error: if
// the final re-read fails, the snapshot loaded in step 3 is returned instead.
// Its type and value are accurate, but its usage fields reflect the state
// before redemption.
func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*RedeemCode, error) {
	// Rate limiting.
	if err := s.checkRedeemRateLimit(ctx, userID); err != nil {
		return nil, err
	}

	// Serialize concurrent attempts on the same code.
	if !s.acquireRedeemLock(ctx, code) {
		return nil, ErrRedeemCodeLocked
	}
	defer s.releaseRedeemLock(ctx, code)

	// Look the code up.
	redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
	if err != nil {
		if errors.Is(err, ErrRedeemCodeNotFound) {
			s.incrementRedeemErrorCount(ctx, userID)
			return nil, ErrRedeemCodeNotFound
		}
		return nil, fmt.Errorf("get redeem code: %w", err)
	}

	// Check the code's status.
	if !redeemCode.CanUse() {
		s.incrementRedeemErrorCount(ctx, userID)
		return nil, ErrRedeemCodeUsed
	}

	// Type-specific precondition: a subscription code must name a group.
	if redeemCode.Type == RedeemTypeSubscription && redeemCode.GroupID == nil {
		return nil, infraerrors.BadRequest("REDEEM_CODE_INVALID", "invalid subscription redeem code: missing group_id")
	}

	// Only the lookup outcome matters here; the user record itself is not needed.
	if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
		return nil, fmt.Errorf("get user: %w", err)
	}

	// Run "mark as used" and "grant benefit" in one transaction so they
	// succeed or fail together.
	tx, err := s.entClient.Tx(ctx)
	if err != nil {
		return nil, fmt.Errorf("begin transaction: %w", err)
	}
	// Undo everything on any early return; the rollback error is ignored.
	defer func() { _ = tx.Rollback() }()

	// Carry the transaction in the context so repository calls share it.
	txCtx := dbent.NewTxContext(ctx, tx)

	// Mark the code as used first. Per the original author's note the
	// repository only updates a row that is still unused, so a concurrent
	// redemption loses here.
	if err := s.redeemRepo.Use(txCtx, redeemCode.ID, userID); err != nil {
		if errors.Is(err, ErrRedeemCodeNotFound) || errors.Is(err, ErrRedeemCodeUsed) {
			return nil, ErrRedeemCodeUsed
		}
		return nil, fmt.Errorf("mark code as used: %w", err)
	}

	// Grant the benefit that matches the code type.
	switch redeemCode.Type {
	case RedeemTypeBalance:
		if err := s.userRepo.UpdateBalance(txCtx, userID, redeemCode.Value); err != nil {
			return nil, fmt.Errorf("update user balance: %w", err)
		}
	case RedeemTypeConcurrency:
		// The value is truncated to an integer number of concurrency slots.
		if err := s.userRepo.UpdateConcurrency(txCtx, userID, int(redeemCode.Value)); err != nil {
			return nil, fmt.Errorf("update user concurrency: %w", err)
		}
	case RedeemTypeSubscription:
		validityDays := redeemCode.ValidityDays
		if validityDays <= 0 {
			validityDays = 30
		}
		_, _, err := s.subscriptionService.AssignOrExtendSubscription(txCtx, &AssignSubscriptionInput{
			UserID:       userID,
			GroupID:      *redeemCode.GroupID, // non-nil: checked before the transaction
			ValidityDays: validityDays,
			AssignedBy:   0, // assigned by the system
			Notes:        fmt.Sprintf("通过兑换码 %s 兑换", redeemCode.Code),
		})
		if err != nil {
			return nil, fmt.Errorf("assign or extend subscription: %w", err)
		}
	default:
		return nil, fmt.Errorf("unsupported redeem type: %s", redeemCode.Type)
	}

	if err := tx.Commit(); err != nil {
		return nil, fmt.Errorf("commit transaction: %w", err)
	}

	// Caches are invalidated only after the commit succeeded.
	s.invalidateRedeemCaches(ctx, userID, redeemCode)

	// Prefer the freshly stored state. The redemption is already committed at
	// this point, so a failed re-read must not be reported as a failed
	// redemption: the caller would retry and hit ErrRedeemCodeUsed. Fall back
	// to the snapshot loaded before the transaction.
	updated, err := s.redeemRepo.GetByID(ctx, redeemCode.ID)
	if err != nil {
		return redeemCode, nil
	}
	return updated, nil
}
// invalidateRedeemCaches drops cached data that a successful redemption made
// stale.
//
// For balance, concurrency and subscription codes the user's auth cache is
// invalidated synchronously when an invalidator is configured. When a billing
// cache service is configured as well, balance codes additionally invalidate
// the user's balance and subscription codes with a group invalidate that
// subscription entry. Both run in a background goroutine under a fresh
// 5-second timeout context. Other code types touch no cache.
func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) {
	kind := redeemCode.Type

	switch kind {
	case RedeemTypeBalance, RedeemTypeConcurrency, RedeemTypeSubscription:
		// Handled below.
	default:
		return
	}

	if s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}
	if s.billingCacheService == nil {
		return
	}

	if kind == RedeemTypeBalance {
		go func() {
			cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
		}()
		return
	}

	if kind == RedeemTypeSubscription && redeemCode.GroupID != nil {
		// Copy the ID so the goroutine does not read through the pointer later.
		groupID := *redeemCode.GroupID
		go func() {
			cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
		}()
	}
}
// GetByID returns the redeem code with the given ID.
// Repository errors are wrapped with the prefix "get redeem code".
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) {
	found, lookupErr := s.redeemRepo.GetByID(ctx, id)
	if lookupErr == nil {
		return found, nil
	}
	return nil, fmt.Errorf("get redeem code: %w", lookupErr)
}
// GetByCode returns the redeem code identified by its code string.
// Repository errors are wrapped with the prefix "get redeem code".
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*RedeemCode, error) {
	found, lookupErr := s.redeemRepo.GetByCode(ctx, code)
	if lookupErr == nil {
		return found, nil
	}
	return nil, fmt.Errorf("get redeem code: %w", lookupErr)
}
// List returns one page of redeem codes plus pagination metadata
// (administrator feature). Repository errors are wrapped with the prefix
// "list redeem codes".
func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
	items, page, listErr := s.redeemRepo.List(ctx, params)
	if listErr == nil {
		return items, page, nil
	}
	return nil, nil, fmt.Errorf("list redeem codes: %w", listErr)
}
// Delete removes the redeem code with the given ID (administrator feature).
//
// The code is loaded first: a lookup failure is returned wrapped, and a code
// that has already been used is refused with a REDEEM_CODE_DELETE_USED
// conflict error. Otherwise the code is deleted through the repository.
func (s *RedeemService) Delete(ctx context.Context, id int64) error {
	existing, err := s.redeemRepo.GetByID(ctx, id)
	switch {
	case err != nil:
		return fmt.Errorf("get redeem code: %w", err)
	case existing.IsUsed():
		// Used codes are kept so their history stays intact.
		return infraerrors.Conflict("REDEEM_CODE_DELETE_USED", "cannot delete used redeem code")
	}

	if err = s.redeemRepo.Delete(ctx, id); err != nil {
		return fmt.Errorf("delete redeem code: %w", err)
	}
	return nil
}
// GetStats returns redeem-code statistics.
//
// TODO: implement the real statistics (number of unused/used codes, total
// face value, ...). For now every counter is a zero placeholder and the
// error is always nil.
func (s *RedeemService) GetStats(ctx context.Context) (map[string]any, error) {
	counters := []string{"total_codes", "unused_codes", "used_codes"}

	stats := make(map[string]any, len(counters)+1)
	for _, key := range counters {
		stats[key] = 0
	}
	stats["total_value"] = 0.0

	return stats, nil
}
// GetUserHistory returns up to limit redeem codes associated with userID,
// as reported by the repository. Repository errors are wrapped with the
// prefix "get user redeem history".
func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) {
	history, listErr := s.redeemRepo.ListByUser(ctx, userID, limit)
	if listErr == nil {
		return history, nil
	}
	return nil, fmt.Errorf("get user redeem history: %w", listErr)
}