fix: 修复代码审核发现的10个问题(P0安全+P1数据一致性+P2性能优化)
P0: OpenAI SSE 错误消息 JSON 注入 — 使用 json.Marshal 替代 fmt.Sprintf P1: subscription 续期包裹 Ent 事务确保原子性 P1: CSP nonce 生成处理 crypto/rand 错误,失败降级为 unsafe-inline P1: singleflight 透传数据库真实错误,不再吞没为 not found P1: GetUserSubscriptionsWithProgress 提取 calculateProgress 消除 N+1 P2: billing_cache/gateway_helper 迁移到 math/rand/v2 消除全局锁争用 P2: generateRandomID 降级分支增加原子计数器防碰撞 P2: CORS 非白名单 origin 不再设置 Allow-Headers/Methods/Max-Age P2: Turnstile 验证移除 VerifyCode 空值跳过条件防绕过 P2: Redis Cluster Lua 脚本空 KEYS 添加兼容性警告注释 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
@@ -40,6 +41,7 @@ type SubscriptionService struct {
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
entClient *dbent.Client
|
||||
|
||||
// L1 缓存:加速中间件热路径的订阅查询
|
||||
subCacheL1 *ristretto.Cache
|
||||
@@ -49,11 +51,12 @@ type SubscriptionService struct {
|
||||
}
|
||||
|
||||
// NewSubscriptionService 创建订阅服务
|
||||
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, cfg *config.Config) *SubscriptionService {
|
||||
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, entClient *dbent.Client, cfg *config.Config) *SubscriptionService {
|
||||
svc := &SubscriptionService{
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
entClient: entClient,
|
||||
}
|
||||
svc.initSubCache(cfg)
|
||||
return svc
|
||||
@@ -191,7 +194,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
validityDays = MaxValidityDays
|
||||
}
|
||||
|
||||
// 已有订阅,执行续期
|
||||
// 已有订阅,执行续期(在事务中完成所有更新)
|
||||
if existingSub != nil {
|
||||
now := time.Now()
|
||||
var newExpiresAt time.Time
|
||||
@@ -209,14 +212,23 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
newExpiresAt = MaxExpiresAt
|
||||
}
|
||||
|
||||
// 开启事务:ExtendExpiry + UpdateStatus + UpdateNotes 在同一事务中完成
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
|
||||
// 更新过期时间
|
||||
if err := s.userSubRepo.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
|
||||
if err := s.userSubRepo.ExtendExpiry(txCtx, existingSub.ID, newExpiresAt); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return nil, false, fmt.Errorf("extend subscription: %w", err)
|
||||
}
|
||||
|
||||
// 如果订阅已过期或被暂停,恢复为active状态
|
||||
if existingSub.Status != SubscriptionStatusActive {
|
||||
if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, SubscriptionStatusActive); err != nil {
|
||||
if err := s.userSubRepo.UpdateStatus(txCtx, existingSub.ID, SubscriptionStatusActive); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return nil, false, fmt.Errorf("update subscription status: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -228,11 +240,17 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
newNotes += "\n"
|
||||
}
|
||||
newNotes += input.Notes
|
||||
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
|
||||
log.Printf("update subscription notes failed: sub_id=%d err=%v", existingSub.ID, err)
|
||||
if err := s.userSubRepo.UpdateNotes(txCtx, existingSub.ID, newNotes); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return nil, false, fmt.Errorf("update subscription notes: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, false, fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
s.InvalidateSubCache(input.UserID, input.GroupID)
|
||||
if s.billingCacheService != nil {
|
||||
@@ -471,7 +489,7 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID,
|
||||
value, err, _ := s.subCacheGroup.Do(key, func() (any, error) {
|
||||
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
return nil, err // 直接透传 repo 已翻译的错误(NotFound → ErrSubscriptionNotFound,其他错误原样返回)
|
||||
}
|
||||
// 写入 L1 缓存
|
||||
if s.subCacheL1 != nil {
|
||||
@@ -763,6 +781,11 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
|
||||
}
|
||||
}
|
||||
|
||||
return s.calculateProgress(sub, group), nil
|
||||
}
|
||||
|
||||
// calculateProgress 根据已加载的订阅和分组数据计算使用进度(纯内存计算,无 DB 查询)
|
||||
func (s *SubscriptionService) calculateProgress(sub *UserSubscription, group *Group) *SubscriptionProgress {
|
||||
progress := &SubscriptionProgress{
|
||||
ID: sub.ID,
|
||||
GroupName: group.Name,
|
||||
@@ -842,23 +865,25 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
|
||||
}
|
||||
}
|
||||
|
||||
return progress, nil
|
||||
return progress
|
||||
}
|
||||
|
||||
// GetUserSubscriptionsWithProgress 获取用户所有订阅及进度
|
||||
func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) {
|
||||
// ListActiveByUserID 已使用 .WithGroup() eager-load Group 关联,1 次查询获取所有数据
|
||||
subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
progresses := make([]SubscriptionProgress, 0, len(subs))
|
||||
for _, sub := range subs {
|
||||
progress, err := s.GetSubscriptionProgress(ctx, sub.ID)
|
||||
if err != nil {
|
||||
for i := range subs {
|
||||
sub := &subs[i]
|
||||
group := sub.Group
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
progresses = append(progresses, *progress)
|
||||
progresses = append(progresses, *s.calculateProgress(sub, group))
|
||||
}
|
||||
|
||||
return progresses, nil
|
||||
|
||||
Reference in New Issue
Block a user