fix: 修复代码审核发现的10个问题(P0安全+P1数据一致性+P2性能优化)

P0: OpenAI SSE 错误消息 JSON 注入 — 使用 json.Marshal 替代 fmt.Sprintf
P1: subscription 续期包裹 Ent 事务确保原子性
P1: CSP nonce 生成处理 crypto/rand 错误,失败降级为 unsafe-inline
P1: singleflight 透传数据库真实错误,不再吞没为 not found
P1: GetUserSubscriptionsWithProgress 提取 calculateProgress 消除 N+1
P2: billing_cache/gateway_helper 迁移到 math/rand/v2 消除全局锁争用
P2: generateRandomID 降级分支增加原子计数器防碰撞
P2: CORS 非白名单 origin 不再设置 Allow-Headers/Methods/Max-Age
P2: Turnstile 验证移除 VerifyCode 空值跳过条件防绕过
P2: Redis Cluster Lua 脚本空 KEYS 添加兼容性警告注释

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-07 22:13:45 +08:00
parent e1ac0db05c
commit 9634494ba9
10 changed files with 100 additions and 50 deletions

View File

@@ -8,6 +8,7 @@ import (
"strconv"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -40,6 +41,7 @@ type SubscriptionService struct {
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
billingCacheService *BillingCacheService
entClient *dbent.Client
// L1 缓存:加速中间件热路径的订阅查询
subCacheL1 *ristretto.Cache
@@ -49,11 +51,12 @@ type SubscriptionService struct {
}
// NewSubscriptionService 创建订阅服务
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, cfg *config.Config) *SubscriptionService {
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, entClient *dbent.Client, cfg *config.Config) *SubscriptionService {
svc := &SubscriptionService{
groupRepo: groupRepo,
userSubRepo: userSubRepo,
billingCacheService: billingCacheService,
entClient: entClient,
}
svc.initSubCache(cfg)
return svc
@@ -191,7 +194,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
validityDays = MaxValidityDays
}
// 已有订阅,执行续期
// 已有订阅,执行续期(在事务中完成所有更新)
if existingSub != nil {
now := time.Now()
var newExpiresAt time.Time
@@ -209,14 +212,23 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
newExpiresAt = MaxExpiresAt
}
// 开启事务ExtendExpiry + UpdateStatus + UpdateNotes 在同一事务中完成
tx, err := s.entClient.Tx(ctx)
if err != nil {
return nil, false, fmt.Errorf("begin transaction: %w", err)
}
txCtx := dbent.NewTxContext(ctx, tx)
// 更新过期时间
if err := s.userSubRepo.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
if err := s.userSubRepo.ExtendExpiry(txCtx, existingSub.ID, newExpiresAt); err != nil {
_ = tx.Rollback()
return nil, false, fmt.Errorf("extend subscription: %w", err)
}
// 如果订阅已过期或被暂停恢复为active状态
if existingSub.Status != SubscriptionStatusActive {
if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, SubscriptionStatusActive); err != nil {
if err := s.userSubRepo.UpdateStatus(txCtx, existingSub.ID, SubscriptionStatusActive); err != nil {
_ = tx.Rollback()
return nil, false, fmt.Errorf("update subscription status: %w", err)
}
}
@@ -228,11 +240,17 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
newNotes += "\n"
}
newNotes += input.Notes
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
log.Printf("update subscription notes failed: sub_id=%d err=%v", existingSub.ID, err)
if err := s.userSubRepo.UpdateNotes(txCtx, existingSub.ID, newNotes); err != nil {
_ = tx.Rollback()
return nil, false, fmt.Errorf("update subscription notes: %w", err)
}
}
// 提交事务
if err := tx.Commit(); err != nil {
return nil, false, fmt.Errorf("commit transaction: %w", err)
}
// 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil {
@@ -471,7 +489,7 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID,
value, err, _ := s.subCacheGroup.Do(key, func() (any, error) {
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, ErrSubscriptionNotFound
return nil, err // 直接透传 repo 已翻译的错误NotFound → ErrSubscriptionNotFound,其他错误原样返回)
}
// 写入 L1 缓存
if s.subCacheL1 != nil {
@@ -763,6 +781,11 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
}
}
return s.calculateProgress(sub, group), nil
}
// calculateProgress 根据已加载的订阅和分组数据计算使用进度(纯内存计算,无 DB 查询)
func (s *SubscriptionService) calculateProgress(sub *UserSubscription, group *Group) *SubscriptionProgress {
progress := &SubscriptionProgress{
ID: sub.ID,
GroupName: group.Name,
@@ -842,23 +865,25 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
}
}
return progress, nil
return progress
}
// GetUserSubscriptionsWithProgress 获取用户所有订阅及进度
func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) {
// ListActiveByUserID 已使用 .WithGroup() eager-load Group 关联1 次查询获取所有数据
subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, err
}
progresses := make([]SubscriptionProgress, 0, len(subs))
for _, sub := range subs {
progress, err := s.GetSubscriptionProgress(ctx, sub.ID)
if err != nil {
for i := range subs {
sub := &subs[i]
group := sub.Group
if group == nil {
continue
}
progresses = append(progresses, *progress)
progresses = append(progresses, *s.calculateProgress(sub, group))
}
return progresses, nil