diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index ab3ce4e0..341da381 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -66,7 +66,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) - subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, configConfig) + subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) secretEncryptor, err := repository.NewAESEncryptor(configConfig) diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 204af666..e0078e14 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -113,12 +113,11 @@ func (h *AuthHandler) Register(c *gin.Context) { return } - // Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过) - if req.VerifyCode == "" { - if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil { - response.ErrorFrom(c, err) - return - } + // Turnstile 验证 — 始终执行,防止绕过 + // TODO: 确认前端在提交邮箱验证码注册时也传递了 turnstile_token + if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil { + response.ErrorFrom(c, err) + return } _, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode) diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 0393f954..94698691 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -4,7 +4,7 @@ import ( "context" "encoding/json" "fmt" - "math/rand" + "math/rand/v2" "net/http" "sync" "time" @@ -242,7 +242,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType backoff := initialBackoff timer := time.NewTimer(backoff) defer timer.Stop() - rng := rand.New(rand.NewSource(time.Now().UnixNano())) for { select { @@ -284,7 +283,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType if result.Acquired { return result.ReleaseFunc, nil } - backoff = nextBackoff(backoff, rng) + backoff = nextBackoff(backoff) timer.Reset(backoff) } } @@ -298,20 +297,16 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, ac // nextBackoff 计算下一次退避时间 // 性能优化:使用指数退避 + 随机抖动,避免惊群效应 // current: 当前退避时间 -// rng: 随机数生成器(可为 nil,此时不添加抖动) // 返回值:下一次退避时间(100ms ~ 2s 之间) -func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration { +func nextBackoff(current time.Duration) time.Duration { // 指数退避:当前时间 * 1.5 next := time.Duration(float64(current) * backoffMultiplier) if next > maxBackoff { next = maxBackoff } - if rng == nil { - return next - } // 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2) // 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis - jitter := 0.8 + rng.Float64()*0.4 + jitter := 0.8 + rand.Float64()*0.4 jittered := time.Duration(float64(next) * jitter) if jittered < initialBackoff { return initialBackoff diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 56e21690..1f8ccba9 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -396,8 +396,19 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status // Stream already started, send error as SSE event then close flusher, ok := c.Writer.(http.Flusher) if ok { - // Send error event in OpenAI SSE format - errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) + // Send error event in OpenAI SSE format with proper JSON marshaling + errorData := map[string]any{ + "error": map[string]string{ + "type": errType, + "message": message, + }, + } + jsonBytes, err := json.Marshal(errorData) + if err != nil { + _ = c.Error(err) + return + } + errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes)) if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { _ = c.Error(err) } diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go index 69829ab6..84687f08 100644 --- a/backend/internal/pkg/antigravity/response_transformer.go +++ b/backend/internal/pkg/antigravity/response_transformer.go @@ -6,6 +6,7 @@ import ( "fmt" "log" "strings" + "sync/atomic" "time" ) @@ -343,6 +344,9 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string { return builder.String() } +// fallbackCounter 降级伪随机 ID 的全局计数器,混入 seed 避免高并发下 UnixNano 相同导致碰撞。 +var fallbackCounter uint64 + // generateRandomID 生成密码学安全的随机 ID func generateRandomID() string { const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" @@ -351,10 +355,9 @@ func generateRandomID() string { if _, err := rand.Read(randBytes); err != nil { // 避免在请求路径里 panic:极端情况下熵源不可用时降级为伪随机。 // 这里主要用于生成响应/工具调用的临时 ID,安全要求不高但需尽量避免碰撞。 - seed := uint64(time.Now().UnixNano()) - if err != nil { - seed ^= uint64(len(err.Error())) << 32 - } + cnt := atomic.AddUint64(&fallbackCounter, 1) + seed := uint64(time.Now().UnixNano()) ^ cnt + seed ^= uint64(len(err.Error())) << 32 for i := range id { seed ^= seed << 13 seed ^= seed >> 7 diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go index 370d0672..e753e1b8 100644 --- a/backend/internal/repository/billing_cache.go +++ b/backend/internal/repository/billing_cache.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" "log" - "math/rand" + "math/rand/v2" "strconv" "time" @@ -26,7 +26,7 @@ func jitteredTTL() time.Duration { if billingCacheJitter <= 0 { return billingCacheTTL } - jitter := time.Duration(rand.Int63n(int64(billingCacheJitter))) + jitter := time.Duration(rand.IntN(int(billingCacheJitter))) return billingCacheTTL - jitter } diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index cc0c6db5..28932cc5 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -147,6 +147,10 @@ var ( return 1 `) + // WARNING: Redis Cluster 不兼容 — 脚本内部拼接 key,Cluster 模式下可能路由到错误节点。 + // 调用时传递空 KEYS 数组,所有 key 在 Lua 内通过 ARGV 动态拼接, + // 无法被 Redis Cluster 正确路由到对应 slot,仅适用于单节点或 Sentinel 模式。 + // // getAccountsLoadBatchScript - batch load query with expired slot cleanup // ARGV[1] = slot TTL (seconds) // ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ... @@ -194,6 +198,10 @@ var ( return result `) + // WARNING: Redis Cluster 不兼容 — 脚本内部拼接 key,Cluster 模式下可能路由到错误节点。 + // 调用时传递空 KEYS 数组,所有 key 在 Lua 内通过 ARGV 动态拼接, + // 无法被 Redis Cluster 正确路由到对应 slot,仅适用于单节点或 Sentinel 模式。 + // // getUsersLoadBatchScript - batch load query for users with expired slot cleanup // ARGV[1] = slot TTL (seconds) // ARGV[2..n] = userID1, maxConcurrency1, userID2, maxConcurrency2, ... diff --git a/backend/internal/server/middleware/cors.go b/backend/internal/server/middleware/cors.go index b54a0b0e..14a09cc2 100644 --- a/backend/internal/server/middleware/cors.go +++ b/backend/internal/server/middleware/cors.go @@ -68,12 +68,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc { if allowCredentials { c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") } + c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key") + c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") + c.Writer.Header().Set("Access-Control-Max-Age", "86400") } - c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key") - c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") - c.Writer.Header().Set("Access-Control-Max-Age", "86400") - // 处理预检请求 if c.Request.Method == http.MethodOptions { if originAllowed { diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index 9ce7f449..67b19c09 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -3,6 +3,8 @@ package middleware import ( "crypto/rand" "encoding/base64" + "fmt" + "log" "strings" "github.com/Wei-Shaw/sub2api/internal/config" @@ -18,11 +20,14 @@ const ( CloudflareInsightsDomain = "https://static.cloudflareinsights.com" ) -// GenerateNonce generates a cryptographically secure random nonce -func GenerateNonce() string { +// GenerateNonce generates a cryptographically secure random nonce. +// 返回 error 以确保调用方在 crypto/rand 失败时能正确降级。 +func GenerateNonce() (string, error) { b := make([]byte, 16) - _, _ = rand.Read(b) - return base64.StdEncoding.EncodeToString(b) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generate CSP nonce: %w", err) + } + return base64.StdEncoding.EncodeToString(b), nil } // GetNonceFromContext retrieves the CSP nonce from gin context @@ -52,12 +57,17 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { if cfg.Enabled { // Generate nonce for this request - nonce := GenerateNonce() - c.Set(CSPNonceKey, nonce) - - // Replace nonce placeholder in policy - finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'") - c.Header("Content-Security-Policy", finalPolicy) + nonce, err := GenerateNonce() + if err != nil { + // crypto/rand 失败时降级为无 nonce 的 CSP 策略 + log.Printf("[SecurityHeaders] %v — 降级为无 nonce 的 CSP", err) + finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'unsafe-inline'") + c.Header("Content-Security-Policy", finalPolicy) + } else { + c.Set(CSPNonceKey, nonce) + finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'") + c.Header("Content-Security-Policy", finalPolicy) + } } c.Next() } diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index 92300b11..4360b261 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -8,6 +8,7 @@ import ( "strconv" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" @@ -40,6 +41,7 @@ type SubscriptionService struct { groupRepo GroupRepository userSubRepo UserSubscriptionRepository billingCacheService *BillingCacheService + entClient *dbent.Client // L1 缓存:加速中间件热路径的订阅查询 subCacheL1 *ristretto.Cache @@ -49,11 +51,12 @@ type SubscriptionService struct { } // NewSubscriptionService 创建订阅服务 -func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, cfg *config.Config) *SubscriptionService { +func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, entClient *dbent.Client, cfg *config.Config) *SubscriptionService { svc := &SubscriptionService{ groupRepo: groupRepo, userSubRepo: userSubRepo, billingCacheService: billingCacheService, + entClient: entClient, } svc.initSubCache(cfg) return svc @@ -191,7 +194,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in validityDays = MaxValidityDays } - // 已有订阅,执行续期 + // 已有订阅,执行续期(在事务中完成所有更新) if existingSub != nil { now := time.Now() var newExpiresAt time.Time @@ -209,14 +212,23 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in newExpiresAt = MaxExpiresAt } + // 开启事务:ExtendExpiry + UpdateStatus + UpdateNotes 在同一事务中完成 + tx, err := s.entClient.Tx(ctx) + if err != nil { + return nil, false, fmt.Errorf("begin transaction: %w", err) + } + txCtx := dbent.NewTxContext(ctx, tx) + // 更新过期时间 - if err := s.userSubRepo.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil { + if err := s.userSubRepo.ExtendExpiry(txCtx, existingSub.ID, newExpiresAt); err != nil { + _ = tx.Rollback() return nil, false, fmt.Errorf("extend subscription: %w", err) } // 如果订阅已过期或被暂停,恢复为active状态 if existingSub.Status != SubscriptionStatusActive { - if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, SubscriptionStatusActive); err != nil { + if err := s.userSubRepo.UpdateStatus(txCtx, existingSub.ID, SubscriptionStatusActive); err != nil { + _ = tx.Rollback() return nil, false, fmt.Errorf("update subscription status: %w", err) } } @@ -228,11 +240,17 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in newNotes += "\n" } newNotes += input.Notes - if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil { - log.Printf("update subscription notes failed: sub_id=%d err=%v", existingSub.ID, err) + if err := s.userSubRepo.UpdateNotes(txCtx, existingSub.ID, newNotes); err != nil { + _ = tx.Rollback() + return nil, false, fmt.Errorf("update subscription notes: %w", err) } } + // 提交事务 + if err := tx.Commit(); err != nil { + return nil, false, fmt.Errorf("commit transaction: %w", err) + } + // 失效订阅缓存 s.InvalidateSubCache(input.UserID, input.GroupID) if s.billingCacheService != nil { @@ -471,7 +489,7 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, value, err, _ := s.subCacheGroup.Do(key, func() (any, error) { sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID) if err != nil { - return nil, ErrSubscriptionNotFound + return nil, err // 直接透传 repo 已翻译的错误(NotFound → ErrSubscriptionNotFound,其他错误原样返回) } // 写入 L1 缓存 if s.subCacheL1 != nil { @@ -763,6 +781,11 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc } } + return s.calculateProgress(sub, group), nil +} + +// calculateProgress 根据已加载的订阅和分组数据计算使用进度(纯内存计算,无 DB 查询) +func (s *SubscriptionService) calculateProgress(sub *UserSubscription, group *Group) *SubscriptionProgress { progress := &SubscriptionProgress{ ID: sub.ID, GroupName: group.Name, @@ -842,23 +865,25 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc } } - return progress, nil + return progress } // GetUserSubscriptionsWithProgress 获取用户所有订阅及进度 func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) { + // ListActiveByUserID 已使用 .WithGroup() eager-load Group 关联,1 次查询获取所有数据 subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID) if err != nil { return nil, err } progresses := make([]SubscriptionProgress, 0, len(subs)) - for _, sub := range subs { - progress, err := s.GetSubscriptionProgress(ctx, sub.ID) - if err != nil { + for i := range subs { + sub := &subs[i] + group := sub.Group + if group == nil { continue } - progresses = append(progresses, *progress) + progresses = append(progresses, *s.calculateProgress(sub, group)) } return progresses, nil