@@ -9,6 +9,7 @@ import (
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
@@ -72,6 +73,7 @@ type RedeemService struct {
subscriptionService * SubscriptionService
cache RedeemCache
billingCacheService * BillingCacheService
entClient * dbent . Client
}
// NewRedeemService 创建兑换码服务实例
@@ -81,6 +83,7 @@ func NewRedeemService(
subscriptionService * SubscriptionService ,
cache RedeemCache ,
billingCacheService * BillingCacheService ,
entClient * dbent . Client ,
) * RedeemService {
return & RedeemService {
redeemRepo : redeemRepo ,
@@ -88,6 +91,7 @@ func NewRedeemService(
subscriptionService : subscriptionService ,
cache : cache ,
billingCacheService : billingCacheService ,
entClient : entClient ,
}
}
@@ -248,9 +252,19 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
}
_ = user // 使用变量避免未使用错误
// 使用数据库事务保证兑换码标记与权益发放的原子性
tx , err := s . entClient . Tx ( ctx )
if err != nil {
return nil , fmt . Errorf ( "begin transaction: %w" , err )
}
defer func ( ) { _ = tx . Rollback ( ) } ( )
// 将事务放入 context, 使 repository 方法能够使用同一事务
txCtx := dbent . NewTxContext ( ctx , tx )
// 【关键】先标记兑换码为已使用,确保并发安全
// 利用数据库乐观锁( WHERE status = 'unused')保证原子性
if err := s . redeemRepo . Use ( c tx, redeemCode . ID , userID ) ; err != nil {
if err := s . redeemRepo . Use ( txC tx, redeemCode . ID , userID ) ; err != nil {
if errors . Is ( err , ErrRedeemCodeNotFound ) || errors . Is ( err , ErrRedeemCodeUsed ) {
return nil , ErrRedeemCodeUsed
}
@@ -261,21 +275,13 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
switch redeemCode . Type {
case RedeemTypeBalance :
// 增加用户余额
if err := s . userRepo . UpdateBalance ( c tx, userID , redeemCode . Value ) ; err != nil {
if err := s . userRepo . UpdateBalance ( txC tx, userID , redeemCode . Value ) ; err != nil {
return nil , fmt . Errorf ( "update user balance: %w" , err )
}
// 失效余额缓存
if s . billingCacheService != nil {
go func ( ) {
cacheCtx , cancel := context . WithTimeout ( context . Background ( ) , 5 * time . Second )
defer cancel ( )
_ = s . billingCacheService . InvalidateUserBalance ( cacheCtx , userID )
} ( )
}
case RedeemTypeConcurrency :
// 增加用户并发数
if err := s . userRepo . UpdateConcurrency ( c tx, userID , int ( redeemCode . Value ) ) ; err != nil {
if err := s . userRepo . UpdateConcurrency ( txC tx, userID , int ( redeemCode . Value ) ) ; err != nil {
return nil , fmt . Errorf ( "update user concurrency: %w" , err )
}
@@ -284,7 +290,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
if validityDays <= 0 {
validityDays = 30
}
_ , _ , err := s . subscriptionService . AssignOrExtendSubscription ( c tx, & AssignSubscriptionInput {
_ , _ , err := s . subscriptionService . AssignOrExtendSubscription ( txC tx, & AssignSubscriptionInput {
UserID : userID ,
GroupID : * redeemCode . GroupID ,
ValidityDays : validityDays ,
@@ -294,20 +300,19 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
if err != nil {
return nil , fmt . Errorf ( "assign or extend subscription: %w" , err )
}
// 失效订阅缓存
if s . billingCacheService != nil {
groupID := * redeemCode . GroupID
go func ( ) {
cacheCtx , cancel := context . WithTimeout ( context . Background ( ) , 5 * time . Second )
defer cancel ( )
_ = s . billingCacheService . InvalidateSubscription ( cacheCtx , userID , groupID )
} ( )
}
default :
return nil , fmt . Errorf ( "unsupported redeem type: %s" , redeemCode . Type )
}
// 提交事务
if err := tx . Commit ( ) ; err != nil {
return nil , fmt . Errorf ( "commit transaction: %w" , err )
}
// 事务提交成功后失效缓存
s . invalidateRedeemCaches ( ctx , userID , redeemCode )
// 重新获取更新后的兑换码
redeemCode , err = s . redeemRepo . GetByID ( ctx , redeemCode . ID )
if err != nil {
@@ -317,6 +322,31 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
return redeemCode , nil
}
// invalidateRedeemCaches 失效兑换相关的缓存
func ( s * RedeemService ) invalidateRedeemCaches ( ctx context . Context , userID int64 , redeemCode * RedeemCode ) {
if s . billingCacheService == nil {
return
}
switch redeemCode . Type {
case RedeemTypeBalance :
go func ( ) {
cacheCtx , cancel := context . WithTimeout ( context . Background ( ) , 5 * time . Second )
defer cancel ( )
_ = s . billingCacheService . InvalidateUserBalance ( cacheCtx , userID )
} ( )
case RedeemTypeSubscription :
if redeemCode . GroupID != nil {
groupID := * redeemCode . GroupID
go func ( ) {
cacheCtx , cancel := context . WithTimeout ( context . Background ( ) , 5 * time . Second )
defer cancel ( )
_ = s . billingCacheService . InvalidateSubscription ( cacheCtx , userID , groupID )
} ( )
}
}
}
// GetByID 根据ID获取兑换码
func ( s * RedeemService ) GetByID ( ctx context . Context , id int64 ) ( * RedeemCode , error ) {
code , err := s . redeemRepo . GetByID ( ctx , id )