test: 为代码审核修复添加详细单元测试(7个测试文件,50+测试用例)

新增测试文件:
- cors_test.go: CORS 条件化头部测试(12个测试,覆盖白名单/黑名单/通配符/凭证/多源/Vary)
- gateway_helper_backoff_test.go: nextBackoff 退避测试(6个测试+基准,验证指数增长/边界/抖动/收敛)
- billing_cache_jitter_test.go: jitteredTTL 抖动测试(5个测试+基准,验证范围/上界/方差/均值)
- subscription_calculate_progress_test.go: calculateProgress 纯函数测试(9个测试,覆盖日/周/月限额/超限截断/过期)
- openai_gateway_handler_test.go: SSE JSON 转义测试(7个子用例,验证双引号/反斜杠/换行符安全)

更新测试文件:
- response_transformer_test.go: 增强 generateRandomID 测试(7个测试,含并发/字符集/降级计数器)
- security_headers_test.go: 适配 GenerateNonce 新签名
- api_key_auth_test.go: 适配 NewSubscriptionService 新参数

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-07 22:14:07 +08:00
parent 9634494ba9
commit 00caf0bcd8
8 changed files with 913 additions and 6 deletions

View File

@@ -60,7 +60,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder()
@@ -99,7 +99,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
}
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder()

View File

@@ -0,0 +1,308 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func init() {
// cors_test 与 security_headers_test 在同一个包,但 init 是幂等的
gin.SetMode(gin.TestMode)
}
// --- Task 8.2: 验证 CORS 条件化头部 ---
func TestCORS_DisallowedOrigin_NoAllowHeaders(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
tests := []struct {
name string
method string
origin string
}{
{
name: "preflight_disallowed_origin",
method: http.MethodOptions,
origin: "https://evil.example.com",
},
{
name: "get_disallowed_origin",
method: http.MethodGet,
origin: "https://evil.example.com",
},
{
name: "post_disallowed_origin",
method: http.MethodPost,
origin: "https://attacker.example.com",
},
{
name: "preflight_no_origin",
method: http.MethodOptions,
origin: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(tt.method, "/", nil)
if tt.origin != "" {
c.Request.Header.Set("Origin", tt.origin)
}
middleware(c)
// 不应设置 Allow-Headers、Allow-Methods 和 Max-Age
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"),
"不允许的 origin 不应收到 Allow-Headers")
assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"),
"不允许的 origin 不应收到 Allow-Methods")
assert.Empty(t, w.Header().Get("Access-Control-Max-Age"),
"不允许的 origin 不应收到 Max-Age")
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"),
"不允许的 origin 不应收到 Allow-Origin")
})
}
}
func TestCORS_AllowedOrigin_HasAllowHeaders(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
tests := []struct {
name string
method string
}{
{name: "preflight_OPTIONS", method: http.MethodOptions},
{name: "normal_GET", method: http.MethodGet},
{name: "normal_POST", method: http.MethodPost},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(tt.method, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
// 应设置 Allow-Headers、Allow-Methods 和 Max-Age
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
"允许的 origin 应收到 Allow-Headers")
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
"允许的 origin 应收到 Allow-Methods")
assert.Equal(t, "86400", w.Header().Get("Access-Control-Max-Age"),
"允许的 origin 应收到 Max-Age=86400")
assert.Equal(t, "https://allowed.example.com", w.Header().Get("Access-Control-Allow-Origin"),
"允许的 origin 应收到 Allow-Origin")
})
}
}
func TestCORS_PreflightDisallowedOrigin_ReturnsForbidden(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
c.Request.Header.Set("Origin", "https://evil.example.com")
middleware(c)
assert.Equal(t, http.StatusForbidden, w.Code,
"不允许的 origin 的 preflight 请求应返回 403")
}
func TestCORS_PreflightAllowedOrigin_ReturnsNoContent(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
assert.Equal(t, http.StatusNoContent, w.Code,
"允许的 origin 的 preflight 请求应返回 204")
}
func TestCORS_WildcardOrigin_AllowsAny(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"*"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://any-origin.example.com")
middleware(c)
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"),
"通配符配置应返回 *")
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
"通配符 origin 应设置 Allow-Headers")
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
"通配符 origin 应设置 Allow-Methods")
}
func TestCORS_AllowCredentials_SetCorrectly(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: true,
}
middleware := CORS(cfg)
t.Run("allowed_origin_gets_credentials", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"),
"允许的 origin 且开启 credentials 应设置 Allow-Credentials")
})
t.Run("disallowed_origin_no_credentials", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://evil.example.com")
middleware(c)
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
"不允许的 origin 不应收到 Allow-Credentials")
})
}
func TestCORS_WildcardWithCredentials_DisablesCredentials(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"*"},
AllowCredentials: true,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://any.example.com")
middleware(c)
// 通配符 + credentials 不兼容credentials 应被禁用
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
"通配符 origin 应禁用 Allow-Credentials")
}
func TestCORS_MultipleAllowedOrigins(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{
"https://app1.example.com",
"https://app2.example.com",
},
AllowCredentials: false,
}
middleware := CORS(cfg)
t.Run("first_origin_allowed", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://app1.example.com")
middleware(c)
assert.Equal(t, "https://app1.example.com", w.Header().Get("Access-Control-Allow-Origin"))
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
})
t.Run("second_origin_allowed", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://app2.example.com")
middleware(c)
assert.Equal(t, "https://app2.example.com", w.Header().Get("Access-Control-Allow-Origin"))
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
})
t.Run("unlisted_origin_rejected", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://app3.example.com")
middleware(c)
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"))
})
}
func TestCORS_VaryHeader_SetForSpecificOrigin(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
assert.Contains(t, w.Header().Values("Vary"), "Origin",
"非通配符允许的 origin 应设置 Vary: Origin")
}
func TestNormalizeOrigins(t *testing.T) {
tests := []struct {
name string
input []string
expect []string
}{
{name: "nil_input", input: nil, expect: nil},
{name: "empty_input", input: []string{}, expect: nil},
{name: "trims_whitespace", input: []string{" https://a.com ", " https://b.com"}, expect: []string{"https://a.com", "https://b.com"}},
{name: "removes_empty_strings", input: []string{"", " ", "https://a.com"}, expect: []string{"https://a.com"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := normalizeOrigins(tt.input)
assert.Equal(t, tt.expect, result)
})
}
}

View File

@@ -19,7 +19,8 @@ func init() {
func TestGenerateNonce(t *testing.T) {
t.Run("generates_valid_base64_string", func(t *testing.T) {
nonce := GenerateNonce()
nonce, err := GenerateNonce()
require.NoError(t, err)
// Should be valid base64
decoded, err := base64.StdEncoding.DecodeString(nonce)
@@ -32,14 +33,16 @@ func TestGenerateNonce(t *testing.T) {
t.Run("generates_unique_nonces", func(t *testing.T) {
nonces := make(map[string]bool)
for i := 0; i < 100; i++ {
nonce := GenerateNonce()
nonce, err := GenerateNonce()
require.NoError(t, err)
assert.False(t, nonces[nonce], "nonce should be unique")
nonces[nonce] = true
}
})
t.Run("nonce_has_expected_length", func(t *testing.T) {
nonce := GenerateNonce()
nonce, err := GenerateNonce()
require.NoError(t, err)
// 16 bytes -> 24 chars in base64 (with padding)
assert.Len(t, nonce, 24)
})
@@ -344,7 +347,7 @@ func TestAddToDirective(t *testing.T) {
// Benchmark tests
func BenchmarkGenerateNonce(b *testing.B) {
for i := 0; i < b.N; i++ {
GenerateNonce()
_, _ = GenerateNonce()
}
}