package service

import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/redis/go-redis/v9"
	"gorm.io/gorm"

	"sub2api/internal/model"
	"sub2api/internal/repository"
)

// Sentinel errors returned by RedeemService. Compare with errors.Is.
var (
	ErrRedeemCodeNotFound  = errors.New("redeem code not found")
	ErrRedeemCodeUsed      = errors.New("redeem code already used")
	ErrRedeemCodeInvalid   = errors.New("invalid redeem code")
	ErrInsufficientBalance = errors.New("insufficient balance")
	ErrRedeemRateLimited   = errors.New("too many failed attempts, please try again later")
	ErrRedeemCodeLocked    = errors.New("redeem code is being processed, please try again")
)

const (
	redeemRateLimitKeyPrefix = "redeem:rate_limit:"
	redeemLockKeyPrefix      = "redeem:lock:"
	redeemMaxErrorsPerHour   = 20
	redeemRateLimitDuration  = time.Hour
	// redeemLockDuration is the lock TTL; it bounds how long a lock can
	// outlive a request that never released it.
	redeemLockDuration = 10 * time.Second
	// redeemLockReleaseTimeout bounds the detached lock release that is used
	// when the request context is already done.
	redeemLockReleaseTimeout = 3 * time.Second
)

// GenerateCodesRequest is the input for batch redeem-code generation.
type GenerateCodesRequest struct {
	Count int     `json:"count"`
	Value float64 `json:"value"`
	Type  string  `json:"type"`
}

// RedeemCodeResponse is the JSON representation of a redeem code.
type RedeemCodeResponse struct {
	Code      string    `json:"code"`
	Value     float64   `json:"value"`
	Status    string    `json:"status"`
	CreatedAt time.Time `json:"created_at"`
}

// RedeemService generates, redeems and manages redeem codes.
type RedeemService struct {
	redeemRepo          *repository.RedeemCodeRepository
	userRepo            *repository.UserRepository
	subscriptionService *SubscriptionService
	rdb                 *redis.Client
	billingCacheService *BillingCacheService
}

// NewRedeemService creates a RedeemService. rdb may be nil, in which case
// rate limiting and the per-code lock are skipped. The billing cache service
// is not set here; see SetBillingCacheService.
func NewRedeemService(redeemRepo *repository.RedeemCodeRepository, userRepo *repository.UserRepository, subscriptionService *SubscriptionService, rdb *redis.Client) *RedeemService {
	return &RedeemService{
		redeemRepo:          redeemRepo,
		userRepo:            userRepo,
		subscriptionService: subscriptionService,
		rdb:                 rdb,
	}
}

// SetBillingCacheService sets the billing cache service used to invalidate
// cached balance/subscription data after a successful redemption.
func (s *RedeemService) SetBillingCacheService(billingCacheService *BillingCacheService) {
	s.billingCacheService = billingCacheService
}

// GenerateRandomCode returns a random redeem code built from 16 bytes of
// crypto/rand data, rendered as four dash-separated groups of 8 upper-case
// hex characters (XXXXXXXX-XXXXXXXX-XXXXXXXX-XXXXXXXX).
func (s *RedeemService) GenerateRandomCode() (string, error) {
	raw := make([]byte, 16)
	if _, err := rand.Read(raw); err != nil {
		return "", fmt.Errorf("generate random bytes: %w", err)
	}

	// 16 bytes encode to exactly 32 hex characters, so the slices below are
	// always in range.
	code := strings.ToUpper(hex.EncodeToString(raw))
	parts := []string{code[0:8], code[8:16], code[16:24], code[24:32]}
	return strings.Join(parts, "-"), nil
}

// GenerateCodes creates req.Count new unused codes with the given value and
// type and inserts them in one batch. Count must be in [1, 1000] and Value
// must be positive. An empty Type defaults to model.RedeemTypeBalance.
func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]model.RedeemCode, error) {
	if req.Count <= 0 {
		return nil, errors.New("count must be greater than 0")
	}
	if req.Value <= 0 {
		return nil, errors.New("value must be greater than 0")
	}
	if req.Count > 1000 {
		return nil, errors.New("cannot generate more than 1000 codes at once")
	}

	codeType := req.Type
	if codeType == "" {
		codeType = model.RedeemTypeBalance
	}

	codes := make([]model.RedeemCode, 0, req.Count)
	for i := 0; i < req.Count; i++ {
		code, err := s.GenerateRandomCode()
		if err != nil {
			return nil, fmt.Errorf("generate code: %w", err)
		}
		codes = append(codes, model.RedeemCode{
			Code:   code,
			Type:   codeType,
			Value:  req.Value,
			Status: model.StatusUnused,
		})
	}

	if err := s.redeemRepo.CreateBatch(ctx, codes); err != nil {
		return nil, fmt.Errorf("create batch codes: %w", err)
	}
	return codes, nil
}

// checkRedeemRateLimit returns ErrRedeemRateLimited when the user's failed
// redeem counter has reached redeemMaxErrorsPerHour. It fails open: without
// Redis, or when Redis returns an error, the user is not blocked.
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
	if s.rdb == nil {
		return nil
	}

	key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
	count, err := s.rdb.Get(ctx, key).Int()
	if err != nil && !errors.Is(err, redis.Nil) {
		// A Redis failure must not prevent users from redeeming.
		return nil
	}

	// A missing key (redis.Nil) leaves count at 0.
	if count >= redeemMaxErrorsPerHour {
		return ErrRedeemRateLimited
	}
	return nil
}

// incrementRedeemErrorCount bumps the user's failed-redeem counter and
// (re)sets its expiry to redeemRateLimitDuration, so the window restarts on
// every failure. Best effort: Redis errors are deliberately ignored.
func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) {
	if s.rdb == nil {
		return
	}

	key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
	pipe := s.rdb.Pipeline()
	pipe.Incr(ctx, key)
	pipe.Expire(ctx, key, redeemRateLimitDuration)
	// Counting failures is advisory; a Redis error here is not actionable.
	_, _ = pipe.Exec(ctx)
}

// acquireRedeemLock tries to take the per-code Redis lock (SET NX with a
// redeemLockDuration TTL). It returns true when the lock was taken and false
// when it is already held. Without Redis, or on a Redis error, it returns
// true so that redemption proceeds without the lock.
//
// NOTE(review): the lock value is a constant, so releaseRedeemLock cannot
// tell whether the key still belongs to this request (e.g. after the TTL
// expired or after the degraded "true" path). Redeem additionally relies on
// redeemRepo.Use to reject a code that was already consumed.
func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool {
	if s.rdb == nil {
		return true // degrade to "no lock" when Redis is not configured
	}

	key := redeemLockKeyPrefix + code
	ok, err := s.rdb.SetNX(ctx, key, "1", redeemLockDuration).Result()
	if err != nil {
		// Do not block the operation on a Redis failure.
		return true
	}
	return ok
}

// releaseRedeemLock deletes the per-code Redis lock. It is best effort: a
// failed delete is ignored and the key then expires after redeemLockDuration.
func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
	if s.rdb == nil {
		return
	}

	// If the request context is already cancelled or timed out, a DEL issued
	// on it would fail immediately and leave the code locked until the TTL
	// expires. Use a short detached context in that case.
	releaseCtx := ctx
	if ctx.Err() != nil {
		var cancel context.CancelFunc
		releaseCtx, cancel = context.WithTimeout(context.Background(), redeemLockReleaseTimeout)
		defer cancel()
	}

	key := redeemLockKeyPrefix + code
	// Best effort; see the function comment.
	_ = s.rdb.Del(releaseCtx, key).Err()
}

// Redeem redeems code for userID and returns the refreshed code record.
//
// Flow: rate-limit check, per-code lock, code lookup and validation, user
// lookup, mark the code as used, grant the benefit, reload the code.
//
// Errors: ErrRedeemRateLimited, ErrRedeemCodeLocked, ErrRedeemCodeNotFound,
// ErrRedeemCodeUsed and ErrUserNotFound are returned as sentinels; other
// failures are wrapped with context. Only "not found" and "not usable"
// outcomes count towards the user's failed-attempt counter.
//
// NOTE(review): marking the code as used and granting the benefit are two
// separate repository calls. If granting fails, the code stays consumed and
// no compensation is performed here.
func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*model.RedeemCode, error) {
	if err := s.checkRedeemRateLimit(ctx, userID); err != nil {
		return nil, err
	}

	// Serialize concurrent attempts on the same code.
	if !s.acquireRedeemLock(ctx, code) {
		return nil, ErrRedeemCodeLocked
	}
	defer s.releaseRedeemLock(ctx, code)

	redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			s.incrementRedeemErrorCount(ctx, userID)
			return nil, ErrRedeemCodeNotFound
		}
		return nil, fmt.Errorf("get redeem code: %w", err)
	}

	if !redeemCode.CanUse() {
		s.incrementRedeemErrorCount(ctx, userID)
		return nil, ErrRedeemCodeUsed
	}

	// Validate everything that can be validated BEFORE the code is consumed,
	// so that a code that cannot be applied is rejected without being burned.
	switch redeemCode.Type {
	case model.RedeemTypeBalance, model.RedeemTypeConcurrency:
		// No extra preconditions.
	case model.RedeemTypeSubscription:
		if redeemCode.GroupID == nil {
			return nil, errors.New("invalid subscription redeem code: missing group_id")
		}
		if s.subscriptionService == nil {
			return nil, errors.New("subscription service not configured")
		}
	default:
		return nil, fmt.Errorf("unsupported redeem type: %s", redeemCode.Type)
	}

	// The user record itself is not needed; this only verifies existence.
	if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrUserNotFound
		}
		return nil, fmt.Errorf("get user: %w", err)
	}

	// Mark the code as used first. Per the original author's note, Use is a
	// conditional update (WHERE status = 'unused'), so a concurrent winner
	// surfaces here as gorm.ErrRecordNotFound — confirm against the repository.
	if err := s.redeemRepo.Use(ctx, redeemCode.ID, userID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrRedeemCodeUsed
		}
		return nil, fmt.Errorf("mark code as used: %w", err)
	}

	if err := s.applyRedeemCode(ctx, userID, redeemCode); err != nil {
		return nil, err
	}

	// Reload so the caller sees the post-redemption state of the code.
	redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID)
	if err != nil {
		return nil, fmt.Errorf("get updated redeem code: %w", err)
	}
	return redeemCode, nil
}

// applyRedeemCode grants the benefit of an already-consumed code to userID
// and asynchronously invalidates the affected billing cache entries.
func (s *RedeemService) applyRedeemCode(ctx context.Context, userID int64, redeemCode *model.RedeemCode) error {
	switch redeemCode.Type {
	case model.RedeemTypeBalance:
		if err := s.userRepo.UpdateBalance(ctx, userID, redeemCode.Value); err != nil {
			return fmt.Errorf("update user balance: %w", err)
		}
		if s.billingCacheService != nil {
			// Detached from the request context so invalidation still runs
			// after the request has finished.
			go func() {
				cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
				defer cancel()
				s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
			}()
		}

	case model.RedeemTypeConcurrency:
		// Value is truncated to an integer number of concurrency slots.
		if err := s.userRepo.UpdateConcurrency(ctx, userID, int(redeemCode.Value)); err != nil {
			return fmt.Errorf("update user concurrency: %w", err)
		}

	case model.RedeemTypeSubscription:
		validityDays := redeemCode.ValidityDays
		if validityDays <= 0 {
			validityDays = 30 // default validity when the code does not specify one
		}
		groupID := *redeemCode.GroupID
		_, _, err := s.subscriptionService.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
			UserID:       userID,
			GroupID:      groupID,
			ValidityDays: validityDays,
			AssignedBy:   0, // assigned by the system
			Notes:        fmt.Sprintf("通过兑换码 %s 兑换", redeemCode.Code),
		})
		if err != nil {
			return fmt.Errorf("assign or extend subscription: %w", err)
		}
		if s.billingCacheService != nil {
			go func() {
				cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
				defer cancel()
				s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
			}()
		}

	default:
		// Redeem rejects unknown types before consuming the code; this branch
		// is kept as a safeguard.
		return fmt.Errorf("unsupported redeem type: %s", redeemCode.Type)
	}
	return nil
}

// GetByID returns the redeem code with the given ID, or
// ErrRedeemCodeNotFound if it does not exist.
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
	code, err := s.redeemRepo.GetByID(ctx, id)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrRedeemCodeNotFound
		}
		return nil, fmt.Errorf("get redeem code: %w", err)
	}
	return code, nil
}

// GetByCode returns the redeem code with the given code string, or
// ErrRedeemCodeNotFound if it does not exist.
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
	redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrRedeemCodeNotFound
		}
		return nil, fmt.Errorf("get redeem code: %w", err)
	}
	return redeemCode, nil
}

// List returns a page of redeem codes (admin functionality).
func (s *RedeemService) List(ctx context.Context, params repository.PaginationParams) ([]model.RedeemCode, *repository.PaginationResult, error) {
	codes, pagination, err := s.redeemRepo.List(ctx, params)
	if err != nil {
		return nil, nil, fmt.Errorf("list redeem codes: %w", err)
	}
	return codes, pagination, nil
}

// Delete removes a redeem code (admin functionality). It returns
// ErrRedeemCodeNotFound for an unknown ID and refuses to delete a code for
// which IsUsed reports true.
func (s *RedeemService) Delete(ctx context.Context, id int64) error {
	code, err := s.redeemRepo.GetByID(ctx, id)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return ErrRedeemCodeNotFound
		}
		return fmt.Errorf("get redeem code: %w", err)
	}

	if code.IsUsed() {
		return errors.New("cannot delete used redeem code")
	}

	if err := s.redeemRepo.Delete(ctx, id); err != nil {
		return fmt.Errorf("delete redeem code: %w", err)
	}
	return nil
}

// GetStats returns redeem-code statistics.
//
// TODO: not implemented yet. It currently returns zero values for every key
// (total/unused/used code counts and total face value).
func (s *RedeemService) GetStats(ctx context.Context) (map[string]interface{}, error) {
	stats := map[string]interface{}{
		"total_codes":  0,
		"unused_codes": 0,
		"used_codes":   0,
		"total_value":  0.0,
	}
	return stats, nil
}

// GetUserHistory returns the user's redeem history as provided by
// redeemRepo.ListByUser; limit is passed through unchanged.
func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
	codes, err := s.redeemRepo.ListByUser(ctx, userID, limit)
	if err != nil {
		return nil, fmt.Errorf("get user redeem history: %w", err)
	}
	return codes, nil
}