fix: 修复代码审核发现的10个问题(P0安全+P1数据一致性+P2性能优化)

P0: OpenAI SSE 错误消息 JSON 注入 — 使用 json.Marshal 替代 fmt.Sprintf
P1: subscription 续期包裹 Ent 事务确保原子性
P1: CSP nonce 生成处理 crypto/rand 错误,失败降级为 unsafe-inline
P1: singleflight 透传数据库真实错误,不再吞没为 not found
P1: GetUserSubscriptionsWithProgress 提取 calculateProgress 消除 N+1
P2: billing_cache/gateway_helper 迁移到 math/rand/v2 消除全局锁争用
P2: generateRandomID 降级分支增加原子计数器防碰撞
P2: CORS 非白名单 origin 不再设置 Allow-Headers/Methods/Max-Age
P2: Turnstile 验证移除 VerifyCode 空值跳过条件防绕过
P2: Redis Cluster Lua 脚本空 KEYS 添加兼容性警告注释

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-07 22:13:45 +08:00
parent e1ac0db05c
commit 9634494ba9
10 changed files with 100 additions and 50 deletions

View File

@@ -66,7 +66,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, configConfig)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)

View File

@@ -113,12 +113,11 @@ func (h *AuthHandler) Register(c *gin.Context) {
return
}
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
if req.VerifyCode == "" {
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
response.ErrorFrom(c, err)
return
}
// Turnstile 验证 — 始终执行,防止绕过
// TODO: 确认前端在提交邮箱验证码注册时也传递了 turnstile_token
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
response.ErrorFrom(c, err)
return
}
_, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)

View File

@@ -4,7 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"math/rand"
"math/rand/v2"
"net/http"
"sync"
"time"
@@ -242,7 +242,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
backoff := initialBackoff
timer := time.NewTimer(backoff)
defer timer.Stop()
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for {
select {
@@ -284,7 +283,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
if result.Acquired {
return result.ReleaseFunc, nil
}
backoff = nextBackoff(backoff, rng)
backoff = nextBackoff(backoff)
timer.Reset(backoff)
}
}
@@ -298,20 +297,16 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, ac
// nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间
// rng: 随机数生成器(可为 nil此时不添加抖动
// 返回值下一次退避时间100ms ~ 2s 之间)
func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration {
func nextBackoff(current time.Duration) time.Duration {
// 指数退避:当前时间 * 1.5
next := time.Duration(float64(current) * backoffMultiplier)
if next > maxBackoff {
next = maxBackoff
}
if rng == nil {
return next
}
// 添加 ±20% 的随机抖动jitter 范围 0.8 ~ 1.2
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
jitter := 0.8 + rng.Float64()*0.4
jitter := 0.8 + rand.Float64()*0.4
jittered := time.Duration(float64(next) * jitter)
if jittered < initialBackoff {
return initialBackoff

View File

@@ -396,8 +396,19 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
// Send error event in OpenAI SSE format
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
// Send error event in OpenAI SSE format with proper JSON marshaling
errorData := map[string]any{
"error": map[string]string{
"type": errType,
"message": message,
},
}
jsonBytes, err := json.Marshal(errorData)
if err != nil {
_ = c.Error(err)
return
}
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"log"
"strings"
"sync/atomic"
"time"
)
@@ -343,6 +344,9 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string {
return builder.String()
}
// fallbackCounter 降级伪随机 ID 的全局计数器,混入 seed 避免高并发下 UnixNano 相同导致碰撞。
var fallbackCounter uint64
// generateRandomID 生成密码学安全的随机 ID
func generateRandomID() string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
@@ -351,10 +355,9 @@ func generateRandomID() string {
if _, err := rand.Read(randBytes); err != nil {
// 避免在请求路径里 panic极端情况下熵源不可用时降级为伪随机。
// 这里主要用于生成响应/工具调用的临时 ID安全要求不高但需尽量避免碰撞。
seed := uint64(time.Now().UnixNano())
if err != nil {
seed ^= uint64(len(err.Error())) << 32
}
cnt := atomic.AddUint64(&fallbackCounter, 1)
seed := uint64(time.Now().UnixNano()) ^ cnt
seed ^= uint64(len(err.Error())) << 32
for i := range id {
seed ^= seed << 13
seed ^= seed >> 7

View File

@@ -5,7 +5,7 @@ import (
"errors"
"fmt"
"log"
"math/rand"
"math/rand/v2"
"strconv"
"time"
@@ -26,7 +26,7 @@ func jitteredTTL() time.Duration {
if billingCacheJitter <= 0 {
return billingCacheTTL
}
jitter := time.Duration(rand.Int63n(int64(billingCacheJitter)))
jitter := time.Duration(rand.IntN(int(billingCacheJitter)))
return billingCacheTTL - jitter
}

View File

@@ -147,6 +147,10 @@ var (
return 1
`)
// WARNING: Redis Cluster 不兼容 — 脚本内部拼接 keyCluster 模式下可能路由到错误节点。
// 调用时传递空 KEYS 数组,所有 key 在 Lua 内通过 ARGV 动态拼接,
// 无法被 Redis Cluster 正确路由到对应 slot仅适用于单节点或 Sentinel 模式。
//
// getAccountsLoadBatchScript - batch load query with expired slot cleanup
// ARGV[1] = slot TTL (seconds)
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
@@ -194,6 +198,10 @@ var (
return result
`)
// WARNING: Redis Cluster 不兼容 — 脚本内部拼接 keyCluster 模式下可能路由到错误节点。
// 调用时传递空 KEYS 数组,所有 key 在 Lua 内通过 ARGV 动态拼接,
// 无法被 Redis Cluster 正确路由到对应 slot仅适用于单节点或 Sentinel 模式。
//
// getUsersLoadBatchScript - batch load query for users with expired slot cleanup
// ARGV[1] = slot TTL (seconds)
// ARGV[2..n] = userID1, maxConcurrency1, userID2, maxConcurrency2, ...

View File

@@ -68,12 +68,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
if allowCredentials {
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
}
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
// 处理预检请求
if c.Request.Method == http.MethodOptions {
if originAllowed {

View File

@@ -3,6 +3,8 @@ package middleware
import (
"crypto/rand"
"encoding/base64"
"fmt"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -18,11 +20,14 @@ const (
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
)
// GenerateNonce generates a cryptographically secure random nonce
func GenerateNonce() string {
// GenerateNonce generates a cryptographically secure random nonce.
// 返回 error 以确保调用方在 crypto/rand 失败时能正确降级。
func GenerateNonce() (string, error) {
b := make([]byte, 16)
_, _ = rand.Read(b)
return base64.StdEncoding.EncodeToString(b)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("generate CSP nonce: %w", err)
}
return base64.StdEncoding.EncodeToString(b), nil
}
// GetNonceFromContext retrieves the CSP nonce from gin context
@@ -52,12 +57,17 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
if cfg.Enabled {
// Generate nonce for this request
nonce := GenerateNonce()
c.Set(CSPNonceKey, nonce)
// Replace nonce placeholder in policy
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
c.Header("Content-Security-Policy", finalPolicy)
nonce, err := GenerateNonce()
if err != nil {
// crypto/rand 失败时降级为无 nonce 的 CSP 策略
log.Printf("[SecurityHeaders] %v — 降级为无 nonce 的 CSP", err)
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'unsafe-inline'")
c.Header("Content-Security-Policy", finalPolicy)
} else {
c.Set(CSPNonceKey, nonce)
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
c.Header("Content-Security-Policy", finalPolicy)
}
}
c.Next()
}

View File

@@ -8,6 +8,7 @@ import (
"strconv"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -40,6 +41,7 @@ type SubscriptionService struct {
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
billingCacheService *BillingCacheService
entClient *dbent.Client
// L1 缓存:加速中间件热路径的订阅查询
subCacheL1 *ristretto.Cache
@@ -49,11 +51,12 @@ type SubscriptionService struct {
}
// NewSubscriptionService 创建订阅服务
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, cfg *config.Config) *SubscriptionService {
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, entClient *dbent.Client, cfg *config.Config) *SubscriptionService {
svc := &SubscriptionService{
groupRepo: groupRepo,
userSubRepo: userSubRepo,
billingCacheService: billingCacheService,
entClient: entClient,
}
svc.initSubCache(cfg)
return svc
@@ -191,7 +194,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
validityDays = MaxValidityDays
}
// 已有订阅,执行续期
// 已有订阅,执行续期(在事务中完成所有更新)
if existingSub != nil {
now := time.Now()
var newExpiresAt time.Time
@@ -209,14 +212,23 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
newExpiresAt = MaxExpiresAt
}
// 开启事务ExtendExpiry + UpdateStatus + UpdateNotes 在同一事务中完成
tx, err := s.entClient.Tx(ctx)
if err != nil {
return nil, false, fmt.Errorf("begin transaction: %w", err)
}
txCtx := dbent.NewTxContext(ctx, tx)
// 更新过期时间
if err := s.userSubRepo.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
if err := s.userSubRepo.ExtendExpiry(txCtx, existingSub.ID, newExpiresAt); err != nil {
_ = tx.Rollback()
return nil, false, fmt.Errorf("extend subscription: %w", err)
}
// 如果订阅已过期或被暂停恢复为active状态
if existingSub.Status != SubscriptionStatusActive {
if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, SubscriptionStatusActive); err != nil {
if err := s.userSubRepo.UpdateStatus(txCtx, existingSub.ID, SubscriptionStatusActive); err != nil {
_ = tx.Rollback()
return nil, false, fmt.Errorf("update subscription status: %w", err)
}
}
@@ -228,11 +240,17 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
newNotes += "\n"
}
newNotes += input.Notes
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
log.Printf("update subscription notes failed: sub_id=%d err=%v", existingSub.ID, err)
if err := s.userSubRepo.UpdateNotes(txCtx, existingSub.ID, newNotes); err != nil {
_ = tx.Rollback()
return nil, false, fmt.Errorf("update subscription notes: %w", err)
}
}
// 提交事务
if err := tx.Commit(); err != nil {
return nil, false, fmt.Errorf("commit transaction: %w", err)
}
// 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil {
@@ -471,7 +489,7 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID,
value, err, _ := s.subCacheGroup.Do(key, func() (any, error) {
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, ErrSubscriptionNotFound
return nil, err // 直接透传 repo 已翻译的错误NotFound → ErrSubscriptionNotFound,其他错误原样返回)
}
// 写入 L1 缓存
if s.subCacheL1 != nil {
@@ -763,6 +781,11 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
}
}
return s.calculateProgress(sub, group), nil
}
// calculateProgress 根据已加载的订阅和分组数据计算使用进度(纯内存计算,无 DB 查询)
func (s *SubscriptionService) calculateProgress(sub *UserSubscription, group *Group) *SubscriptionProgress {
progress := &SubscriptionProgress{
ID: sub.ID,
GroupName: group.Name,
@@ -842,23 +865,25 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
}
}
return progress, nil
return progress
}
// GetUserSubscriptionsWithProgress 获取用户所有订阅及进度
func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) {
// ListActiveByUserID 已使用 .WithGroup() eager-load Group 关联1 次查询获取所有数据
subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, err
}
progresses := make([]SubscriptionProgress, 0, len(subs))
for _, sub := range subs {
progress, err := s.GetSubscriptionProgress(ctx, sub.ID)
if err != nil {
for i := range subs {
sub := &subs[i]
group := sub.Group
if group == nil {
continue
}
progresses = append(progresses, *progress)
progresses = append(progresses, *s.calculateProgress(sub, group))
}
return progresses, nil