diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 289a14bd..27404b02 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -78,7 +78,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
- timingWheelService := service.ProvideTimingWheelService()
+ timingWheelService, err := service.ProvideTimingWheelService()
+ if err != nil {
+ return nil, err
+ }
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
accountRepository := repository.NewAccountRepository(client, db)
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index 944e0f84..5dc6ad19 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -19,7 +19,9 @@ const (
RunModeSimple = "simple"
)
-const DefaultCSPPolicy = "default-src 'self'; script-src 'self' https://challenges.cloudflare.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
+// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
+// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
+const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// 连接池隔离策略常量
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go
index 9fca0cd3..9ce7f449 100644
--- a/backend/internal/server/middleware/security_headers.go
+++ b/backend/internal/server/middleware/security_headers.go
@@ -1,12 +1,40 @@
package middleware
import (
+ "crypto/rand"
+ "encoding/base64"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
+const (
+ // CSPNonceKey is the context key for storing the CSP nonce
+ CSPNonceKey = "csp_nonce"
+ // NonceTemplate is the placeholder in CSP policy for nonce
+ NonceTemplate = "__CSP_NONCE__"
+ // CloudflareInsightsDomain is the domain for Cloudflare Web Analytics
+ CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
+)
+
+// GenerateNonce generates a cryptographically secure random nonce
+func GenerateNonce() string {
+ b := make([]byte, 16)
+ _, _ = rand.Read(b)
+ return base64.StdEncoding.EncodeToString(b)
+}
+
+// GetNonceFromContext retrieves the CSP nonce from gin context
+func GetNonceFromContext(c *gin.Context) string {
+ if nonce, exists := c.Get(CSPNonceKey); exists {
+ if s, ok := nonce.(string); ok {
+ return s
+ }
+ }
+ return ""
+}
+
// SecurityHeaders sets baseline security headers for all responses.
func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy := strings.TrimSpace(cfg.Policy)
@@ -14,13 +42,75 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy = config.DefaultCSPPolicy
}
+ // Enhance policy with required directives (nonce placeholder and Cloudflare Insights)
+ policy = enhanceCSPPolicy(policy)
+
return func(c *gin.Context) {
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
+
if cfg.Enabled {
- c.Header("Content-Security-Policy", policy)
+ // Generate nonce for this request
+ nonce := GenerateNonce()
+ c.Set(CSPNonceKey, nonce)
+
+ // Replace nonce placeholder in policy
+ finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
+ c.Header("Content-Security-Policy", finalPolicy)
}
c.Next()
}
}
+
+// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
+// This allows the application to work correctly even if the config file has an older CSP policy.
+func enhanceCSPPolicy(policy string) string {
+ // Add nonce placeholder to script-src if not present
+ if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
+ policy = addToDirective(policy, "script-src", NonceTemplate)
+ }
+
+ // Add Cloudflare Insights domain to script-src if not present
+ if !strings.Contains(policy, CloudflareInsightsDomain) {
+ policy = addToDirective(policy, "script-src", CloudflareInsightsDomain)
+ }
+
+ return policy
+}
+
+// addToDirective adds a value to a specific CSP directive.
+// If the directive doesn't exist, it will be added after default-src.
+func addToDirective(policy, directive, value string) string {
+ // Find the directive in the policy
+ directivePrefix := directive + " "
+ idx := strings.Index(policy, directivePrefix)
+
+ if idx == -1 {
+ // Directive not found, add it after default-src or at the beginning
+ defaultSrcIdx := strings.Index(policy, "default-src ")
+ if defaultSrcIdx != -1 {
+ // Find the end of default-src directive (next semicolon)
+ endIdx := strings.Index(policy[defaultSrcIdx:], ";")
+ if endIdx != -1 {
+ insertPos := defaultSrcIdx + endIdx + 1
+ // Insert new directive after default-src
+ return policy[:insertPos] + " " + directive + " 'self' " + value + ";" + policy[insertPos:]
+ }
+ }
+ // Fallback: prepend the directive
+ return directive + " 'self' " + value + "; " + policy
+ }
+
+ // Find the end of this directive (next semicolon or end of string)
+ endIdx := strings.Index(policy[idx:], ";")
+
+ if endIdx == -1 {
+ // No semicolon found, directive goes to end of string
+ return policy + " " + value
+ }
+
+ // Insert value before the semicolon
+ insertPos := idx + endIdx
+ return policy[:insertPos] + " " + value + policy[insertPos:]
+}
diff --git a/backend/internal/server/middleware/security_headers_test.go b/backend/internal/server/middleware/security_headers_test.go
new file mode 100644
index 00000000..dc7a87d8
--- /dev/null
+++ b/backend/internal/server/middleware/security_headers_test.go
@@ -0,0 +1,365 @@
+package middleware
+
+import (
+ "encoding/base64"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func init() {
+ gin.SetMode(gin.TestMode)
+}
+
+func TestGenerateNonce(t *testing.T) {
+ t.Run("generates_valid_base64_string", func(t *testing.T) {
+ nonce := GenerateNonce()
+
+ // Should be valid base64
+ decoded, err := base64.StdEncoding.DecodeString(nonce)
+ require.NoError(t, err)
+
+ // Should decode to 16 bytes
+ assert.Len(t, decoded, 16)
+ })
+
+ t.Run("generates_unique_nonces", func(t *testing.T) {
+ nonces := make(map[string]bool)
+ for i := 0; i < 100; i++ {
+ nonce := GenerateNonce()
+ assert.False(t, nonces[nonce], "nonce should be unique")
+ nonces[nonce] = true
+ }
+ })
+
+ t.Run("nonce_has_expected_length", func(t *testing.T) {
+ nonce := GenerateNonce()
+ // 16 bytes -> 24 chars in base64 (with padding)
+ assert.Len(t, nonce, 24)
+ })
+}
+
+func TestGetNonceFromContext(t *testing.T) {
+ t.Run("returns_nonce_when_present", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+
+ expectedNonce := "test-nonce-123"
+ c.Set(CSPNonceKey, expectedNonce)
+
+ nonce := GetNonceFromContext(c)
+ assert.Equal(t, expectedNonce, nonce)
+ })
+
+ t.Run("returns_empty_string_when_not_present", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+
+ nonce := GetNonceFromContext(c)
+ assert.Empty(t, nonce)
+ })
+
+ t.Run("returns_empty_for_wrong_type", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+
+ // Set a non-string value
+ c.Set(CSPNonceKey, 12345)
+
+ // Should return empty string for wrong type (safe type assertion)
+ nonce := GetNonceFromContext(c)
+ assert.Empty(t, nonce)
+ })
+}
+
+func TestSecurityHeaders(t *testing.T) {
+ t.Run("sets_basic_security_headers", func(t *testing.T) {
+ cfg := config.CSPConfig{Enabled: false}
+ middleware := SecurityHeaders(cfg)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ middleware(c)
+
+ assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
+ assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
+ assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
+ })
+
+ t.Run("csp_disabled_no_csp_header", func(t *testing.T) {
+ cfg := config.CSPConfig{Enabled: false}
+ middleware := SecurityHeaders(cfg)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ middleware(c)
+
+ assert.Empty(t, w.Header().Get("Content-Security-Policy"))
+ })
+
+ t.Run("csp_enabled_sets_csp_header", func(t *testing.T) {
+ cfg := config.CSPConfig{
+ Enabled: true,
+ Policy: "default-src 'self'",
+ }
+ middleware := SecurityHeaders(cfg)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ middleware(c)
+
+ csp := w.Header().Get("Content-Security-Policy")
+ assert.NotEmpty(t, csp)
+ // Policy is auto-enhanced with nonce and Cloudflare Insights domain
+ assert.Contains(t, csp, "default-src 'self'")
+ assert.Contains(t, csp, "'nonce-")
+ assert.Contains(t, csp, CloudflareInsightsDomain)
+ })
+
+ t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) {
+ cfg := config.CSPConfig{
+ Enabled: true,
+ Policy: "script-src 'self' __CSP_NONCE__",
+ }
+ middleware := SecurityHeaders(cfg)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ middleware(c)
+
+ csp := w.Header().Get("Content-Security-Policy")
+ assert.NotEmpty(t, csp)
+ assert.NotContains(t, csp, "__CSP_NONCE__", "placeholder should be replaced")
+ assert.Contains(t, csp, "'nonce-", "should contain nonce directive")
+
+ // Verify nonce is stored in context
+ nonce := GetNonceFromContext(c)
+ assert.NotEmpty(t, nonce)
+ assert.Contains(t, csp, "'nonce-"+nonce+"'")
+ })
+
+ t.Run("uses_default_policy_when_empty", func(t *testing.T) {
+ cfg := config.CSPConfig{
+ Enabled: true,
+ Policy: "",
+ }
+ middleware := SecurityHeaders(cfg)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ middleware(c)
+
+ csp := w.Header().Get("Content-Security-Policy")
+ assert.NotEmpty(t, csp)
+ // Default policy should contain these elements
+ assert.Contains(t, csp, "default-src 'self'")
+ })
+
+ t.Run("uses_default_policy_when_whitespace_only", func(t *testing.T) {
+ cfg := config.CSPConfig{
+ Enabled: true,
+ Policy: " \t\n ",
+ }
+ middleware := SecurityHeaders(cfg)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ middleware(c)
+
+ csp := w.Header().Get("Content-Security-Policy")
+ assert.NotEmpty(t, csp)
+ assert.Contains(t, csp, "default-src 'self'")
+ })
+
+ t.Run("multiple_nonce_placeholders_replaced", func(t *testing.T) {
+ cfg := config.CSPConfig{
+ Enabled: true,
+ Policy: "script-src __CSP_NONCE__; style-src __CSP_NONCE__",
+ }
+ middleware := SecurityHeaders(cfg)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ middleware(c)
+
+ csp := w.Header().Get("Content-Security-Policy")
+ nonce := GetNonceFromContext(c)
+
+ // Count occurrences of the nonce
+ count := strings.Count(csp, "'nonce-"+nonce+"'")
+ assert.Equal(t, 2, count, "both placeholders should be replaced with same nonce")
+ })
+
+ t.Run("calls_next_handler", func(t *testing.T) {
+ cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"}
+ middleware := SecurityHeaders(cfg)
+
+ nextCalled := false
+ router := gin.New()
+ router.Use(middleware)
+ router.GET("/test", func(c *gin.Context) {
+ nextCalled = true
+ c.Status(http.StatusOK)
+ })
+
+ w := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ router.ServeHTTP(w, req)
+
+ assert.True(t, nextCalled, "next handler should be called")
+ assert.Equal(t, http.StatusOK, w.Code)
+ })
+
+ t.Run("nonce_unique_per_request", func(t *testing.T) {
+ cfg := config.CSPConfig{
+ Enabled: true,
+ Policy: "script-src __CSP_NONCE__",
+ }
+ middleware := SecurityHeaders(cfg)
+
+ nonces := make(map[string]bool)
+ for i := 0; i < 10; i++ {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ middleware(c)
+
+ nonce := GetNonceFromContext(c)
+ assert.False(t, nonces[nonce], "nonce should be unique per request")
+ nonces[nonce] = true
+ }
+ })
+}
+
+func TestCSPNonceKey(t *testing.T) {
+ t.Run("constant_value", func(t *testing.T) {
+ assert.Equal(t, "csp_nonce", CSPNonceKey)
+ })
+}
+
+func TestNonceTemplate(t *testing.T) {
+ t.Run("constant_value", func(t *testing.T) {
+ assert.Equal(t, "__CSP_NONCE__", NonceTemplate)
+ })
+}
+
+func TestEnhanceCSPPolicy(t *testing.T) {
+ t.Run("adds_nonce_placeholder_if_missing", func(t *testing.T) {
+ policy := "default-src 'self'; script-src 'self'"
+ enhanced := enhanceCSPPolicy(policy)
+
+ assert.Contains(t, enhanced, NonceTemplate)
+ assert.Contains(t, enhanced, CloudflareInsightsDomain)
+ })
+
+ t.Run("does_not_duplicate_nonce_placeholder", func(t *testing.T) {
+ policy := "default-src 'self'; script-src 'self' __CSP_NONCE__"
+ enhanced := enhanceCSPPolicy(policy)
+
+ // Should not duplicate
+ count := strings.Count(enhanced, NonceTemplate)
+ assert.Equal(t, 1, count)
+ })
+
+ t.Run("does_not_duplicate_cloudflare_domain", func(t *testing.T) {
+ policy := "default-src 'self'; script-src 'self' https://static.cloudflareinsights.com"
+ enhanced := enhanceCSPPolicy(policy)
+
+ count := strings.Count(enhanced, CloudflareInsightsDomain)
+ assert.Equal(t, 1, count)
+ })
+
+ t.Run("handles_policy_without_script_src", func(t *testing.T) {
+ policy := "default-src 'self'"
+ enhanced := enhanceCSPPolicy(policy)
+
+ assert.Contains(t, enhanced, "script-src")
+ assert.Contains(t, enhanced, NonceTemplate)
+ assert.Contains(t, enhanced, CloudflareInsightsDomain)
+ })
+
+ t.Run("preserves_existing_nonce", func(t *testing.T) {
+ policy := "script-src 'self' 'nonce-existing'"
+ enhanced := enhanceCSPPolicy(policy)
+
+ // Should not add placeholder if nonce already exists
+ assert.NotContains(t, enhanced, NonceTemplate)
+ assert.Contains(t, enhanced, "'nonce-existing'")
+ })
+}
+
+func TestAddToDirective(t *testing.T) {
+ t.Run("adds_to_existing_directive", func(t *testing.T) {
+ policy := "script-src 'self'; style-src 'self'"
+ result := addToDirective(policy, "script-src", "https://example.com")
+
+ assert.Contains(t, result, "script-src 'self' https://example.com")
+ })
+
+ t.Run("creates_directive_if_not_exists", func(t *testing.T) {
+ policy := "default-src 'self'"
+ result := addToDirective(policy, "script-src", "https://example.com")
+
+ assert.Contains(t, result, "script-src")
+ assert.Contains(t, result, "https://example.com")
+ })
+
+ t.Run("handles_directive_at_end_without_semicolon", func(t *testing.T) {
+ policy := "default-src 'self'; script-src 'self'"
+ result := addToDirective(policy, "script-src", "https://example.com")
+
+ assert.Contains(t, result, "https://example.com")
+ })
+
+ t.Run("handles_empty_policy", func(t *testing.T) {
+ policy := ""
+ result := addToDirective(policy, "script-src", "https://example.com")
+
+ assert.Contains(t, result, "script-src")
+ assert.Contains(t, result, "https://example.com")
+ })
+}
+
+// Benchmark tests
+func BenchmarkGenerateNonce(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ GenerateNonce()
+ }
+}
+
+func BenchmarkSecurityHeadersMiddleware(b *testing.B) {
+ cfg := config.CSPConfig{
+ Enabled: true,
+ Policy: "script-src 'self' __CSP_NONCE__",
+ }
+ middleware := SecurityHeaders(cfg)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+ middleware(c)
+ }
+}
diff --git a/backend/internal/service/timing_wheel_service.go b/backend/internal/service/timing_wheel_service.go
index c4e64e33..5a2dea75 100644
--- a/backend/internal/service/timing_wheel_service.go
+++ b/backend/internal/service/timing_wheel_service.go
@@ -1,6 +1,7 @@
package service
import (
+ "fmt"
"log"
"sync"
"time"
@@ -8,6 +9,8 @@ import (
"github.com/zeromicro/go-zero/core/collection"
)
+var newTimingWheel = collection.NewTimingWheel
+
// TimingWheelService wraps go-zero's TimingWheel for task scheduling
type TimingWheelService struct {
tw *collection.TimingWheel
@@ -15,18 +18,18 @@ type TimingWheelService struct {
}
// NewTimingWheelService creates a new TimingWheelService instance
-func NewTimingWheelService() *TimingWheelService {
+func NewTimingWheelService() (*TimingWheelService, error) {
// 1 second tick, 3600 slots = supports up to 1 hour delay
// execute function: runs func() type tasks
- tw, err := collection.NewTimingWheel(1*time.Second, 3600, func(key, value any) {
+ tw, err := newTimingWheel(1*time.Second, 3600, func(key, value any) {
if fn, ok := value.(func()); ok {
fn()
}
})
if err != nil {
- panic(err)
+ return nil, fmt.Errorf("创建 timing wheel 失败: %w", err)
}
- return &TimingWheelService{tw: tw}
+ return &TimingWheelService{tw: tw}, nil
}
// Start starts the timing wheel
diff --git a/backend/internal/service/timing_wheel_service_test.go b/backend/internal/service/timing_wheel_service_test.go
new file mode 100644
index 00000000..cd0bffb7
--- /dev/null
+++ b/backend/internal/service/timing_wheel_service_test.go
@@ -0,0 +1,146 @@
+package service
+
+import (
+ "errors"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/zeromicro/go-zero/core/collection"
+)
+
+func TestNewTimingWheelService_InitFail_NoPanicAndReturnError(t *testing.T) {
+ original := newTimingWheel
+ t.Cleanup(func() { newTimingWheel = original })
+
+ newTimingWheel = func(_ time.Duration, _ int, _ collection.Execute) (*collection.TimingWheel, error) {
+ return nil, errors.New("boom")
+ }
+
+ svc, err := NewTimingWheelService()
+ if err == nil {
+ t.Fatalf("期望返回 error,但得到 nil")
+ }
+ if svc != nil {
+ t.Fatalf("期望返回 nil svc,但得到非空")
+ }
+}
+
+func TestNewTimingWheelService_Success(t *testing.T) {
+ svc, err := NewTimingWheelService()
+ if err != nil {
+ t.Fatalf("期望 err 为 nil,但得到: %v", err)
+ }
+ if svc == nil {
+ t.Fatalf("期望 svc 非空,但得到 nil")
+ }
+ svc.Stop()
+}
+
+func TestNewTimingWheelService_ExecuteCallbackRunsFunc(t *testing.T) {
+ original := newTimingWheel
+ t.Cleanup(func() { newTimingWheel = original })
+
+ var captured collection.Execute
+ newTimingWheel = func(interval time.Duration, numSlots int, execute collection.Execute) (*collection.TimingWheel, error) {
+ captured = execute
+ return original(interval, numSlots, execute)
+ }
+
+ svc, err := NewTimingWheelService()
+ if err != nil {
+ t.Fatalf("期望 err 为 nil,但得到: %v", err)
+ }
+ if captured == nil {
+ t.Fatalf("期望 captured 非空,但得到 nil")
+ }
+
+ called := false
+ captured("k", func() { called = true })
+ if !called {
+ t.Fatalf("期望 execute 回调触发传入函数执行")
+ }
+
+ svc.Stop()
+}
+
+func TestTimingWheelService_Schedule_ExecutesOnce(t *testing.T) {
+ original := newTimingWheel
+ t.Cleanup(func() { newTimingWheel = original })
+
+ newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
+ return original(10*time.Millisecond, 128, execute)
+ }
+
+ svc, err := NewTimingWheelService()
+ if err != nil {
+ t.Fatalf("期望 err 为 nil,但得到: %v", err)
+ }
+ defer svc.Stop()
+
+ ch := make(chan struct{}, 1)
+ svc.Schedule("once", 30*time.Millisecond, func() { ch <- struct{}{} })
+
+ select {
+ case <-ch:
+ case <-time.After(500 * time.Millisecond):
+ t.Fatalf("等待任务执行超时")
+ }
+
+ select {
+ case <-ch:
+ t.Fatalf("任务不应重复执行")
+ case <-time.After(80 * time.Millisecond):
+ }
+}
+
+func TestTimingWheelService_Cancel_PreventsExecution(t *testing.T) {
+ original := newTimingWheel
+ t.Cleanup(func() { newTimingWheel = original })
+
+ newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
+ return original(10*time.Millisecond, 128, execute)
+ }
+
+ svc, err := NewTimingWheelService()
+ if err != nil {
+ t.Fatalf("期望 err 为 nil,但得到: %v", err)
+ }
+ defer svc.Stop()
+
+ ch := make(chan struct{}, 1)
+ svc.Schedule("cancel", 80*time.Millisecond, func() { ch <- struct{}{} })
+ svc.Cancel("cancel")
+
+ select {
+ case <-ch:
+ t.Fatalf("任务已取消,不应执行")
+ case <-time.After(200 * time.Millisecond):
+ }
+}
+
+func TestTimingWheelService_ScheduleRecurring_ExecutesMultipleTimes(t *testing.T) {
+ original := newTimingWheel
+ t.Cleanup(func() { newTimingWheel = original })
+
+ newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
+ return original(10*time.Millisecond, 128, execute)
+ }
+
+ svc, err := NewTimingWheelService()
+ if err != nil {
+ t.Fatalf("期望 err 为 nil,但得到: %v", err)
+ }
+ defer svc.Stop()
+
+ var count int32
+ svc.ScheduleRecurring("rec", 30*time.Millisecond, func() { atomic.AddInt32(&count, 1) })
+
+ deadline := time.Now().Add(500 * time.Millisecond)
+ for atomic.LoadInt32(&count) < 2 && time.Now().Before(deadline) {
+ time.Sleep(10 * time.Millisecond)
+ }
+ if atomic.LoadInt32(&count) < 2 {
+ t.Fatalf("期望周期任务至少执行 2 次,但只执行了 %d 次", atomic.LoadInt32(&count))
+ }
+}
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index 5ba093a4..acc0a5fb 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -65,10 +65,13 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe
}
// ProvideTimingWheelService creates and starts TimingWheelService
-func ProvideTimingWheelService() *TimingWheelService {
- svc := NewTimingWheelService()
+func ProvideTimingWheelService() (*TimingWheelService, error) {
+ svc, err := NewTimingWheelService()
+ if err != nil {
+ return nil, err
+ }
svc.Start()
- return svc
+ return svc, nil
}
// ProvideDeferredService creates and starts DeferredService
diff --git a/backend/internal/service/wire_test.go b/backend/internal/service/wire_test.go
new file mode 100644
index 00000000..5f7866f6
--- /dev/null
+++ b/backend/internal/service/wire_test.go
@@ -0,0 +1,37 @@
+package service
+
+import (
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/zeromicro/go-zero/core/collection"
+)
+
+func TestProvideTimingWheelService_ReturnsError(t *testing.T) {
+ original := newTimingWheel
+ t.Cleanup(func() { newTimingWheel = original })
+
+ newTimingWheel = func(_ time.Duration, _ int, _ collection.Execute) (*collection.TimingWheel, error) {
+ return nil, errors.New("boom")
+ }
+
+ svc, err := ProvideTimingWheelService()
+ if err == nil {
+ t.Fatalf("期望返回 error,但得到 nil")
+ }
+ if svc != nil {
+ t.Fatalf("期望返回 nil svc,但得到非空")
+ }
+}
+
+func TestProvideTimingWheelService_Success(t *testing.T) {
+ svc, err := ProvideTimingWheelService()
+ if err != nil {
+ t.Fatalf("期望 err 为 nil,但得到: %v", err)
+ }
+ if svc == nil {
+ t.Fatalf("期望 svc 非空,但得到 nil")
+ }
+ svc.Stop()
+}
diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go
index 35697fbb..7f37d59c 100644
--- a/backend/internal/web/embed_on.go
+++ b/backend/internal/web/embed_on.go
@@ -13,9 +13,15 @@ import (
"strings"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
)
+const (
+ // NonceHTMLPlaceholder is the placeholder for nonce in HTML script tags
+ NonceHTMLPlaceholder = "__CSP_NONCE_VALUE__"
+)
+
//go:embed all:dist
var frontendFS embed.FS
@@ -115,6 +121,9 @@ func (s *FrontendServer) fileExists(path string) bool {
}
func (s *FrontendServer) serveIndexHTML(c *gin.Context) {
+ // Get nonce from context (generated by SecurityHeaders middleware)
+ nonce := middleware.GetNonceFromContext(c)
+
// Check cache first
cached := s.cache.Get()
if cached != nil {
@@ -125,9 +134,12 @@ func (s *FrontendServer) serveIndexHTML(c *gin.Context) {
return
}
+ // Replace nonce placeholder with actual nonce before serving
+ content := replaceNoncePlaceholder(cached.Content, nonce)
+
c.Header("ETag", cached.ETag)
c.Header("Cache-Control", "no-cache") // Must revalidate
- c.Data(http.StatusOK, "text/html; charset=utf-8", cached.Content)
+ c.Data(http.StatusOK, "text/html; charset=utf-8", content)
c.Abort()
return
}
@@ -155,24 +167,33 @@ func (s *FrontendServer) serveIndexHTML(c *gin.Context) {
rendered := s.injectSettings(settingsJSON)
s.cache.Set(rendered, settingsJSON)
+ // Replace nonce placeholder with actual nonce before serving
+ content := replaceNoncePlaceholder(rendered, nonce)
+
cached = s.cache.Get()
if cached != nil {
c.Header("ETag", cached.ETag)
}
c.Header("Cache-Control", "no-cache")
- c.Data(http.StatusOK, "text/html; charset=utf-8", rendered)
+ c.Data(http.StatusOK, "text/html; charset=utf-8", content)
c.Abort()
}
func (s *FrontendServer) injectSettings(settingsJSON []byte) []byte {
- // Create the script tag to inject
- script := []byte(``)
+ // Create the script tag to inject with nonce placeholder
+ // The placeholder will be replaced with actual nonce at request time
+ script := []byte(``)
// Inject before
headClose := []byte("")
return bytes.Replace(s.baseHTML, headClose, append(script, headClose...), 1)
}
+// replaceNoncePlaceholder replaces the nonce placeholder with actual nonce value
+func replaceNoncePlaceholder(html []byte, nonce string) []byte {
+ return bytes.ReplaceAll(html, []byte(NonceHTMLPlaceholder), []byte(nonce))
+}
+
// ServeEmbeddedFrontend returns a middleware for serving embedded frontend
// This is the legacy function for backward compatibility when no settings provider is available
func ServeEmbeddedFrontend() gin.HandlerFunc {
diff --git a/backend/internal/web/embed_test.go b/backend/internal/web/embed_test.go
new file mode 100644
index 00000000..50f5a323
--- /dev/null
+++ b/backend/internal/web/embed_test.go
@@ -0,0 +1,660 @@
+//go:build embed
+
+package web
+
+import (
+ "bytes"
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func init() {
+ gin.SetMode(gin.TestMode)
+}
+
+func TestReplaceNoncePlaceholder(t *testing.T) {
+ t.Run("replaces_single_placeholder", func(t *testing.T) {
+ html := []byte(``)
+ nonce := "abc123xyz"
+
+ result := replaceNoncePlaceholder(html, nonce)
+
+ expected := ``
+ assert.Equal(t, expected, string(result))
+ })
+
+ t.Run("replaces_multiple_placeholders", func(t *testing.T) {
+ html := []byte(``)
+ nonce := "nonce123"
+
+ result := replaceNoncePlaceholder(html, nonce)
+
+ assert.Equal(t, 2, strings.Count(string(result), `nonce="nonce123"`))
+ assert.NotContains(t, string(result), NonceHTMLPlaceholder)
+ })
+
+ t.Run("handles_empty_nonce", func(t *testing.T) {
+ html := []byte(``)
+ nonce := ""
+
+ result := replaceNoncePlaceholder(html, nonce)
+
+ assert.Equal(t, ``, string(result))
+ })
+
+ t.Run("no_placeholder_returns_unchanged", func(t *testing.T) {
+ html := []byte(``)
+ nonce := "abc123"
+
+ result := replaceNoncePlaceholder(html, nonce)
+
+ assert.Equal(t, string(html), string(result))
+ })
+
+ t.Run("handles_empty_html", func(t *testing.T) {
+ html := []byte(``)
+ nonce := "abc123"
+
+ result := replaceNoncePlaceholder(html, nonce)
+
+ assert.Empty(t, result)
+ })
+}
+
+func TestNonceHTMLPlaceholder(t *testing.T) {
+ t.Run("constant_value", func(t *testing.T) {
+ assert.Equal(t, "__CSP_NONCE_VALUE__", NonceHTMLPlaceholder)
+ })
+}
+
+// mockSettingsProvider implements PublicSettingsProvider for testing
+type mockSettingsProvider struct {
+ settings any
+ err error
+ called int
+}
+
+func (m *mockSettingsProvider) GetPublicSettingsForInjection(ctx context.Context) (any, error) {
+ m.called++
+ return m.settings, m.err
+}
+
+func TestFrontendServer_InjectSettings(t *testing.T) {
+ t.Run("injects_settings_with_nonce_placeholder", func(t *testing.T) {
+ provider := &mockSettingsProvider{
+ settings: map[string]string{"key": "value"},
+ }
+
+ server, err := NewFrontendServer(provider)
+ require.NoError(t, err)
+
+ settingsJSON := []byte(`{"test":"data"}`)
+ result := server.injectSettings(settingsJSON)
+
+ // Should contain the script with nonce placeholder
+ assert.Contains(t, string(result), ``)
+ })
+
+ t.Run("injects_before_head_close", func(t *testing.T) {
+ provider := &mockSettingsProvider{
+ settings: map[string]string{"key": "value"},
+ }
+
+ server, err := NewFrontendServer(provider)
+ require.NoError(t, err)
+
+ settingsJSON := []byte(`{}`)
+ result := server.injectSettings(settingsJSON)
+
+ // Script should be injected before
+ headCloseIndex := bytes.Index(result, []byte(""))
+ scriptIndex := bytes.Index(result, []byte(`