From 00caf0bcd895c2fa6dbcca412ba3e83d908e2199 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 7 Feb 2026 22:14:07 +0800 Subject: [PATCH] =?UTF-8?q?test:=20=E4=B8=BA=E4=BB=A3=E7=A0=81=E5=AE=A1?= =?UTF-8?q?=E6=A0=B8=E4=BF=AE=E5=A4=8D=E6=B7=BB=E5=8A=A0=E8=AF=A6=E7=BB=86?= =?UTF-8?q?=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95=EF=BC=887=E4=B8=AA?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E6=96=87=E4=BB=B6=EF=BC=8C50+=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E7=94=A8=E4=BE=8B=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增测试文件: - cors_test.go: CORS 条件化头部测试(12个测试,覆盖白名单/黑名单/通配符/凭证/多源/Vary) - gateway_helper_backoff_test.go: nextBackoff 退避测试(6个测试+基准,验证指数增长/边界/抖动/收敛) - billing_cache_jitter_test.go: jitteredTTL 抖动测试(5个测试+基准,验证范围/上界/方差/均值) - subscription_calculate_progress_test.go: calculateProgress 纯函数测试(9个测试,覆盖日/周/月限额/超限截断/过期) - openai_gateway_handler_test.go: SSE JSON 转义测试(7个子用例,验证双引号/反斜杠/换行符安全) 更新测试文件: - response_transformer_test.go: 增强 generateRandomID 测试(7个测试,含并发/字符集/降级计数器) - security_headers_test.go: 适配 GenerateNonce 新签名 - api_key_auth_test.go: 适配 NewSubscriptionService 新参数 Co-Authored-By: Claude Opus 4.6 --- .../handler/gateway_helper_backoff_test.go | 106 ++++++ .../handler/openai_gateway_handler_test.go | 104 ++++++ .../antigravity/response_transformer_test.go | 73 +++++ .../repository/billing_cache_jitter_test.go | 82 +++++ .../server/middleware/api_key_auth_test.go | 4 +- .../internal/server/middleware/cors_test.go | 308 ++++++++++++++++++ .../middleware/security_headers_test.go | 11 +- .../subscription_calculate_progress_test.go | 231 +++++++++++++ 8 files changed, 913 insertions(+), 6 deletions(-) create mode 100644 backend/internal/handler/gateway_helper_backoff_test.go create mode 100644 backend/internal/handler/openai_gateway_handler_test.go create mode 100644 backend/internal/repository/billing_cache_jitter_test.go create mode 100644 backend/internal/server/middleware/cors_test.go create mode 100644 backend/internal/service/subscription_calculate_progress_test.go diff --git a/backend/internal/handler/gateway_helper_backoff_test.go b/backend/internal/handler/gateway_helper_backoff_test.go new file mode 100644 index 00000000..a5056bbb --- /dev/null +++ b/backend/internal/handler/gateway_helper_backoff_test.go @@ -0,0 +1,106 @@ +package handler + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Task 6.2 验证: math/rand/v2 迁移后 nextBackoff 行为正确 --- + +func TestNextBackoff_ExponentialGrowth(t *testing.T) { + // 验证退避时间指数增长(乘数 1.5) + // 由于有随机抖动(±20%),需要验证范围 + current := initialBackoff // 100ms + + for i := 0; i < 10; i++ { + next := nextBackoff(current) + + // 退避结果应在 [initialBackoff, maxBackoff] 范围内 + assert.GreaterOrEqual(t, int64(next), int64(initialBackoff), + "第 %d 次退避不应低于初始值 %v", i, initialBackoff) + assert.LessOrEqual(t, int64(next), int64(maxBackoff), + "第 %d 次退避不应超过最大值 %v", i, maxBackoff) + + // 为下一轮提供当前退避值 + current = next + } +} + +func TestNextBackoff_BoundedByMaxBackoff(t *testing.T) { + // 即使输入非常大,输出也不超过 maxBackoff + for i := 0; i < 100; i++ { + result := nextBackoff(10 * time.Second) + assert.LessOrEqual(t, int64(result), int64(maxBackoff), + "退避值不应超过 maxBackoff") + } +} + +func TestNextBackoff_BoundedByInitialBackoff(t *testing.T) { + // 即使输入非常小,输出也不低于 initialBackoff + for i := 0; i < 100; i++ { + result := nextBackoff(1 * time.Millisecond) + assert.GreaterOrEqual(t, int64(result), int64(initialBackoff), + "退避值不应低于 initialBackoff") + } +} + +func TestNextBackoff_HasJitter(t *testing.T) { + // 验证多次调用会产生不同的值(随机抖动生效) + // 使用相同的输入调用 50 次,收集结果 + results := make(map[time.Duration]bool) + current := 500 * time.Millisecond + + for i := 0; i < 50; i++ { + result := nextBackoff(current) + results[result] = true + } + + // 50 次调用应该至少有 2 个不同的值(抖动存在) + require.Greater(t, len(results), 1, + "nextBackoff 应产生随机抖动,但所有 50 次调用结果相同") +} + +func TestNextBackoff_InitialValueGrows(t *testing.T) { + // 验证从初始值开始,退避趋势是增长的 + current := initialBackoff + var sum time.Duration + + runs := 100 + for i := 0; i < runs; i++ { + next := nextBackoff(current) + sum += next + current = next + } + + avg := sum / time.Duration(runs) + // 平均退避时间应大于初始值(因为指数增长 + 上限) + assert.Greater(t, int64(avg), int64(initialBackoff), + "平均退避时间应大于初始退避值") +} + +func TestNextBackoff_ConvergesToMaxBackoff(t *testing.T) { + // 从初始值开始,经过多次退避后应收敛到 maxBackoff 附近 + current := initialBackoff + for i := 0; i < 20; i++ { + current = nextBackoff(current) + } + + // 经过 20 次迭代后,应该已经到达 maxBackoff 区间 + // 由于抖动,允许 ±20% 的范围 + lowerBound := time.Duration(float64(maxBackoff) * 0.8) + assert.GreaterOrEqual(t, int64(current), int64(lowerBound), + "经过多次退避后应收敛到 maxBackoff 附近") +} + +func BenchmarkNextBackoff(b *testing.B) { + current := initialBackoff + for i := 0; i < b.N; i++ { + current = nextBackoff(current) + if current > maxBackoff { + current = initialBackoff + } + } +} diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go new file mode 100644 index 00000000..ec59818d --- /dev/null +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -0,0 +1,104 @@ +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) { + tests := []struct { + name string + errType string + message string + }{ + { + name: "包含双引号的消息", + errType: "server_error", + message: `upstream returned "invalid" response`, + }, + { + name: "包含反斜杠的消息", + errType: "server_error", + message: `path C:\Users\test\file.txt not found`, + }, + { + name: "包含双引号和反斜杠的消息", + errType: "upstream_error", + message: `error parsing "key\value": unexpected token`, + }, + { + name: "包含换行符的消息", + errType: "server_error", + message: "line1\nline2\ttab", + }, + { + name: "普通消息", + errType: "upstream_error", + message: "Upstream service temporarily unavailable", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &OpenAIGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true) + + body := w.Body.String() + + // 验证 SSE 格式:event: error\ndata: {JSON}\n\n + assert.True(t, strings.HasPrefix(body, "event: error\n"), "应以 'event: error\\n' 开头") + assert.True(t, strings.HasSuffix(body, "\n\n"), "应以 '\\n\\n' 结尾") + + // 提取 data 部分 + lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n") + require.Len(t, lines, 2, "应有 event 行和 data 行") + dataLine := lines[1] + require.True(t, strings.HasPrefix(dataLine, "data: "), "第二行应以 'data: ' 开头") + jsonStr := strings.TrimPrefix(dataLine, "data: ") + + // 验证 JSON 合法性 + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "JSON 应能被成功解析,原始 JSON: %s", jsonStr) + + // 验证结构 + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok, "应包含 error 对象") + assert.Equal(t, tt.errType, errorObj["type"]) + assert.Equal(t, tt.message, errorObj["message"]) + }) + } +} + +func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &OpenAIGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "test error", false) + + // 非流式应返回 JSON 响应 + assert.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "test error", errorObj["message"]) +} diff --git a/backend/internal/pkg/antigravity/response_transformer_test.go b/backend/internal/pkg/antigravity/response_transformer_test.go index 9731d906..da402b17 100644 --- a/backend/internal/pkg/antigravity/response_transformer_test.go +++ b/backend/internal/pkg/antigravity/response_transformer_test.go @@ -3,11 +3,16 @@ package antigravity import ( + "sync" + "sync/atomic" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// --- Task 7: 验证 generateRandomID 和降级碰撞防护 --- + func TestGenerateRandomID_Uniqueness(t *testing.T) { seen := make(map[string]struct{}, 100) for i := 0; i < 100; i++ { @@ -19,6 +24,39 @@ func TestGenerateRandomID_Uniqueness(t *testing.T) { } } +func TestFallbackCounter_Increments(t *testing.T) { + // 验证 fallbackCounter 的原子递增行为确保降级分支不会生成相同 seed + before := atomic.LoadUint64(&fallbackCounter) + cnt1 := atomic.AddUint64(&fallbackCounter, 1) + cnt2 := atomic.AddUint64(&fallbackCounter, 1) + require.Equal(t, before+1, cnt1, "第一次递增应为 before+1") + require.Equal(t, before+2, cnt2, "第二次递增应为 before+2") + require.NotEqual(t, cnt1, cnt2, "连续两次递增的计数器值应不同") +} + +func TestFallbackCounter_ConcurrentIncrements(t *testing.T) { + // 验证并发递增的原子性 — 每次递增都应产生唯一值 + const goroutines = 50 + results := make([]uint64, goroutines) + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + results[idx] = atomic.AddUint64(&fallbackCounter, 1) + }(i) + } + wg.Wait() + + // 所有结果应唯一 + seen := make(map[uint64]bool, goroutines) + for _, v := range results { + assert.False(t, seen[v], "并发递增产生了重复值: %d", v) + seen[v] = true + } +} + func TestGenerateRandomID_Charset(t *testing.T) { const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" validSet := make(map[byte]struct{}, len(validChars)) @@ -34,3 +72,38 @@ func TestGenerateRandomID_Charset(t *testing.T) { } } } + +func TestGenerateRandomID_Length(t *testing.T) { + for i := 0; i < 100; i++ { + id := generateRandomID() + assert.Len(t, id, 12, "每次生成的 ID 长度应为 12") + } +} + +func TestGenerateRandomID_ConcurrentUniqueness(t *testing.T) { + // 验证并发调用不会产生重复 ID + const goroutines = 100 + results := make([]string, goroutines) + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + results[idx] = generateRandomID() + }(i) + } + wg.Wait() + + seen := make(map[string]bool, goroutines) + for _, id := range results { + assert.False(t, seen[id], "并发调用产生了重复 ID: %s", id) + seen[id] = true + } +} + +func BenchmarkGenerateRandomID(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = generateRandomID() + } +} diff --git a/backend/internal/repository/billing_cache_jitter_test.go b/backend/internal/repository/billing_cache_jitter_test.go new file mode 100644 index 00000000..32c42cf4 --- /dev/null +++ b/backend/internal/repository/billing_cache_jitter_test.go @@ -0,0 +1,82 @@ +package repository + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Task 6.1 验证: math/rand/v2 迁移后 jitteredTTL 行为正确 --- + +func TestJitteredTTL_WithinExpectedRange(t *testing.T) { + // jitteredTTL 使用减法抖动: billingCacheTTL - [0, billingCacheJitter) + // 所以结果应在 [billingCacheTTL - billingCacheJitter, billingCacheTTL] 范围内 + lowerBound := billingCacheTTL - billingCacheJitter // 5min - 30s = 4min30s + upperBound := billingCacheTTL // 5min + + for i := 0; i < 200; i++ { + ttl := jitteredTTL() + assert.GreaterOrEqual(t, int64(ttl), int64(lowerBound), + "TTL 不应低于 %v,实际得到 %v", lowerBound, ttl) + assert.LessOrEqual(t, int64(ttl), int64(upperBound), + "TTL 不应超过 %v(上界不变保证),实际得到 %v", upperBound, ttl) + } +} + +func TestJitteredTTL_NeverExceedsBase(t *testing.T) { + // 关键安全性测试:jitteredTTL 使用减法抖动,确保永远不超过 billingCacheTTL + for i := 0; i < 500; i++ { + ttl := jitteredTTL() + assert.LessOrEqual(t, int64(ttl), int64(billingCacheTTL), + "jitteredTTL 不应超过基础 TTL(上界预期不被打破)") + } +} + +func TestJitteredTTL_HasVariance(t *testing.T) { + // 验证抖动确实产生了不同的值 + results := make(map[time.Duration]bool) + for i := 0; i < 100; i++ { + ttl := jitteredTTL() + results[ttl] = true + } + + require.Greater(t, len(results), 1, + "jitteredTTL 应产生不同的值(抖动生效),但 100 次调用结果全部相同") +} + +func TestJitteredTTL_AverageNearCenter(t *testing.T) { + // 验证平均值大约在抖动范围中间 + var sum time.Duration + runs := 1000 + for i := 0; i < runs; i++ { + sum += jitteredTTL() + } + + avg := sum / time.Duration(runs) + expectedCenter := billingCacheTTL - billingCacheJitter/2 // 4min45s + + // 允许 ±5s 的误差 + tolerance := 5 * time.Second + assert.InDelta(t, float64(expectedCenter), float64(avg), float64(tolerance), + "平均 TTL 应接近抖动范围中心 %v", expectedCenter) +} + +func TestBillingKeyGeneration(t *testing.T) { + t.Run("balance_key", func(t *testing.T) { + key := billingBalanceKey(12345) + assert.Equal(t, "billing:balance:12345", key) + }) + + t.Run("sub_key", func(t *testing.T) { + key := billingSubKey(100, 200) + assert.Equal(t, "billing:sub:100:200", key) + }) +} + +func BenchmarkJitteredTTL(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = jitteredTTL() + } +} diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 3605aaff..6d1f8ecd 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -60,7 +60,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { cfg := &config.Config{RunMode: config.RunModeSimple} apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) - subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, cfg) + subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) w := httptest.NewRecorder() @@ -99,7 +99,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, } - subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, cfg) + subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) w := httptest.NewRecorder() diff --git a/backend/internal/server/middleware/cors_test.go b/backend/internal/server/middleware/cors_test.go new file mode 100644 index 00000000..6d0bea36 --- /dev/null +++ b/backend/internal/server/middleware/cors_test.go @@ -0,0 +1,308 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func init() { + // cors_test 与 security_headers_test 在同一个包,但 init 是幂等的 + gin.SetMode(gin.TestMode) +} + +// --- Task 8.2: 验证 CORS 条件化头部 --- + +func TestCORS_DisallowedOrigin_NoAllowHeaders(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + tests := []struct { + name string + method string + origin string + }{ + { + name: "preflight_disallowed_origin", + method: http.MethodOptions, + origin: "https://evil.example.com", + }, + { + name: "get_disallowed_origin", + method: http.MethodGet, + origin: "https://evil.example.com", + }, + { + name: "post_disallowed_origin", + method: http.MethodPost, + origin: "https://attacker.example.com", + }, + { + name: "preflight_no_origin", + method: http.MethodOptions, + origin: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(tt.method, "/", nil) + if tt.origin != "" { + c.Request.Header.Set("Origin", tt.origin) + } + + middleware(c) + + // 不应设置 Allow-Headers、Allow-Methods 和 Max-Age + assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"), + "不允许的 origin 不应收到 Allow-Headers") + assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"), + "不允许的 origin 不应收到 Allow-Methods") + assert.Empty(t, w.Header().Get("Access-Control-Max-Age"), + "不允许的 origin 不应收到 Max-Age") + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"), + "不允许的 origin 不应收到 Allow-Origin") + }) + } +} + +func TestCORS_AllowedOrigin_HasAllowHeaders(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + tests := []struct { + name string + method string + }{ + {name: "preflight_OPTIONS", method: http.MethodOptions}, + {name: "normal_GET", method: http.MethodGet}, + {name: "normal_POST", method: http.MethodPost}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(tt.method, "/", nil) + c.Request.Header.Set("Origin", "https://allowed.example.com") + + middleware(c) + + // 应设置 Allow-Headers、Allow-Methods 和 Max-Age + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"), + "允许的 origin 应收到 Allow-Headers") + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"), + "允许的 origin 应收到 Allow-Methods") + assert.Equal(t, "86400", w.Header().Get("Access-Control-Max-Age"), + "允许的 origin 应收到 Max-Age=86400") + assert.Equal(t, "https://allowed.example.com", w.Header().Get("Access-Control-Allow-Origin"), + "允许的 origin 应收到 Allow-Origin") + }) + } +} + +func TestCORS_PreflightDisallowedOrigin_ReturnsForbidden(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodOptions, "/", nil) + c.Request.Header.Set("Origin", "https://evil.example.com") + + middleware(c) + + assert.Equal(t, http.StatusForbidden, w.Code, + "不允许的 origin 的 preflight 请求应返回 403") +} + +func TestCORS_PreflightAllowedOrigin_ReturnsNoContent(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodOptions, "/", nil) + c.Request.Header.Set("Origin", "https://allowed.example.com") + + middleware(c) + + assert.Equal(t, http.StatusNoContent, w.Code, + "允许的 origin 的 preflight 请求应返回 204") +} + +func TestCORS_WildcardOrigin_AllowsAny(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://any-origin.example.com") + + middleware(c) + + assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"), + "通配符配置应返回 *") + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"), + "通配符 origin 应设置 Allow-Headers") + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"), + "通配符 origin 应设置 Allow-Methods") +} + +func TestCORS_AllowCredentials_SetCorrectly(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: true, + } + middleware := CORS(cfg) + + t.Run("allowed_origin_gets_credentials", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://allowed.example.com") + + middleware(c) + + assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"), + "允许的 origin 且开启 credentials 应设置 Allow-Credentials") + }) + + t.Run("disallowed_origin_no_credentials", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://evil.example.com") + + middleware(c) + + assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"), + "不允许的 origin 不应收到 Allow-Credentials") + }) +} + +func TestCORS_WildcardWithCredentials_DisablesCredentials(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowCredentials: true, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://any.example.com") + + middleware(c) + + // 通配符 + credentials 不兼容,credentials 应被禁用 + assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"), + "通配符 origin 应禁用 Allow-Credentials") +} + +func TestCORS_MultipleAllowedOrigins(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{ + "https://app1.example.com", + "https://app2.example.com", + }, + AllowCredentials: false, + } + middleware := CORS(cfg) + + t.Run("first_origin_allowed", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://app1.example.com") + + middleware(c) + + assert.Equal(t, "https://app1.example.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers")) + }) + + t.Run("second_origin_allowed", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://app2.example.com") + + middleware(c) + + assert.Equal(t, "https://app2.example.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers")) + }) + + t.Run("unlisted_origin_rejected", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://app3.example.com") + + middleware(c) + + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers")) + }) +} + +func TestCORS_VaryHeader_SetForSpecificOrigin(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://allowed.example.com") + + middleware(c) + + assert.Contains(t, w.Header().Values("Vary"), "Origin", + "非通配符允许的 origin 应设置 Vary: Origin") +} + +func TestNormalizeOrigins(t *testing.T) { + tests := []struct { + name string + input []string + expect []string + }{ + {name: "nil_input", input: nil, expect: nil}, + {name: "empty_input", input: []string{}, expect: nil}, + {name: "trims_whitespace", input: []string{" https://a.com ", " https://b.com"}, expect: []string{"https://a.com", "https://b.com"}}, + {name: "removes_empty_strings", input: []string{"", " ", "https://a.com"}, expect: []string{"https://a.com"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := normalizeOrigins(tt.input) + assert.Equal(t, tt.expect, result) + }) + } +} diff --git a/backend/internal/server/middleware/security_headers_test.go b/backend/internal/server/middleware/security_headers_test.go index dc7a87d8..43462b82 100644 --- a/backend/internal/server/middleware/security_headers_test.go +++ b/backend/internal/server/middleware/security_headers_test.go @@ -19,7 +19,8 @@ func init() { func TestGenerateNonce(t *testing.T) { t.Run("generates_valid_base64_string", func(t *testing.T) { - nonce := GenerateNonce() + nonce, err := GenerateNonce() + require.NoError(t, err) // Should be valid base64 decoded, err := base64.StdEncoding.DecodeString(nonce) @@ -32,14 +33,16 @@ func TestGenerateNonce(t *testing.T) { t.Run("generates_unique_nonces", func(t *testing.T) { nonces := make(map[string]bool) for i := 0; i < 100; i++ { - nonce := GenerateNonce() + nonce, err := GenerateNonce() + require.NoError(t, err) assert.False(t, nonces[nonce], "nonce should be unique") nonces[nonce] = true } }) t.Run("nonce_has_expected_length", func(t *testing.T) { - nonce := GenerateNonce() + nonce, err := GenerateNonce() + require.NoError(t, err) // 16 bytes -> 24 chars in base64 (with padding) assert.Len(t, nonce, 24) }) @@ -344,7 +347,7 @@ func TestAddToDirective(t *testing.T) { // Benchmark tests func BenchmarkGenerateNonce(b *testing.B) { for i := 0; i < b.N; i++ { - GenerateNonce() + _, _ = GenerateNonce() } } diff --git a/backend/internal/service/subscription_calculate_progress_test.go b/backend/internal/service/subscription_calculate_progress_test.go new file mode 100644 index 00000000..d8adf7f7 --- /dev/null +++ b/backend/internal/service/subscription_calculate_progress_test.go @@ -0,0 +1,231 @@ +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Task 5: 验证 calculateProgress 纯函数行为正确 --- + +func newTestSubscriptionService() *SubscriptionService { + return &SubscriptionService{} +} + +func ptrFloat64(v float64) *float64 { return &v } +func ptrTime(t time.Time) *time.Time { return &t } + +func TestCalculateProgress_BasicFields(t *testing.T) { + svc := newTestSubscriptionService() + now := time.Now() + + sub := &UserSubscription{ + ID: 100, + ExpiresAt: now.Add(30 * 24 * time.Hour), + } + group := &Group{ + Name: "Premium", + } + + progress := svc.calculateProgress(sub, group) + + assert.Equal(t, int64(100), progress.ID) + assert.Equal(t, "Premium", progress.GroupName) + assert.Equal(t, sub.ExpiresAt, progress.ExpiresAt) + assert.Equal(t, 29, progress.ExpiresInDays) // 约 30 天 + assert.Nil(t, progress.Daily, "无日限额时 Daily 应为 nil") + assert.Nil(t, progress.Weekly, "无周限额时 Weekly 应为 nil") + assert.Nil(t, progress.Monthly, "无月限额时 Monthly 应为 nil") +} + +func TestCalculateProgress_DailyUsage(t *testing.T) { + svc := newTestSubscriptionService() + now := time.Now() + dailyStart := now.Add(-12 * time.Hour) + + sub := &UserSubscription{ + ID: 1, + ExpiresAt: now.Add(10 * 24 * time.Hour), + DailyUsageUSD: 3.0, + DailyWindowStart: ptrTime(dailyStart), + } + group := &Group{ + Name: "Pro", + DailyLimitUSD: ptrFloat64(10.0), + } + + progress := svc.calculateProgress(sub, group) + + require.NotNil(t, progress.Daily, "有日限额和窗口时 Daily 不应为 nil") + assert.Equal(t, 10.0, progress.Daily.LimitUSD) + assert.Equal(t, 3.0, progress.Daily.UsedUSD) + assert.Equal(t, 7.0, progress.Daily.RemainingUSD) + assert.Equal(t, 30.0, progress.Daily.Percentage) + assert.Equal(t, dailyStart, progress.Daily.WindowStart) +} + +func TestCalculateProgress_WeeklyUsage(t *testing.T) { + svc := newTestSubscriptionService() + now := time.Now() + weeklyStart := now.Add(-3 * 24 * time.Hour) + + sub := &UserSubscription{ + ID: 1, + ExpiresAt: now.Add(10 * 24 * time.Hour), + WeeklyUsageUSD: 25.0, + WeeklyWindowStart: ptrTime(weeklyStart), + } + group := &Group{ + Name: "Pro", + WeeklyLimitUSD: ptrFloat64(50.0), + } + + progress := svc.calculateProgress(sub, group) + + require.NotNil(t, progress.Weekly, "有周限额和窗口时 Weekly 不应为 nil") + assert.Equal(t, 50.0, progress.Weekly.LimitUSD) + assert.Equal(t, 25.0, progress.Weekly.UsedUSD) + assert.Equal(t, 25.0, progress.Weekly.RemainingUSD) + assert.Equal(t, 50.0, progress.Weekly.Percentage) +} + +func TestCalculateProgress_MonthlyUsage(t *testing.T) { + svc := newTestSubscriptionService() + now := time.Now() + monthlyStart := now.Add(-15 * 24 * time.Hour) + + sub := &UserSubscription{ + ID: 1, + ExpiresAt: now.Add(10 * 24 * time.Hour), + MonthlyUsageUSD: 80.0, + MonthlyWindowStart: ptrTime(monthlyStart), + } + group := &Group{ + Name: "Enterprise", + MonthlyLimitUSD: ptrFloat64(100.0), + } + + progress := svc.calculateProgress(sub, group) + + require.NotNil(t, progress.Monthly, "有月限额和窗口时 Monthly 不应为 nil") + assert.Equal(t, 100.0, progress.Monthly.LimitUSD) + assert.Equal(t, 80.0, progress.Monthly.UsedUSD) + assert.Equal(t, 20.0, progress.Monthly.RemainingUSD) + assert.Equal(t, 80.0, progress.Monthly.Percentage) +} + +func TestCalculateProgress_OverLimit_ClampedTo100Percent(t *testing.T) { + svc := newTestSubscriptionService() + now := time.Now() + + sub := &UserSubscription{ + ID: 1, + ExpiresAt: now.Add(10 * 24 * time.Hour), + DailyUsageUSD: 15.0, // 超过限额 + DailyWindowStart: ptrTime(now.Add(-1 * time.Hour)), + } + group := &Group{ + Name: "Pro", + DailyLimitUSD: ptrFloat64(10.0), + } + + progress := svc.calculateProgress(sub, group) + + require.NotNil(t, progress.Daily) + assert.Equal(t, 100.0, progress.Daily.Percentage, "超额使用应被截断为 100%") + assert.Equal(t, 0.0, progress.Daily.RemainingUSD, "超额使用时剩余应为 0") +} + +func TestCalculateProgress_NoWindowStart_NoProgress(t *testing.T) { + svc := newTestSubscriptionService() + now := time.Now() + + // 有限额但无窗口起始时间(订阅未激活) + sub := &UserSubscription{ + ID: 1, + ExpiresAt: now.Add(10 * 24 * time.Hour), + DailyUsageUSD: 0, + WeeklyUsageUSD: 0, + } + group := &Group{ + Name: "Pro", + DailyLimitUSD: ptrFloat64(10.0), + WeeklyLimitUSD: ptrFloat64(50.0), + } + + progress := svc.calculateProgress(sub, group) + + assert.Nil(t, progress.Daily, "无 DailyWindowStart 时 Daily 应为 nil") + assert.Nil(t, progress.Weekly, "无 WeeklyWindowStart 时 Weekly 应为 nil") +} + +func TestCalculateProgress_AllLimits(t *testing.T) { + svc := newTestSubscriptionService() + now := time.Now() + + sub := &UserSubscription{ + ID: 1, + ExpiresAt: now.Add(10 * 24 * time.Hour), + DailyUsageUSD: 5.0, + WeeklyUsageUSD: 20.0, + MonthlyUsageUSD: 60.0, + DailyWindowStart: ptrTime(now.Add(-6 * time.Hour)), + WeeklyWindowStart: ptrTime(now.Add(-3 * 24 * time.Hour)), + MonthlyWindowStart: ptrTime(now.Add(-15 * 24 * time.Hour)), + } + group := &Group{ + Name: "Full", + DailyLimitUSD: ptrFloat64(10.0), + WeeklyLimitUSD: ptrFloat64(50.0), + MonthlyLimitUSD: ptrFloat64(100.0), + } + + progress := svc.calculateProgress(sub, group) + + require.NotNil(t, progress.Daily) + require.NotNil(t, progress.Weekly) + require.NotNil(t, progress.Monthly) + + assert.Equal(t, 50.0, progress.Daily.Percentage) + assert.Equal(t, 40.0, progress.Weekly.Percentage) + assert.Equal(t, 60.0, progress.Monthly.Percentage) +} + +func TestCalculateProgress_ExpiredSubscription(t *testing.T) { + svc := newTestSubscriptionService() + + sub := &UserSubscription{ + ID: 1, + ExpiresAt: time.Now().Add(-24 * time.Hour), // 已过期 + } + group := &Group{Name: "Expired"} + + progress := svc.calculateProgress(sub, group) + + assert.Equal(t, 0, progress.ExpiresInDays, "过期订阅的剩余天数应为 0") +} + +func TestCalculateProgress_ResetsInSeconds_NotNegative(t *testing.T) { + svc := newTestSubscriptionService() + // 使用过去的窗口起始时间,使得重置时间已过 + pastStart := time.Now().Add(-48 * time.Hour) + + sub := &UserSubscription{ + ID: 1, + ExpiresAt: time.Now().Add(10 * 24 * time.Hour), + DailyUsageUSD: 1.0, + DailyWindowStart: ptrTime(pastStart), + } + group := &Group{ + Name: "Test", + DailyLimitUSD: ptrFloat64(10.0), + } + + progress := svc.calculateProgress(sub, group) + + require.NotNil(t, progress.Daily) + assert.GreaterOrEqual(t, progress.Daily.ResetsInSeconds, int64(0), + "ResetsInSeconds 不应为负数") +}