package service import ( "context" "crypto/rand" "encoding/hex" "errors" "fmt" "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) var ( ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found") ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used") ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance") ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later") ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again") ) const ( redeemMaxErrorsPerHour = 20 redeemRateLimitDuration = time.Hour redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁 ) // RedeemCache defines cache operations for redeem service type RedeemCache interface { GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) ReleaseRedeemLock(ctx context.Context, code string) error } type RedeemCodeRepository interface { Create(ctx context.Context, code *RedeemCode) error CreateBatch(ctx context.Context, codes []RedeemCode) error GetByID(ctx context.Context, id int64) (*RedeemCode, error) GetByCode(ctx context.Context, code string) (*RedeemCode, error) Update(ctx context.Context, code *RedeemCode) error Delete(ctx context.Context, id int64) error Use(ctx context.Context, id, userID int64) error List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) } // GenerateCodesRequest 生成兑换码请求 type GenerateCodesRequest struct { Count int `json:"count"` Value float64 `json:"value"` Type string `json:"type"` } // RedeemCodeResponse 兑换码响应 type RedeemCodeResponse struct { Code string `json:"code"` Value float64 `json:"value"` Status string `json:"status"` CreatedAt time.Time `json:"created_at"` } // RedeemService 兑换码服务 type RedeemService struct { redeemRepo RedeemCodeRepository userRepo UserRepository subscriptionService *SubscriptionService cache RedeemCache billingCacheService *BillingCacheService entClient *dbent.Client authCacheInvalidator APIKeyAuthCacheInvalidator } // NewRedeemService 创建兑换码服务实例 func NewRedeemService( redeemRepo RedeemCodeRepository, userRepo UserRepository, subscriptionService *SubscriptionService, cache RedeemCache, billingCacheService *BillingCacheService, entClient *dbent.Client, authCacheInvalidator APIKeyAuthCacheInvalidator, ) *RedeemService { return &RedeemService{ redeemRepo: redeemRepo, userRepo: userRepo, subscriptionService: subscriptionService, cache: cache, billingCacheService: billingCacheService, entClient: entClient, authCacheInvalidator: authCacheInvalidator, } } // GenerateRandomCode 生成随机兑换码 func (s *RedeemService) GenerateRandomCode() (string, error) { // 生成16字节随机数据 bytes := make([]byte, 16) if _, err := rand.Read(bytes); err != nil { return "", fmt.Errorf("generate random bytes: %w", err) } // 转换为十六进制字符串 code := hex.EncodeToString(bytes) // 格式化为 XXXX-XXXX-XXXX-XXXX 格式 parts := []string{ strings.ToUpper(code[0:8]), strings.ToUpper(code[8:16]), strings.ToUpper(code[16:24]), strings.ToUpper(code[24:32]), } return strings.Join(parts, "-"), nil } // GenerateCodes 批量生成兑换码 func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]RedeemCode, error) { if req.Count <= 0 { return nil, errors.New("count must be greater than 0") } // 邀请码类型不需要数值,其他类型需要 if req.Type != RedeemTypeInvitation && req.Value <= 0 { return nil, errors.New("value must be greater than 0") } if req.Count > 1000 { return nil, errors.New("cannot generate more than 1000 codes at once") } codeType := req.Type if codeType == "" { codeType = RedeemTypeBalance } // 邀请码类型的 value 设为 0 value := req.Value if codeType == RedeemTypeInvitation { value = 0 } codes := make([]RedeemCode, 0, req.Count) for i := 0; i < req.Count; i++ { code, err := s.GenerateRandomCode() if err != nil { return nil, fmt.Errorf("generate code: %w", err) } codes = append(codes, RedeemCode{ Code: code, Type: codeType, Value: value, Status: StatusUnused, }) } // 批量插入 if err := s.redeemRepo.CreateBatch(ctx, codes); err != nil { return nil, fmt.Errorf("create batch codes: %w", err) } return codes, nil } // checkRedeemRateLimit 检查用户兑换错误次数是否超限 func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error { if s.cache == nil { return nil } count, err := s.cache.GetRedeemAttemptCount(ctx, userID) if err != nil { // Redis 出错时不阻止用户操作 return nil } if count >= redeemMaxErrorsPerHour { return ErrRedeemRateLimited } return nil } // incrementRedeemErrorCount 增加用户兑换错误计数 func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) { if s.cache == nil { return } _ = s.cache.IncrementRedeemAttemptCount(ctx, userID) } // acquireRedeemLock 尝试获取兑换码的分布式锁 // 返回 true 表示获取成功,false 表示锁已被占用 func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool { if s.cache == nil { return true // 无 Redis 时降级为不加锁 } ok, err := s.cache.AcquireRedeemLock(ctx, code, redeemLockDuration) if err != nil { // Redis 出错时不阻止操作,依赖数据库层面的状态检查 return true } return ok } // releaseRedeemLock 释放兑换码的分布式锁 func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) { if s.cache == nil { return } _ = s.cache.ReleaseRedeemLock(ctx, code) } // Redeem 使用兑换码 func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*RedeemCode, error) { // 检查限流 if err := s.checkRedeemRateLimit(ctx, userID); err != nil { return nil, err } // 获取分布式锁,防止同一兑换码并发使用 if !s.acquireRedeemLock(ctx, code) { return nil, ErrRedeemCodeLocked } defer s.releaseRedeemLock(ctx, code) // 查找兑换码 redeemCode, err := s.redeemRepo.GetByCode(ctx, code) if err != nil { if errors.Is(err, ErrRedeemCodeNotFound) { s.incrementRedeemErrorCount(ctx, userID) return nil, ErrRedeemCodeNotFound } return nil, fmt.Errorf("get redeem code: %w", err) } // 检查兑换码状态 if !redeemCode.CanUse() { s.incrementRedeemErrorCount(ctx, userID) return nil, ErrRedeemCodeUsed } // 验证兑换码类型的前置条件 if redeemCode.Type == RedeemTypeSubscription && redeemCode.GroupID == nil { return nil, infraerrors.BadRequest("REDEEM_CODE_INVALID", "invalid subscription redeem code: missing group_id") } // 获取用户信息 user, err := s.userRepo.GetByID(ctx, userID) if err != nil { return nil, fmt.Errorf("get user: %w", err) } _ = user // 使用变量避免未使用错误 // 使用数据库事务保证兑换码标记与权益发放的原子性 tx, err := s.entClient.Tx(ctx) if err != nil { return nil, fmt.Errorf("begin transaction: %w", err) } defer func() { _ = tx.Rollback() }() // 将事务放入 context,使 repository 方法能够使用同一事务 txCtx := dbent.NewTxContext(ctx, tx) // 【关键】先标记兑换码为已使用,确保并发安全 // 利用数据库乐观锁(WHERE status = 'unused')保证原子性 if err := s.redeemRepo.Use(txCtx, redeemCode.ID, userID); err != nil { if errors.Is(err, ErrRedeemCodeNotFound) || errors.Is(err, ErrRedeemCodeUsed) { return nil, ErrRedeemCodeUsed } return nil, fmt.Errorf("mark code as used: %w", err) } // 执行兑换逻辑(兑换码已被锁定,此时可安全操作) switch redeemCode.Type { case RedeemTypeBalance: // 增加用户余额 if err := s.userRepo.UpdateBalance(txCtx, userID, redeemCode.Value); err != nil { return nil, fmt.Errorf("update user balance: %w", err) } case RedeemTypeConcurrency: // 增加用户并发数 if err := s.userRepo.UpdateConcurrency(txCtx, userID, int(redeemCode.Value)); err != nil { return nil, fmt.Errorf("update user concurrency: %w", err) } case RedeemTypeSubscription: validityDays := redeemCode.ValidityDays if validityDays <= 0 { validityDays = 30 } _, _, err := s.subscriptionService.AssignOrExtendSubscription(txCtx, &AssignSubscriptionInput{ UserID: userID, GroupID: *redeemCode.GroupID, ValidityDays: validityDays, AssignedBy: 0, // 系统分配 Notes: fmt.Sprintf("通过兑换码 %s 兑换", redeemCode.Code), }) if err != nil { return nil, fmt.Errorf("assign or extend subscription: %w", err) } default: return nil, fmt.Errorf("unsupported redeem type: %s", redeemCode.Type) } // 提交事务 if err := tx.Commit(); err != nil { return nil, fmt.Errorf("commit transaction: %w", err) } // 事务提交成功后失效缓存 s.invalidateRedeemCaches(ctx, userID, redeemCode) // 重新获取更新后的兑换码 redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID) if err != nil { return nil, fmt.Errorf("get updated redeem code: %w", err) } return redeemCode, nil } // invalidateRedeemCaches 失效兑换相关的缓存 func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) { switch redeemCode.Type { case RedeemTypeBalance: if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) } if s.billingCacheService == nil { return } go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) }() case RedeemTypeConcurrency: if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) } if s.billingCacheService == nil { return } case RedeemTypeSubscription: if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) } if s.billingCacheService == nil { return } if redeemCode.GroupID != nil { groupID := *redeemCode.GroupID go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) }() } } } // GetByID 根据ID获取兑换码 func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) { code, err := s.redeemRepo.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("get redeem code: %w", err) } return code, nil } // GetByCode 根据Code获取兑换码 func (s *RedeemService) GetByCode(ctx context.Context, code string) (*RedeemCode, error) { redeemCode, err := s.redeemRepo.GetByCode(ctx, code) if err != nil { return nil, fmt.Errorf("get redeem code: %w", err) } return redeemCode, nil } // List 获取兑换码列表(管理员功能) func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) { codes, pagination, err := s.redeemRepo.List(ctx, params) if err != nil { return nil, nil, fmt.Errorf("list redeem codes: %w", err) } return codes, pagination, nil } // Delete 删除兑换码(管理员功能) func (s *RedeemService) Delete(ctx context.Context, id int64) error { // 检查兑换码是否存在 code, err := s.redeemRepo.GetByID(ctx, id) if err != nil { return fmt.Errorf("get redeem code: %w", err) } // 不允许删除已使用的兑换码 if code.IsUsed() { return infraerrors.Conflict("REDEEM_CODE_DELETE_USED", "cannot delete used redeem code") } if err := s.redeemRepo.Delete(ctx, id); err != nil { return fmt.Errorf("delete redeem code: %w", err) } return nil } // GetStats 获取兑换码统计信息 func (s *RedeemService) GetStats(ctx context.Context) (map[string]any, error) { // TODO: 实现统计逻辑 // 统计未使用、已使用的兑换码数量 // 统计总面值等 stats := map[string]any{ "total_codes": 0, "unused_codes": 0, "used_codes": 0, "total_value": 0.0, } return stats, nil } // GetUserHistory 获取用户的兑换历史 func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) { codes, err := s.redeemRepo.ListByUser(ctx, userID, limit) if err != nil { return nil, fmt.Errorf("get user redeem history: %w", err) } return codes, nil }