451 lines
14 KiB
Go
451 lines
14 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"crypto/rand"
|
||
"encoding/hex"
|
||
"errors"
|
||
"fmt"
|
||
"strings"
|
||
"time"
|
||
|
||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||
)
|
||
|
||
var (
|
||
ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found")
|
||
ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used")
|
||
ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance")
|
||
ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later")
|
||
ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again")
|
||
)
|
||
|
||
// Tunables for redeem rate limiting and per-code locking.
const (
	// redeemMaxErrorsPerHour is the attempt count at which
	// checkRedeemRateLimit starts rejecting requests.
	redeemMaxErrorsPerHour = 20

	// redeemRateLimitDuration is the rate-limit window.
	// NOTE(review): not referenced in the visible part of this file;
	// presumably consumed by the cache implementation — confirm.
	redeemRateLimitDuration = time.Hour

	// redeemLockDuration is the TTL handed to AcquireRedeemLock. The expiry
	// guards against a lock that is never released (deadlock prevention).
	redeemLockDuration = 10 * time.Second
)
|
||
|
||
// RedeemCache is the cache abstraction the redeem service relies on for
// failed-attempt counting and for the per-code lock.
//
// The service treats every method as best-effort: a nil cache or a cache
// error never blocks a redemption.
type RedeemCache interface {
	// GetRedeemAttemptCount reports the attempt counter stored for userID.
	// The service compares it against redeemMaxErrorsPerHour.
	GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error)
	// IncrementRedeemAttemptCount bumps the attempt counter for userID.
	// The service calls it after a failed redemption attempt.
	IncrementRedeemAttemptCount(ctx context.Context, userID int64) error

	// AcquireRedeemLock tries to take the lock for code with the given TTL.
	// The boolean result tells whether the lock was obtained.
	AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error)
	// ReleaseRedeemLock gives up the lock for code.
	ReleaseRedeemLock(ctx context.Context, code string) error
}
|
||
|
||
type RedeemCodeRepository interface {
|
||
Create(ctx context.Context, code *RedeemCode) error
|
||
CreateBatch(ctx context.Context, codes []RedeemCode) error
|
||
GetByID(ctx context.Context, id int64) (*RedeemCode, error)
|
||
GetByCode(ctx context.Context, code string) (*RedeemCode, error)
|
||
Update(ctx context.Context, code *RedeemCode) error
|
||
Delete(ctx context.Context, id int64) error
|
||
Use(ctx context.Context, id, userID int64) error
|
||
|
||
List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error)
|
||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error)
|
||
ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error)
|
||
// ListByUserPaginated returns paginated balance/concurrency history for a specific user.
|
||
// codeType filter is optional - pass empty string to return all types.
|
||
ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error)
|
||
// SumPositiveBalanceByUser returns the total recharged amount (sum of positive balance values) for a user.
|
||
SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error)
|
||
}
|
||
|
||
// GenerateCodesRequest describes a batch of redeem codes to generate.
type GenerateCodesRequest struct {
	// Count is the number of codes to create; GenerateCodes accepts 1 to 1000.
	Count int `json:"count"`
	// Value is the value carried by each code. It must be positive unless
	// Type is the invitation type, for which it is stored as 0.
	Value float64 `json:"value"`
	// Type is the code type; GenerateCodes treats an empty string as the
	// balance type.
	Type string `json:"type"`
}
|
||
|
||
// RedeemCodeResponse is the serialized view of a redeem code.
type RedeemCodeResponse struct {
	// Code is the redeem code string.
	Code string `json:"code"`
	// Value is the value carried by the code.
	Value float64 `json:"value"`
	// Status is the code's status string.
	Status string `json:"status"`
	// CreatedAt is the creation timestamp.
	CreatedAt time.Time `json:"created_at"`
}
|
||
|
||
// RedeemService 兑换码服务
|
||
type RedeemService struct {
|
||
redeemRepo RedeemCodeRepository
|
||
userRepo UserRepository
|
||
subscriptionService *SubscriptionService
|
||
cache RedeemCache
|
||
billingCacheService *BillingCacheService
|
||
entClient *dbent.Client
|
||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||
}
|
||
|
||
// NewRedeemService 创建兑换码服务实例
|
||
func NewRedeemService(
|
||
redeemRepo RedeemCodeRepository,
|
||
userRepo UserRepository,
|
||
subscriptionService *SubscriptionService,
|
||
cache RedeemCache,
|
||
billingCacheService *BillingCacheService,
|
||
entClient *dbent.Client,
|
||
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
||
) *RedeemService {
|
||
return &RedeemService{
|
||
redeemRepo: redeemRepo,
|
||
userRepo: userRepo,
|
||
subscriptionService: subscriptionService,
|
||
cache: cache,
|
||
billingCacheService: billingCacheService,
|
||
entClient: entClient,
|
||
authCacheInvalidator: authCacheInvalidator,
|
||
}
|
||
}
|
||
|
||
// GenerateRandomCode 生成随机兑换码
|
||
func (s *RedeemService) GenerateRandomCode() (string, error) {
|
||
// 生成16字节随机数据
|
||
bytes := make([]byte, 16)
|
||
if _, err := rand.Read(bytes); err != nil {
|
||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||
}
|
||
|
||
// 转换为十六进制字符串
|
||
code := hex.EncodeToString(bytes)
|
||
|
||
// 格式化为 XXXX-XXXX-XXXX-XXXX 格式
|
||
parts := []string{
|
||
strings.ToUpper(code[0:8]),
|
||
strings.ToUpper(code[8:16]),
|
||
strings.ToUpper(code[16:24]),
|
||
strings.ToUpper(code[24:32]),
|
||
}
|
||
|
||
return strings.Join(parts, "-"), nil
|
||
}
|
||
|
||
// GenerateCodes 批量生成兑换码
|
||
func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]RedeemCode, error) {
|
||
if req.Count <= 0 {
|
||
return nil, errors.New("count must be greater than 0")
|
||
}
|
||
|
||
// 邀请码类型不需要数值,其他类型需要
|
||
if req.Type != RedeemTypeInvitation && req.Value <= 0 {
|
||
return nil, errors.New("value must be greater than 0")
|
||
}
|
||
|
||
if req.Count > 1000 {
|
||
return nil, errors.New("cannot generate more than 1000 codes at once")
|
||
}
|
||
|
||
codeType := req.Type
|
||
if codeType == "" {
|
||
codeType = RedeemTypeBalance
|
||
}
|
||
|
||
// 邀请码类型的 value 设为 0
|
||
value := req.Value
|
||
if codeType == RedeemTypeInvitation {
|
||
value = 0
|
||
}
|
||
|
||
codes := make([]RedeemCode, 0, req.Count)
|
||
for i := 0; i < req.Count; i++ {
|
||
code, err := s.GenerateRandomCode()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("generate code: %w", err)
|
||
}
|
||
|
||
codes = append(codes, RedeemCode{
|
||
Code: code,
|
||
Type: codeType,
|
||
Value: value,
|
||
Status: StatusUnused,
|
||
})
|
||
}
|
||
|
||
// 批量插入
|
||
if err := s.redeemRepo.CreateBatch(ctx, codes); err != nil {
|
||
return nil, fmt.Errorf("create batch codes: %w", err)
|
||
}
|
||
|
||
return codes, nil
|
||
}
|
||
|
||
// checkRedeemRateLimit 检查用户兑换错误次数是否超限
|
||
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
|
||
if s.cache == nil {
|
||
return nil
|
||
}
|
||
|
||
count, err := s.cache.GetRedeemAttemptCount(ctx, userID)
|
||
if err != nil {
|
||
// Redis 出错时不阻止用户操作
|
||
return nil
|
||
}
|
||
|
||
if count >= redeemMaxErrorsPerHour {
|
||
return ErrRedeemRateLimited
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// incrementRedeemErrorCount 增加用户兑换错误计数
|
||
func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) {
|
||
if s.cache == nil {
|
||
return
|
||
}
|
||
|
||
_ = s.cache.IncrementRedeemAttemptCount(ctx, userID)
|
||
}
|
||
|
||
// acquireRedeemLock 尝试获取兑换码的分布式锁
|
||
// 返回 true 表示获取成功,false 表示锁已被占用
|
||
func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool {
|
||
if s.cache == nil {
|
||
return true // 无 Redis 时降级为不加锁
|
||
}
|
||
|
||
ok, err := s.cache.AcquireRedeemLock(ctx, code, redeemLockDuration)
|
||
if err != nil {
|
||
// Redis 出错时不阻止操作,依赖数据库层面的状态检查
|
||
return true
|
||
}
|
||
return ok
|
||
}
|
||
|
||
// releaseRedeemLock 释放兑换码的分布式锁
|
||
func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
|
||
if s.cache == nil {
|
||
return
|
||
}
|
||
|
||
_ = s.cache.ReleaseRedeemLock(ctx, code)
|
||
}
|
||
|
||
// Redeem 使用兑换码
|
||
func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*RedeemCode, error) {
|
||
// 检查限流
|
||
if err := s.checkRedeemRateLimit(ctx, userID); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 获取分布式锁,防止同一兑换码并发使用
|
||
if !s.acquireRedeemLock(ctx, code) {
|
||
return nil, ErrRedeemCodeLocked
|
||
}
|
||
defer s.releaseRedeemLock(ctx, code)
|
||
|
||
// 查找兑换码
|
||
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
|
||
if err != nil {
|
||
if errors.Is(err, ErrRedeemCodeNotFound) {
|
||
s.incrementRedeemErrorCount(ctx, userID)
|
||
return nil, ErrRedeemCodeNotFound
|
||
}
|
||
return nil, fmt.Errorf("get redeem code: %w", err)
|
||
}
|
||
|
||
// 检查兑换码状态
|
||
if !redeemCode.CanUse() {
|
||
s.incrementRedeemErrorCount(ctx, userID)
|
||
return nil, ErrRedeemCodeUsed
|
||
}
|
||
|
||
// 验证兑换码类型的前置条件
|
||
if redeemCode.Type == RedeemTypeSubscription && redeemCode.GroupID == nil {
|
||
return nil, infraerrors.BadRequest("REDEEM_CODE_INVALID", "invalid subscription redeem code: missing group_id")
|
||
}
|
||
|
||
// 获取用户信息
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get user: %w", err)
|
||
}
|
||
_ = user // 使用变量避免未使用错误
|
||
|
||
// 使用数据库事务保证兑换码标记与权益发放的原子性
|
||
tx, err := s.entClient.Tx(ctx)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("begin transaction: %w", err)
|
||
}
|
||
defer func() { _ = tx.Rollback() }()
|
||
|
||
// 将事务放入 context,使 repository 方法能够使用同一事务
|
||
txCtx := dbent.NewTxContext(ctx, tx)
|
||
|
||
// 【关键】先标记兑换码为已使用,确保并发安全
|
||
// 利用数据库乐观锁(WHERE status = 'unused')保证原子性
|
||
if err := s.redeemRepo.Use(txCtx, redeemCode.ID, userID); err != nil {
|
||
if errors.Is(err, ErrRedeemCodeNotFound) || errors.Is(err, ErrRedeemCodeUsed) {
|
||
return nil, ErrRedeemCodeUsed
|
||
}
|
||
return nil, fmt.Errorf("mark code as used: %w", err)
|
||
}
|
||
|
||
// 执行兑换逻辑(兑换码已被锁定,此时可安全操作)
|
||
switch redeemCode.Type {
|
||
case RedeemTypeBalance:
|
||
// 增加用户余额
|
||
if err := s.userRepo.UpdateBalance(txCtx, userID, redeemCode.Value); err != nil {
|
||
return nil, fmt.Errorf("update user balance: %w", err)
|
||
}
|
||
|
||
case RedeemTypeConcurrency:
|
||
// 增加用户并发数
|
||
if err := s.userRepo.UpdateConcurrency(txCtx, userID, int(redeemCode.Value)); err != nil {
|
||
return nil, fmt.Errorf("update user concurrency: %w", err)
|
||
}
|
||
|
||
case RedeemTypeSubscription:
|
||
validityDays := redeemCode.ValidityDays
|
||
if validityDays <= 0 {
|
||
validityDays = 30
|
||
}
|
||
_, _, err := s.subscriptionService.AssignOrExtendSubscription(txCtx, &AssignSubscriptionInput{
|
||
UserID: userID,
|
||
GroupID: *redeemCode.GroupID,
|
||
ValidityDays: validityDays,
|
||
AssignedBy: 0, // 系统分配
|
||
Notes: fmt.Sprintf("通过兑换码 %s 兑换", redeemCode.Code),
|
||
})
|
||
if err != nil {
|
||
return nil, fmt.Errorf("assign or extend subscription: %w", err)
|
||
}
|
||
|
||
default:
|
||
return nil, fmt.Errorf("unsupported redeem type: %s", redeemCode.Type)
|
||
}
|
||
|
||
// 提交事务
|
||
if err := tx.Commit(); err != nil {
|
||
return nil, fmt.Errorf("commit transaction: %w", err)
|
||
}
|
||
|
||
// 事务提交成功后失效缓存
|
||
s.invalidateRedeemCaches(ctx, userID, redeemCode)
|
||
|
||
// 重新获取更新后的兑换码
|
||
redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get updated redeem code: %w", err)
|
||
}
|
||
|
||
return redeemCode, nil
|
||
}
|
||
|
||
// invalidateRedeemCaches 失效兑换相关的缓存
|
||
func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) {
|
||
switch redeemCode.Type {
|
||
case RedeemTypeBalance:
|
||
if s.authCacheInvalidator != nil {
|
||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||
}
|
||
if s.billingCacheService == nil {
|
||
return
|
||
}
|
||
go func() {
|
||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||
}()
|
||
case RedeemTypeConcurrency:
|
||
if s.authCacheInvalidator != nil {
|
||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||
}
|
||
if s.billingCacheService == nil {
|
||
return
|
||
}
|
||
case RedeemTypeSubscription:
|
||
if s.authCacheInvalidator != nil {
|
||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||
}
|
||
if s.billingCacheService == nil {
|
||
return
|
||
}
|
||
if redeemCode.GroupID != nil {
|
||
groupID := *redeemCode.GroupID
|
||
go func() {
|
||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||
}()
|
||
}
|
||
}
|
||
}
|
||
|
||
// GetByID 根据ID获取兑换码
|
||
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) {
|
||
code, err := s.redeemRepo.GetByID(ctx, id)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get redeem code: %w", err)
|
||
}
|
||
return code, nil
|
||
}
|
||
|
||
// GetByCode 根据Code获取兑换码
|
||
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*RedeemCode, error) {
|
||
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get redeem code: %w", err)
|
||
}
|
||
return redeemCode, nil
|
||
}
|
||
|
||
// List 获取兑换码列表(管理员功能)
|
||
func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||
codes, pagination, err := s.redeemRepo.List(ctx, params)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("list redeem codes: %w", err)
|
||
}
|
||
return codes, pagination, nil
|
||
}
|
||
|
||
// Delete 删除兑换码(管理员功能)
|
||
func (s *RedeemService) Delete(ctx context.Context, id int64) error {
|
||
// 检查兑换码是否存在
|
||
code, err := s.redeemRepo.GetByID(ctx, id)
|
||
if err != nil {
|
||
return fmt.Errorf("get redeem code: %w", err)
|
||
}
|
||
|
||
// 不允许删除已使用的兑换码
|
||
if code.IsUsed() {
|
||
return infraerrors.Conflict("REDEEM_CODE_DELETE_USED", "cannot delete used redeem code")
|
||
}
|
||
|
||
if err := s.redeemRepo.Delete(ctx, id); err != nil {
|
||
return fmt.Errorf("delete redeem code: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetStats 获取兑换码统计信息
|
||
func (s *RedeemService) GetStats(ctx context.Context) (map[string]any, error) {
|
||
// TODO: 实现统计逻辑
|
||
// 统计未使用、已使用的兑换码数量
|
||
// 统计总面值等
|
||
|
||
stats := map[string]any{
|
||
"total_codes": 0,
|
||
"unused_codes": 0,
|
||
"used_codes": 0,
|
||
"total_value": 0.0,
|
||
}
|
||
|
||
return stats, nil
|
||
}
|
||
|
||
// GetUserHistory 获取用户的兑换历史
|
||
func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) {
|
||
codes, err := s.redeemRepo.ListByUser(ctx, userID, limit)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get user redeem history: %w", err)
|
||
}
|
||
return codes, nil
|
||
}
|