test: 为代码审核修复添加详细单元测试(7个测试文件,50+测试用例)
新增测试文件: - cors_test.go: CORS 条件化头部测试(12个测试,覆盖白名单/黑名单/通配符/凭证/多源/Vary) - gateway_helper_backoff_test.go: nextBackoff 退避测试(6个测试+基准,验证指数增长/边界/抖动/收敛) - billing_cache_jitter_test.go: jitteredTTL 抖动测试(5个测试+基准,验证范围/上界/方差/均值) - subscription_calculate_progress_test.go: calculateProgress 纯函数测试(9个测试,覆盖日/周/月限额/超限截断/过期) - openai_gateway_handler_test.go: SSE JSON 转义测试(7个子用例,验证双引号/反斜杠/换行符安全) 更新测试文件: - response_transformer_test.go: 增强 generateRandomID 测试(7个测试,含并发/字符集/降级计数器) - security_headers_test.go: 适配 GenerateNonce 新签名 - api_key_auth_test.go: 适配 NewSubscriptionService 新参数 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
308
backend/internal/server/middleware/cors_test.go
Normal file
308
backend/internal/server/middleware/cors_test.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// cors_test 与 security_headers_test 在同一个包,但 init 是幂等的
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// --- Task 8.2: 验证 CORS 条件化头部 ---
|
||||
|
||||
func TestCORS_DisallowedOrigin_NoAllowHeaders(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
origin string
|
||||
}{
|
||||
{
|
||||
name: "preflight_disallowed_origin",
|
||||
method: http.MethodOptions,
|
||||
origin: "https://evil.example.com",
|
||||
},
|
||||
{
|
||||
name: "get_disallowed_origin",
|
||||
method: http.MethodGet,
|
||||
origin: "https://evil.example.com",
|
||||
},
|
||||
{
|
||||
name: "post_disallowed_origin",
|
||||
method: http.MethodPost,
|
||||
origin: "https://attacker.example.com",
|
||||
},
|
||||
{
|
||||
name: "preflight_no_origin",
|
||||
method: http.MethodOptions,
|
||||
origin: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(tt.method, "/", nil)
|
||||
if tt.origin != "" {
|
||||
c.Request.Header.Set("Origin", tt.origin)
|
||||
}
|
||||
|
||||
middleware(c)
|
||||
|
||||
// 不应设置 Allow-Headers、Allow-Methods 和 Max-Age
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"),
|
||||
"不允许的 origin 不应收到 Allow-Headers")
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"),
|
||||
"不允许的 origin 不应收到 Allow-Methods")
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Max-Age"),
|
||||
"不允许的 origin 不应收到 Max-Age")
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"),
|
||||
"不允许的 origin 不应收到 Allow-Origin")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS_AllowedOrigin_HasAllowHeaders(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
}{
|
||||
{name: "preflight_OPTIONS", method: http.MethodOptions},
|
||||
{name: "normal_GET", method: http.MethodGet},
|
||||
{name: "normal_POST", method: http.MethodPost},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(tt.method, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://allowed.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
// 应设置 Allow-Headers、Allow-Methods 和 Max-Age
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
|
||||
"允许的 origin 应收到 Allow-Headers")
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
|
||||
"允许的 origin 应收到 Allow-Methods")
|
||||
assert.Equal(t, "86400", w.Header().Get("Access-Control-Max-Age"),
|
||||
"允许的 origin 应收到 Max-Age=86400")
|
||||
assert.Equal(t, "https://allowed.example.com", w.Header().Get("Access-Control-Allow-Origin"),
|
||||
"允许的 origin 应收到 Allow-Origin")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS_PreflightDisallowedOrigin_ReturnsForbidden(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://evil.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, http.StatusForbidden, w.Code,
|
||||
"不允许的 origin 的 preflight 请求应返回 403")
|
||||
}
|
||||
|
||||
func TestCORS_PreflightAllowedOrigin_ReturnsNoContent(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://allowed.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, w.Code,
|
||||
"允许的 origin 的 preflight 请求应返回 204")
|
||||
}
|
||||
|
||||
func TestCORS_WildcardOrigin_AllowsAny(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://any-origin.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"),
|
||||
"通配符配置应返回 *")
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
|
||||
"通配符 origin 应设置 Allow-Headers")
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
|
||||
"通配符 origin 应设置 Allow-Methods")
|
||||
}
|
||||
|
||||
func TestCORS_AllowCredentials_SetCorrectly(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: true,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
t.Run("allowed_origin_gets_credentials", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://allowed.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"),
|
||||
"允许的 origin 且开启 credentials 应设置 Allow-Credentials")
|
||||
})
|
||||
|
||||
t.Run("disallowed_origin_no_credentials", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://evil.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
|
||||
"不允许的 origin 不应收到 Allow-Credentials")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCORS_WildcardWithCredentials_DisablesCredentials(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowCredentials: true,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://any.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
// 通配符 + credentials 不兼容,credentials 应被禁用
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
|
||||
"通配符 origin 应禁用 Allow-Credentials")
|
||||
}
|
||||
|
||||
func TestCORS_MultipleAllowedOrigins(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{
|
||||
"https://app1.example.com",
|
||||
"https://app2.example.com",
|
||||
},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
t.Run("first_origin_allowed", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://app1.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "https://app1.example.com", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
|
||||
})
|
||||
|
||||
t.Run("second_origin_allowed", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://app2.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "https://app2.example.com", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
|
||||
})
|
||||
|
||||
t.Run("unlisted_origin_rejected", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://app3.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCORS_VaryHeader_SetForSpecificOrigin(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://allowed.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Contains(t, w.Header().Values("Vary"), "Origin",
|
||||
"非通配符允许的 origin 应设置 Vary: Origin")
|
||||
}
|
||||
|
||||
func TestNormalizeOrigins(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []string
|
||||
expect []string
|
||||
}{
|
||||
{name: "nil_input", input: nil, expect: nil},
|
||||
{name: "empty_input", input: []string{}, expect: nil},
|
||||
{name: "trims_whitespace", input: []string{" https://a.com ", " https://b.com"}, expect: []string{"https://a.com", "https://b.com"}},
|
||||
{name: "removes_empty_strings", input: []string{"", " ", "https://a.com"}, expect: []string{"https://a.com"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := normalizeOrigins(tt.input)
|
||||
assert.Equal(t, tt.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user