test: 为代码审核修复添加详细单元测试(7个测试文件,50+测试用例)
新增测试文件: - cors_test.go: CORS 条件化头部测试(12个测试,覆盖白名单/黑名单/通配符/凭证/多源/Vary) - gateway_helper_backoff_test.go: nextBackoff 退避测试(6个测试+基准,验证指数增长/边界/抖动/收敛) - billing_cache_jitter_test.go: jitteredTTL 抖动测试(5个测试+基准,验证范围/上界/方差/均值) - subscription_calculate_progress_test.go: calculateProgress 纯函数测试(9个测试,覆盖日/周/月限额/超限截断/过期) - openai_gateway_handler_test.go: SSE JSON 转义测试(7个子用例,验证双引号/反斜杠/换行符安全) 更新测试文件: - response_transformer_test.go: 增强 generateRandomID 测试(7个测试,含并发/字符集/降级计数器) - security_headers_test.go: 适配 GenerateNonce 新签名 - api_key_auth_test.go: 适配 NewSubscriptionService 新参数 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
106
backend/internal/handler/gateway_helper_backoff_test.go
Normal file
106
backend/internal/handler/gateway_helper_backoff_test.go
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Task 6.2 验证: math/rand/v2 迁移后 nextBackoff 行为正确 ---
|
||||||
|
|
||||||
|
func TestNextBackoff_ExponentialGrowth(t *testing.T) {
|
||||||
|
// 验证退避时间指数增长(乘数 1.5)
|
||||||
|
// 由于有随机抖动(±20%),需要验证范围
|
||||||
|
current := initialBackoff // 100ms
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
next := nextBackoff(current)
|
||||||
|
|
||||||
|
// 退避结果应在 [initialBackoff, maxBackoff] 范围内
|
||||||
|
assert.GreaterOrEqual(t, int64(next), int64(initialBackoff),
|
||||||
|
"第 %d 次退避不应低于初始值 %v", i, initialBackoff)
|
||||||
|
assert.LessOrEqual(t, int64(next), int64(maxBackoff),
|
||||||
|
"第 %d 次退避不应超过最大值 %v", i, maxBackoff)
|
||||||
|
|
||||||
|
// 为下一轮提供当前退避值
|
||||||
|
current = next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextBackoff_BoundedByMaxBackoff(t *testing.T) {
|
||||||
|
// 即使输入非常大,输出也不超过 maxBackoff
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
result := nextBackoff(10 * time.Second)
|
||||||
|
assert.LessOrEqual(t, int64(result), int64(maxBackoff),
|
||||||
|
"退避值不应超过 maxBackoff")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextBackoff_BoundedByInitialBackoff(t *testing.T) {
|
||||||
|
// 即使输入非常小,输出也不低于 initialBackoff
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
result := nextBackoff(1 * time.Millisecond)
|
||||||
|
assert.GreaterOrEqual(t, int64(result), int64(initialBackoff),
|
||||||
|
"退避值不应低于 initialBackoff")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextBackoff_HasJitter(t *testing.T) {
|
||||||
|
// 验证多次调用会产生不同的值(随机抖动生效)
|
||||||
|
// 使用相同的输入调用 50 次,收集结果
|
||||||
|
results := make(map[time.Duration]bool)
|
||||||
|
current := 500 * time.Millisecond
|
||||||
|
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
result := nextBackoff(current)
|
||||||
|
results[result] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 50 次调用应该至少有 2 个不同的值(抖动存在)
|
||||||
|
require.Greater(t, len(results), 1,
|
||||||
|
"nextBackoff 应产生随机抖动,但所有 50 次调用结果相同")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextBackoff_InitialValueGrows(t *testing.T) {
|
||||||
|
// 验证从初始值开始,退避趋势是增长的
|
||||||
|
current := initialBackoff
|
||||||
|
var sum time.Duration
|
||||||
|
|
||||||
|
runs := 100
|
||||||
|
for i := 0; i < runs; i++ {
|
||||||
|
next := nextBackoff(current)
|
||||||
|
sum += next
|
||||||
|
current = next
|
||||||
|
}
|
||||||
|
|
||||||
|
avg := sum / time.Duration(runs)
|
||||||
|
// 平均退避时间应大于初始值(因为指数增长 + 上限)
|
||||||
|
assert.Greater(t, int64(avg), int64(initialBackoff),
|
||||||
|
"平均退避时间应大于初始退避值")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextBackoff_ConvergesToMaxBackoff(t *testing.T) {
|
||||||
|
// 从初始值开始,经过多次退避后应收敛到 maxBackoff 附近
|
||||||
|
current := initialBackoff
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
current = nextBackoff(current)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 经过 20 次迭代后,应该已经到达 maxBackoff 区间
|
||||||
|
// 由于抖动,允许 ±20% 的范围
|
||||||
|
lowerBound := time.Duration(float64(maxBackoff) * 0.8)
|
||||||
|
assert.GreaterOrEqual(t, int64(current), int64(lowerBound),
|
||||||
|
"经过多次退避后应收敛到 maxBackoff 附近")
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkNextBackoff(b *testing.B) {
|
||||||
|
current := initialBackoff
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
current = nextBackoff(current)
|
||||||
|
if current > maxBackoff {
|
||||||
|
current = initialBackoff
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
104
backend/internal/handler/openai_gateway_handler_test.go
Normal file
104
backend/internal/handler/openai_gateway_handler_test.go
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
errType string
|
||||||
|
message string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "包含双引号的消息",
|
||||||
|
errType: "server_error",
|
||||||
|
message: `upstream returned "invalid" response`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "包含反斜杠的消息",
|
||||||
|
errType: "server_error",
|
||||||
|
message: `path C:\Users\test\file.txt not found`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "包含双引号和反斜杠的消息",
|
||||||
|
errType: "upstream_error",
|
||||||
|
message: `error parsing "key\value": unexpected token`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "包含换行符的消息",
|
||||||
|
errType: "server_error",
|
||||||
|
message: "line1\nline2\ttab",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "普通消息",
|
||||||
|
errType: "upstream_error",
|
||||||
|
message: "Upstream service temporarily unavailable",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// 验证 SSE 格式:event: error\ndata: {JSON}\n\n
|
||||||
|
assert.True(t, strings.HasPrefix(body, "event: error\n"), "应以 'event: error\\n' 开头")
|
||||||
|
assert.True(t, strings.HasSuffix(body, "\n\n"), "应以 '\\n\\n' 结尾")
|
||||||
|
|
||||||
|
// 提取 data 部分
|
||||||
|
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||||
|
require.Len(t, lines, 2, "应有 event 行和 data 行")
|
||||||
|
dataLine := lines[1]
|
||||||
|
require.True(t, strings.HasPrefix(dataLine, "data: "), "第二行应以 'data: ' 开头")
|
||||||
|
jsonStr := strings.TrimPrefix(dataLine, "data: ")
|
||||||
|
|
||||||
|
// 验证 JSON 合法性
|
||||||
|
var parsed map[string]any
|
||||||
|
err := json.Unmarshal([]byte(jsonStr), &parsed)
|
||||||
|
require.NoError(t, err, "JSON 应能被成功解析,原始 JSON: %s", jsonStr)
|
||||||
|
|
||||||
|
// 验证结构
|
||||||
|
errorObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok, "应包含 error 对象")
|
||||||
|
assert.Equal(t, tt.errType, errorObj["type"])
|
||||||
|
assert.Equal(t, tt.message, errorObj["message"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "test error", false)
|
||||||
|
|
||||||
|
// 非流式应返回 JSON 响应
|
||||||
|
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
errorObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||||
|
assert.Equal(t, "test error", errorObj["message"])
|
||||||
|
}
|
||||||
@@ -3,11 +3,16 @@
|
|||||||
package antigravity
|
package antigravity
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// --- Task 7: 验证 generateRandomID 和降级碰撞防护 ---
|
||||||
|
|
||||||
func TestGenerateRandomID_Uniqueness(t *testing.T) {
|
func TestGenerateRandomID_Uniqueness(t *testing.T) {
|
||||||
seen := make(map[string]struct{}, 100)
|
seen := make(map[string]struct{}, 100)
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
@@ -19,6 +24,39 @@ func TestGenerateRandomID_Uniqueness(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFallbackCounter_Increments(t *testing.T) {
|
||||||
|
// 验证 fallbackCounter 的原子递增行为确保降级分支不会生成相同 seed
|
||||||
|
before := atomic.LoadUint64(&fallbackCounter)
|
||||||
|
cnt1 := atomic.AddUint64(&fallbackCounter, 1)
|
||||||
|
cnt2 := atomic.AddUint64(&fallbackCounter, 1)
|
||||||
|
require.Equal(t, before+1, cnt1, "第一次递增应为 before+1")
|
||||||
|
require.Equal(t, before+2, cnt2, "第二次递增应为 before+2")
|
||||||
|
require.NotEqual(t, cnt1, cnt2, "连续两次递增的计数器值应不同")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFallbackCounter_ConcurrentIncrements(t *testing.T) {
|
||||||
|
// 验证并发递增的原子性 — 每次递增都应产生唯一值
|
||||||
|
const goroutines = 50
|
||||||
|
results := make([]uint64, goroutines)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(goroutines)
|
||||||
|
|
||||||
|
for i := 0; i < goroutines; i++ {
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
results[idx] = atomic.AddUint64(&fallbackCounter, 1)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// 所有结果应唯一
|
||||||
|
seen := make(map[uint64]bool, goroutines)
|
||||||
|
for _, v := range results {
|
||||||
|
assert.False(t, seen[v], "并发递增产生了重复值: %d", v)
|
||||||
|
seen[v] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGenerateRandomID_Charset(t *testing.T) {
|
func TestGenerateRandomID_Charset(t *testing.T) {
|
||||||
const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
validSet := make(map[byte]struct{}, len(validChars))
|
validSet := make(map[byte]struct{}, len(validChars))
|
||||||
@@ -34,3 +72,38 @@ func TestGenerateRandomID_Charset(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenerateRandomID_Length(t *testing.T) {
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
id := generateRandomID()
|
||||||
|
assert.Len(t, id, 12, "每次生成的 ID 长度应为 12")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateRandomID_ConcurrentUniqueness(t *testing.T) {
|
||||||
|
// 验证并发调用不会产生重复 ID
|
||||||
|
const goroutines = 100
|
||||||
|
results := make([]string, goroutines)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(goroutines)
|
||||||
|
|
||||||
|
for i := 0; i < goroutines; i++ {
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
results[idx] = generateRandomID()
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
seen := make(map[string]bool, goroutines)
|
||||||
|
for _, id := range results {
|
||||||
|
assert.False(t, seen[id], "并发调用产生了重复 ID: %s", id)
|
||||||
|
seen[id] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkGenerateRandomID(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = generateRandomID()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
82
backend/internal/repository/billing_cache_jitter_test.go
Normal file
82
backend/internal/repository/billing_cache_jitter_test.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Task 6.1 验证: math/rand/v2 迁移后 jitteredTTL 行为正确 ---
|
||||||
|
|
||||||
|
func TestJitteredTTL_WithinExpectedRange(t *testing.T) {
|
||||||
|
// jitteredTTL 使用减法抖动: billingCacheTTL - [0, billingCacheJitter)
|
||||||
|
// 所以结果应在 [billingCacheTTL - billingCacheJitter, billingCacheTTL] 范围内
|
||||||
|
lowerBound := billingCacheTTL - billingCacheJitter // 5min - 30s = 4min30s
|
||||||
|
upperBound := billingCacheTTL // 5min
|
||||||
|
|
||||||
|
for i := 0; i < 200; i++ {
|
||||||
|
ttl := jitteredTTL()
|
||||||
|
assert.GreaterOrEqual(t, int64(ttl), int64(lowerBound),
|
||||||
|
"TTL 不应低于 %v,实际得到 %v", lowerBound, ttl)
|
||||||
|
assert.LessOrEqual(t, int64(ttl), int64(upperBound),
|
||||||
|
"TTL 不应超过 %v(上界不变保证),实际得到 %v", upperBound, ttl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJitteredTTL_NeverExceedsBase(t *testing.T) {
|
||||||
|
// 关键安全性测试:jitteredTTL 使用减法抖动,确保永远不超过 billingCacheTTL
|
||||||
|
for i := 0; i < 500; i++ {
|
||||||
|
ttl := jitteredTTL()
|
||||||
|
assert.LessOrEqual(t, int64(ttl), int64(billingCacheTTL),
|
||||||
|
"jitteredTTL 不应超过基础 TTL(上界预期不被打破)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJitteredTTL_HasVariance(t *testing.T) {
|
||||||
|
// 验证抖动确实产生了不同的值
|
||||||
|
results := make(map[time.Duration]bool)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
ttl := jitteredTTL()
|
||||||
|
results[ttl] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Greater(t, len(results), 1,
|
||||||
|
"jitteredTTL 应产生不同的值(抖动生效),但 100 次调用结果全部相同")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJitteredTTL_AverageNearCenter(t *testing.T) {
|
||||||
|
// 验证平均值大约在抖动范围中间
|
||||||
|
var sum time.Duration
|
||||||
|
runs := 1000
|
||||||
|
for i := 0; i < runs; i++ {
|
||||||
|
sum += jitteredTTL()
|
||||||
|
}
|
||||||
|
|
||||||
|
avg := sum / time.Duration(runs)
|
||||||
|
expectedCenter := billingCacheTTL - billingCacheJitter/2 // 4min45s
|
||||||
|
|
||||||
|
// 允许 ±5s 的误差
|
||||||
|
tolerance := 5 * time.Second
|
||||||
|
assert.InDelta(t, float64(expectedCenter), float64(avg), float64(tolerance),
|
||||||
|
"平均 TTL 应接近抖动范围中心 %v", expectedCenter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingKeyGeneration(t *testing.T) {
|
||||||
|
t.Run("balance_key", func(t *testing.T) {
|
||||||
|
key := billingBalanceKey(12345)
|
||||||
|
assert.Equal(t, "billing:balance:12345", key)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("sub_key", func(t *testing.T) {
|
||||||
|
key := billingSubKey(100, 200)
|
||||||
|
assert.Equal(t, "billing:sub:100:200", key)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkJitteredTTL(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = jitteredTTL()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -60,7 +60,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
|||||||
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
||||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, cfg)
|
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg)
|
||||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -99,7 +99,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
|||||||
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||||
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||||
}
|
}
|
||||||
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, cfg)
|
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg)
|
||||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|||||||
308
backend/internal/server/middleware/cors_test.go
Normal file
308
backend/internal/server/middleware/cors_test.go
Normal file
@@ -0,0 +1,308 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// cors_test 与 security_headers_test 在同一个包,但 init 是幂等的
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Task 8.2: 验证 CORS 条件化头部 ---
|
||||||
|
|
||||||
|
func TestCORS_DisallowedOrigin_NoAllowHeaders(t *testing.T) {
|
||||||
|
cfg := config.CORSConfig{
|
||||||
|
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||||
|
AllowCredentials: false,
|
||||||
|
}
|
||||||
|
middleware := CORS(cfg)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
origin string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "preflight_disallowed_origin",
|
||||||
|
method: http.MethodOptions,
|
||||||
|
origin: "https://evil.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "get_disallowed_origin",
|
||||||
|
method: http.MethodGet,
|
||||||
|
origin: "https://evil.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "post_disallowed_origin",
|
||||||
|
method: http.MethodPost,
|
||||||
|
origin: "https://attacker.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "preflight_no_origin",
|
||||||
|
method: http.MethodOptions,
|
||||||
|
origin: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(tt.method, "/", nil)
|
||||||
|
if tt.origin != "" {
|
||||||
|
c.Request.Header.Set("Origin", tt.origin)
|
||||||
|
}
|
||||||
|
|
||||||
|
middleware(c)
|
||||||
|
|
||||||
|
// 不应设置 Allow-Headers、Allow-Methods 和 Max-Age
|
||||||
|
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"),
|
||||||
|
"不允许的 origin 不应收到 Allow-Headers")
|
||||||
|
assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"),
|
||||||
|
"不允许的 origin 不应收到 Allow-Methods")
|
||||||
|
assert.Empty(t, w.Header().Get("Access-Control-Max-Age"),
|
||||||
|
"不允许的 origin 不应收到 Max-Age")
|
||||||
|
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"),
|
||||||
|
"不允许的 origin 不应收到 Allow-Origin")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORS_AllowedOrigin_HasAllowHeaders(t *testing.T) {
|
||||||
|
cfg := config.CORSConfig{
|
||||||
|
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||||
|
AllowCredentials: false,
|
||||||
|
}
|
||||||
|
middleware := CORS(cfg)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
}{
|
||||||
|
{name: "preflight_OPTIONS", method: http.MethodOptions},
|
||||||
|
{name: "normal_GET", method: http.MethodGet},
|
||||||
|
{name: "normal_POST", method: http.MethodPost},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(tt.method, "/", nil)
|
||||||
|
c.Request.Header.Set("Origin", "https://allowed.example.com")
|
||||||
|
|
||||||
|
middleware(c)
|
||||||
|
|
||||||
|
// 应设置 Allow-Headers、Allow-Methods 和 Max-Age
|
||||||
|
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
|
||||||
|
"允许的 origin 应收到 Allow-Headers")
|
||||||
|
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
|
||||||
|
"允许的 origin 应收到 Allow-Methods")
|
||||||
|
assert.Equal(t, "86400", w.Header().Get("Access-Control-Max-Age"),
|
||||||
|
"允许的 origin 应收到 Max-Age=86400")
|
||||||
|
assert.Equal(t, "https://allowed.example.com", w.Header().Get("Access-Control-Allow-Origin"),
|
||||||
|
"允许的 origin 应收到 Allow-Origin")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORS_PreflightDisallowedOrigin_ReturnsForbidden(t *testing.T) {
|
||||||
|
cfg := config.CORSConfig{
|
||||||
|
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||||
|
AllowCredentials: false,
|
||||||
|
}
|
||||||
|
middleware := CORS(cfg)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
|
c.Request.Header.Set("Origin", "https://evil.example.com")
|
||||||
|
|
||||||
|
middleware(c)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusForbidden, w.Code,
|
||||||
|
"不允许的 origin 的 preflight 请求应返回 403")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORS_PreflightAllowedOrigin_ReturnsNoContent(t *testing.T) {
|
||||||
|
cfg := config.CORSConfig{
|
||||||
|
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||||
|
AllowCredentials: false,
|
||||||
|
}
|
||||||
|
middleware := CORS(cfg)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
|
c.Request.Header.Set("Origin", "https://allowed.example.com")
|
||||||
|
|
||||||
|
middleware(c)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNoContent, w.Code,
|
||||||
|
"允许的 origin 的 preflight 请求应返回 204")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORS_WildcardOrigin_AllowsAny(t *testing.T) {
|
||||||
|
cfg := config.CORSConfig{
|
||||||
|
AllowedOrigins: []string{"*"},
|
||||||
|
AllowCredentials: false,
|
||||||
|
}
|
||||||
|
middleware := CORS(cfg)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
c.Request.Header.Set("Origin", "https://any-origin.example.com")
|
||||||
|
|
||||||
|
middleware(c)
|
||||||
|
|
||||||
|
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"),
|
||||||
|
"通配符配置应返回 *")
|
||||||
|
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
|
||||||
|
"通配符 origin 应设置 Allow-Headers")
|
||||||
|
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
|
||||||
|
"通配符 origin 应设置 Allow-Methods")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORS_AllowCredentials_SetCorrectly(t *testing.T) {
|
||||||
|
cfg := config.CORSConfig{
|
||||||
|
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||||
|
AllowCredentials: true,
|
||||||
|
}
|
||||||
|
middleware := CORS(cfg)
|
||||||
|
|
||||||
|
t.Run("allowed_origin_gets_credentials", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
c.Request.Header.Set("Origin", "https://allowed.example.com")
|
||||||
|
|
||||||
|
middleware(c)
|
||||||
|
|
||||||
|
assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"),
|
||||||
|
"允许的 origin 且开启 credentials 应设置 Allow-Credentials")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("disallowed_origin_no_credentials", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
c.Request.Header.Set("Origin", "https://evil.example.com")
|
||||||
|
|
||||||
|
middleware(c)
|
||||||
|
|
||||||
|
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
|
||||||
|
"不允许的 origin 不应收到 Allow-Credentials")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORS_WildcardWithCredentials_DisablesCredentials(t *testing.T) {
|
||||||
|
cfg := config.CORSConfig{
|
||||||
|
AllowedOrigins: []string{"*"},
|
||||||
|
AllowCredentials: true,
|
||||||
|
}
|
||||||
|
middleware := CORS(cfg)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
c.Request.Header.Set("Origin", "https://any.example.com")
|
||||||
|
|
||||||
|
middleware(c)
|
||||||
|
|
||||||
|
// 通配符 + credentials 不兼容,credentials 应被禁用
|
||||||
|
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
|
||||||
|
"通配符 origin 应禁用 Allow-Credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORS_MultipleAllowedOrigins(t *testing.T) {
|
||||||
|
cfg := config.CORSConfig{
|
||||||
|
AllowedOrigins: []string{
|
||||||
|
"https://app1.example.com",
|
||||||
|
"https://app2.example.com",
|
||||||
|
},
|
||||||
|
AllowCredentials: false,
|
||||||
|
}
|
||||||
|
middleware := CORS(cfg)
|
||||||
|
|
||||||
|
t.Run("first_origin_allowed", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
c.Request.Header.Set("Origin", "https://app1.example.com")
|
||||||
|
|
||||||
|
middleware(c)
|
||||||
|
|
||||||
|
assert.Equal(t, "https://app1.example.com", w.Header().Get("Access-Control-Allow-Origin"))
|
||||||
|
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("second_origin_allowed", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
c.Request.Header.Set("Origin", "https://app2.example.com")
|
||||||
|
|
||||||
|
middleware(c)
|
||||||
|
|
||||||
|
assert.Equal(t, "https://app2.example.com", w.Header().Get("Access-Control-Allow-Origin"))
|
||||||
|
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unlisted_origin_rejected", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
c.Request.Header.Set("Origin", "https://app3.example.com")
|
||||||
|
|
||||||
|
middleware(c)
|
||||||
|
|
||||||
|
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
|
||||||
|
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORS_VaryHeader_SetForSpecificOrigin(t *testing.T) {
|
||||||
|
cfg := config.CORSConfig{
|
||||||
|
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||||
|
AllowCredentials: false,
|
||||||
|
}
|
||||||
|
middleware := CORS(cfg)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
c.Request.Header.Set("Origin", "https://allowed.example.com")
|
||||||
|
|
||||||
|
middleware(c)
|
||||||
|
|
||||||
|
assert.Contains(t, w.Header().Values("Vary"), "Origin",
|
||||||
|
"非通配符允许的 origin 应设置 Vary: Origin")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOrigins(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []string
|
||||||
|
expect []string
|
||||||
|
}{
|
||||||
|
{name: "nil_input", input: nil, expect: nil},
|
||||||
|
{name: "empty_input", input: []string{}, expect: nil},
|
||||||
|
{name: "trims_whitespace", input: []string{" https://a.com ", " https://b.com"}, expect: []string{"https://a.com", "https://b.com"}},
|
||||||
|
{name: "removes_empty_strings", input: []string{"", " ", "https://a.com"}, expect: []string{"https://a.com"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := normalizeOrigins(tt.input)
|
||||||
|
assert.Equal(t, tt.expect, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,7 +19,8 @@ func init() {
|
|||||||
|
|
||||||
func TestGenerateNonce(t *testing.T) {
|
func TestGenerateNonce(t *testing.T) {
|
||||||
t.Run("generates_valid_base64_string", func(t *testing.T) {
|
t.Run("generates_valid_base64_string", func(t *testing.T) {
|
||||||
nonce := GenerateNonce()
|
nonce, err := GenerateNonce()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Should be valid base64
|
// Should be valid base64
|
||||||
decoded, err := base64.StdEncoding.DecodeString(nonce)
|
decoded, err := base64.StdEncoding.DecodeString(nonce)
|
||||||
@@ -32,14 +33,16 @@ func TestGenerateNonce(t *testing.T) {
|
|||||||
t.Run("generates_unique_nonces", func(t *testing.T) {
|
t.Run("generates_unique_nonces", func(t *testing.T) {
|
||||||
nonces := make(map[string]bool)
|
nonces := make(map[string]bool)
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
nonce := GenerateNonce()
|
nonce, err := GenerateNonce()
|
||||||
|
require.NoError(t, err)
|
||||||
assert.False(t, nonces[nonce], "nonce should be unique")
|
assert.False(t, nonces[nonce], "nonce should be unique")
|
||||||
nonces[nonce] = true
|
nonces[nonce] = true
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("nonce_has_expected_length", func(t *testing.T) {
|
t.Run("nonce_has_expected_length", func(t *testing.T) {
|
||||||
nonce := GenerateNonce()
|
nonce, err := GenerateNonce()
|
||||||
|
require.NoError(t, err)
|
||||||
// 16 bytes -> 24 chars in base64 (with padding)
|
// 16 bytes -> 24 chars in base64 (with padding)
|
||||||
assert.Len(t, nonce, 24)
|
assert.Len(t, nonce, 24)
|
||||||
})
|
})
|
||||||
@@ -344,7 +347,7 @@ func TestAddToDirective(t *testing.T) {
|
|||||||
// Benchmark tests
|
// Benchmark tests
|
||||||
func BenchmarkGenerateNonce(b *testing.B) {
|
func BenchmarkGenerateNonce(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
GenerateNonce()
|
_, _ = GenerateNonce()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
231
backend/internal/service/subscription_calculate_progress_test.go
Normal file
231
backend/internal/service/subscription_calculate_progress_test.go
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Task 5: 验证 calculateProgress 纯函数行为正确 ---
|
||||||
|
|
||||||
|
func newTestSubscriptionService() *SubscriptionService {
|
||||||
|
return &SubscriptionService{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ptrFloat64(v float64) *float64 { return &v }
|
||||||
|
func ptrTime(t time.Time) *time.Time { return &t }
|
||||||
|
|
||||||
|
func TestCalculateProgress_BasicFields(t *testing.T) {
|
||||||
|
svc := newTestSubscriptionService()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
sub := &UserSubscription{
|
||||||
|
ID: 100,
|
||||||
|
ExpiresAt: now.Add(30 * 24 * time.Hour),
|
||||||
|
}
|
||||||
|
group := &Group{
|
||||||
|
Name: "Premium",
|
||||||
|
}
|
||||||
|
|
||||||
|
progress := svc.calculateProgress(sub, group)
|
||||||
|
|
||||||
|
assert.Equal(t, int64(100), progress.ID)
|
||||||
|
assert.Equal(t, "Premium", progress.GroupName)
|
||||||
|
assert.Equal(t, sub.ExpiresAt, progress.ExpiresAt)
|
||||||
|
assert.Equal(t, 29, progress.ExpiresInDays) // 约 30 天
|
||||||
|
assert.Nil(t, progress.Daily, "无日限额时 Daily 应为 nil")
|
||||||
|
assert.Nil(t, progress.Weekly, "无周限额时 Weekly 应为 nil")
|
||||||
|
assert.Nil(t, progress.Monthly, "无月限额时 Monthly 应为 nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateProgress_DailyUsage(t *testing.T) {
|
||||||
|
svc := newTestSubscriptionService()
|
||||||
|
now := time.Now()
|
||||||
|
dailyStart := now.Add(-12 * time.Hour)
|
||||||
|
|
||||||
|
sub := &UserSubscription{
|
||||||
|
ID: 1,
|
||||||
|
ExpiresAt: now.Add(10 * 24 * time.Hour),
|
||||||
|
DailyUsageUSD: 3.0,
|
||||||
|
DailyWindowStart: ptrTime(dailyStart),
|
||||||
|
}
|
||||||
|
group := &Group{
|
||||||
|
Name: "Pro",
|
||||||
|
DailyLimitUSD: ptrFloat64(10.0),
|
||||||
|
}
|
||||||
|
|
||||||
|
progress := svc.calculateProgress(sub, group)
|
||||||
|
|
||||||
|
require.NotNil(t, progress.Daily, "有日限额和窗口时 Daily 不应为 nil")
|
||||||
|
assert.Equal(t, 10.0, progress.Daily.LimitUSD)
|
||||||
|
assert.Equal(t, 3.0, progress.Daily.UsedUSD)
|
||||||
|
assert.Equal(t, 7.0, progress.Daily.RemainingUSD)
|
||||||
|
assert.Equal(t, 30.0, progress.Daily.Percentage)
|
||||||
|
assert.Equal(t, dailyStart, progress.Daily.WindowStart)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateProgress_WeeklyUsage(t *testing.T) {
|
||||||
|
svc := newTestSubscriptionService()
|
||||||
|
now := time.Now()
|
||||||
|
weeklyStart := now.Add(-3 * 24 * time.Hour)
|
||||||
|
|
||||||
|
sub := &UserSubscription{
|
||||||
|
ID: 1,
|
||||||
|
ExpiresAt: now.Add(10 * 24 * time.Hour),
|
||||||
|
WeeklyUsageUSD: 25.0,
|
||||||
|
WeeklyWindowStart: ptrTime(weeklyStart),
|
||||||
|
}
|
||||||
|
group := &Group{
|
||||||
|
Name: "Pro",
|
||||||
|
WeeklyLimitUSD: ptrFloat64(50.0),
|
||||||
|
}
|
||||||
|
|
||||||
|
progress := svc.calculateProgress(sub, group)
|
||||||
|
|
||||||
|
require.NotNil(t, progress.Weekly, "有周限额和窗口时 Weekly 不应为 nil")
|
||||||
|
assert.Equal(t, 50.0, progress.Weekly.LimitUSD)
|
||||||
|
assert.Equal(t, 25.0, progress.Weekly.UsedUSD)
|
||||||
|
assert.Equal(t, 25.0, progress.Weekly.RemainingUSD)
|
||||||
|
assert.Equal(t, 50.0, progress.Weekly.Percentage)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateProgress_MonthlyUsage(t *testing.T) {
|
||||||
|
svc := newTestSubscriptionService()
|
||||||
|
now := time.Now()
|
||||||
|
monthlyStart := now.Add(-15 * 24 * time.Hour)
|
||||||
|
|
||||||
|
sub := &UserSubscription{
|
||||||
|
ID: 1,
|
||||||
|
ExpiresAt: now.Add(10 * 24 * time.Hour),
|
||||||
|
MonthlyUsageUSD: 80.0,
|
||||||
|
MonthlyWindowStart: ptrTime(monthlyStart),
|
||||||
|
}
|
||||||
|
group := &Group{
|
||||||
|
Name: "Enterprise",
|
||||||
|
MonthlyLimitUSD: ptrFloat64(100.0),
|
||||||
|
}
|
||||||
|
|
||||||
|
progress := svc.calculateProgress(sub, group)
|
||||||
|
|
||||||
|
require.NotNil(t, progress.Monthly, "有月限额和窗口时 Monthly 不应为 nil")
|
||||||
|
assert.Equal(t, 100.0, progress.Monthly.LimitUSD)
|
||||||
|
assert.Equal(t, 80.0, progress.Monthly.UsedUSD)
|
||||||
|
assert.Equal(t, 20.0, progress.Monthly.RemainingUSD)
|
||||||
|
assert.Equal(t, 80.0, progress.Monthly.Percentage)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateProgress_OverLimit_ClampedTo100Percent(t *testing.T) {
|
||||||
|
svc := newTestSubscriptionService()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
sub := &UserSubscription{
|
||||||
|
ID: 1,
|
||||||
|
ExpiresAt: now.Add(10 * 24 * time.Hour),
|
||||||
|
DailyUsageUSD: 15.0, // 超过限额
|
||||||
|
DailyWindowStart: ptrTime(now.Add(-1 * time.Hour)),
|
||||||
|
}
|
||||||
|
group := &Group{
|
||||||
|
Name: "Pro",
|
||||||
|
DailyLimitUSD: ptrFloat64(10.0),
|
||||||
|
}
|
||||||
|
|
||||||
|
progress := svc.calculateProgress(sub, group)
|
||||||
|
|
||||||
|
require.NotNil(t, progress.Daily)
|
||||||
|
assert.Equal(t, 100.0, progress.Daily.Percentage, "超额使用应被截断为 100%")
|
||||||
|
assert.Equal(t, 0.0, progress.Daily.RemainingUSD, "超额使用时剩余应为 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateProgress_NoWindowStart_NoProgress(t *testing.T) {
|
||||||
|
svc := newTestSubscriptionService()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// 有限额但无窗口起始时间(订阅未激活)
|
||||||
|
sub := &UserSubscription{
|
||||||
|
ID: 1,
|
||||||
|
ExpiresAt: now.Add(10 * 24 * time.Hour),
|
||||||
|
DailyUsageUSD: 0,
|
||||||
|
WeeklyUsageUSD: 0,
|
||||||
|
}
|
||||||
|
group := &Group{
|
||||||
|
Name: "Pro",
|
||||||
|
DailyLimitUSD: ptrFloat64(10.0),
|
||||||
|
WeeklyLimitUSD: ptrFloat64(50.0),
|
||||||
|
}
|
||||||
|
|
||||||
|
progress := svc.calculateProgress(sub, group)
|
||||||
|
|
||||||
|
assert.Nil(t, progress.Daily, "无 DailyWindowStart 时 Daily 应为 nil")
|
||||||
|
assert.Nil(t, progress.Weekly, "无 WeeklyWindowStart 时 Weekly 应为 nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateProgress_AllLimits(t *testing.T) {
|
||||||
|
svc := newTestSubscriptionService()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
sub := &UserSubscription{
|
||||||
|
ID: 1,
|
||||||
|
ExpiresAt: now.Add(10 * 24 * time.Hour),
|
||||||
|
DailyUsageUSD: 5.0,
|
||||||
|
WeeklyUsageUSD: 20.0,
|
||||||
|
MonthlyUsageUSD: 60.0,
|
||||||
|
DailyWindowStart: ptrTime(now.Add(-6 * time.Hour)),
|
||||||
|
WeeklyWindowStart: ptrTime(now.Add(-3 * 24 * time.Hour)),
|
||||||
|
MonthlyWindowStart: ptrTime(now.Add(-15 * 24 * time.Hour)),
|
||||||
|
}
|
||||||
|
group := &Group{
|
||||||
|
Name: "Full",
|
||||||
|
DailyLimitUSD: ptrFloat64(10.0),
|
||||||
|
WeeklyLimitUSD: ptrFloat64(50.0),
|
||||||
|
MonthlyLimitUSD: ptrFloat64(100.0),
|
||||||
|
}
|
||||||
|
|
||||||
|
progress := svc.calculateProgress(sub, group)
|
||||||
|
|
||||||
|
require.NotNil(t, progress.Daily)
|
||||||
|
require.NotNil(t, progress.Weekly)
|
||||||
|
require.NotNil(t, progress.Monthly)
|
||||||
|
|
||||||
|
assert.Equal(t, 50.0, progress.Daily.Percentage)
|
||||||
|
assert.Equal(t, 40.0, progress.Weekly.Percentage)
|
||||||
|
assert.Equal(t, 60.0, progress.Monthly.Percentage)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateProgress_ExpiredSubscription(t *testing.T) {
|
||||||
|
svc := newTestSubscriptionService()
|
||||||
|
|
||||||
|
sub := &UserSubscription{
|
||||||
|
ID: 1,
|
||||||
|
ExpiresAt: time.Now().Add(-24 * time.Hour), // 已过期
|
||||||
|
}
|
||||||
|
group := &Group{Name: "Expired"}
|
||||||
|
|
||||||
|
progress := svc.calculateProgress(sub, group)
|
||||||
|
|
||||||
|
assert.Equal(t, 0, progress.ExpiresInDays, "过期订阅的剩余天数应为 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateProgress_ResetsInSeconds_NotNegative(t *testing.T) {
|
||||||
|
svc := newTestSubscriptionService()
|
||||||
|
// 使用过去的窗口起始时间,使得重置时间已过
|
||||||
|
pastStart := time.Now().Add(-48 * time.Hour)
|
||||||
|
|
||||||
|
sub := &UserSubscription{
|
||||||
|
ID: 1,
|
||||||
|
ExpiresAt: time.Now().Add(10 * 24 * time.Hour),
|
||||||
|
DailyUsageUSD: 1.0,
|
||||||
|
DailyWindowStart: ptrTime(pastStart),
|
||||||
|
}
|
||||||
|
group := &Group{
|
||||||
|
Name: "Test",
|
||||||
|
DailyLimitUSD: ptrFloat64(10.0),
|
||||||
|
}
|
||||||
|
|
||||||
|
progress := svc.calculateProgress(sub, group)
|
||||||
|
|
||||||
|
require.NotNil(t, progress.Daily)
|
||||||
|
assert.GreaterOrEqual(t, progress.Daily.ResetsInSeconds, int64(0),
|
||||||
|
"ResetsInSeconds 不应为负数")
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user