fix: 修复代码审核发现的10个问题(P0安全+P1数据一致性+P2性能优化)
P0: OpenAI SSE 错误消息 JSON 注入 — 使用 json.Marshal 替代 fmt.Sprintf P1: subscription 续期包裹 Ent 事务确保原子性 P1: CSP nonce 生成处理 crypto/rand 错误,失败降级为 unsafe-inline P1: singleflight 透传数据库真实错误,不再吞没为 not found P1: GetUserSubscriptionsWithProgress 提取 calculateProgress 消除 N+1 P2: billing_cache/gateway_helper 迁移到 math/rand/v2 消除全局锁争用 P2: generateRandomID 降级分支增加原子计数器防碰撞 P2: CORS 非白名单 origin 不再设置 Allow-Headers/Methods/Max-Age P2: Turnstile 验证移除 VerifyCode 空值跳过条件防绕过 P2: Redis Cluster Lua 脚本空 KEYS 添加兼容性警告注释 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -66,7 +66,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, configConfig)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||
redeemCache := repository.NewRedeemCache(redisClient)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
|
||||
|
||||
@@ -113,12 +113,11 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
|
||||
if req.VerifyCode == "" {
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
// Turnstile 验证 — 始终执行,防止绕过
|
||||
// TODO: 确认前端在提交邮箱验证码注册时也传递了 turnstile_token
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
_, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -242,7 +242,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
||||
backoff := initialBackoff
|
||||
timer := time.NewTimer(backoff)
|
||||
defer timer.Stop()
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -284,7 +283,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
backoff = nextBackoff(backoff, rng)
|
||||
backoff = nextBackoff(backoff)
|
||||
timer.Reset(backoff)
|
||||
}
|
||||
}
|
||||
@@ -298,20 +297,16 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, ac
|
||||
// nextBackoff 计算下一次退避时间
|
||||
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
|
||||
// current: 当前退避时间
|
||||
// rng: 随机数生成器(可为 nil,此时不添加抖动)
|
||||
// 返回值:下一次退避时间(100ms ~ 2s 之间)
|
||||
func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration {
|
||||
func nextBackoff(current time.Duration) time.Duration {
|
||||
// 指数退避:当前时间 * 1.5
|
||||
next := time.Duration(float64(current) * backoffMultiplier)
|
||||
if next > maxBackoff {
|
||||
next = maxBackoff
|
||||
}
|
||||
if rng == nil {
|
||||
return next
|
||||
}
|
||||
// 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2)
|
||||
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
|
||||
jitter := 0.8 + rng.Float64()*0.4
|
||||
jitter := 0.8 + rand.Float64()*0.4
|
||||
jittered := time.Duration(float64(next) * jitter)
|
||||
if jittered < initialBackoff {
|
||||
return initialBackoff
|
||||
|
||||
@@ -396,8 +396,19 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
// Send error event in OpenAI SSE format
|
||||
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||
// Send error event in OpenAI SSE format with proper JSON marshaling
|
||||
errorData := map[string]any{
|
||||
"error": map[string]string{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(errorData)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -343,6 +344,9 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string {
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// fallbackCounter 降级伪随机 ID 的全局计数器,混入 seed 避免高并发下 UnixNano 相同导致碰撞。
|
||||
var fallbackCounter uint64
|
||||
|
||||
// generateRandomID 生成密码学安全的随机 ID
|
||||
func generateRandomID() string {
|
||||
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
@@ -351,10 +355,9 @@ func generateRandomID() string {
|
||||
if _, err := rand.Read(randBytes); err != nil {
|
||||
// 避免在请求路径里 panic:极端情况下熵源不可用时降级为伪随机。
|
||||
// 这里主要用于生成响应/工具调用的临时 ID,安全要求不高但需尽量避免碰撞。
|
||||
seed := uint64(time.Now().UnixNano())
|
||||
if err != nil {
|
||||
seed ^= uint64(len(err.Error())) << 32
|
||||
}
|
||||
cnt := atomic.AddUint64(&fallbackCounter, 1)
|
||||
seed := uint64(time.Now().UnixNano()) ^ cnt
|
||||
seed ^= uint64(len(err.Error())) << 32
|
||||
for i := range id {
|
||||
seed ^= seed << 13
|
||||
seed ^= seed >> 7
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"math/rand/v2"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -26,7 +26,7 @@ func jitteredTTL() time.Duration {
|
||||
if billingCacheJitter <= 0 {
|
||||
return billingCacheTTL
|
||||
}
|
||||
jitter := time.Duration(rand.Int63n(int64(billingCacheJitter)))
|
||||
jitter := time.Duration(rand.IntN(int(billingCacheJitter)))
|
||||
return billingCacheTTL - jitter
|
||||
}
|
||||
|
||||
|
||||
@@ -147,6 +147,10 @@ var (
|
||||
return 1
|
||||
`)
|
||||
|
||||
// WARNING: Redis Cluster 不兼容 — 脚本内部拼接 key,Cluster 模式下可能路由到错误节点。
|
||||
// 调用时传递空 KEYS 数组,所有 key 在 Lua 内通过 ARGV 动态拼接,
|
||||
// 无法被 Redis Cluster 正确路由到对应 slot,仅适用于单节点或 Sentinel 模式。
|
||||
//
|
||||
// getAccountsLoadBatchScript - batch load query with expired slot cleanup
|
||||
// ARGV[1] = slot TTL (seconds)
|
||||
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
|
||||
@@ -194,6 +198,10 @@ var (
|
||||
return result
|
||||
`)
|
||||
|
||||
// WARNING: Redis Cluster 不兼容 — 脚本内部拼接 key,Cluster 模式下可能路由到错误节点。
|
||||
// 调用时传递空 KEYS 数组,所有 key 在 Lua 内通过 ARGV 动态拼接,
|
||||
// 无法被 Redis Cluster 正确路由到对应 slot,仅适用于单节点或 Sentinel 模式。
|
||||
//
|
||||
// getUsersLoadBatchScript - batch load query for users with expired slot cleanup
|
||||
// ARGV[1] = slot TTL (seconds)
|
||||
// ARGV[2..n] = userID1, maxConcurrency1, userID2, maxConcurrency2, ...
|
||||
|
||||
@@ -68,12 +68,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
|
||||
if allowCredentials {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
||||
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
||||
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
|
||||
|
||||
// 处理预检请求
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
if originAllowed {
|
||||
|
||||
@@ -3,6 +3,8 @@ package middleware
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -18,11 +20,14 @@ const (
|
||||
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
|
||||
)
|
||||
|
||||
// GenerateNonce generates a cryptographically secure random nonce
|
||||
func GenerateNonce() string {
|
||||
// GenerateNonce generates a cryptographically secure random nonce.
|
||||
// 返回 error 以确保调用方在 crypto/rand 失败时能正确降级。
|
||||
func GenerateNonce() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
_, _ = rand.Read(b)
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate CSP nonce: %w", err)
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// GetNonceFromContext retrieves the CSP nonce from gin context
|
||||
@@ -52,12 +57,17 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
|
||||
|
||||
if cfg.Enabled {
|
||||
// Generate nonce for this request
|
||||
nonce := GenerateNonce()
|
||||
c.Set(CSPNonceKey, nonce)
|
||||
|
||||
// Replace nonce placeholder in policy
|
||||
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
|
||||
c.Header("Content-Security-Policy", finalPolicy)
|
||||
nonce, err := GenerateNonce()
|
||||
if err != nil {
|
||||
// crypto/rand 失败时降级为无 nonce 的 CSP 策略
|
||||
log.Printf("[SecurityHeaders] %v — 降级为无 nonce 的 CSP", err)
|
||||
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'unsafe-inline'")
|
||||
c.Header("Content-Security-Policy", finalPolicy)
|
||||
} else {
|
||||
c.Set(CSPNonceKey, nonce)
|
||||
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
|
||||
c.Header("Content-Security-Policy", finalPolicy)
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
@@ -40,6 +41,7 @@ type SubscriptionService struct {
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
entClient *dbent.Client
|
||||
|
||||
// L1 缓存:加速中间件热路径的订阅查询
|
||||
subCacheL1 *ristretto.Cache
|
||||
@@ -49,11 +51,12 @@ type SubscriptionService struct {
|
||||
}
|
||||
|
||||
// NewSubscriptionService 创建订阅服务
|
||||
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, cfg *config.Config) *SubscriptionService {
|
||||
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, entClient *dbent.Client, cfg *config.Config) *SubscriptionService {
|
||||
svc := &SubscriptionService{
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
entClient: entClient,
|
||||
}
|
||||
svc.initSubCache(cfg)
|
||||
return svc
|
||||
@@ -191,7 +194,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
validityDays = MaxValidityDays
|
||||
}
|
||||
|
||||
// 已有订阅,执行续期
|
||||
// 已有订阅,执行续期(在事务中完成所有更新)
|
||||
if existingSub != nil {
|
||||
now := time.Now()
|
||||
var newExpiresAt time.Time
|
||||
@@ -209,14 +212,23 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
newExpiresAt = MaxExpiresAt
|
||||
}
|
||||
|
||||
// 开启事务:ExtendExpiry + UpdateStatus + UpdateNotes 在同一事务中完成
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
|
||||
// 更新过期时间
|
||||
if err := s.userSubRepo.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
|
||||
if err := s.userSubRepo.ExtendExpiry(txCtx, existingSub.ID, newExpiresAt); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return nil, false, fmt.Errorf("extend subscription: %w", err)
|
||||
}
|
||||
|
||||
// 如果订阅已过期或被暂停,恢复为active状态
|
||||
if existingSub.Status != SubscriptionStatusActive {
|
||||
if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, SubscriptionStatusActive); err != nil {
|
||||
if err := s.userSubRepo.UpdateStatus(txCtx, existingSub.ID, SubscriptionStatusActive); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return nil, false, fmt.Errorf("update subscription status: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -228,11 +240,17 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
newNotes += "\n"
|
||||
}
|
||||
newNotes += input.Notes
|
||||
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
|
||||
log.Printf("update subscription notes failed: sub_id=%d err=%v", existingSub.ID, err)
|
||||
if err := s.userSubRepo.UpdateNotes(txCtx, existingSub.ID, newNotes); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return nil, false, fmt.Errorf("update subscription notes: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, false, fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
s.InvalidateSubCache(input.UserID, input.GroupID)
|
||||
if s.billingCacheService != nil {
|
||||
@@ -471,7 +489,7 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID,
|
||||
value, err, _ := s.subCacheGroup.Do(key, func() (any, error) {
|
||||
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
return nil, err // 直接透传 repo 已翻译的错误(NotFound → ErrSubscriptionNotFound,其他错误原样返回)
|
||||
}
|
||||
// 写入 L1 缓存
|
||||
if s.subCacheL1 != nil {
|
||||
@@ -763,6 +781,11 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
|
||||
}
|
||||
}
|
||||
|
||||
return s.calculateProgress(sub, group), nil
|
||||
}
|
||||
|
||||
// calculateProgress 根据已加载的订阅和分组数据计算使用进度(纯内存计算,无 DB 查询)
|
||||
func (s *SubscriptionService) calculateProgress(sub *UserSubscription, group *Group) *SubscriptionProgress {
|
||||
progress := &SubscriptionProgress{
|
||||
ID: sub.ID,
|
||||
GroupName: group.Name,
|
||||
@@ -842,23 +865,25 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
|
||||
}
|
||||
}
|
||||
|
||||
return progress, nil
|
||||
return progress
|
||||
}
|
||||
|
||||
// GetUserSubscriptionsWithProgress 获取用户所有订阅及进度
|
||||
func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) {
|
||||
// ListActiveByUserID 已使用 .WithGroup() eager-load Group 关联,1 次查询获取所有数据
|
||||
subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
progresses := make([]SubscriptionProgress, 0, len(subs))
|
||||
for _, sub := range subs {
|
||||
progress, err := s.GetSubscriptionProgress(ctx, sub.ID)
|
||||
if err != nil {
|
||||
for i := range subs {
|
||||
sub := &subs[i]
|
||||
group := sub.Group
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
progresses = append(progresses, *progress)
|
||||
progresses = append(progresses, *s.calculateProgress(sub, group))
|
||||
}
|
||||
|
||||
return progresses, nil
|
||||
|
||||
Reference in New Issue
Block a user