@@ -78,7 +78,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
|
||||
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
|
||||
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
|
||||
timingWheelService := service.ProvideTimingWheelService()
|
||||
timingWheelService, err := service.ProvideTimingWheelService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
|
||||
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
|
||||
accountRepository := repository.NewAccountRepository(client, db)
|
||||
|
||||
@@ -19,7 +19,9 @@ const (
|
||||
RunModeSimple = "simple"
|
||||
)
|
||||
|
||||
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' https://challenges.cloudflare.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
|
||||
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
|
||||
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||
|
||||
// 连接池隔离策略常量
|
||||
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
|
||||
|
||||
@@ -1,12 +1,40 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// CSPNonceKey is the context key for storing the CSP nonce
|
||||
CSPNonceKey = "csp_nonce"
|
||||
// NonceTemplate is the placeholder in CSP policy for nonce
|
||||
NonceTemplate = "__CSP_NONCE__"
|
||||
// CloudflareInsightsDomain is the domain for Cloudflare Web Analytics
|
||||
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
|
||||
)
|
||||
|
||||
// GenerateNonce generates a cryptographically secure random nonce
|
||||
func GenerateNonce() string {
|
||||
b := make([]byte, 16)
|
||||
_, _ = rand.Read(b)
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
// GetNonceFromContext retrieves the CSP nonce from gin context
|
||||
func GetNonceFromContext(c *gin.Context) string {
|
||||
if nonce, exists := c.Get(CSPNonceKey); exists {
|
||||
if s, ok := nonce.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SecurityHeaders sets baseline security headers for all responses.
|
||||
func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
|
||||
policy := strings.TrimSpace(cfg.Policy)
|
||||
@@ -14,13 +42,75 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
|
||||
policy = config.DefaultCSPPolicy
|
||||
}
|
||||
|
||||
// Enhance policy with required directives (nonce placeholder and Cloudflare Insights)
|
||||
policy = enhanceCSPPolicy(policy)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
|
||||
if cfg.Enabled {
|
||||
c.Header("Content-Security-Policy", policy)
|
||||
// Generate nonce for this request
|
||||
nonce := GenerateNonce()
|
||||
c.Set(CSPNonceKey, nonce)
|
||||
|
||||
// Replace nonce placeholder in policy
|
||||
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
|
||||
c.Header("Content-Security-Policy", finalPolicy)
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
|
||||
// This allows the application to work correctly even if the config file has an older CSP policy.
|
||||
func enhanceCSPPolicy(policy string) string {
|
||||
// Add nonce placeholder to script-src if not present
|
||||
if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
|
||||
policy = addToDirective(policy, "script-src", NonceTemplate)
|
||||
}
|
||||
|
||||
// Add Cloudflare Insights domain to script-src if not present
|
||||
if !strings.Contains(policy, CloudflareInsightsDomain) {
|
||||
policy = addToDirective(policy, "script-src", CloudflareInsightsDomain)
|
||||
}
|
||||
|
||||
return policy
|
||||
}
|
||||
|
||||
// addToDirective adds a value to a specific CSP directive.
|
||||
// If the directive doesn't exist, it will be added after default-src.
|
||||
func addToDirective(policy, directive, value string) string {
|
||||
// Find the directive in the policy
|
||||
directivePrefix := directive + " "
|
||||
idx := strings.Index(policy, directivePrefix)
|
||||
|
||||
if idx == -1 {
|
||||
// Directive not found, add it after default-src or at the beginning
|
||||
defaultSrcIdx := strings.Index(policy, "default-src ")
|
||||
if defaultSrcIdx != -1 {
|
||||
// Find the end of default-src directive (next semicolon)
|
||||
endIdx := strings.Index(policy[defaultSrcIdx:], ";")
|
||||
if endIdx != -1 {
|
||||
insertPos := defaultSrcIdx + endIdx + 1
|
||||
// Insert new directive after default-src
|
||||
return policy[:insertPos] + " " + directive + " 'self' " + value + ";" + policy[insertPos:]
|
||||
}
|
||||
}
|
||||
// Fallback: prepend the directive
|
||||
return directive + " 'self' " + value + "; " + policy
|
||||
}
|
||||
|
||||
// Find the end of this directive (next semicolon or end of string)
|
||||
endIdx := strings.Index(policy[idx:], ";")
|
||||
|
||||
if endIdx == -1 {
|
||||
// No semicolon found, directive goes to end of string
|
||||
return policy + " " + value
|
||||
}
|
||||
|
||||
// Insert value before the semicolon
|
||||
insertPos := idx + endIdx
|
||||
return policy[:insertPos] + " " + value + policy[insertPos:]
|
||||
}
|
||||
|
||||
365
backend/internal/server/middleware/security_headers_test.go
Normal file
365
backend/internal/server/middleware/security_headers_test.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func TestGenerateNonce(t *testing.T) {
|
||||
t.Run("generates_valid_base64_string", func(t *testing.T) {
|
||||
nonce := GenerateNonce()
|
||||
|
||||
// Should be valid base64
|
||||
decoded, err := base64.StdEncoding.DecodeString(nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should decode to 16 bytes
|
||||
assert.Len(t, decoded, 16)
|
||||
})
|
||||
|
||||
t.Run("generates_unique_nonces", func(t *testing.T) {
|
||||
nonces := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
nonce := GenerateNonce()
|
||||
assert.False(t, nonces[nonce], "nonce should be unique")
|
||||
nonces[nonce] = true
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nonce_has_expected_length", func(t *testing.T) {
|
||||
nonce := GenerateNonce()
|
||||
// 16 bytes -> 24 chars in base64 (with padding)
|
||||
assert.Len(t, nonce, 24)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetNonceFromContext(t *testing.T) {
|
||||
t.Run("returns_nonce_when_present", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
expectedNonce := "test-nonce-123"
|
||||
c.Set(CSPNonceKey, expectedNonce)
|
||||
|
||||
nonce := GetNonceFromContext(c)
|
||||
assert.Equal(t, expectedNonce, nonce)
|
||||
})
|
||||
|
||||
t.Run("returns_empty_string_when_not_present", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
nonce := GetNonceFromContext(c)
|
||||
assert.Empty(t, nonce)
|
||||
})
|
||||
|
||||
t.Run("returns_empty_for_wrong_type", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Set a non-string value
|
||||
c.Set(CSPNonceKey, 12345)
|
||||
|
||||
// Should return empty string for wrong type (safe type assertion)
|
||||
nonce := GetNonceFromContext(c)
|
||||
assert.Empty(t, nonce)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSecurityHeaders(t *testing.T) {
|
||||
t.Run("sets_basic_security_headers", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{Enabled: false}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
|
||||
assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
|
||||
assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
|
||||
})
|
||||
|
||||
t.Run("csp_disabled_no_csp_header", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{Enabled: false}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Empty(t, w.Header().Get("Content-Security-Policy"))
|
||||
})
|
||||
|
||||
t.Run("csp_enabled_sets_csp_header", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: "default-src 'self'",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
csp := w.Header().Get("Content-Security-Policy")
|
||||
assert.NotEmpty(t, csp)
|
||||
// Policy is auto-enhanced with nonce and Cloudflare Insights domain
|
||||
assert.Contains(t, csp, "default-src 'self'")
|
||||
assert.Contains(t, csp, "'nonce-")
|
||||
assert.Contains(t, csp, CloudflareInsightsDomain)
|
||||
})
|
||||
|
||||
t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: "script-src 'self' __CSP_NONCE__",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
csp := w.Header().Get("Content-Security-Policy")
|
||||
assert.NotEmpty(t, csp)
|
||||
assert.NotContains(t, csp, "__CSP_NONCE__", "placeholder should be replaced")
|
||||
assert.Contains(t, csp, "'nonce-", "should contain nonce directive")
|
||||
|
||||
// Verify nonce is stored in context
|
||||
nonce := GetNonceFromContext(c)
|
||||
assert.NotEmpty(t, nonce)
|
||||
assert.Contains(t, csp, "'nonce-"+nonce+"'")
|
||||
})
|
||||
|
||||
t.Run("uses_default_policy_when_empty", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: "",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
csp := w.Header().Get("Content-Security-Policy")
|
||||
assert.NotEmpty(t, csp)
|
||||
// Default policy should contain these elements
|
||||
assert.Contains(t, csp, "default-src 'self'")
|
||||
})
|
||||
|
||||
t.Run("uses_default_policy_when_whitespace_only", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: " \t\n ",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
csp := w.Header().Get("Content-Security-Policy")
|
||||
assert.NotEmpty(t, csp)
|
||||
assert.Contains(t, csp, "default-src 'self'")
|
||||
})
|
||||
|
||||
t.Run("multiple_nonce_placeholders_replaced", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: "script-src __CSP_NONCE__; style-src __CSP_NONCE__",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
csp := w.Header().Get("Content-Security-Policy")
|
||||
nonce := GetNonceFromContext(c)
|
||||
|
||||
// Count occurrences of the nonce
|
||||
count := strings.Count(csp, "'nonce-"+nonce+"'")
|
||||
assert.Equal(t, 2, count, "both placeholders should be replaced with same nonce")
|
||||
})
|
||||
|
||||
t.Run("calls_next_handler", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
|
||||
nextCalled := false
|
||||
router := gin.New()
|
||||
router.Use(middleware)
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
nextCalled = true
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.True(t, nextCalled, "next handler should be called")
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("nonce_unique_per_request", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: "script-src __CSP_NONCE__",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
|
||||
nonces := make(map[string]bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
nonce := GetNonceFromContext(c)
|
||||
assert.False(t, nonces[nonce], "nonce should be unique per request")
|
||||
nonces[nonce] = true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCSPNonceKey(t *testing.T) {
|
||||
t.Run("constant_value", func(t *testing.T) {
|
||||
assert.Equal(t, "csp_nonce", CSPNonceKey)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNonceTemplate(t *testing.T) {
|
||||
t.Run("constant_value", func(t *testing.T) {
|
||||
assert.Equal(t, "__CSP_NONCE__", NonceTemplate)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnhanceCSPPolicy(t *testing.T) {
|
||||
t.Run("adds_nonce_placeholder_if_missing", func(t *testing.T) {
|
||||
policy := "default-src 'self'; script-src 'self'"
|
||||
enhanced := enhanceCSPPolicy(policy)
|
||||
|
||||
assert.Contains(t, enhanced, NonceTemplate)
|
||||
assert.Contains(t, enhanced, CloudflareInsightsDomain)
|
||||
})
|
||||
|
||||
t.Run("does_not_duplicate_nonce_placeholder", func(t *testing.T) {
|
||||
policy := "default-src 'self'; script-src 'self' __CSP_NONCE__"
|
||||
enhanced := enhanceCSPPolicy(policy)
|
||||
|
||||
// Should not duplicate
|
||||
count := strings.Count(enhanced, NonceTemplate)
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
|
||||
t.Run("does_not_duplicate_cloudflare_domain", func(t *testing.T) {
|
||||
policy := "default-src 'self'; script-src 'self' https://static.cloudflareinsights.com"
|
||||
enhanced := enhanceCSPPolicy(policy)
|
||||
|
||||
count := strings.Count(enhanced, CloudflareInsightsDomain)
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
|
||||
t.Run("handles_policy_without_script_src", func(t *testing.T) {
|
||||
policy := "default-src 'self'"
|
||||
enhanced := enhanceCSPPolicy(policy)
|
||||
|
||||
assert.Contains(t, enhanced, "script-src")
|
||||
assert.Contains(t, enhanced, NonceTemplate)
|
||||
assert.Contains(t, enhanced, CloudflareInsightsDomain)
|
||||
})
|
||||
|
||||
t.Run("preserves_existing_nonce", func(t *testing.T) {
|
||||
policy := "script-src 'self' 'nonce-existing'"
|
||||
enhanced := enhanceCSPPolicy(policy)
|
||||
|
||||
// Should not add placeholder if nonce already exists
|
||||
assert.NotContains(t, enhanced, NonceTemplate)
|
||||
assert.Contains(t, enhanced, "'nonce-existing'")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddToDirective(t *testing.T) {
|
||||
t.Run("adds_to_existing_directive", func(t *testing.T) {
|
||||
policy := "script-src 'self'; style-src 'self'"
|
||||
result := addToDirective(policy, "script-src", "https://example.com")
|
||||
|
||||
assert.Contains(t, result, "script-src 'self' https://example.com")
|
||||
})
|
||||
|
||||
t.Run("creates_directive_if_not_exists", func(t *testing.T) {
|
||||
policy := "default-src 'self'"
|
||||
result := addToDirective(policy, "script-src", "https://example.com")
|
||||
|
||||
assert.Contains(t, result, "script-src")
|
||||
assert.Contains(t, result, "https://example.com")
|
||||
})
|
||||
|
||||
t.Run("handles_directive_at_end_without_semicolon", func(t *testing.T) {
|
||||
policy := "default-src 'self'; script-src 'self'"
|
||||
result := addToDirective(policy, "script-src", "https://example.com")
|
||||
|
||||
assert.Contains(t, result, "https://example.com")
|
||||
})
|
||||
|
||||
t.Run("handles_empty_policy", func(t *testing.T) {
|
||||
policy := ""
|
||||
result := addToDirective(policy, "script-src", "https://example.com")
|
||||
|
||||
assert.Contains(t, result, "script-src")
|
||||
assert.Contains(t, result, "https://example.com")
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkGenerateNonce(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
GenerateNonce()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSecurityHeadersMiddleware(b *testing.B) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: "script-src 'self' __CSP_NONCE__",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
middleware(c)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -8,6 +9,8 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
)
|
||||
|
||||
var newTimingWheel = collection.NewTimingWheel
|
||||
|
||||
// TimingWheelService wraps go-zero's TimingWheel for task scheduling
|
||||
type TimingWheelService struct {
|
||||
tw *collection.TimingWheel
|
||||
@@ -15,18 +18,18 @@ type TimingWheelService struct {
|
||||
}
|
||||
|
||||
// NewTimingWheelService creates a new TimingWheelService instance
|
||||
func NewTimingWheelService() *TimingWheelService {
|
||||
func NewTimingWheelService() (*TimingWheelService, error) {
|
||||
// 1 second tick, 3600 slots = supports up to 1 hour delay
|
||||
// execute function: runs func() type tasks
|
||||
tw, err := collection.NewTimingWheel(1*time.Second, 3600, func(key, value any) {
|
||||
tw, err := newTimingWheel(1*time.Second, 3600, func(key, value any) {
|
||||
if fn, ok := value.(func()); ok {
|
||||
fn()
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return nil, fmt.Errorf("创建 timing wheel 失败: %w", err)
|
||||
}
|
||||
return &TimingWheelService{tw: tw}
|
||||
return &TimingWheelService{tw: tw}, nil
|
||||
}
|
||||
|
||||
// Start starts the timing wheel
|
||||
|
||||
146
backend/internal/service/timing_wheel_service_test.go
Normal file
146
backend/internal/service/timing_wheel_service_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
)
|
||||
|
||||
func TestNewTimingWheelService_InitFail_NoPanicAndReturnError(t *testing.T) {
|
||||
original := newTimingWheel
|
||||
t.Cleanup(func() { newTimingWheel = original })
|
||||
|
||||
newTimingWheel = func(_ time.Duration, _ int, _ collection.Execute) (*collection.TimingWheel, error) {
|
||||
return nil, errors.New("boom")
|
||||
}
|
||||
|
||||
svc, err := NewTimingWheelService()
|
||||
if err == nil {
|
||||
t.Fatalf("期望返回 error,但得到 nil")
|
||||
}
|
||||
if svc != nil {
|
||||
t.Fatalf("期望返回 nil svc,但得到非空")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTimingWheelService_Success(t *testing.T) {
|
||||
svc, err := NewTimingWheelService()
|
||||
if err != nil {
|
||||
t.Fatalf("期望 err 为 nil,但得到: %v", err)
|
||||
}
|
||||
if svc == nil {
|
||||
t.Fatalf("期望 svc 非空,但得到 nil")
|
||||
}
|
||||
svc.Stop()
|
||||
}
|
||||
|
||||
func TestNewTimingWheelService_ExecuteCallbackRunsFunc(t *testing.T) {
|
||||
original := newTimingWheel
|
||||
t.Cleanup(func() { newTimingWheel = original })
|
||||
|
||||
var captured collection.Execute
|
||||
newTimingWheel = func(interval time.Duration, numSlots int, execute collection.Execute) (*collection.TimingWheel, error) {
|
||||
captured = execute
|
||||
return original(interval, numSlots, execute)
|
||||
}
|
||||
|
||||
svc, err := NewTimingWheelService()
|
||||
if err != nil {
|
||||
t.Fatalf("期望 err 为 nil,但得到: %v", err)
|
||||
}
|
||||
if captured == nil {
|
||||
t.Fatalf("期望 captured 非空,但得到 nil")
|
||||
}
|
||||
|
||||
called := false
|
||||
captured("k", func() { called = true })
|
||||
if !called {
|
||||
t.Fatalf("期望 execute 回调触发传入函数执行")
|
||||
}
|
||||
|
||||
svc.Stop()
|
||||
}
|
||||
|
||||
func TestTimingWheelService_Schedule_ExecutesOnce(t *testing.T) {
|
||||
original := newTimingWheel
|
||||
t.Cleanup(func() { newTimingWheel = original })
|
||||
|
||||
newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
|
||||
return original(10*time.Millisecond, 128, execute)
|
||||
}
|
||||
|
||||
svc, err := NewTimingWheelService()
|
||||
if err != nil {
|
||||
t.Fatalf("期望 err 为 nil,但得到: %v", err)
|
||||
}
|
||||
defer svc.Stop()
|
||||
|
||||
ch := make(chan struct{}, 1)
|
||||
svc.Schedule("once", 30*time.Millisecond, func() { ch <- struct{}{} })
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatalf("等待任务执行超时")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
t.Fatalf("任务不应重复执行")
|
||||
case <-time.After(80 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimingWheelService_Cancel_PreventsExecution(t *testing.T) {
|
||||
original := newTimingWheel
|
||||
t.Cleanup(func() { newTimingWheel = original })
|
||||
|
||||
newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
|
||||
return original(10*time.Millisecond, 128, execute)
|
||||
}
|
||||
|
||||
svc, err := NewTimingWheelService()
|
||||
if err != nil {
|
||||
t.Fatalf("期望 err 为 nil,但得到: %v", err)
|
||||
}
|
||||
defer svc.Stop()
|
||||
|
||||
ch := make(chan struct{}, 1)
|
||||
svc.Schedule("cancel", 80*time.Millisecond, func() { ch <- struct{}{} })
|
||||
svc.Cancel("cancel")
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
t.Fatalf("任务已取消,不应执行")
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimingWheelService_ScheduleRecurring_ExecutesMultipleTimes(t *testing.T) {
|
||||
original := newTimingWheel
|
||||
t.Cleanup(func() { newTimingWheel = original })
|
||||
|
||||
newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
|
||||
return original(10*time.Millisecond, 128, execute)
|
||||
}
|
||||
|
||||
svc, err := NewTimingWheelService()
|
||||
if err != nil {
|
||||
t.Fatalf("期望 err 为 nil,但得到: %v", err)
|
||||
}
|
||||
defer svc.Stop()
|
||||
|
||||
var count int32
|
||||
svc.ScheduleRecurring("rec", 30*time.Millisecond, func() { atomic.AddInt32(&count, 1) })
|
||||
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for atomic.LoadInt32(&count) < 2 && time.Now().Before(deadline) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
if atomic.LoadInt32(&count) < 2 {
|
||||
t.Fatalf("期望周期任务至少执行 2 次,但只执行了 %d 次", atomic.LoadInt32(&count))
|
||||
}
|
||||
}
|
||||
@@ -65,10 +65,13 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe
|
||||
}
|
||||
|
||||
// ProvideTimingWheelService creates and starts TimingWheelService
|
||||
func ProvideTimingWheelService() *TimingWheelService {
|
||||
svc := NewTimingWheelService()
|
||||
func ProvideTimingWheelService() (*TimingWheelService, error) {
|
||||
svc, err := NewTimingWheelService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
svc.Start()
|
||||
return svc
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
// ProvideDeferredService creates and starts DeferredService
|
||||
|
||||
37
backend/internal/service/wire_test.go
Normal file
37
backend/internal/service/wire_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
)
|
||||
|
||||
func TestProvideTimingWheelService_ReturnsError(t *testing.T) {
|
||||
original := newTimingWheel
|
||||
t.Cleanup(func() { newTimingWheel = original })
|
||||
|
||||
newTimingWheel = func(_ time.Duration, _ int, _ collection.Execute) (*collection.TimingWheel, error) {
|
||||
return nil, errors.New("boom")
|
||||
}
|
||||
|
||||
svc, err := ProvideTimingWheelService()
|
||||
if err == nil {
|
||||
t.Fatalf("期望返回 error,但得到 nil")
|
||||
}
|
||||
if svc != nil {
|
||||
t.Fatalf("期望返回 nil svc,但得到非空")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvideTimingWheelService_Success(t *testing.T) {
|
||||
svc, err := ProvideTimingWheelService()
|
||||
if err != nil {
|
||||
t.Fatalf("期望 err 为 nil,但得到: %v", err)
|
||||
}
|
||||
if svc == nil {
|
||||
t.Fatalf("期望 svc 非空,但得到 nil")
|
||||
}
|
||||
svc.Stop()
|
||||
}
|
||||
@@ -13,9 +13,15 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// NonceHTMLPlaceholder is the placeholder for nonce in HTML script tags
|
||||
NonceHTMLPlaceholder = "__CSP_NONCE_VALUE__"
|
||||
)
|
||||
|
||||
//go:embed all:dist
|
||||
var frontendFS embed.FS
|
||||
|
||||
@@ -115,6 +121,9 @@ func (s *FrontendServer) fileExists(path string) bool {
|
||||
}
|
||||
|
||||
func (s *FrontendServer) serveIndexHTML(c *gin.Context) {
|
||||
// Get nonce from context (generated by SecurityHeaders middleware)
|
||||
nonce := middleware.GetNonceFromContext(c)
|
||||
|
||||
// Check cache first
|
||||
cached := s.cache.Get()
|
||||
if cached != nil {
|
||||
@@ -125,9 +134,12 @@ func (s *FrontendServer) serveIndexHTML(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Replace nonce placeholder with actual nonce before serving
|
||||
content := replaceNoncePlaceholder(cached.Content, nonce)
|
||||
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Cache-Control", "no-cache") // Must revalidate
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", cached.Content)
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", content)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -155,24 +167,33 @@ func (s *FrontendServer) serveIndexHTML(c *gin.Context) {
|
||||
rendered := s.injectSettings(settingsJSON)
|
||||
s.cache.Set(rendered, settingsJSON)
|
||||
|
||||
// Replace nonce placeholder with actual nonce before serving
|
||||
content := replaceNoncePlaceholder(rendered, nonce)
|
||||
|
||||
cached = s.cache.Get()
|
||||
if cached != nil {
|
||||
c.Header("ETag", cached.ETag)
|
||||
}
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", rendered)
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", content)
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
func (s *FrontendServer) injectSettings(settingsJSON []byte) []byte {
|
||||
// Create the script tag to inject
|
||||
script := []byte(`<script>window.__APP_CONFIG__=` + string(settingsJSON) + `;</script>`)
|
||||
// Create the script tag to inject with nonce placeholder
|
||||
// The placeholder will be replaced with actual nonce at request time
|
||||
script := []byte(`<script nonce="` + NonceHTMLPlaceholder + `">window.__APP_CONFIG__=` + string(settingsJSON) + `;</script>`)
|
||||
|
||||
// Inject before </head>
|
||||
headClose := []byte("</head>")
|
||||
return bytes.Replace(s.baseHTML, headClose, append(script, headClose...), 1)
|
||||
}
|
||||
|
||||
// replaceNoncePlaceholder replaces the nonce placeholder with actual nonce value
|
||||
func replaceNoncePlaceholder(html []byte, nonce string) []byte {
|
||||
return bytes.ReplaceAll(html, []byte(NonceHTMLPlaceholder), []byte(nonce))
|
||||
}
|
||||
|
||||
// ServeEmbeddedFrontend returns a middleware for serving embedded frontend
|
||||
// This is the legacy function for backward compatibility when no settings provider is available
|
||||
func ServeEmbeddedFrontend() gin.HandlerFunc {
|
||||
|
||||
660
backend/internal/web/embed_test.go
Normal file
660
backend/internal/web/embed_test.go
Normal file
@@ -0,0 +1,660 @@
|
||||
//go:build embed
|
||||
|
||||
package web
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func TestReplaceNoncePlaceholder(t *testing.T) {
|
||||
t.Run("replaces_single_placeholder", func(t *testing.T) {
|
||||
html := []byte(`<script nonce="__CSP_NONCE_VALUE__">console.log('test');</script>`)
|
||||
nonce := "abc123xyz"
|
||||
|
||||
result := replaceNoncePlaceholder(html, nonce)
|
||||
|
||||
expected := `<script nonce="abc123xyz">console.log('test');</script>`
|
||||
assert.Equal(t, expected, string(result))
|
||||
})
|
||||
|
||||
t.Run("replaces_multiple_placeholders", func(t *testing.T) {
|
||||
html := []byte(`<script nonce="__CSP_NONCE_VALUE__">a</script><script nonce="__CSP_NONCE_VALUE__">b</script>`)
|
||||
nonce := "nonce123"
|
||||
|
||||
result := replaceNoncePlaceholder(html, nonce)
|
||||
|
||||
assert.Equal(t, 2, strings.Count(string(result), `nonce="nonce123"`))
|
||||
assert.NotContains(t, string(result), NonceHTMLPlaceholder)
|
||||
})
|
||||
|
||||
t.Run("handles_empty_nonce", func(t *testing.T) {
|
||||
html := []byte(`<script nonce="__CSP_NONCE_VALUE__">test</script>`)
|
||||
nonce := ""
|
||||
|
||||
result := replaceNoncePlaceholder(html, nonce)
|
||||
|
||||
assert.Equal(t, `<script nonce="">test</script>`, string(result))
|
||||
})
|
||||
|
||||
t.Run("no_placeholder_returns_unchanged", func(t *testing.T) {
|
||||
html := []byte(`<script>console.log('test');</script>`)
|
||||
nonce := "abc123"
|
||||
|
||||
result := replaceNoncePlaceholder(html, nonce)
|
||||
|
||||
assert.Equal(t, string(html), string(result))
|
||||
})
|
||||
|
||||
t.Run("handles_empty_html", func(t *testing.T) {
|
||||
html := []byte(``)
|
||||
nonce := "abc123"
|
||||
|
||||
result := replaceNoncePlaceholder(html, nonce)
|
||||
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNonceHTMLPlaceholder(t *testing.T) {
|
||||
t.Run("constant_value", func(t *testing.T) {
|
||||
assert.Equal(t, "__CSP_NONCE_VALUE__", NonceHTMLPlaceholder)
|
||||
})
|
||||
}
|
||||
|
||||
// mockSettingsProvider implements PublicSettingsProvider for testing
|
||||
type mockSettingsProvider struct {
|
||||
settings any
|
||||
err error
|
||||
called int
|
||||
}
|
||||
|
||||
func (m *mockSettingsProvider) GetPublicSettingsForInjection(ctx context.Context) (any, error) {
|
||||
m.called++
|
||||
return m.settings, m.err
|
||||
}
|
||||
|
||||
func TestFrontendServer_InjectSettings(t *testing.T) {
|
||||
t.Run("injects_settings_with_nonce_placeholder", func(t *testing.T) {
|
||||
provider := &mockSettingsProvider{
|
||||
settings: map[string]string{"key": "value"},
|
||||
}
|
||||
|
||||
server, err := NewFrontendServer(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
settingsJSON := []byte(`{"test":"data"}`)
|
||||
result := server.injectSettings(settingsJSON)
|
||||
|
||||
// Should contain the script with nonce placeholder
|
||||
assert.Contains(t, string(result), `<script nonce="__CSP_NONCE_VALUE__">`)
|
||||
assert.Contains(t, string(result), `window.__APP_CONFIG__={"test":"data"};`)
|
||||
assert.Contains(t, string(result), `</script></head>`)
|
||||
})
|
||||
|
||||
t.Run("injects_before_head_close", func(t *testing.T) {
|
||||
provider := &mockSettingsProvider{
|
||||
settings: map[string]string{"key": "value"},
|
||||
}
|
||||
|
||||
server, err := NewFrontendServer(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
settingsJSON := []byte(`{}`)
|
||||
result := server.injectSettings(settingsJSON)
|
||||
|
||||
// Script should be injected before </head>
|
||||
headCloseIndex := bytes.Index(result, []byte("</head>"))
|
||||
scriptIndex := bytes.Index(result, []byte(`<script nonce="`))
|
||||
|
||||
assert.True(t, scriptIndex < headCloseIndex, "script should be before </head>")
|
||||
})
|
||||
|
||||
t.Run("handles_complex_settings", func(t *testing.T) {
|
||||
provider := &mockSettingsProvider{
|
||||
settings: map[string]any{
|
||||
"nested": map[string]any{
|
||||
"array": []int{1, 2, 3},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
server, err := NewFrontendServer(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
settingsJSON := []byte(`{"nested":{"array":[1,2,3]},"special":"<>&"}`)
|
||||
result := server.injectSettings(settingsJSON)
|
||||
|
||||
assert.Contains(t, string(result), `window.__APP_CONFIG__={"nested":{"array":[1,2,3]},"special":"<>&"};`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFrontendServer_ServeIndexHTML(t *testing.T) {
|
||||
t.Run("serves_html_with_nonce", func(t *testing.T) {
|
||||
provider := &mockSettingsProvider{
|
||||
settings: map[string]string{"test": "value"},
|
||||
}
|
||||
|
||||
server, err := NewFrontendServer(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a gin context with nonce
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
// Set nonce in context (simulating SecurityHeaders middleware)
|
||||
testNonce := "test-nonce-12345"
|
||||
c.Set(middleware.CSPNonceKey, testNonce)
|
||||
|
||||
server.serveIndexHTML(c)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Header().Get("Content-Type"), "text/html")
|
||||
|
||||
body := w.Body.String()
|
||||
// Nonce placeholder should be replaced
|
||||
assert.NotContains(t, body, NonceHTMLPlaceholder)
|
||||
assert.Contains(t, body, `nonce="`+testNonce+`"`)
|
||||
})
|
||||
|
||||
t.Run("caches_html_content", func(t *testing.T) {
|
||||
provider := &mockSettingsProvider{
|
||||
settings: map[string]string{"test": "value"},
|
||||
}
|
||||
|
||||
server, err := NewFrontendServer(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First request
|
||||
w1 := httptest.NewRecorder()
|
||||
c1, _ := gin.CreateTestContext(w1)
|
||||
c1.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c1.Set(middleware.CSPNonceKey, "nonce1")
|
||||
|
||||
server.serveIndexHTML(c1)
|
||||
assert.Equal(t, 1, provider.called)
|
||||
|
||||
// Second request - should use cache
|
||||
w2 := httptest.NewRecorder()
|
||||
c2, _ := gin.CreateTestContext(w2)
|
||||
c2.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c2.Set(middleware.CSPNonceKey, "nonce2")
|
||||
|
||||
server.serveIndexHTML(c2)
|
||||
// Settings provider should not be called again
|
||||
assert.Equal(t, 1, provider.called)
|
||||
|
||||
// But nonce should be different
|
||||
assert.Contains(t, w2.Body.String(), `nonce="nonce2"`)
|
||||
})
|
||||
|
||||
t.Run("sets_etag_header", func(t *testing.T) {
|
||||
provider := &mockSettingsProvider{
|
||||
settings: map[string]string{"test": "value"},
|
||||
}
|
||||
|
||||
server, err := NewFrontendServer(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Set(middleware.CSPNonceKey, "nonce123")
|
||||
|
||||
server.serveIndexHTML(c)
|
||||
|
||||
etag := w.Header().Get("ETag")
|
||||
assert.NotEmpty(t, etag)
|
||||
assert.True(t, strings.HasPrefix(etag, `"`))
|
||||
assert.True(t, strings.HasSuffix(etag, `"`))
|
||||
})
|
||||
|
||||
t.Run("returns_304_for_matching_etag", func(t *testing.T) {
|
||||
provider := &mockSettingsProvider{
|
||||
settings: map[string]string{"test": "value"},
|
||||
}
|
||||
|
||||
server, err := NewFrontendServer(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Use a real router for proper 304 handling
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(middleware.CSPNonceKey, "test-nonce")
|
||||
c.Next()
|
||||
})
|
||||
router.Use(server.Middleware())
|
||||
|
||||
// First request to populate cache and get ETag
|
||||
w1 := httptest.NewRecorder()
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
router.ServeHTTP(w1, req1)
|
||||
etag := w1.Header().Get("ETag")
|
||||
require.NotEmpty(t, etag)
|
||||
|
||||
// Second request with If-None-Match
|
||||
w2 := httptest.NewRecorder()
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req2.Header.Set("If-None-Match", etag)
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, http.StatusNotModified, w2.Code)
|
||||
assert.Empty(t, w2.Body.String())
|
||||
})
|
||||
|
||||
t.Run("sets_cache_control_header", func(t *testing.T) {
|
||||
provider := &mockSettingsProvider{
|
||||
settings: map[string]string{"test": "value"},
|
||||
}
|
||||
|
||||
server, err := NewFrontendServer(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Set(middleware.CSPNonceKey, "nonce123")
|
||||
|
||||
server.serveIndexHTML(c)
|
||||
|
||||
assert.Equal(t, "no-cache", w.Header().Get("Cache-Control"))
|
||||
})
|
||||
|
||||
t.Run("fallback_on_settings_error", func(t *testing.T) {
|
||||
provider := &mockSettingsProvider{
|
||||
err: context.DeadlineExceeded,
|
||||
}
|
||||
|
||||
server, err := NewFrontendServer(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Invalidate cache to force settings fetch
|
||||
server.InvalidateCache()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Set(middleware.CSPNonceKey, "nonce123")
|
||||
|
||||
server.serveIndexHTML(c)
|
||||
|
||||
// Should still return 200 with base HTML
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Header().Get("Content-Type"), "text/html")
|
||||
})
|
||||
}
|
||||
|
||||
func TestFrontendServer_InvalidateCache(t *testing.T) {
|
||||
t.Run("invalidates_cache", func(t *testing.T) {
|
||||
provider := &mockSettingsProvider{
|
||||
settings: map[string]string{"test": "value"},
|
||||
}
|
||||
|
||||
server, err := NewFrontendServer(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First request to populate cache
|
||||
w1 := httptest.NewRecorder()
|
||||
c1, _ := gin.CreateTestContext(w1)
|
||||
c1.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c1.Set(middleware.CSPNonceKey, "nonce1")
|
||||
|
||||
server.serveIndexHTML(c1)
|
||||
assert.Equal(t, 1, provider.called)
|
||||
|
||||
// Invalidate cache
|
||||
server.InvalidateCache()
|
||||
|
||||
// Update settings
|
||||
provider.settings = map[string]string{"test": "new_value"}
|
||||
|
||||
// Second request should fetch new settings
|
||||
w2 := httptest.NewRecorder()
|
||||
c2, _ := gin.CreateTestContext(w2)
|
||||
c2.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c2.Set(middleware.CSPNonceKey, "nonce2")
|
||||
|
||||
server.serveIndexHTML(c2)
|
||||
assert.Equal(t, 2, provider.called)
|
||||
})
|
||||
|
||||
t.Run("handles_nil_server", func(t *testing.T) {
|
||||
var server *FrontendServer
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
server.InvalidateCache()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("handles_nil_cache", func(t *testing.T) {
|
||||
server := &FrontendServer{}
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
server.InvalidateCache()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestFrontendServer_Middleware(t *testing.T) {
|
||||
t.Run("skips_api_routes", func(t *testing.T) {
|
||||
provider := &mockSettingsProvider{
|
||||
settings: map[string]string{"test": "value"},
|
||||
}
|
||||
|
||||
server, err := NewFrontendServer(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
apiPaths := []string{
|
||||
"/api/v1/users",
|
||||
"/v1/models",
|
||||
"/v1beta/chat",
|
||||
"/antigravity/test",
|
||||
"/setup/init",
|
||||
"/health",
|
||||
"/responses",
|
||||
}
|
||||
|
||||
for _, path := range apiPaths {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(server.Middleware())
|
||||
nextCalled := false
|
||||
router.GET(path, func(c *gin.Context) {
|
||||
nextCalled = true
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.True(t, nextCalled, "next handler should be called for API route")
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("serves_index_for_spa_routes", func(t *testing.T) {
|
||||
provider := &mockSettingsProvider{
|
||||
settings: map[string]string{"test": "value"},
|
||||
}
|
||||
|
||||
server, err := NewFrontendServer(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(middleware.CSPNonceKey, "test-nonce")
|
||||
c.Next()
|
||||
})
|
||||
router.Use(server.Middleware())
|
||||
|
||||
spaPaths := []string{
|
||||
"/",
|
||||
"/dashboard",
|
||||
"/users/123",
|
||||
"/settings/profile",
|
||||
}
|
||||
|
||||
for _, path := range spaPaths {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Header().Get("Content-Type"), "text/html")
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("serves_static_files", func(t *testing.T) {
|
||||
provider := &mockSettingsProvider{
|
||||
settings: map[string]string{"test": "value"},
|
||||
}
|
||||
|
||||
server, err := NewFrontendServer(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(server.Middleware())
|
||||
|
||||
// Request for existing static file
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logo.png", nil)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Header().Get("Content-Type"), "image/png")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewFrontendServer(t *testing.T) {
|
||||
t.Run("creates_server_successfully", func(t *testing.T) {
|
||||
provider := &mockSettingsProvider{
|
||||
settings: map[string]string{"test": "value"},
|
||||
}
|
||||
|
||||
server, err := NewFrontendServer(provider)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, server)
|
||||
assert.NotNil(t, server.distFS)
|
||||
assert.NotNil(t, server.fileServer)
|
||||
assert.NotNil(t, server.baseHTML)
|
||||
assert.NotNil(t, server.cache)
|
||||
assert.Equal(t, provider, server.settings)
|
||||
})
|
||||
|
||||
t.Run("reads_base_html", func(t *testing.T) {
|
||||
provider := &mockSettingsProvider{
|
||||
settings: map[string]string{"test": "value"},
|
||||
}
|
||||
|
||||
server, err := NewFrontendServer(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, server.baseHTML)
|
||||
assert.Contains(t, string(server.baseHTML), "<!doctype html>")
|
||||
})
|
||||
}
|
||||
|
||||
func TestHasEmbeddedFrontend(t *testing.T) {
|
||||
t.Run("returns_true_when_frontend_embedded", func(t *testing.T) {
|
||||
result := HasEmbeddedFrontend()
|
||||
assert.True(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
// Tests for legacy ServeEmbeddedFrontend function
|
||||
func TestServeEmbeddedFrontend(t *testing.T) {
|
||||
t.Run("serves_static_files", func(t *testing.T) {
|
||||
middleware := ServeEmbeddedFrontend()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(middleware)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logo.png", nil)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Header().Get("Content-Type"), "image/png")
|
||||
})
|
||||
|
||||
t.Run("serves_index_html_for_root", func(t *testing.T) {
|
||||
middleware := ServeEmbeddedFrontend()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(middleware)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Header().Get("Content-Type"), "text/html")
|
||||
assert.Contains(t, w.Body.String(), "<!doctype html>")
|
||||
})
|
||||
|
||||
t.Run("serves_index_html_for_spa_routes", func(t *testing.T) {
|
||||
middleware := ServeEmbeddedFrontend()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(middleware)
|
||||
|
||||
spaPaths := []string{"/dashboard", "/users/123", "/settings"}
|
||||
|
||||
for _, path := range spaPaths {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Header().Get("Content-Type"), "text/html")
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips_api_routes", func(t *testing.T) {
|
||||
middleware := ServeEmbeddedFrontend()
|
||||
|
||||
apiPaths := []string{
|
||||
"/api/users",
|
||||
"/v1/models",
|
||||
"/v1beta/chat",
|
||||
"/antigravity/test",
|
||||
"/setup/init",
|
||||
"/health",
|
||||
"/responses",
|
||||
}
|
||||
|
||||
for _, path := range apiPaths {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
nextCalled := false
|
||||
router := gin.New()
|
||||
router.Use(middleware)
|
||||
router.GET(path, func(c *gin.Context) {
|
||||
nextCalled = true
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.True(t, nextCalled, "next handler should be called for API route")
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Tests for HTMLCache
|
||||
func TestHTMLCache(t *testing.T) {
|
||||
t.Run("new_cache_returns_nil", func(t *testing.T) {
|
||||
cache := NewHTMLCache()
|
||||
assert.Nil(t, cache.Get())
|
||||
})
|
||||
|
||||
t.Run("set_and_get", func(t *testing.T) {
|
||||
cache := NewHTMLCache()
|
||||
cache.SetBaseHTML([]byte("<html></html>"))
|
||||
|
||||
html := []byte("<html><body>test</body></html>")
|
||||
settings := []byte(`{"key":"value"}`)
|
||||
cache.Set(html, settings)
|
||||
|
||||
result := cache.Get()
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, html, result.Content)
|
||||
assert.NotEmpty(t, result.ETag)
|
||||
})
|
||||
|
||||
t.Run("invalidate_clears_cache", func(t *testing.T) {
|
||||
cache := NewHTMLCache()
|
||||
cache.SetBaseHTML([]byte("<html></html>"))
|
||||
|
||||
html := []byte("<html><body>test</body></html>")
|
||||
settings := []byte(`{"key":"value"}`)
|
||||
cache.Set(html, settings)
|
||||
|
||||
require.NotNil(t, cache.Get())
|
||||
|
||||
cache.Invalidate()
|
||||
|
||||
assert.Nil(t, cache.Get())
|
||||
})
|
||||
|
||||
t.Run("etag_changes_with_settings", func(t *testing.T) {
|
||||
cache := NewHTMLCache()
|
||||
cache.SetBaseHTML([]byte("<html></html>"))
|
||||
|
||||
html := []byte("<html><body>test</body></html>")
|
||||
|
||||
cache.Set(html, []byte(`{"v":1}`))
|
||||
etag1 := cache.Get().ETag
|
||||
|
||||
cache.Invalidate()
|
||||
cache.Set(html, []byte(`{"v":2}`))
|
||||
etag2 := cache.Get().ETag
|
||||
|
||||
assert.NotEqual(t, etag1, etag2)
|
||||
})
|
||||
|
||||
t.Run("etag_format", func(t *testing.T) {
|
||||
cache := NewHTMLCache()
|
||||
cache.SetBaseHTML([]byte("<html></html>"))
|
||||
|
||||
cache.Set([]byte("<html></html>"), []byte(`{}`))
|
||||
result := cache.Get()
|
||||
|
||||
// ETag should be quoted
|
||||
assert.True(t, strings.HasPrefix(result.ETag, `"`))
|
||||
assert.True(t, strings.HasSuffix(result.ETag, `"`))
|
||||
// Should contain dash separator
|
||||
assert.Contains(t, result.ETag[1:len(result.ETag)-1], "-")
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkReplaceNoncePlaceholder(b *testing.B) {
|
||||
html := []byte(`<!DOCTYPE html><html><head><script nonce="__CSP_NONCE_VALUE__">window.__APP_CONFIG__={"test":"data"};</script></head><body></body></html>`)
|
||||
nonce := "abcdefghijklmnop123456=="
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
replaceNoncePlaceholder(html, nonce)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFrontendServerServeIndexHTML(b *testing.B) {
|
||||
provider := &mockSettingsProvider{
|
||||
settings: map[string]string{"test": "value"},
|
||||
}
|
||||
|
||||
server, _ := NewFrontendServer(provider)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Set(middleware.CSPNonceKey, "test-nonce")
|
||||
|
||||
server.serveIndexHTML(c)
|
||||
}
|
||||
}
|
||||
@@ -97,7 +97,9 @@ security:
|
||||
enabled: true
|
||||
# Default CSP policy (override if you host assets on other domains)
|
||||
# 默认 CSP 策略(如果静态资源托管在其他域名,请自行覆盖)
|
||||
policy: "default-src 'self'; script-src 'self' https://challenges.cloudflare.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||
# Note: __CSP_NONCE__ will be replaced with 'nonce-xxx' at request time for inline script security
|
||||
# 注意:__CSP_NONCE__ 会在请求时被替换为 'nonce-xxx',用于内联脚本安全
|
||||
policy: "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||
proxy_probe:
|
||||
# Allow skipping TLS verification for proxy probe (debug only)
|
||||
# 允许代理探测时跳过 TLS 证书验证(仅用于调试)
|
||||
|
||||
@@ -400,9 +400,33 @@ router.beforeEach((to, _from, next) => {
|
||||
|
||||
/**
|
||||
* Navigation guard: Error handling
|
||||
* Handles dynamic import failures caused by deployment updates
|
||||
*/
|
||||
router.onError((error) => {
|
||||
console.error('Router error:', error)
|
||||
|
||||
// Check if this is a dynamic import failure (chunk loading error)
|
||||
const isChunkLoadError =
|
||||
error.message?.includes('Failed to fetch dynamically imported module') ||
|
||||
error.message?.includes('Loading chunk') ||
|
||||
error.message?.includes('Loading CSS chunk') ||
|
||||
error.name === 'ChunkLoadError'
|
||||
|
||||
if (isChunkLoadError) {
|
||||
// Avoid infinite reload loop by checking sessionStorage
|
||||
const reloadKey = 'chunk_reload_attempted'
|
||||
const lastReload = sessionStorage.getItem(reloadKey)
|
||||
const now = Date.now()
|
||||
|
||||
// Allow reload if never attempted or more than 10 seconds ago
|
||||
if (!lastReload || now - parseInt(lastReload) > 10000) {
|
||||
sessionStorage.setItem(reloadKey, now.toString())
|
||||
console.warn('Chunk load error detected, reloading page to fetch latest version...')
|
||||
window.location.reload()
|
||||
} else {
|
||||
console.error('Chunk load error persists after reload. Please clear browser cache.')
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
export default router
|
||||
|
||||
Reference in New Issue
Block a user