test: 为代码审核修复添加详细单元测试(7个测试文件,50+测试用例)

新增测试文件:
- cors_test.go: CORS 条件化头部测试(12个测试,覆盖白名单/黑名单/通配符/凭证/多源/Vary)
- gateway_helper_backoff_test.go: nextBackoff 退避测试(6个测试+基准,验证指数增长/边界/抖动/收敛)
- billing_cache_jitter_test.go: jitteredTTL 抖动测试(5个测试+基准,验证范围/上界/方差/均值)
- subscription_calculate_progress_test.go: calculateProgress 纯函数测试(9个测试,覆盖日/周/月限额/超限截断/过期)
- openai_gateway_handler_test.go: SSE JSON 转义测试(7个子用例,验证双引号/反斜杠/换行符安全)

更新测试文件:
- response_transformer_test.go: 增强 generateRandomID 测试(7个测试,含并发/字符集/降级计数器)
- security_headers_test.go: 适配 GenerateNonce 新签名
- api_key_auth_test.go: 适配 NewSubscriptionService 新参数

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-07 22:14:07 +08:00
parent 9634494ba9
commit 00caf0bcd8
8 changed files with 913 additions and 6 deletions

View File

@@ -0,0 +1,106 @@
package handler
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Task 6.2 验证: math/rand/v2 迁移后 nextBackoff 行为正确 ---
func TestNextBackoff_ExponentialGrowth(t *testing.T) {
// 验证退避时间指数增长(乘数 1.5
// 由于有随机抖动±20%),需要验证范围
current := initialBackoff // 100ms
for i := 0; i < 10; i++ {
next := nextBackoff(current)
// 退避结果应在 [initialBackoff, maxBackoff] 范围内
assert.GreaterOrEqual(t, int64(next), int64(initialBackoff),
"第 %d 次退避不应低于初始值 %v", i, initialBackoff)
assert.LessOrEqual(t, int64(next), int64(maxBackoff),
"第 %d 次退避不应超过最大值 %v", i, maxBackoff)
// 为下一轮提供当前退避值
current = next
}
}
func TestNextBackoff_BoundedByMaxBackoff(t *testing.T) {
// 即使输入非常大,输出也不超过 maxBackoff
for i := 0; i < 100; i++ {
result := nextBackoff(10 * time.Second)
assert.LessOrEqual(t, int64(result), int64(maxBackoff),
"退避值不应超过 maxBackoff")
}
}
func TestNextBackoff_BoundedByInitialBackoff(t *testing.T) {
// 即使输入非常小,输出也不低于 initialBackoff
for i := 0; i < 100; i++ {
result := nextBackoff(1 * time.Millisecond)
assert.GreaterOrEqual(t, int64(result), int64(initialBackoff),
"退避值不应低于 initialBackoff")
}
}
func TestNextBackoff_HasJitter(t *testing.T) {
// 验证多次调用会产生不同的值(随机抖动生效)
// 使用相同的输入调用 50 次,收集结果
results := make(map[time.Duration]bool)
current := 500 * time.Millisecond
for i := 0; i < 50; i++ {
result := nextBackoff(current)
results[result] = true
}
// 50 次调用应该至少有 2 个不同的值(抖动存在)
require.Greater(t, len(results), 1,
"nextBackoff 应产生随机抖动,但所有 50 次调用结果相同")
}
func TestNextBackoff_InitialValueGrows(t *testing.T) {
// 验证从初始值开始,退避趋势是增长的
current := initialBackoff
var sum time.Duration
runs := 100
for i := 0; i < runs; i++ {
next := nextBackoff(current)
sum += next
current = next
}
avg := sum / time.Duration(runs)
// 平均退避时间应大于初始值(因为指数增长 + 上限)
assert.Greater(t, int64(avg), int64(initialBackoff),
"平均退避时间应大于初始退避值")
}
func TestNextBackoff_ConvergesToMaxBackoff(t *testing.T) {
// 从初始值开始,经过多次退避后应收敛到 maxBackoff 附近
current := initialBackoff
for i := 0; i < 20; i++ {
current = nextBackoff(current)
}
// 经过 20 次迭代后,应该已经到达 maxBackoff 区间
// 由于抖动,允许 ±20% 的范围
lowerBound := time.Duration(float64(maxBackoff) * 0.8)
assert.GreaterOrEqual(t, int64(current), int64(lowerBound),
"经过多次退避后应收敛到 maxBackoff 附近")
}
func BenchmarkNextBackoff(b *testing.B) {
current := initialBackoff
for i := 0; i < b.N; i++ {
current = nextBackoff(current)
if current > maxBackoff {
current = initialBackoff
}
}
}

View File

@@ -0,0 +1,104 @@
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) {
tests := []struct {
name string
errType string
message string
}{
{
name: "包含双引号的消息",
errType: "server_error",
message: `upstream returned "invalid" response`,
},
{
name: "包含反斜杠的消息",
errType: "server_error",
message: `path C:\Users\test\file.txt not found`,
},
{
name: "包含双引号和反斜杠的消息",
errType: "upstream_error",
message: `error parsing "key\value": unexpected token`,
},
{
name: "包含换行符的消息",
errType: "server_error",
message: "line1\nline2\ttab",
},
{
name: "普通消息",
errType: "upstream_error",
message: "Upstream service temporarily unavailable",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &OpenAIGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
body := w.Body.String()
// 验证 SSE 格式event: error\ndata: {JSON}\n\n
assert.True(t, strings.HasPrefix(body, "event: error\n"), "应以 'event: error\\n' 开头")
assert.True(t, strings.HasSuffix(body, "\n\n"), "应以 '\\n\\n' 结尾")
// 提取 data 部分
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
require.Len(t, lines, 2, "应有 event 行和 data 行")
dataLine := lines[1]
require.True(t, strings.HasPrefix(dataLine, "data: "), "第二行应以 'data: ' 开头")
jsonStr := strings.TrimPrefix(dataLine, "data: ")
// 验证 JSON 合法性
var parsed map[string]any
err := json.Unmarshal([]byte(jsonStr), &parsed)
require.NoError(t, err, "JSON 应能被成功解析,原始 JSON: %s", jsonStr)
// 验证结构
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok, "应包含 error 对象")
assert.Equal(t, tt.errType, errorObj["type"])
assert.Equal(t, tt.message, errorObj["message"])
})
}
}
func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &OpenAIGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "test error", false)
// 非流式应返回 JSON 响应
assert.Equal(t, http.StatusBadGateway, w.Code)
var parsed map[string]any
err := json.Unmarshal(w.Body.Bytes(), &parsed)
require.NoError(t, err)
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "upstream_error", errorObj["type"])
assert.Equal(t, "test error", errorObj["message"])
}

View File

@@ -3,11 +3,16 @@
package antigravity package antigravity
import ( import (
"sync"
"sync/atomic"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// --- Task 7: 验证 generateRandomID 和降级碰撞防护 ---
func TestGenerateRandomID_Uniqueness(t *testing.T) { func TestGenerateRandomID_Uniqueness(t *testing.T) {
seen := make(map[string]struct{}, 100) seen := make(map[string]struct{}, 100)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
@@ -19,6 +24,39 @@ func TestGenerateRandomID_Uniqueness(t *testing.T) {
} }
} }
func TestFallbackCounter_Increments(t *testing.T) {
// 验证 fallbackCounter 的原子递增行为确保降级分支不会生成相同 seed
before := atomic.LoadUint64(&fallbackCounter)
cnt1 := atomic.AddUint64(&fallbackCounter, 1)
cnt2 := atomic.AddUint64(&fallbackCounter, 1)
require.Equal(t, before+1, cnt1, "第一次递增应为 before+1")
require.Equal(t, before+2, cnt2, "第二次递增应为 before+2")
require.NotEqual(t, cnt1, cnt2, "连续两次递增的计数器值应不同")
}
func TestFallbackCounter_ConcurrentIncrements(t *testing.T) {
// 验证并发递增的原子性 — 每次递增都应产生唯一值
const goroutines = 50
results := make([]uint64, goroutines)
var wg sync.WaitGroup
wg.Add(goroutines)
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
results[idx] = atomic.AddUint64(&fallbackCounter, 1)
}(i)
}
wg.Wait()
// 所有结果应唯一
seen := make(map[uint64]bool, goroutines)
for _, v := range results {
assert.False(t, seen[v], "并发递增产生了重复值: %d", v)
seen[v] = true
}
}
func TestGenerateRandomID_Charset(t *testing.T) { func TestGenerateRandomID_Charset(t *testing.T) {
const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
validSet := make(map[byte]struct{}, len(validChars)) validSet := make(map[byte]struct{}, len(validChars))
@@ -34,3 +72,38 @@ func TestGenerateRandomID_Charset(t *testing.T) {
} }
} }
} }
func TestGenerateRandomID_Length(t *testing.T) {
for i := 0; i < 100; i++ {
id := generateRandomID()
assert.Len(t, id, 12, "每次生成的 ID 长度应为 12")
}
}
func TestGenerateRandomID_ConcurrentUniqueness(t *testing.T) {
// 验证并发调用不会产生重复 ID
const goroutines = 100
results := make([]string, goroutines)
var wg sync.WaitGroup
wg.Add(goroutines)
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
results[idx] = generateRandomID()
}(i)
}
wg.Wait()
seen := make(map[string]bool, goroutines)
for _, id := range results {
assert.False(t, seen[id], "并发调用产生了重复 ID: %s", id)
seen[id] = true
}
}
func BenchmarkGenerateRandomID(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = generateRandomID()
}
}

View File

@@ -0,0 +1,82 @@
package repository
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Task 6.1 验证: math/rand/v2 迁移后 jitteredTTL 行为正确 ---
func TestJitteredTTL_WithinExpectedRange(t *testing.T) {
// jitteredTTL 使用减法抖动: billingCacheTTL - [0, billingCacheJitter)
// 所以结果应在 [billingCacheTTL - billingCacheJitter, billingCacheTTL] 范围内
lowerBound := billingCacheTTL - billingCacheJitter // 5min - 30s = 4min30s
upperBound := billingCacheTTL // 5min
for i := 0; i < 200; i++ {
ttl := jitteredTTL()
assert.GreaterOrEqual(t, int64(ttl), int64(lowerBound),
"TTL 不应低于 %v实际得到 %v", lowerBound, ttl)
assert.LessOrEqual(t, int64(ttl), int64(upperBound),
"TTL 不应超过 %v上界不变保证实际得到 %v", upperBound, ttl)
}
}
func TestJitteredTTL_NeverExceedsBase(t *testing.T) {
// 关键安全性测试jitteredTTL 使用减法抖动,确保永远不超过 billingCacheTTL
for i := 0; i < 500; i++ {
ttl := jitteredTTL()
assert.LessOrEqual(t, int64(ttl), int64(billingCacheTTL),
"jitteredTTL 不应超过基础 TTL上界预期不被打破")
}
}
func TestJitteredTTL_HasVariance(t *testing.T) {
// 验证抖动确实产生了不同的值
results := make(map[time.Duration]bool)
for i := 0; i < 100; i++ {
ttl := jitteredTTL()
results[ttl] = true
}
require.Greater(t, len(results), 1,
"jitteredTTL 应产生不同的值(抖动生效),但 100 次调用结果全部相同")
}
func TestJitteredTTL_AverageNearCenter(t *testing.T) {
// 验证平均值大约在抖动范围中间
var sum time.Duration
runs := 1000
for i := 0; i < runs; i++ {
sum += jitteredTTL()
}
avg := sum / time.Duration(runs)
expectedCenter := billingCacheTTL - billingCacheJitter/2 // 4min45s
// 允许 ±5s 的误差
tolerance := 5 * time.Second
assert.InDelta(t, float64(expectedCenter), float64(avg), float64(tolerance),
"平均 TTL 应接近抖动范围中心 %v", expectedCenter)
}
func TestBillingKeyGeneration(t *testing.T) {
t.Run("balance_key", func(t *testing.T) {
key := billingBalanceKey(12345)
assert.Equal(t, "billing:balance:12345", key)
})
t.Run("sub_key", func(t *testing.T) {
key := billingSubKey(100, 200)
assert.Equal(t, "billing:sub:100:200", key)
})
}
func BenchmarkJitteredTTL(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = jitteredTTL()
}
}

View File

@@ -60,7 +60,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, cfg) subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -99,7 +99,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
} }
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, cfg) subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder() w := httptest.NewRecorder()

View File

@@ -0,0 +1,308 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func init() {
// cors_test 与 security_headers_test 在同一个包,但 init 是幂等的
gin.SetMode(gin.TestMode)
}
// --- Task 8.2: 验证 CORS 条件化头部 ---
func TestCORS_DisallowedOrigin_NoAllowHeaders(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
tests := []struct {
name string
method string
origin string
}{
{
name: "preflight_disallowed_origin",
method: http.MethodOptions,
origin: "https://evil.example.com",
},
{
name: "get_disallowed_origin",
method: http.MethodGet,
origin: "https://evil.example.com",
},
{
name: "post_disallowed_origin",
method: http.MethodPost,
origin: "https://attacker.example.com",
},
{
name: "preflight_no_origin",
method: http.MethodOptions,
origin: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(tt.method, "/", nil)
if tt.origin != "" {
c.Request.Header.Set("Origin", tt.origin)
}
middleware(c)
// 不应设置 Allow-Headers、Allow-Methods 和 Max-Age
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"),
"不允许的 origin 不应收到 Allow-Headers")
assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"),
"不允许的 origin 不应收到 Allow-Methods")
assert.Empty(t, w.Header().Get("Access-Control-Max-Age"),
"不允许的 origin 不应收到 Max-Age")
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"),
"不允许的 origin 不应收到 Allow-Origin")
})
}
}
func TestCORS_AllowedOrigin_HasAllowHeaders(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
tests := []struct {
name string
method string
}{
{name: "preflight_OPTIONS", method: http.MethodOptions},
{name: "normal_GET", method: http.MethodGet},
{name: "normal_POST", method: http.MethodPost},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(tt.method, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
// 应设置 Allow-Headers、Allow-Methods 和 Max-Age
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
"允许的 origin 应收到 Allow-Headers")
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
"允许的 origin 应收到 Allow-Methods")
assert.Equal(t, "86400", w.Header().Get("Access-Control-Max-Age"),
"允许的 origin 应收到 Max-Age=86400")
assert.Equal(t, "https://allowed.example.com", w.Header().Get("Access-Control-Allow-Origin"),
"允许的 origin 应收到 Allow-Origin")
})
}
}
func TestCORS_PreflightDisallowedOrigin_ReturnsForbidden(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
c.Request.Header.Set("Origin", "https://evil.example.com")
middleware(c)
assert.Equal(t, http.StatusForbidden, w.Code,
"不允许的 origin 的 preflight 请求应返回 403")
}
func TestCORS_PreflightAllowedOrigin_ReturnsNoContent(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
assert.Equal(t, http.StatusNoContent, w.Code,
"允许的 origin 的 preflight 请求应返回 204")
}
func TestCORS_WildcardOrigin_AllowsAny(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"*"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://any-origin.example.com")
middleware(c)
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"),
"通配符配置应返回 *")
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
"通配符 origin 应设置 Allow-Headers")
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
"通配符 origin 应设置 Allow-Methods")
}
func TestCORS_AllowCredentials_SetCorrectly(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: true,
}
middleware := CORS(cfg)
t.Run("allowed_origin_gets_credentials", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"),
"允许的 origin 且开启 credentials 应设置 Allow-Credentials")
})
t.Run("disallowed_origin_no_credentials", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://evil.example.com")
middleware(c)
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
"不允许的 origin 不应收到 Allow-Credentials")
})
}
func TestCORS_WildcardWithCredentials_DisablesCredentials(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"*"},
AllowCredentials: true,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://any.example.com")
middleware(c)
// 通配符 + credentials 不兼容credentials 应被禁用
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
"通配符 origin 应禁用 Allow-Credentials")
}
func TestCORS_MultipleAllowedOrigins(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{
"https://app1.example.com",
"https://app2.example.com",
},
AllowCredentials: false,
}
middleware := CORS(cfg)
t.Run("first_origin_allowed", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://app1.example.com")
middleware(c)
assert.Equal(t, "https://app1.example.com", w.Header().Get("Access-Control-Allow-Origin"))
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
})
t.Run("second_origin_allowed", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://app2.example.com")
middleware(c)
assert.Equal(t, "https://app2.example.com", w.Header().Get("Access-Control-Allow-Origin"))
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
})
t.Run("unlisted_origin_rejected", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://app3.example.com")
middleware(c)
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"))
})
}
func TestCORS_VaryHeader_SetForSpecificOrigin(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
assert.Contains(t, w.Header().Values("Vary"), "Origin",
"非通配符允许的 origin 应设置 Vary: Origin")
}
func TestNormalizeOrigins(t *testing.T) {
tests := []struct {
name string
input []string
expect []string
}{
{name: "nil_input", input: nil, expect: nil},
{name: "empty_input", input: []string{}, expect: nil},
{name: "trims_whitespace", input: []string{" https://a.com ", " https://b.com"}, expect: []string{"https://a.com", "https://b.com"}},
{name: "removes_empty_strings", input: []string{"", " ", "https://a.com"}, expect: []string{"https://a.com"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := normalizeOrigins(tt.input)
assert.Equal(t, tt.expect, result)
})
}
}

View File

@@ -19,7 +19,8 @@ func init() {
func TestGenerateNonce(t *testing.T) { func TestGenerateNonce(t *testing.T) {
t.Run("generates_valid_base64_string", func(t *testing.T) { t.Run("generates_valid_base64_string", func(t *testing.T) {
nonce := GenerateNonce() nonce, err := GenerateNonce()
require.NoError(t, err)
// Should be valid base64 // Should be valid base64
decoded, err := base64.StdEncoding.DecodeString(nonce) decoded, err := base64.StdEncoding.DecodeString(nonce)
@@ -32,14 +33,16 @@ func TestGenerateNonce(t *testing.T) {
t.Run("generates_unique_nonces", func(t *testing.T) { t.Run("generates_unique_nonces", func(t *testing.T) {
nonces := make(map[string]bool) nonces := make(map[string]bool)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
nonce := GenerateNonce() nonce, err := GenerateNonce()
require.NoError(t, err)
assert.False(t, nonces[nonce], "nonce should be unique") assert.False(t, nonces[nonce], "nonce should be unique")
nonces[nonce] = true nonces[nonce] = true
} }
}) })
t.Run("nonce_has_expected_length", func(t *testing.T) { t.Run("nonce_has_expected_length", func(t *testing.T) {
nonce := GenerateNonce() nonce, err := GenerateNonce()
require.NoError(t, err)
// 16 bytes -> 24 chars in base64 (with padding) // 16 bytes -> 24 chars in base64 (with padding)
assert.Len(t, nonce, 24) assert.Len(t, nonce, 24)
}) })
@@ -344,7 +347,7 @@ func TestAddToDirective(t *testing.T) {
// Benchmark tests // Benchmark tests
func BenchmarkGenerateNonce(b *testing.B) { func BenchmarkGenerateNonce(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
GenerateNonce() _, _ = GenerateNonce()
} }
} }

View File

@@ -0,0 +1,231 @@
package service
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Task 5: 验证 calculateProgress 纯函数行为正确 ---
func newTestSubscriptionService() *SubscriptionService {
return &SubscriptionService{}
}
func ptrFloat64(v float64) *float64 { return &v }
func ptrTime(t time.Time) *time.Time { return &t }
func TestCalculateProgress_BasicFields(t *testing.T) {
svc := newTestSubscriptionService()
now := time.Now()
sub := &UserSubscription{
ID: 100,
ExpiresAt: now.Add(30 * 24 * time.Hour),
}
group := &Group{
Name: "Premium",
}
progress := svc.calculateProgress(sub, group)
assert.Equal(t, int64(100), progress.ID)
assert.Equal(t, "Premium", progress.GroupName)
assert.Equal(t, sub.ExpiresAt, progress.ExpiresAt)
assert.Equal(t, 29, progress.ExpiresInDays) // 约 30 天
assert.Nil(t, progress.Daily, "无日限额时 Daily 应为 nil")
assert.Nil(t, progress.Weekly, "无周限额时 Weekly 应为 nil")
assert.Nil(t, progress.Monthly, "无月限额时 Monthly 应为 nil")
}
func TestCalculateProgress_DailyUsage(t *testing.T) {
svc := newTestSubscriptionService()
now := time.Now()
dailyStart := now.Add(-12 * time.Hour)
sub := &UserSubscription{
ID: 1,
ExpiresAt: now.Add(10 * 24 * time.Hour),
DailyUsageUSD: 3.0,
DailyWindowStart: ptrTime(dailyStart),
}
group := &Group{
Name: "Pro",
DailyLimitUSD: ptrFloat64(10.0),
}
progress := svc.calculateProgress(sub, group)
require.NotNil(t, progress.Daily, "有日限额和窗口时 Daily 不应为 nil")
assert.Equal(t, 10.0, progress.Daily.LimitUSD)
assert.Equal(t, 3.0, progress.Daily.UsedUSD)
assert.Equal(t, 7.0, progress.Daily.RemainingUSD)
assert.Equal(t, 30.0, progress.Daily.Percentage)
assert.Equal(t, dailyStart, progress.Daily.WindowStart)
}
func TestCalculateProgress_WeeklyUsage(t *testing.T) {
svc := newTestSubscriptionService()
now := time.Now()
weeklyStart := now.Add(-3 * 24 * time.Hour)
sub := &UserSubscription{
ID: 1,
ExpiresAt: now.Add(10 * 24 * time.Hour),
WeeklyUsageUSD: 25.0,
WeeklyWindowStart: ptrTime(weeklyStart),
}
group := &Group{
Name: "Pro",
WeeklyLimitUSD: ptrFloat64(50.0),
}
progress := svc.calculateProgress(sub, group)
require.NotNil(t, progress.Weekly, "有周限额和窗口时 Weekly 不应为 nil")
assert.Equal(t, 50.0, progress.Weekly.LimitUSD)
assert.Equal(t, 25.0, progress.Weekly.UsedUSD)
assert.Equal(t, 25.0, progress.Weekly.RemainingUSD)
assert.Equal(t, 50.0, progress.Weekly.Percentage)
}
func TestCalculateProgress_MonthlyUsage(t *testing.T) {
svc := newTestSubscriptionService()
now := time.Now()
monthlyStart := now.Add(-15 * 24 * time.Hour)
sub := &UserSubscription{
ID: 1,
ExpiresAt: now.Add(10 * 24 * time.Hour),
MonthlyUsageUSD: 80.0,
MonthlyWindowStart: ptrTime(monthlyStart),
}
group := &Group{
Name: "Enterprise",
MonthlyLimitUSD: ptrFloat64(100.0),
}
progress := svc.calculateProgress(sub, group)
require.NotNil(t, progress.Monthly, "有月限额和窗口时 Monthly 不应为 nil")
assert.Equal(t, 100.0, progress.Monthly.LimitUSD)
assert.Equal(t, 80.0, progress.Monthly.UsedUSD)
assert.Equal(t, 20.0, progress.Monthly.RemainingUSD)
assert.Equal(t, 80.0, progress.Monthly.Percentage)
}
func TestCalculateProgress_OverLimit_ClampedTo100Percent(t *testing.T) {
svc := newTestSubscriptionService()
now := time.Now()
sub := &UserSubscription{
ID: 1,
ExpiresAt: now.Add(10 * 24 * time.Hour),
DailyUsageUSD: 15.0, // 超过限额
DailyWindowStart: ptrTime(now.Add(-1 * time.Hour)),
}
group := &Group{
Name: "Pro",
DailyLimitUSD: ptrFloat64(10.0),
}
progress := svc.calculateProgress(sub, group)
require.NotNil(t, progress.Daily)
assert.Equal(t, 100.0, progress.Daily.Percentage, "超额使用应被截断为 100%")
assert.Equal(t, 0.0, progress.Daily.RemainingUSD, "超额使用时剩余应为 0")
}
func TestCalculateProgress_NoWindowStart_NoProgress(t *testing.T) {
svc := newTestSubscriptionService()
now := time.Now()
// 有限额但无窗口起始时间(订阅未激活)
sub := &UserSubscription{
ID: 1,
ExpiresAt: now.Add(10 * 24 * time.Hour),
DailyUsageUSD: 0,
WeeklyUsageUSD: 0,
}
group := &Group{
Name: "Pro",
DailyLimitUSD: ptrFloat64(10.0),
WeeklyLimitUSD: ptrFloat64(50.0),
}
progress := svc.calculateProgress(sub, group)
assert.Nil(t, progress.Daily, "无 DailyWindowStart 时 Daily 应为 nil")
assert.Nil(t, progress.Weekly, "无 WeeklyWindowStart 时 Weekly 应为 nil")
}
func TestCalculateProgress_AllLimits(t *testing.T) {
svc := newTestSubscriptionService()
now := time.Now()
sub := &UserSubscription{
ID: 1,
ExpiresAt: now.Add(10 * 24 * time.Hour),
DailyUsageUSD: 5.0,
WeeklyUsageUSD: 20.0,
MonthlyUsageUSD: 60.0,
DailyWindowStart: ptrTime(now.Add(-6 * time.Hour)),
WeeklyWindowStart: ptrTime(now.Add(-3 * 24 * time.Hour)),
MonthlyWindowStart: ptrTime(now.Add(-15 * 24 * time.Hour)),
}
group := &Group{
Name: "Full",
DailyLimitUSD: ptrFloat64(10.0),
WeeklyLimitUSD: ptrFloat64(50.0),
MonthlyLimitUSD: ptrFloat64(100.0),
}
progress := svc.calculateProgress(sub, group)
require.NotNil(t, progress.Daily)
require.NotNil(t, progress.Weekly)
require.NotNil(t, progress.Monthly)
assert.Equal(t, 50.0, progress.Daily.Percentage)
assert.Equal(t, 40.0, progress.Weekly.Percentage)
assert.Equal(t, 60.0, progress.Monthly.Percentage)
}
func TestCalculateProgress_ExpiredSubscription(t *testing.T) {
svc := newTestSubscriptionService()
sub := &UserSubscription{
ID: 1,
ExpiresAt: time.Now().Add(-24 * time.Hour), // 已过期
}
group := &Group{Name: "Expired"}
progress := svc.calculateProgress(sub, group)
assert.Equal(t, 0, progress.ExpiresInDays, "过期订阅的剩余天数应为 0")
}
func TestCalculateProgress_ResetsInSeconds_NotNegative(t *testing.T) {
svc := newTestSubscriptionService()
// 使用过去的窗口起始时间,使得重置时间已过
pastStart := time.Now().Add(-48 * time.Hour)
sub := &UserSubscription{
ID: 1,
ExpiresAt: time.Now().Add(10 * 24 * time.Hour),
DailyUsageUSD: 1.0,
DailyWindowStart: ptrTime(pastStart),
}
group := &Group{
Name: "Test",
DailyLimitUSD: ptrFloat64(10.0),
}
progress := svc.calculateProgress(sub, group)
require.NotNil(t, progress.Daily)
assert.GreaterOrEqual(t, progress.Daily.ResetsInSeconds, int64(0),
"ResetsInSeconds 不应为负数")
}