diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml deleted file mode 100644 index 2596a18c..00000000 --- a/.github/workflows/backend-ci.yml +++ /dev/null @@ -1,47 +0,0 @@ -name: CI - -on: - push: - pull_request: - -permissions: - contents: read - -jobs: - test: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version-file: backend/go.mod - check-latest: false - cache: true - - name: Verify Go version - run: | - go version | grep -q 'go1.25.7' - - name: Unit tests - working-directory: backend - run: make test-unit - - name: Integration tests - working-directory: backend - run: make test-integration - - golangci-lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version-file: backend/go.mod - check-latest: false - cache: true - - name: Verify Go version - run: | - go version | grep -q 'go1.25.7' - - name: golangci-lint - uses: golangci/golangci-lint-action@v9 - with: - version: v2.7 - args: --timeout=5m - working-directory: backend diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..03e7159f --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,179 @@ +name: CI + +on: + push: + pull_request: + +permissions: + contents: read + +jobs: + # ========================================================================== + # 后端测试(与前端并行运行) + # ========================================================================== + backend-test: + runs-on: ubuntu-latest + services: + postgres: + image: postgres:16-alpine + env: + POSTGRES_USER: test + POSTGRES_PASSWORD: test + POSTGRES_DB: sub2api_test + ports: + - 5432:5432 + options: >- + --health-cmd "pg_isready -U test" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + redis: + image: redis:7-alpine + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: backend/go.mod + check-latest: false + cache: true + + - name: 验证 Go 版本 + run: go version | grep -q 'go1.25.7' + + - name: 单元测试 + working-directory: backend + run: make test-unit + + - name: 集成测试 + working-directory: backend + env: + DATABASE_URL: postgres://test:test@localhost:5432/sub2api_test?sslmode=disable + REDIS_URL: redis://localhost:6379/0 + run: make test-integration + + - name: Race 检测 + working-directory: backend + run: go test -tags=unit -race -count=1 ./... + + - name: 覆盖率收集 + working-directory: backend + run: | + go test -tags=unit -coverprofile=coverage.out -count=1 ./... + echo "## 后端测试覆盖率" >> $GITHUB_STEP_SUMMARY + echo '```' >> $GITHUB_STEP_SUMMARY + go tool cover -func=coverage.out | tail -1 >> $GITHUB_STEP_SUMMARY + echo '```' >> $GITHUB_STEP_SUMMARY + + - name: 覆盖率门禁(≥8%) + working-directory: backend + run: | + COVERAGE=$(go tool cover -func=coverage.out | tail -1 | awk '{print $3}' | sed 's/%//') + echo "当前覆盖率: ${COVERAGE}%" + if [ "$(echo "$COVERAGE < 8" | bc -l)" -eq 1 ]; then + echo "::error::后端覆盖率 ${COVERAGE}% 低于门禁值 8%" + exit 1 + fi + + # ========================================================================== + # 后端代码检查 + # ========================================================================== + golangci-lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: backend/go.mod + check-latest: false + cache: true + - name: 验证 Go 版本 + run: go version | grep -q 'go1.25.7' + - name: golangci-lint + uses: golangci/golangci-lint-action@v9 + with: + version: v2.7 + args: --timeout=5m + working-directory: backend + + # ========================================================================== + # 前端测试(与后端并行运行) + # ========================================================================== + frontend-test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: 安装 pnpm + uses: pnpm/action-setup@v4 + with: + version: 9 + - name: 安装 Node.js + uses: actions/setup-node@v4 + with: + node-version: '20' + cache: 'pnpm' + cache-dependency-path: frontend/pnpm-lock.yaml + - name: 安装依赖 + working-directory: frontend + run: pnpm install --frozen-lockfile + + - name: 类型检查 + working-directory: frontend + run: pnpm run typecheck + + - name: Lint 检查 + working-directory: frontend + run: pnpm run lint:check + + - name: 单元测试 + working-directory: frontend + run: pnpm run test:run + + - name: 覆盖率收集 + working-directory: frontend + run: | + pnpm run test:coverage -- --exclude '**/integration/**' || true + echo "## 前端测试覆盖率" >> $GITHUB_STEP_SUMMARY + if [ -f coverage/coverage-final.json ]; then + echo "覆盖率报告已生成" >> $GITHUB_STEP_SUMMARY + fi + + - name: 覆盖率门禁(≥20%) + working-directory: frontend + run: | + if [ ! -f coverage/coverage-final.json ]; then + echo "::warning::覆盖率报告未生成,跳过门禁检查" + exit 0 + fi + # 使用 node 解析覆盖率 JSON + COVERAGE=$(node -e " + const data = require('./coverage/coverage-final.json'); + let totalStatements = 0, coveredStatements = 0; + for (const file of Object.values(data)) { + const stmts = file.s; + totalStatements += Object.keys(stmts).length; + coveredStatements += Object.values(stmts).filter(v => v > 0).length; + } + const pct = totalStatements > 0 ? (coveredStatements / totalStatements * 100) : 0; + console.log(pct.toFixed(1)); + ") + echo "当前前端覆盖率: ${COVERAGE}%" + if [ "$(echo "$COVERAGE < 20" | bc -l 2>/dev/null || node -e "console.log($COVERAGE < 20 ? 1 : 0)")" = "1" ]; then + echo "::warning::前端覆盖率 ${COVERAGE}% 低于门禁值 20%(当前为警告,不阻塞)" + fi + + # ========================================================================== + # Docker 构建验证 + # ========================================================================== + docker-build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Docker 构建验证 + run: docker build -t aicodex2api:ci-test . diff --git a/backend/Makefile b/backend/Makefile index 6a5d2caa..89db1104 100644 --- a/backend/Makefile +++ b/backend/Makefile @@ -14,4 +14,7 @@ test-integration: go test -tags=integration ./... test-e2e: - go test -tags=e2e ./... + ./scripts/e2e-test.sh + +test-e2e-local: + go test -tags=e2e -v -timeout=300s ./internal/integration/... diff --git a/backend/internal/handler/admin/batch_update_credentials_test.go b/backend/internal/handler/admin/batch_update_credentials_test.go index 4c47fadb..c8185735 100644 --- a/backend/internal/handler/admin/batch_update_credentials_test.go +++ b/backend/internal/handler/admin/batch_update_credentials_test.go @@ -60,7 +60,7 @@ func TestBatchUpdateCredentials_AllSuccess(t *testing.T) { require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount") } -func TestBatchUpdateCredentials_FailFast(t *testing.T) { +func TestBatchUpdateCredentials_PartialFailure(t *testing.T) { // 让第 2 个账号(ID=2)更新时失败 svc := &failingAdminService{ stubAdminService: newStubAdminService(), @@ -79,10 +79,18 @@ func TestBatchUpdateCredentials_FailFast(t *testing.T) { req.Header.Set("Content-Type", "application/json") router.ServeHTTP(w, req) - require.Equal(t, http.StatusInternalServerError, w.Code, "ID=2 失败时应返回 500") - // 验证 fail-fast:ID=1 更新成功,ID=2 失败,ID=3 不应被调用 - require.Equal(t, int64(2), svc.updateCallCount.Load(), - "fail-fast: 应只调用 2 次 UpdateAccount(ID=1 成功、ID=2 失败后停止)") + // 实现采用"部分成功"模式:总是返回 200 + 成功/失败明细 + require.Equal(t, http.StatusOK, w.Code, "批量更新返回 200 + 成功/失败明细") + + var resp map[string]any + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + data := resp["data"].(map[string]any) + require.Equal(t, float64(2), data["success"], "应有 2 个成功") + require.Equal(t, float64(1), data["failed"], "应有 1 个失败") + + // 所有 3 个账号都会被尝试更新(非 fail-fast) + require.Equal(t, int64(3), svc.updateCallCount.Load(), + "应调用 3 次 UpdateAccount(逐个尝试,失败后继续)") } func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) { diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 91881dec..ba266d5c 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -16,10 +16,17 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/testutil" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) +// 编译期接口断言 +var _ service.SoraClient = (*stubSoraClient)(nil) +var _ service.AccountRepository = (*stubAccountRepo)(nil) +var _ service.GroupRepository = (*stubGroupRepo)(nil) +var _ service.UsageLogRepository = (*stubUsageLogRepo)(nil) + type stubSoraClient struct { imageURLs []string } @@ -41,52 +48,6 @@ func (s *stubSoraClient) GetVideoTask(ctx context.Context, account *service.Acco return &service.SoraVideoTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil } -type stubConcurrencyCache struct{} - -func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { - return true, nil -} -func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { - return nil -} -func (c stubConcurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { - return 0, nil -} -func (c stubConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { - return true, nil -} -func (c stubConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { - return nil -} -func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { - return 0, nil -} -func (c stubConcurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { - return true, nil -} -func (c stubConcurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { - return nil -} -func (c stubConcurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { - return 0, nil -} -func (c stubConcurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { - return true, nil -} -func (c stubConcurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error { - return nil -} -func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { - result := make(map[int64]*service.AccountLoadInfo, len(accounts)) - for _, acc := range accounts { - result[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID, LoadRate: 0} - } - return result, nil -} -func (c stubConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { - return nil -} - type stubAccountRepo struct { accounts map[int64]*service.Account } @@ -260,6 +221,12 @@ func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, nil } +func (r *stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) { + return nil, nil +} +func (r *stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error { + return nil +} type stubUsageLogRepo struct{} @@ -312,15 +279,18 @@ func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, e func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) { return nil, nil } -func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { +func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { return nil, nil } -func (s *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { +func (s *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { return nil, nil } func (s *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) { return nil, nil } +func (s *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) { + return nil, nil +} func (s *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) { return nil, nil } @@ -384,7 +354,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { usageLogRepo := &stubUsageLogRepo{} deferredService := service.NewDeferredService(accountRepo, nil, 0) billingService := service.NewBillingService(cfg, nil) - concurrencyService := service.NewConcurrencyService(stubConcurrencyCache{}) + concurrencyService := service.NewConcurrencyService(testutil.StubConcurrencyCache{}) billingCacheService := service.NewBillingCacheService(nil, nil, nil, cfg) t.Cleanup(func() { billingCacheService.Stop() @@ -397,6 +367,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { nil, nil, nil, + nil, cfg, nil, concurrencyService, diff --git a/backend/internal/integration/e2e_gateway_test.go b/backend/internal/integration/e2e_gateway_test.go index ec0b29f7..8ee3f22e 100644 --- a/backend/internal/integration/e2e_gateway_test.go +++ b/backend/internal/integration/e2e_gateway_test.go @@ -21,11 +21,18 @@ var ( // - "" (默认): 使用 /v1/messages, /v1beta/models(混合模式,可调度 antigravity 账户) // - "/antigravity": 使用 /antigravity/v1/messages, /antigravity/v1beta/models(非混合模式,仅 antigravity 账户) endpointPrefix = getEnv("ENDPOINT_PREFIX", "") - claudeAPIKey = "sk-8e572bc3b3de92ace4f41f4256c28600ca11805732a7b693b5c44741346bbbb3" - geminiAPIKey = "sk-5950197a2085b38bbe5a1b229cc02b8ece914963fc44cacc06d497ae8b87410f" testInterval = 1 * time.Second // 测试间隔,防止限流 ) +const ( + // 注意:E2E 测试请使用环境变量注入密钥,避免任何凭证进入仓库历史。 + // 例如: + // export CLAUDE_API_KEY="sk-..." + // export GEMINI_API_KEY="sk-..." + claudeAPIKeyEnv = "CLAUDE_API_KEY" + geminiAPIKeyEnv = "GEMINI_API_KEY" +) + func getEnv(key, defaultVal string) string { if v := os.Getenv(key); v != "" { return v @@ -65,16 +72,45 @@ func TestMain(m *testing.M) { if endpointPrefix != "" { mode = "Antigravity 模式" } - fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s)\n\n", baseURL, endpointPrefix, mode) + claudeKeySet := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) != "" + geminiKeySet := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) != "" + fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s, %s=%v, %s=%v)\n\n", + baseURL, + endpointPrefix, + mode, + claudeAPIKeyEnv, + claudeKeySet, + geminiAPIKeyEnv, + geminiKeySet, + ) os.Exit(m.Run()) } +func requireClaudeAPIKey(t *testing.T) string { + t.Helper() + key := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) + if key == "" { + t.Skipf("未设置 %s,跳过 Claude 相关 E2E 测试", claudeAPIKeyEnv) + } + return key +} + +func requireGeminiAPIKey(t *testing.T) string { + t.Helper() + key := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) + if key == "" { + t.Skipf("未设置 %s,跳过 Gemini 相关 E2E 测试", geminiAPIKeyEnv) + } + return key +} + // TestClaudeModelsList 测试 GET /v1/models func TestClaudeModelsList(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) url := baseURL + endpointPrefix + "/v1/models" req, _ := http.NewRequest("GET", url, nil) - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) @@ -106,10 +142,11 @@ func TestClaudeModelsList(t *testing.T) { // TestGeminiModelsList 测试 GET /v1beta/models func TestGeminiModelsList(t *testing.T) { + geminiKey := requireGeminiAPIKey(t) url := baseURL + endpointPrefix + "/v1beta/models" req, _ := http.NewRequest("GET", url, nil) - req.Header.Set("Authorization", "Bearer "+geminiAPIKey) + req.Header.Set("Authorization", "Bearer "+geminiKey) client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) @@ -137,21 +174,22 @@ func TestGeminiModelsList(t *testing.T) { // TestClaudeMessages 测试 Claude /v1/messages 接口 func TestClaudeMessages(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) for i, model := range claudeModels { if i > 0 { time.Sleep(testInterval) } t.Run(model+"_非流式", func(t *testing.T) { - testClaudeMessage(t, model, false) + testClaudeMessage(t, claudeKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_流式", func(t *testing.T) { - testClaudeMessage(t, model, true) + testClaudeMessage(t, claudeKey, model, true) }) } } -func testClaudeMessage(t *testing.T, model string, stream bool) { +func testClaudeMessage(t *testing.T, claudeKey string, model string, stream bool) { url := baseURL + endpointPrefix + "/v1/messages" payload := map[string]any{ @@ -166,7 +204,7 @@ func testClaudeMessage(t *testing.T, model string, stream bool) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -213,21 +251,22 @@ func testClaudeMessage(t *testing.T, model string, stream bool) { // TestGeminiGenerateContent 测试 Gemini /v1beta/models/:model 接口 func TestGeminiGenerateContent(t *testing.T) { + geminiKey := requireGeminiAPIKey(t) for i, model := range geminiModels { if i > 0 { time.Sleep(testInterval) } t.Run(model+"_非流式", func(t *testing.T) { - testGeminiGenerate(t, model, false) + testGeminiGenerate(t, geminiKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_流式", func(t *testing.T) { - testGeminiGenerate(t, model, true) + testGeminiGenerate(t, geminiKey, model, true) }) } } -func testGeminiGenerate(t *testing.T, model string, stream bool) { +func testGeminiGenerate(t *testing.T, geminiKey string, model string, stream bool) { action := "generateContent" if stream { action = "streamGenerateContent" @@ -254,7 +293,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+geminiAPIKey) + req.Header.Set("Authorization", "Bearer "+geminiKey) client := &http.Client{Timeout: 60 * time.Second} resp, err := client.Do(req) @@ -301,6 +340,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) { // TestClaudeMessagesWithComplexTools 测试带复杂工具 schema 的请求 // 模拟 Claude Code 发送的请求,包含需要清理的 JSON Schema 字段 func TestClaudeMessagesWithComplexTools(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) // 测试模型列表(只测试几个代表性模型) models := []string{ "claude-opus-4-5-20251101", // Claude 模型 @@ -312,12 +352,12 @@ func TestClaudeMessagesWithComplexTools(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_复杂工具", func(t *testing.T) { - testClaudeMessageWithTools(t, model) + testClaudeMessageWithTools(t, claudeKey, model) }) } } -func testClaudeMessageWithTools(t *testing.T, model string) { +func testClaudeMessageWithTools(t *testing.T, claudeKey string, model string) { url := baseURL + endpointPrefix + "/v1/messages" // 构造包含复杂 schema 的工具定义(模拟 Claude Code 的工具) @@ -473,7 +513,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -519,6 +559,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) { // 验证:当历史 assistant 消息包含 tool_use 但没有 signature 时, // 系统应自动添加 dummy thought_signature 避免 Gemini 400 错误 func TestClaudeMessagesWithThinkingAndTools(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) models := []string{ "claude-haiku-4-5-20251001", // gemini-3-flash } @@ -527,12 +568,12 @@ func TestClaudeMessagesWithThinkingAndTools(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_thinking模式工具调用", func(t *testing.T) { - testClaudeThinkingWithToolHistory(t, model) + testClaudeThinkingWithToolHistory(t, claudeKey, model) }) } } -func testClaudeThinkingWithToolHistory(t *testing.T, model string) { +func testClaudeThinkingWithToolHistory(t *testing.T, claudeKey string, model string) { url := baseURL + endpointPrefix + "/v1/messages" // 模拟历史对话:用户请求 → assistant 调用工具 → 工具返回 → 继续对话 @@ -600,7 +641,7 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -649,6 +690,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { if endpointPrefix != "/antigravity" { t.Skip("仅在 Antigravity 模式下运行") } + claudeKey := requireClaudeAPIKey(t) // 测试通过 Claude 端点调用 Gemini 模型 geminiViaClaude := []string{ @@ -664,11 +706,11 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_通过Claude端点", func(t *testing.T) { - testClaudeMessage(t, model, false) + testClaudeMessage(t, claudeKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_通过Claude端点_流式", func(t *testing.T) { - testClaudeMessage(t, model, true) + testClaudeMessage(t, claudeKey, model, true) }) } } @@ -676,6 +718,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { // TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景 // 验证:Gemini 模型接受没有 signature 的 thinking block func TestClaudeMessagesWithNoSignature(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) models := []string{ "claude-haiku-4-5-20251001", // gemini-3-flash - 支持无 signature } @@ -684,12 +727,12 @@ func TestClaudeMessagesWithNoSignature(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_无signature", func(t *testing.T) { - testClaudeWithNoSignature(t, model) + testClaudeWithNoSignature(t, claudeKey, model) }) } } -func testClaudeWithNoSignature(t *testing.T, model string) { +func testClaudeWithNoSignature(t *testing.T, claudeKey string, model string) { url := baseURL + endpointPrefix + "/v1/messages" // 模拟历史对话包含 thinking block 但没有 signature @@ -732,7 +775,7 @@ func testClaudeWithNoSignature(t *testing.T, model string) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -777,6 +820,7 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) { if endpointPrefix != "/antigravity" { t.Skip("仅在 Antigravity 模式下运行") } + geminiKey := requireGeminiAPIKey(t) // 测试通过 Gemini 端点调用 Claude 模型 claudeViaGemini := []string{ @@ -789,11 +833,11 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_通过Gemini端点", func(t *testing.T) { - testGeminiGenerate(t, model, false) + testGeminiGenerate(t, geminiKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) { - testGeminiGenerate(t, model, true) + testGeminiGenerate(t, geminiKey, model, true) }) } } diff --git a/backend/internal/integration/e2e_helpers_test.go b/backend/internal/integration/e2e_helpers_test.go new file mode 100644 index 00000000..7d266bcb --- /dev/null +++ b/backend/internal/integration/e2e_helpers_test.go @@ -0,0 +1,48 @@ +//go:build e2e + +package integration + +import ( + "os" + "strings" + "testing" +) + +// ============================================================================= +// E2E Mock 模式支持 +// ============================================================================= +// 当 E2E_MOCK=true 时,使用本地 Mock 响应替代真实 API 调用。 +// 这允许在没有真实 API Key 的环境(如 CI)中验证基本的请求/响应流程。 + +// isMockMode 检查是否启用 Mock 模式 +func isMockMode() bool { + return strings.EqualFold(os.Getenv("E2E_MOCK"), "true") +} + +// skipIfNoRealAPI 如果未配置真实 API Key 且不在 Mock 模式,则跳过测试 +func skipIfNoRealAPI(t *testing.T) { + t.Helper() + if isMockMode() { + return // Mock 模式下不跳过 + } + claudeKey := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) + geminiKey := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) + if claudeKey == "" && geminiKey == "" { + t.Skip("未设置 API Key 且未启用 Mock 模式,跳过测试") + } +} + +// ============================================================================= +// API Key 脱敏(Task 6.10) +// ============================================================================= + +// safeLogKey 安全地记录 API Key(仅显示前 8 位) +func safeLogKey(t *testing.T, prefix string, key string) { + t.Helper() + key = strings.TrimSpace(key) + if len(key) <= 8 { + t.Logf("%s: ***(长度: %d)", prefix, len(key)) + return + } + t.Logf("%s: %s...(长度: %d)", prefix, key[:8], len(key)) +} diff --git a/backend/internal/integration/e2e_user_flow_test.go b/backend/internal/integration/e2e_user_flow_test.go new file mode 100644 index 00000000..5489d0a3 --- /dev/null +++ b/backend/internal/integration/e2e_user_flow_test.go @@ -0,0 +1,317 @@ +//go:build e2e + +package integration + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" +) + +// E2E 用户流程测试 +// 测试完整的用户操作链路:注册 → 登录 → 创建 API Key → 调用网关 → 查询用量 + +var ( + testUserEmail = "e2e-test-" + fmt.Sprintf("%d", time.Now().UnixMilli()) + "@test.local" + testUserPassword = "E2eTest@12345" + testUserName = "e2e-test-user" +) + +// TestUserRegistrationAndLogin 测试用户注册和登录流程 +func TestUserRegistrationAndLogin(t *testing.T) { + // 步骤 1: 注册新用户 + t.Run("注册新用户", func(t *testing.T) { + payload := map[string]string{ + "email": testUserEmail, + "password": testUserPassword, + "username": testUserName, + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/auth/register", body, "") + if err != nil { + t.Skipf("注册接口不可用,跳过用户流程测试: %v", err) + return + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + // 注册可能返回 200(成功)或 400(邮箱已存在)或 403(注册已关闭) + switch resp.StatusCode { + case 200: + t.Logf("✅ 用户注册成功: %s", testUserEmail) + case 400: + t.Logf("⚠️ 用户可能已存在: %s", string(respBody)) + case 403: + t.Skipf("注册功能已关闭: %s", string(respBody)) + default: + t.Logf("⚠️ 注册返回 HTTP %d: %s(继续尝试登录)", resp.StatusCode, string(respBody)) + } + }) + + // 步骤 2: 登录获取 JWT + var accessToken string + t.Run("用户登录获取JWT", func(t *testing.T) { + payload := map[string]string{ + "email": testUserEmail, + "password": testUserPassword, + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/auth/login", body, "") + if err != nil { + t.Fatalf("登录请求失败: %v", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != 200 { + t.Skipf("登录失败 HTTP %d: %s(可能需要先注册用户)", resp.StatusCode, string(respBody)) + return + } + + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + t.Fatalf("解析登录响应失败: %v", err) + } + + // 尝试从标准响应格式获取 token + if token, ok := result["access_token"].(string); ok && token != "" { + accessToken = token + } else if data, ok := result["data"].(map[string]any); ok { + if token, ok := data["access_token"].(string); ok { + accessToken = token + } + } + + if accessToken == "" { + t.Skipf("未获取到 access_token,响应: %s", string(respBody)) + return + } + + // 验证 token 不为空且格式基本正确 + if len(accessToken) < 10 { + t.Fatalf("access_token 格式异常: %s", accessToken) + } + + t.Logf("✅ 登录成功,获取 JWT(长度: %d)", len(accessToken)) + }) + + if accessToken == "" { + t.Skip("未获取到 JWT,跳过后续测试") + return + } + + // 步骤 3: 使用 JWT 获取当前用户信息 + t.Run("获取当前用户信息", func(t *testing.T) { + resp, err := doRequest(t, "GET", "/api/user/me", nil, accessToken) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + t.Logf("✅ 成功获取用户信息") + }) +} + +// TestAPIKeyLifecycle 测试 API Key 的创建和使用 +func TestAPIKeyLifecycle(t *testing.T) { + // 先登录获取 JWT + accessToken := loginTestUser(t) + if accessToken == "" { + t.Skip("无法登录,跳过 API Key 生命周期测试") + return + } + + var apiKey string + + // 步骤 1: 创建 API Key + t.Run("创建API_Key", func(t *testing.T) { + payload := map[string]string{ + "name": "e2e-test-key-" + fmt.Sprintf("%d", time.Now().UnixMilli()), + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/keys", body, accessToken) + if err != nil { + t.Fatalf("创建 API Key 请求失败: %v", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != 200 { + t.Skipf("创建 API Key 失败 HTTP %d: %s", resp.StatusCode, string(respBody)) + return + } + + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + // 从响应中提取 key + if key, ok := result["key"].(string); ok { + apiKey = key + } else if data, ok := result["data"].(map[string]any); ok { + if key, ok := data["key"].(string); ok { + apiKey = key + } + } + + if apiKey == "" { + t.Skipf("未获取到 API Key,响应: %s", string(respBody)) + return + } + + // 验证 API Key 脱敏日志(只显示前 8 位) + masked := apiKey + if len(masked) > 8 { + masked = masked[:8] + "..." + } + t.Logf("✅ API Key 创建成功: %s", masked) + }) + + if apiKey == "" { + t.Skip("未创建 API Key,跳过后续测试") + return + } + + // 步骤 2: 使用 API Key 调用网关(需要 Claude 或 Gemini 可用) + t.Run("使用API_Key调用网关", func(t *testing.T) { + // 尝试调用 models 列表(最轻量的 API 调用) + resp, err := doRequest(t, "GET", "/v1/models", nil, apiKey) + if err != nil { + t.Fatalf("网关请求失败: %v", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + // 可能返回 200(成功)或 402(余额不足)或 403(无可用账户) + switch { + case resp.StatusCode == 200: + t.Logf("✅ API Key 网关调用成功") + case resp.StatusCode == 402: + t.Logf("⚠️ 余额不足,但 API Key 认证通过") + case resp.StatusCode == 403: + t.Logf("⚠️ 无可用账户,但 API Key 认证通过") + default: + t.Logf("⚠️ 网关返回 HTTP %d: %s", resp.StatusCode, string(respBody)) + } + }) + + // 步骤 3: 查询用量记录 + t.Run("查询用量记录", func(t *testing.T) { + resp, err := doRequest(t, "GET", "/api/usage/dashboard", nil, accessToken) + if err != nil { + t.Fatalf("用量查询请求失败: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + t.Logf("⚠️ 用量查询返回 HTTP %d: %s", resp.StatusCode, string(body)) + return + } + + t.Logf("✅ 用量查询成功") + }) +} + +// ============================================================================= +// 辅助函数 +// ============================================================================= + +func doRequest(t *testing.T, method, path string, body []byte, token string) (*http.Response, error) { + t.Helper() + + url := baseURL + path + var bodyReader io.Reader + if body != nil { + bodyReader = bytes.NewReader(body) + } + + req, err := http.NewRequest(method, url, bodyReader) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + + client := &http.Client{Timeout: 30 * time.Second} + return client.Do(req) +} + +func loginTestUser(t *testing.T) string { + t.Helper() + + // 先尝试用管理员账户登录 + adminEmail := getEnv("ADMIN_EMAIL", "admin@sub2api.local") + adminPassword := getEnv("ADMIN_PASSWORD", "") + + if adminPassword == "" { + // 尝试用测试用户 + adminEmail = testUserEmail + adminPassword = testUserPassword + } + + payload := map[string]string{ + "email": adminEmail, + "password": adminPassword, + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/auth/login", body, "") + if err != nil { + return "" + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return "" + } + + respBody, _ := io.ReadAll(resp.Body) + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + return "" + } + + if token, ok := result["access_token"].(string); ok { + return token + } + if data, ok := result["data"].(map[string]any); ok { + if token, ok := data["access_token"].(string); ok { + return token + } + } + + return "" +} + +// redactAPIKey API Key 脱敏,只显示前 8 位 +func redactAPIKey(key string) string { + key = strings.TrimSpace(key) + if len(key) <= 8 { + return "***" + } + return key[:8] + "..." +} diff --git a/backend/internal/middleware/rate_limiter_test.go b/backend/internal/middleware/rate_limiter_test.go index 0c379c0f..e362274f 100644 --- a/backend/internal/middleware/rate_limiter_test.go +++ b/backend/internal/middleware/rate_limiter_test.go @@ -60,6 +60,49 @@ func TestRateLimiterFailureModes(t *testing.T) { require.Equal(t, http.StatusTooManyRequests, recorder.Code) } +func TestRateLimiterDifferentIPsIndependent(t *testing.T) { + gin.SetMode(gin.TestMode) + + callCounts := make(map[string]int64) + originalRun := rateLimitRun + rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) { + callCounts[key]++ + return callCounts[key], false, nil + } + t.Cleanup(func() { + rateLimitRun = originalRun + }) + + limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"})) + + router := gin.New() + router.Use(limiter.Limit("api", 1, time.Second)) + router.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + // 第一个 IP 的请求应通过 + req1 := httptest.NewRequest(http.MethodGet, "/test", nil) + req1.RemoteAddr = "10.0.0.1:1234" + rec1 := httptest.NewRecorder() + router.ServeHTTP(rec1, req1) + require.Equal(t, http.StatusOK, rec1.Code, "第一个 IP 的第一次请求应通过") + + // 第二个 IP 的请求应独立通过(不受第一个 IP 的计数影响) + req2 := httptest.NewRequest(http.MethodGet, "/test", nil) + req2.RemoteAddr = "10.0.0.2:5678" + rec2 := httptest.NewRecorder() + router.ServeHTTP(rec2, req2) + require.Equal(t, http.StatusOK, rec2.Code, "第二个 IP 的第一次请求应独立通过") + + // 第一个 IP 的第二次请求应被限流 + req3 := httptest.NewRequest(http.MethodGet, "/test", nil) + req3.RemoteAddr = "10.0.0.1:1234" + rec3 := httptest.NewRecorder() + router.ServeHTTP(rec3, req3) + require.Equal(t, http.StatusTooManyRequests, rec3.Code, "第一个 IP 的第二次请求应被限流") +} + func TestRateLimiterSuccessAndLimit(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/pkg/tlsfingerprint/dialer_test.go b/backend/internal/pkg/tlsfingerprint/dialer_test.go index 345067e5..6d3db174 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer_test.go +++ b/backend/internal/pkg/tlsfingerprint/dialer_test.go @@ -1,3 +1,5 @@ +//go:build unit + // Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients. // // Unit tests for TLS fingerprint dialer. @@ -20,24 +22,6 @@ import ( "time" ) -// FingerprintResponse represents the response from tls.peet.ws/api/all. -type FingerprintResponse struct { - IP string `json:"ip"` - TLS TLSInfo `json:"tls"` - HTTP2 any `json:"http2"` -} - -// TLSInfo contains TLS fingerprint details. -type TLSInfo struct { - JA3 string `json:"ja3"` - JA3Hash string `json:"ja3_hash"` - JA4 string `json:"ja4"` - PeetPrint string `json:"peetprint"` - PeetPrintHash string `json:"peetprint_hash"` - ClientRandom string `json:"client_random"` - SessionID string `json:"session_id"` -} - // TestDialerBasicConnection tests that the dialer can establish TLS connections. func TestDialerBasicConnection(t *testing.T) { skipNetworkTest(t) diff --git a/backend/internal/pkg/tlsfingerprint/test_types_test.go b/backend/internal/pkg/tlsfingerprint/test_types_test.go new file mode 100644 index 00000000..2bbf2d22 --- /dev/null +++ b/backend/internal/pkg/tlsfingerprint/test_types_test.go @@ -0,0 +1,20 @@ +package tlsfingerprint + +// FingerprintResponse represents the response from tls.peet.ws/api/all. +// 共享测试类型,供 unit 和 integration 测试文件使用。 +type FingerprintResponse struct { + IP string `json:"ip"` + TLS TLSInfo `json:"tls"` + HTTP2 any `json:"http2"` +} + +// TLSInfo contains TLS fingerprint details. +type TLSInfo struct { + JA3 string `json:"ja3"` + JA3Hash string `json:"ja3_hash"` + JA4 string `json:"ja4"` + PeetPrint string `json:"peetprint"` + PeetPrintHash string `json:"peetprint_hash"` + ClientRandom string `json:"client_random"` + SessionID string `json:"session_id"` +} diff --git a/backend/internal/repository/billing_cache_jitter_test.go b/backend/internal/repository/billing_cache_jitter_test.go index 32c42cf4..ba4f2873 100644 --- a/backend/internal/repository/billing_cache_jitter_test.go +++ b/backend/internal/repository/billing_cache_jitter_test.go @@ -14,7 +14,7 @@ func TestJitteredTTL_WithinExpectedRange(t *testing.T) { // jitteredTTL 使用减法抖动: billingCacheTTL - [0, billingCacheJitter) // 所以结果应在 [billingCacheTTL - billingCacheJitter, billingCacheTTL] 范围内 lowerBound := billingCacheTTL - billingCacheJitter // 5min - 30s = 4min30s - upperBound := billingCacheTTL // 5min + upperBound := billingCacheTTL // 5min for i := 0; i < 200; i++ { ttl := jitteredTTL() diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index d92dcc47..6851e71a 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -603,7 +603,7 @@ func newContractDeps(t *testing.T) *contractDeps { usageRepo := newStubUsageLogRepo() usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) - subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, cfg) + subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, nil, cfg) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil) diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go new file mode 100644 index 00000000..e1b8e1ad --- /dev/null +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -0,0 +1,234 @@ +//go:build unit + +package middleware + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// stubJWTUserRepo 实现 UserRepository 的最小子集,仅支持 GetByID。 +type stubJWTUserRepo struct { + service.UserRepository + users map[int64]*service.User +} + +func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, error) { + u, ok := r.users[id] + if !ok { + return nil, errors.New("user not found") + } + return u, nil +} + +// newJWTTestEnv 创建 JWT 认证中间件测试环境。 +// 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。 +func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!" + cfg.JWT.AccessTokenExpireMinutes = 60 + + userRepo := &stubJWTUserRepo{users: users} + authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil) + userSvc := service.NewUserService(userRepo, nil, nil) + mw := NewJWTAuthMiddleware(authSvc, userSvc) + + r := gin.New() + r.Use(gin.HandlerFunc(mw)) + r.GET("/protected", func(c *gin.Context) { + subject, _ := GetAuthSubjectFromContext(c) + role, _ := GetUserRoleFromContext(c) + c.JSON(http.StatusOK, gin.H{ + "user_id": subject.UserID, + "role": role, + }) + }) + return r, authSvc +} + +func TestJWTAuth_ValidToken(t *testing.T) { + user := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + Concurrency: 5, + TokenVersion: 1, + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user}) + + token, err := authSvc.GenerateToken(user) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + + var body map[string]any + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, float64(1), body["user_id"]) + require.Equal(t, "user", body["role"]) +} + +func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) { + router, _ := newJWTTestEnv(nil) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "UNAUTHORIZED", body.Code) +} + +func TestJWTAuth_InvalidHeaderFormat(t *testing.T) { + tests := []struct { + name string + header string + }{ + {"无Bearer前缀", "Token abc123"}, + {"缺少空格分隔", "Bearerabc123"}, + {"仅有单词", "abc123"}, + } + router, _ := newJWTTestEnv(nil) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", tt.header) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "INVALID_AUTH_HEADER", body.Code) + }) + } +} + +func TestJWTAuth_EmptyToken(t *testing.T) { + router, _ := newJWTTestEnv(nil) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer ") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "EMPTY_TOKEN", body.Code) +} + +func TestJWTAuth_TamperedToken(t *testing.T) { + router, _ := newJWTTestEnv(nil) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer eyJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoxfQ.invalid_signature") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "INVALID_TOKEN", body.Code) +} + +func TestJWTAuth_UserNotFound(t *testing.T) { + // 使用 user ID=1 的 token,但 repo 中没有该用户 + fakeUser := &service.User{ + ID: 999, + Email: "ghost@example.com", + Role: "user", + Status: service.StatusActive, + TokenVersion: 1, + } + // 创建环境时不注入此用户,这样 GetByID 会失败 + router, authSvc := newJWTTestEnv(map[int64]*service.User{}) + + token, err := authSvc.GenerateToken(fakeUser) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "USER_NOT_FOUND", body.Code) +} + +func TestJWTAuth_UserInactive(t *testing.T) { + user := &service.User{ + ID: 1, + Email: "disabled@example.com", + Role: "user", + Status: service.StatusDisabled, + TokenVersion: 1, + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user}) + + token, err := authSvc.GenerateToken(user) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "USER_INACTIVE", body.Code) +} + +func TestJWTAuth_TokenVersionMismatch(t *testing.T) { + // Token 生成时 TokenVersion=1,但数据库中用户已更新为 TokenVersion=2(密码修改) + userForToken := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + TokenVersion: 1, + } + userInDB := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + TokenVersion: 2, // 密码修改后版本递增 + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: userInDB}) + + token, err := authSvc.GenerateToken(userForToken) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "TOKEN_REVOKED", body.Code) +} diff --git a/backend/internal/server/middleware/recovery_test.go b/backend/internal/server/middleware/recovery_test.go index 439f44cb..33e71d51 100644 --- a/backend/internal/server/middleware/recovery_test.go +++ b/backend/internal/server/middleware/recovery_test.go @@ -3,6 +3,7 @@ package middleware import ( + "bytes" "encoding/json" "net/http" "net/http/httptest" @@ -14,6 +15,34 @@ import ( "github.com/stretchr/testify/require" ) +func TestRecovery_PanicLogContainsInfo(t *testing.T) { + gin.SetMode(gin.TestMode) + + // 临时替换 DefaultErrorWriter 以捕获日志输出 + var buf bytes.Buffer + originalWriter := gin.DefaultErrorWriter + gin.DefaultErrorWriter = &buf + t.Cleanup(func() { + gin.DefaultErrorWriter = originalWriter + }) + + r := gin.New() + r.Use(Recovery()) + r.GET("/panic", func(c *gin.Context) { + panic("custom panic message for test") + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/panic", nil) + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusInternalServerError, w.Code) + + logOutput := buf.String() + require.Contains(t, logOutput, "custom panic message for test", "日志应包含 panic 信息") + require.Contains(t, logOutput, "recovery_test.go", "日志应包含堆栈跟踪文件名") +} + func TestRecovery(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go index 20936356..2b4a5504 100644 --- a/backend/internal/service/antigravity_rate_limit_test.go +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -15,6 +15,12 @@ import ( "github.com/stretchr/testify/require" ) +// 编译期接口断言 +var _ HTTPUpstream = (*stubAntigravityUpstream)(nil) +var _ HTTPUpstream = (*recordingOKUpstream)(nil) +var _ AccountRepository = (*stubAntigravityAccountRepo)(nil) +var _ SchedulerCache = (*stubSchedulerCache)(nil) + type stubAntigravityUpstream struct { firstBase string secondBase string diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go new file mode 100644 index 00000000..cdaf6953 --- /dev/null +++ b/backend/internal/service/billing_service_test.go @@ -0,0 +1,310 @@ +//go:build unit + +package service + +import ( + "math" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func newTestBillingService() *BillingService { + return NewBillingService(&config.Config{}, nil) +} + +func TestCalculateCost_BasicComputation(t *testing.T) { + svc := newTestBillingService() + + // 使用 claude-sonnet-4 的回退价格:Input $3/MTok, Output $15/MTok + tokens := UsageTokens{ + InputTokens: 1000, + OutputTokens: 500, + } + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + // 1000 * 3e-6 = 0.003, 500 * 15e-6 = 0.0075 + expectedInput := 1000 * 3e-6 + expectedOutput := 500 * 15e-6 + require.InDelta(t, expectedInput, cost.InputCost, 1e-10) + require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10) +} + +func TestCalculateCost_WithCacheTokens(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 1000, + OutputTokens: 500, + CacheCreationTokens: 2000, + CacheReadTokens: 3000, + } + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + expectedCacheCreation := 2000 * 3.75e-6 + expectedCacheRead := 3000 * 0.3e-6 + require.InDelta(t, expectedCacheCreation, cost.CacheCreationCost, 1e-10) + require.InDelta(t, expectedCacheRead, cost.CacheReadCost, 1e-10) + + expectedTotal := cost.InputCost + cost.OutputCost + expectedCacheCreation + expectedCacheRead + require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10) +} + +func TestCalculateCost_RateMultiplier(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + + cost1x, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + cost2x, err := svc.CalculateCost("claude-sonnet-4", tokens, 2.0) + require.NoError(t, err) + + // TotalCost 不受倍率影响,ActualCost 翻倍 + require.InDelta(t, cost1x.TotalCost, cost2x.TotalCost, 1e-10) + require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10) +} + +func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000} + + costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0) + require.NoError(t, err) + + costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) +} + +func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000} + + costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0) + require.NoError(t, err) + + costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) +} + +func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) { + svc := newTestBillingService() + + tests := []struct { + model string + expectedInput float64 + }{ + {"claude-opus-4.5-20250101", 5e-6}, + {"claude-3-opus-20240229", 15e-6}, + {"claude-sonnet-4-20250514", 3e-6}, + {"claude-3-5-sonnet-20241022", 3e-6}, + {"claude-3-5-haiku-20241022", 1e-6}, + {"claude-3-haiku-20240307", 0.25e-6}, + } + + for _, tt := range tests { + pricing, err := svc.GetModelPricing(tt.model) + require.NoError(t, err, "模型 %s", tt.model) + require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12, "模型 %s 输入价格", tt.model) + } +} + +func TestGetModelPricing_CaseInsensitive(t *testing.T) { + svc := newTestBillingService() + + p1, err := svc.GetModelPricing("Claude-Sonnet-4") + require.NoError(t, err) + + p2, err := svc.GetModelPricing("claude-sonnet-4") + require.NoError(t, err) + + require.Equal(t, p1.InputPricePerToken, p2.InputPricePerToken) +} + +func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) { + svc := newTestBillingService() + + // 不包含 opus/sonnet/haiku 关键词的 Claude 模型会走默认 Sonnet 价格 + pricing, err := svc.GetModelPricing("claude-unknown-model") + require.NoError(t, err) + require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) +} + +func TestCalculateCostWithLongContext_BelowThreshold(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 50000, + OutputTokens: 1000, + CacheReadTokens: 100000, + } + // 总输入 150k < 200k 阈值,应走正常计费 + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0) + require.NoError(t, err) + + normalCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, normalCost.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateCostWithLongContext_AboveThreshold_CacheExceedsThreshold(t *testing.T) { + svc := newTestBillingService() + + // 缓存 210k + 输入 10k = 220k > 200k 阈值 + // 缓存已超阈值:范围内 200k 缓存,范围外 10k 缓存 + 10k 输入 + tokens := UsageTokens{ + InputTokens: 10000, + OutputTokens: 1000, + CacheReadTokens: 210000, + } + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0) + require.NoError(t, err) + + // 范围内:200k cache + 0 input + 1k output + inRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{ + InputTokens: 0, + OutputTokens: 1000, + CacheReadTokens: 200000, + }, 1.0) + + // 范围外:10k cache + 10k input,倍率 2.0 + outRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{ + InputTokens: 10000, + CacheReadTokens: 10000, + }, 2.0) + + require.InDelta(t, inRange.ActualCost+outRange.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateCostWithLongContext_AboveThreshold_CacheBelowThreshold(t *testing.T) { + svc := newTestBillingService() + + // 缓存 100k + 输入 150k = 250k > 200k 阈值 + // 缓存未超阈值:范围内 100k 缓存 + 100k 输入,范围外 50k 输入 + tokens := UsageTokens{ + InputTokens: 150000, + OutputTokens: 1000, + CacheReadTokens: 100000, + } + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0) + require.NoError(t, err) + + require.True(t, cost.ActualCost > 0, "费用应大于 0") + + // 正常费用不含长上下文 + normalCost, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.True(t, cost.ActualCost > normalCost.ActualCost, "长上下文费用应高于正常费用") +} + +func TestCalculateCostWithLongContext_DisabledThreshold(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 300000, CacheReadTokens: 0} + + // threshold <= 0 应禁用长上下文计费 + cost1, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 0, 2.0) + require.NoError(t, err) + + cost2, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, cost2.ActualCost, cost1.ActualCost, 1e-10) +} + +func TestCalculateCostWithLongContext_ExtraMultiplierLessEqualOne(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 300000} + + // extraMultiplier <= 1 应禁用长上下文计费 + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 1.0) + require.NoError(t, err) + + normalCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, normalCost.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateImageCost(t *testing.T) { + svc := newTestBillingService() + + price := 0.134 + cfg := &ImagePriceConfig{Price1K: &price} + cost := svc.CalculateImageCost("gpt-image-1", "1K", 3, cfg, 1.0) + + require.InDelta(t, 0.134*3, cost.TotalCost, 1e-10) + require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10) +} + +func TestCalculateSoraVideoCost(t *testing.T) { + svc := newTestBillingService() + + price := 0.5 + cfg := &SoraPriceConfig{VideoPricePerRequest: &price} + cost := svc.CalculateSoraVideoCost("sora-video", cfg, 1.0) + + require.InDelta(t, 0.5, cost.TotalCost, 1e-10) +} + +func TestCalculateSoraVideoCost_HDModel(t *testing.T) { + svc := newTestBillingService() + + hdPrice := 1.0 + normalPrice := 0.5 + cfg := &SoraPriceConfig{ + VideoPricePerRequest: &normalPrice, + VideoPricePerRequestHD: &hdPrice, + } + cost := svc.CalculateSoraVideoCost("sora2pro-hd", cfg, 1.0) + require.InDelta(t, 1.0, cost.TotalCost, 1e-10) +} + +func TestIsModelSupported(t *testing.T) { + svc := newTestBillingService() + + require.True(t, svc.IsModelSupported("claude-sonnet-4")) + require.True(t, svc.IsModelSupported("Claude-Opus-4.5")) + require.True(t, svc.IsModelSupported("claude-3-haiku")) + require.False(t, svc.IsModelSupported("gpt-4o")) + require.False(t, svc.IsModelSupported("gemini-pro")) +} + +func TestCalculateCost_ZeroTokens(t *testing.T) { + svc := newTestBillingService() + + cost, err := svc.CalculateCost("claude-sonnet-4", UsageTokens{}, 1.0) + require.NoError(t, err) + require.Equal(t, 0.0, cost.TotalCost) + require.Equal(t, 0.0, cost.ActualCost) +} + +func TestCalculateCost_LargeTokenCount(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 1_000_000, + OutputTokens: 1_000_000, + } + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + // Input: 1M * 3e-6 = $3, Output: 1M * 15e-6 = $15 + require.InDelta(t, 3.0, cost.InputCost, 1e-6) + require.InDelta(t, 15.0, cost.OutputCost, 1e-6) + require.False(t, math.IsNaN(cost.TotalCost)) + require.False(t, math.IsInf(cost.TotalCost, 0)) +} diff --git a/backend/internal/service/claude_code_detection_test.go b/backend/internal/service/claude_code_detection_test.go new file mode 100644 index 00000000..ff7ad7f4 --- /dev/null +++ b/backend/internal/service/claude_code_detection_test.go @@ -0,0 +1,282 @@ +//go:build unit + +package service + +import ( + "context" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func newTestValidator() *ClaudeCodeValidator { + return NewClaudeCodeValidator() +} + +// validClaudeCodeBody 构造一个完整有效的 Claude Code 请求体 +func validClaudeCodeBody() map[string]any { + return map[string]any{ + "model": "claude-sonnet-4-20250514", + "system": []any{ + map[string]any{ + "type": "text", + "text": "You are Claude Code, Anthropic's official CLI for Claude.", + }, + }, + "metadata": map[string]any{ + "user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_" + "12345678-1234-1234-1234-123456789abc", + }, + } +} + +func TestValidate_ClaudeCLIUserAgent(t *testing.T) { + v := newTestValidator() + + tests := []struct { + name string + ua string + want bool + }{ + {"标准版本号", "claude-cli/1.0.0", true}, + {"多位版本号", "claude-cli/12.34.56", true}, + {"大写开头", "Claude-CLI/1.0.0", true}, + {"非 claude-cli", "curl/7.64.1", false}, + {"空 User-Agent", "", false}, + {"部分匹配", "not-claude-cli/1.0.0", false}, + {"缺少版本号", "claude-cli/", false}, + {"版本格式不对", "claude-cli/1.0", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, v.ValidateUserAgent(tt.ua), "UA: %q", tt.ua) + }) + } +} + +func TestValidate_NonMessagesPath_UAOnly(t *testing.T) { + v := newTestValidator() + + // 非 messages 路径只检查 UA + req := httptest.NewRequest("GET", "/v1/models", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + + result := v.Validate(req, nil) + require.True(t, result, "非 messages 路径只需 UA 匹配") +} + +func TestValidate_NonMessagesPath_InvalidUA(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("GET", "/v1/models", nil) + req.Header.Set("User-Agent", "curl/7.64.1") + + result := v.Validate(req, nil) + require.False(t, result, "UA 不匹配时应返回 false") +} + +func TestValidate_MessagesPath_FullValid(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15") + req.Header.Set("anthropic-version", "2023-06-01") + + result := v.Validate(req, validClaudeCodeBody()) + require.True(t, result, "完整有效请求应通过") +} + +func TestValidate_MessagesPath_MissingHeaders(t *testing.T) { + v := newTestValidator() + body := validClaudeCodeBody() + + tests := []struct { + name string + missingHeader string + }{ + {"缺少 X-App", "X-App"}, + {"缺少 anthropic-beta", "anthropic-beta"}, + {"缺少 anthropic-version", "anthropic-version"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + req.Header.Del(tt.missingHeader) + + result := v.Validate(req, body) + require.False(t, result, "缺少 %s 应返回 false", tt.missingHeader) + }) + } +} + +func TestValidate_MessagesPath_InvalidMetadataUserID(t *testing.T) { + v := newTestValidator() + + tests := []struct { + name string + metadata map[string]any + }{ + {"缺少 metadata", nil}, + {"缺少 user_id", map[string]any{"other": "value"}}, + {"空 user_id", map[string]any{"user_id": ""}}, + {"格式错误", map[string]any{"user_id": "invalid-format"}}, + {"hex 长度不足", map[string]any{"user_id": "user_abc_account__session_uuid"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + + body := map[string]any{ + "model": "claude-sonnet-4", + "system": []any{ + map[string]any{ + "type": "text", + "text": "You are Claude Code, Anthropic's official CLI for Claude.", + }, + }, + } + if tt.metadata != nil { + body["metadata"] = tt.metadata + } + + result := v.Validate(req, body) + require.False(t, result, "metadata.user_id: %v", tt.metadata) + }) + } +} + +func TestValidate_MessagesPath_InvalidSystemPrompt(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + + body := map[string]any{ + "model": "claude-sonnet-4", + "system": []any{ + map[string]any{ + "type": "text", + "text": "Generate JSON data for testing database migrations.", + }, + }, + "metadata": map[string]any{ + "user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_12345678-1234-1234-1234-123456789abc", + }, + } + + result := v.Validate(req, body) + require.False(t, result, "无关系统提示词应返回 false") +} + +func TestValidate_MaxTokensOneHaikuBypass(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + // 不设置 X-App 等头,通过 context 标记为 haiku 探测请求 + ctx := context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true) + req = req.WithContext(ctx) + + // 即使 body 不包含 system prompt,也应通过 + result := v.Validate(req, map[string]any{"model": "claude-3-haiku", "max_tokens": 1}) + require.True(t, result, "max_tokens=1+haiku 探测请求应绕过严格验证") +} + +func TestSystemPromptSimilarity(t *testing.T) { + v := newTestValidator() + + tests := []struct { + name string + prompt string + want bool + }{ + {"精确匹配", "You are Claude Code, Anthropic's official CLI for Claude.", true}, + {"带多余空格", "You are Claude Code, Anthropic's official CLI for Claude.", true}, + {"Agent SDK 模板", "You are a Claude agent, built on Anthropic's Claude Agent SDK.", true}, + {"文件搜索专家模板", "You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.", true}, + {"对话摘要模板", "You are a helpful AI assistant tasked with summarizing conversations.", true}, + {"交互式 CLI 模板", "You are an interactive CLI tool that helps users", true}, + {"无关文本", "Write me a poem about cats", false}, + {"空文本", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body := map[string]any{ + "model": "claude-sonnet-4", + "system": []any{ + map[string]any{"type": "text", "text": tt.prompt}, + }, + } + result := v.IncludesClaudeCodeSystemPrompt(body) + require.Equal(t, tt.want, result, "提示词: %q", tt.prompt) + }) + } +} + +func TestDiceCoefficient(t *testing.T) { + tests := []struct { + name string + a string + b string + want float64 + tol float64 + }{ + {"相同字符串", "hello", "hello", 1.0, 0.001}, + {"完全不同", "abc", "xyz", 0.0, 0.001}, + {"空字符串", "", "hello", 0.0, 0.001}, + {"单字符", "a", "b", 0.0, 0.001}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := diceCoefficient(tt.a, tt.b) + require.InDelta(t, tt.want, result, tt.tol) + }) + } +} + +func TestIsClaudeCodeClient_Context(t *testing.T) { + ctx := context.Background() + + // 默认应为 false + require.False(t, IsClaudeCodeClient(ctx)) + + // 设置为 true + ctx = SetClaudeCodeClient(ctx, true) + require.True(t, IsClaudeCodeClient(ctx)) + + // 设置为 false + ctx = SetClaudeCodeClient(ctx, false) + require.False(t, IsClaudeCodeClient(ctx)) +} + +func TestValidate_NilBody_MessagesPath(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + + result := v.Validate(req, nil) + require.False(t, result, "nil body 的 messages 请求应返回 false") +} diff --git a/backend/internal/service/concurrency_service_test.go b/backend/internal/service/concurrency_service_test.go new file mode 100644 index 00000000..33ce4cb9 --- /dev/null +++ b/backend/internal/service/concurrency_service_test.go @@ -0,0 +1,280 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩 +type stubConcurrencyCacheForTest struct { + acquireResult bool + acquireErr error + releaseErr error + concurrency int + concurrencyErr error + waitAllowed bool + waitErr error + waitCount int + waitCountErr error + loadBatch map[int64]*AccountLoadInfo + loadBatchErr error + usersLoadBatch map[int64]*UserLoadInfo + usersLoadErr error + cleanupErr error + + // 记录调用 + releasedAccountIDs []int64 + releasedRequestIDs []string +} + +var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil) + +func (c *stubConcurrencyCacheForTest) AcquireAccountSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) { + return c.acquireResult, c.acquireErr +} +func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, accountID int64, requestID string) error { + c.releasedAccountIDs = append(c.releasedAccountIDs, accountID) + c.releasedRequestIDs = append(c.releasedRequestIDs, requestID) + return c.releaseErr +} +func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) { + return c.concurrency, c.concurrencyErr +} +func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) { + return c.waitAllowed, c.waitErr +} +func (c *stubConcurrencyCacheForTest) DecrementAccountWaitCount(_ context.Context, _ int64) error { + return nil +} +func (c *stubConcurrencyCacheForTest) GetAccountWaitingCount(_ context.Context, _ int64) (int, error) { + return c.waitCount, c.waitCountErr +} +func (c *stubConcurrencyCacheForTest) AcquireUserSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) { + return c.acquireResult, c.acquireErr +} +func (c *stubConcurrencyCacheForTest) ReleaseUserSlot(_ context.Context, _ int64, _ string) error { + return c.releaseErr +} +func (c *stubConcurrencyCacheForTest) GetUserConcurrency(_ context.Context, _ int64) (int, error) { + return c.concurrency, c.concurrencyErr +} +func (c *stubConcurrencyCacheForTest) IncrementWaitCount(_ context.Context, _ int64, _ int) (bool, error) { + return c.waitAllowed, c.waitErr +} +func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ int64) error { + return nil +} +func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + return c.loadBatch, c.loadBatchErr +} +func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) { + return c.usersLoadBatch, c.usersLoadErr +} +func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Context, _ int64) error { + return c.cleanupErr +} + +func TestAcquireAccountSlot_Success(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: true} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) + require.NoError(t, err) + require.True(t, result.Acquired) + require.NotNil(t, result.ReleaseFunc) +} + +func TestAcquireAccountSlot_Failure(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: false} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) + require.NoError(t, err) + require.False(t, result.Acquired) + require.Nil(t, result.ReleaseFunc) +} + +func TestAcquireAccountSlot_UnlimitedConcurrency(t *testing.T) { + svc := NewConcurrencyService(&stubConcurrencyCacheForTest{}) + + for _, maxConcurrency := range []int{0, -1} { + result, err := svc.AcquireAccountSlot(context.Background(), 1, maxConcurrency) + require.NoError(t, err) + require.True(t, result.Acquired, "maxConcurrency=%d 应无限制通过", maxConcurrency) + require.NotNil(t, result.ReleaseFunc, "ReleaseFunc 应为 no-op 函数") + } +} + +func TestAcquireAccountSlot_CacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireErr: errors.New("redis down")} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) + require.Error(t, err) + require.Nil(t, result) +} + +func TestAcquireAccountSlot_ReleaseDecrements(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: true} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 42, 5) + require.NoError(t, err) + require.True(t, result.Acquired) + + // 调用 ReleaseFunc 应释放槽位 + result.ReleaseFunc() + + require.Len(t, cache.releasedAccountIDs, 1) + require.Equal(t, int64(42), cache.releasedAccountIDs[0]) + require.Len(t, cache.releasedRequestIDs, 1) + require.NotEmpty(t, cache.releasedRequestIDs[0], "requestID 不应为空") +} + +func TestAcquireUserSlot_IndependentFromAccount(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: true} + svc := NewConcurrencyService(cache) + + // 用户槽位获取应独立于账户槽位 + result, err := svc.AcquireUserSlot(context.Background(), 100, 3) + require.NoError(t, err) + require.True(t, result.Acquired) + require.NotNil(t, result.ReleaseFunc) +} + +func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) { + svc := NewConcurrencyService(&stubConcurrencyCacheForTest{}) + + result, err := svc.AcquireUserSlot(context.Background(), 1, 0) + require.NoError(t, err) + require.True(t, result.Acquired) +} + +func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) { + expected := map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60}, + 2: {AccountID: 2, CurrentConcurrency: 5, WaitingCount: 2, LoadRate: 100}, + } + cache := &stubConcurrencyCacheForTest{loadBatch: expected} + svc := NewConcurrencyService(cache) + + accounts := []AccountWithConcurrency{ + {ID: 1, MaxConcurrency: 5}, + {ID: 2, MaxConcurrency: 5}, + } + result, err := svc.GetAccountsLoadBatch(context.Background(), accounts) + require.NoError(t, err) + require.Equal(t, expected, result) +} + +func TestGetAccountsLoadBatch_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + result, err := svc.GetAccountsLoadBatch(context.Background(), nil) + require.NoError(t, err) + require.Empty(t, result) +} + +func TestIncrementWaitCount_Success(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: true} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.True(t, allowed) +} + +func TestIncrementWaitCount_QueueFull(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: false} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.False(t, allowed) +} + +func TestIncrementWaitCount_FailOpen(t *testing.T) { + // Redis 错误时应 fail-open(允许请求通过) + cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis timeout")} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err, "Redis 错误不应传播") + require.True(t, allowed, "Redis 错误时应 fail-open") +} + +func TestIncrementWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.True(t, allowed, "nil cache 应 fail-open") +} + +func TestCalculateMaxWait(t *testing.T) { + tests := []struct { + concurrency int + expected int + }{ + {5, 25}, // 5 + 20 + {1, 21}, // 1 + 20 + {0, 21}, // min(1) + 20 + {-1, 21}, // min(1) + 20 + {10, 30}, // 10 + 20 + } + for _, tt := range tests { + result := CalculateMaxWait(tt.concurrency) + require.Equal(t, tt.expected, result, "CalculateMaxWait(%d)", tt.concurrency) + } +} + +func TestGetAccountWaitingCount(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitCount: 5} + svc := NewConcurrencyService(cache) + + count, err := svc.GetAccountWaitingCount(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, 5, count) +} + +func TestGetAccountWaitingCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + count, err := svc.GetAccountWaitingCount(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, 0, count) +} + +func TestGetAccountConcurrencyBatch(t *testing.T) { + cache := &stubConcurrencyCacheForTest{concurrency: 3} + svc := NewConcurrencyService(cache) + + result, err := svc.GetAccountConcurrencyBatch(context.Background(), []int64{1, 2, 3}) + require.NoError(t, err) + require.Len(t, result, 3) + for _, id := range []int64{1, 2, 3} { + require.Equal(t, 3, result[id]) + } +} + +func TestIncrementAccountWaitCount_FailOpen(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis error")} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err, "Redis 错误不应传播") + require.True(t, allowed, "Redis 错误时应 fail-open") +} + +func TestIncrementAccountWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err) + require.True(t, allowed) +} diff --git a/backend/internal/service/gateway_account_selection_test.go b/backend/internal/service/gateway_account_selection_test.go new file mode 100644 index 00000000..70c5d6c5 --- /dev/null +++ b/backend/internal/service/gateway_account_selection_test.go @@ -0,0 +1,198 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// --- helpers --- + +func testTimePtr(t time.Time) *time.Time { return &t } + +func makeAccWithLoad(id int64, priority int, loadRate int, lastUsed *time.Time, accType string) accountWithLoad { + return accountWithLoad{ + account: &Account{ + ID: id, + Priority: priority, + LastUsedAt: lastUsed, + Type: accType, + Schedulable: true, + Status: StatusActive, + }, + loadInfo: &AccountLoadInfo{ + AccountID: id, + CurrentConcurrency: 0, + LoadRate: loadRate, + }, + } +} + +// --- sortAccountsByPriorityAndLastUsed --- + +func TestSortAccountsByPriorityAndLastUsed_ByPriority(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 1, Priority: 5, LastUsedAt: testTimePtr(now)}, + {ID: 2, Priority: 1, LastUsedAt: testTimePtr(now)}, + {ID: 3, Priority: 3, LastUsedAt: testTimePtr(now)}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + require.Equal(t, int64(2), accounts[0].ID, "优先级最低的排第一") + require.Equal(t, int64(3), accounts[1].ID) + require.Equal(t, int64(1), accounts[2].ID) +} + +func TestSortAccountsByPriorityAndLastUsed_SamePriorityByLastUsed(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: testTimePtr(now)}, + {ID: 2, Priority: 1, LastUsedAt: testTimePtr(now.Add(-1 * time.Hour))}, + {ID: 3, Priority: 1, LastUsedAt: nil}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + require.Equal(t, int64(3), accounts[0].ID, "nil LastUsedAt 排最前") + require.Equal(t, int64(2), accounts[1].ID, "更早使用的排前面") + require.Equal(t, int64(1), accounts[2].ID) +} + +func TestSortAccountsByPriorityAndLastUsed_PreferOAuth(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + {ID: 2, Priority: 1, LastUsedAt: nil, Type: AccountTypeOAuth}, + } + sortAccountsByPriorityAndLastUsed(accounts, true) + require.Equal(t, int64(2), accounts[0].ID, "preferOAuth 时 OAuth 账号排前面") +} + +func TestSortAccountsByPriorityAndLastUsed_StableSort(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + {ID: 2, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + {ID: 3, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + // 稳定排序:相同键值的元素保持原始顺序 + require.Equal(t, int64(1), accounts[0].ID) + require.Equal(t, int64(2), accounts[1].ID) + require.Equal(t, int64(3), accounts[2].ID) +} + +func TestSortAccountsByPriorityAndLastUsed_MixedPriorityAndTime(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 1, Priority: 2, LastUsedAt: nil}, + {ID: 2, Priority: 1, LastUsedAt: testTimePtr(now)}, + {ID: 3, Priority: 1, LastUsedAt: testTimePtr(now.Add(-1 * time.Hour))}, + {ID: 4, Priority: 2, LastUsedAt: testTimePtr(now.Add(-2 * time.Hour))}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + // 优先级1排前:nil < earlier + require.Equal(t, int64(3), accounts[0].ID, "优先级1 + 更早") + require.Equal(t, int64(2), accounts[1].ID, "优先级1 + 现在") + // 优先级2排后:nil < time + require.Equal(t, int64(1), accounts[2].ID, "优先级2 + nil") + require.Equal(t, int64(4), accounts[3].ID, "优先级2 + 有时间") +} + +// --- selectByCallCount --- + +func TestSelectByCallCount_Empty(t *testing.T) { + result := selectByCallCount(nil, nil, false) + require.Nil(t, result) +} + +func TestSelectByCallCount_Single(t *testing.T) { + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey), + } + result := selectByCallCount(accounts, map[int64]*ModelLoadInfo{1: {CallCount: 10}}, false) + require.NotNil(t, result) + require.Equal(t, int64(1), result.account.ID) +} + +func TestSelectByCallCount_NilModelLoadFallsBackToLRU(t *testing.T) { + now := time.Now() + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 50, testTimePtr(now), AccountTypeAPIKey), + makeAccWithLoad(2, 1, 50, testTimePtr(now.Add(-1*time.Hour)), AccountTypeAPIKey), + } + result := selectByCallCount(accounts, nil, false) + require.NotNil(t, result) + require.Equal(t, int64(2), result.account.ID, "nil modelLoadMap 应回退到 LRU 选择") +} + +func TestSelectByCallCount_SelectsMinCallCount(t *testing.T) { + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey), + makeAccWithLoad(2, 1, 50, nil, AccountTypeAPIKey), + makeAccWithLoad(3, 1, 50, nil, AccountTypeAPIKey), + } + modelLoad := map[int64]*ModelLoadInfo{ + 1: {CallCount: 100}, + 2: {CallCount: 5}, + 3: {CallCount: 50}, + } + // 运行多次确认总是选调用次数最少的 + for i := 0; i < 10; i++ { + result := selectByCallCount(accounts, modelLoad, false) + require.NotNil(t, result) + require.Equal(t, int64(2), result.account.ID, "应选择调用次数最少的账号") + } +} + +func TestSelectByCallCount_NewAccountUsesAverage(t *testing.T) { + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey), + makeAccWithLoad(2, 1, 50, nil, AccountTypeAPIKey), + makeAccWithLoad(3, 1, 50, nil, AccountTypeAPIKey), + } + // 账号1和2有调用记录,账号3是新账号(CallCount=0) + // 平均调用次数 = (100 + 200) / 2 = 150 + // 新账号用平均值 150,比账号1(100)多,所以应选账号1 + modelLoad := map[int64]*ModelLoadInfo{ + 1: {CallCount: 100}, + 2: {CallCount: 200}, + // 3 没有记录 + } + for i := 0; i < 10; i++ { + result := selectByCallCount(accounts, modelLoad, false) + require.NotNil(t, result) + require.Equal(t, int64(1), result.account.ID, "新账号虚拟调用次数(150)高于账号1(100),应选账号1") + } +} + +func TestSelectByCallCount_AllNewAccountsFallToAvgZero(t *testing.T) { + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey), + makeAccWithLoad(2, 1, 50, nil, AccountTypeAPIKey), + } + // 所有账号都是新的,avgCallCount = 0,所有人 effectiveCallCount 都是 0 + modelLoad := map[int64]*ModelLoadInfo{} + validIDs := map[int64]bool{1: true, 2: true} + for i := 0; i < 10; i++ { + result := selectByCallCount(accounts, modelLoad, false) + require.NotNil(t, result) + require.True(t, validIDs[result.account.ID], "所有新账号应随机选择") + } +} + +func TestSelectByCallCount_PreferOAuth(t *testing.T) { + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey), + makeAccWithLoad(2, 1, 50, nil, AccountTypeOAuth), + } + // 两个账号调用次数相同 + modelLoad := map[int64]*ModelLoadInfo{ + 1: {CallCount: 10}, + 2: {CallCount: 10}, + } + for i := 0; i < 10; i++ { + result := selectByCallCount(accounts, modelLoad, true) + require.NotNil(t, result) + require.Equal(t, int64(2), result.account.ID, "调用次数相同时应优先选择 OAuth 账号") + } +} diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go new file mode 100644 index 00000000..50b998a3 --- /dev/null +++ b/backend/internal/service/gateway_streaming_test.go @@ -0,0 +1,203 @@ +//go:build unit + +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// --- parseSSEUsage 测试 --- + +func newMinimalGatewayService() *GatewayService { + return &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } +} + +func TestParseSSEUsage_MessageStart(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + data := `{"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation_input_tokens":50,"cache_read_input_tokens":200}}}` + svc.parseSSEUsage(data, usage) + + require.Equal(t, 100, usage.InputTokens) + require.Equal(t, 50, usage.CacheCreationInputTokens) + require.Equal(t, 200, usage.CacheReadInputTokens) + require.Equal(t, 0, usage.OutputTokens, "message_start 不应设置 output_tokens") +} + +func TestParseSSEUsage_MessageDelta(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + data := `{"type":"message_delta","usage":{"output_tokens":42}}` + svc.parseSSEUsage(data, usage) + + require.Equal(t, 42, usage.OutputTokens) + require.Equal(t, 0, usage.InputTokens, "message_delta 的 output_tokens 不应影响已有的 input_tokens") +} + +func TestParseSSEUsage_DeltaDoesNotOverwriteStartValues(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 先处理 message_start + svc.parseSSEUsage(`{"type":"message_start","message":{"usage":{"input_tokens":100}}}`, usage) + require.Equal(t, 100, usage.InputTokens) + + // 再处理 message_delta(output_tokens > 0, input_tokens = 0) + svc.parseSSEUsage(`{"type":"message_delta","usage":{"output_tokens":50}}`, usage) + require.Equal(t, 100, usage.InputTokens, "delta 中 input_tokens=0 不应覆盖 start 中的值") + require.Equal(t, 50, usage.OutputTokens) +} + +func TestParseSSEUsage_DeltaOverwritesWithNonZero(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // GLM 等 API 会在 delta 中包含所有 usage 信息 + svc.parseSSEUsage(`{"type":"message_delta","usage":{"input_tokens":200,"output_tokens":100,"cache_creation_input_tokens":30,"cache_read_input_tokens":60}}`, usage) + require.Equal(t, 200, usage.InputTokens) + require.Equal(t, 100, usage.OutputTokens) + require.Equal(t, 30, usage.CacheCreationInputTokens) + require.Equal(t, 60, usage.CacheReadInputTokens) +} + +func TestParseSSEUsage_InvalidJSON(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 无效 JSON 不应 panic + svc.parseSSEUsage("not json", usage) + require.Equal(t, 0, usage.InputTokens) + require.Equal(t, 0, usage.OutputTokens) +} + +func TestParseSSEUsage_UnknownType(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 不是 message_start 或 message_delta 的类型 + svc.parseSSEUsage(`{"type":"content_block_delta","delta":{"text":"hello"}}`, usage) + require.Equal(t, 0, usage.InputTokens) + require.Equal(t, 0, usage.OutputTokens) +} + +func TestParseSSEUsage_EmptyString(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + svc.parseSSEUsage("", usage) + require.Equal(t, 0, usage.InputTokens) +} + +func TestParseSSEUsage_DoneEvent(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // [DONE] 事件不应影响 usage + svc.parseSSEUsage("[DONE]", usage) + require.Equal(t, 0, usage.InputTokens) +} + +// --- 流式响应端到端测试 --- + +func TestHandleStreamingResponse_CacheTokens(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":10,\"cache_creation_input_tokens\":20,\"cache_read_input_tokens\":30}}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":15}}\n\n")) + _, _ = pw.Write([]byte("data: [DONE]\n\n")) + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 10, result.usage.InputTokens) + require.Equal(t, 15, result.usage.OutputTokens) + require.Equal(t, 20, result.usage.CacheCreationInputTokens) + require.Equal(t, 30, result.usage.CacheReadInputTokens) +} + +func TestHandleStreamingResponse_EmptyStream(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + // 直接关闭,不发送任何事件 + _ = pw.Close() + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) +} + +func TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + // 包含特殊字符的 content_block_delta(引号、换行、Unicode) + _, _ = pw.Write([]byte("data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello \\\"world\\\"\\n你好\"}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":3}}\n\n")) + _, _ = pw.Write([]byte("data: [DONE]\n\n")) + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 5, result.usage.InputTokens) + require.Equal(t, 3, result.usage.OutputTokens) + + // 验证响应中包含转发的数据 + body := rec.Body.String() + require.Contains(t, body, "content_block_delta", "响应应包含转发的 SSE 事件") +} diff --git a/backend/internal/service/gateway_waiting_queue_test.go b/backend/internal/service/gateway_waiting_queue_test.go new file mode 100644 index 00000000..0ed95c87 --- /dev/null +++ b/backend/internal/service/gateway_waiting_queue_test.go @@ -0,0 +1,120 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestDecrementWaitCount_NilCache 确保 nil cache 不会 panic +func TestDecrementWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + // 不应 panic + svc.DecrementWaitCount(context.Background(), 1) +} + +// TestDecrementWaitCount_CacheError 确保 cache 错误不会传播 +func TestDecrementWaitCount_CacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{} + svc := NewConcurrencyService(cache) + // DecrementWaitCount 使用 background context,错误只记录日志不传播 + svc.DecrementWaitCount(context.Background(), 1) +} + +// TestDecrementAccountWaitCount_NilCache 确保 nil cache 不会 panic +func TestDecrementAccountWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + svc.DecrementAccountWaitCount(context.Background(), 1) +} + +// TestDecrementAccountWaitCount_CacheError 确保 cache 错误不会传播 +func TestDecrementAccountWaitCount_CacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{} + svc := NewConcurrencyService(cache) + svc.DecrementAccountWaitCount(context.Background(), 1) +} + +// TestWaitingQueueFlow_IncrementThenDecrement 测试完整的等待队列增减流程 +func TestWaitingQueueFlow_IncrementThenDecrement(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: true} + svc := NewConcurrencyService(cache) + + // 进入等待队列 + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.True(t, allowed) + + // 离开等待队列(不应 panic) + svc.DecrementWaitCount(context.Background(), 1) +} + +// TestWaitingQueueFlow_AccountLevel 测试账号级等待队列流程 +func TestWaitingQueueFlow_AccountLevel(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: true} + svc := NewConcurrencyService(cache) + + // 进入账号等待队列 + allowed, err := svc.IncrementAccountWaitCount(context.Background(), 42, 10) + require.NoError(t, err) + require.True(t, allowed) + + // 离开账号等待队列 + svc.DecrementAccountWaitCount(context.Background(), 42) +} + +// TestWaitingQueueFull_Returns429Signal 测试等待队列满时返回 false +func TestWaitingQueueFull_Returns429Signal(t *testing.T) { + // waitAllowed=false 模拟队列已满 + cache := &stubConcurrencyCacheForTest{waitAllowed: false} + svc := NewConcurrencyService(cache) + + // 用户级等待队列满 + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.False(t, allowed, "等待队列满时应返回 false(调用方根据此返回 429)") + + // 账号级等待队列满 + allowed, err = svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err) + require.False(t, allowed, "账号等待队列满时应返回 false") +} + +// TestWaitingQueue_FailOpen_OnCacheError 测试 Redis 故障时 fail-open +func TestWaitingQueue_FailOpen_OnCacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis connection refused")} + svc := NewConcurrencyService(cache) + + // 用户级:Redis 错误时允许通过 + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err, "Redis 错误不应向调用方传播") + require.True(t, allowed, "Redis 故障时应 fail-open 放行") + + // 账号级:同样 fail-open + allowed, err = svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err, "Redis 错误不应向调用方传播") + require.True(t, allowed, "Redis 故障时应 fail-open 放行") +} + +// TestCalculateMaxWait_Scenarios 测试最大等待队列大小计算 +func TestCalculateMaxWait_Scenarios(t *testing.T) { + tests := []struct { + concurrency int + expected int + }{ + {5, 25}, // 5 + 20 + {10, 30}, // 10 + 20 + {1, 21}, // 1 + 20 + {0, 21}, // min(1) + 20 + {-1, 21}, // min(1) + 20 + {-10, 21}, // min(1) + 20 + {100, 120}, // 100 + 20 + } + for _, tt := range tests { + result := CalculateMaxWait(tt.concurrency) + require.Equal(t, tt.expected, result, "CalculateMaxWait(%d)", tt.concurrency) + } +} diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 91dbaa4b..a6eeb3eb 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -17,6 +17,10 @@ import ( "github.com/stretchr/testify/require" ) +// 编译期接口断言 +var _ AccountRepository = (*stubOpenAIAccountRepo)(nil) +var _ GatewayCache = (*stubGatewayCache)(nil) + type stubOpenAIAccountRepo struct { AccountRepository accounts []Account diff --git a/backend/internal/service/ops_alert_evaluator_service_test.go b/backend/internal/service/ops_alert_evaluator_service_test.go index 068ab6bb..83d358a3 100644 --- a/backend/internal/service/ops_alert_evaluator_service_test.go +++ b/backend/internal/service/ops_alert_evaluator_service_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/require" ) +var _ OpsRepository = (*stubOpsRepo)(nil) + type stubOpsRepo struct { OpsRepository overview *OpsDashboardOverview diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go index caa10427..0a77d228 100644 --- a/backend/internal/service/sora_gateway_service_test.go +++ b/backend/internal/service/sora_gateway_service_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/require" ) +var _ SoraClient = (*stubSoraClientForPoll)(nil) + type stubSoraClientForPoll struct { imageStatus *SoraImageTaskStatus videoStatus *SoraVideoTaskStatus diff --git a/backend/internal/service/subscription_calculate_progress_test.go b/backend/internal/service/subscription_calculate_progress_test.go index d8adf7f7..22018bcd 100644 --- a/backend/internal/service/subscription_calculate_progress_test.go +++ b/backend/internal/service/subscription_calculate_progress_test.go @@ -14,7 +14,7 @@ func newTestSubscriptionService() *SubscriptionService { return &SubscriptionService{} } -func ptrFloat64(v float64) *float64 { return &v } +func ptrFloat64(v float64) *float64 { return &v } func ptrTime(t time.Time) *time.Time { return &t } func TestCalculateProgress_BasicFields(t *testing.T) { diff --git a/backend/internal/testutil/fixtures.go b/backend/internal/testutil/fixtures.go new file mode 100644 index 00000000..747767bc --- /dev/null +++ b/backend/internal/testutil/fixtures.go @@ -0,0 +1,78 @@ +//go:build unit + +package testutil + +import ( + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// NewTestUser 创建一个可用的测试用户,可通过 opts 覆盖默认值。 +func NewTestUser(opts ...func(*service.User)) *service.User { + u := &service.User{ + ID: 1, + Email: "test@example.com", + Username: "testuser", + Role: "user", + Balance: 100.0, + Concurrency: 5, + Status: service.StatusActive, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + for _, opt := range opts { + opt(u) + } + return u +} + +// NewTestAccount 创建一个可用的测试账户,可通过 opts 覆盖默认值。 +func NewTestAccount(opts ...func(*service.Account)) *service.Account { + a := &service.Account{ + ID: 1, + Name: "test-account", + Platform: service.PlatformAnthropic, + Status: service.StatusActive, + Schedulable: true, + Concurrency: 5, + Priority: 1, + } + for _, opt := range opts { + opt(a) + } + return a +} + +// NewTestAPIKey 创建一个可用的测试 API Key,可通过 opts 覆盖默认值。 +func NewTestAPIKey(opts ...func(*service.APIKey)) *service.APIKey { + groupID := int64(1) + k := &service.APIKey{ + ID: 1, + UserID: 1, + Key: "sk-test-key-12345678", + Name: "test-key", + GroupID: &groupID, + Status: service.StatusActive, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + for _, opt := range opts { + opt(k) + } + return k +} + +// NewTestGroup 创建一个可用的测试分组,可通过 opts 覆盖默认值。 +func NewTestGroup(opts ...func(*service.Group)) *service.Group { + g := &service.Group{ + ID: 1, + Platform: service.PlatformAnthropic, + Status: service.StatusActive, + Hydrated: true, + } + for _, opt := range opts { + opt(g) + } + return g +} diff --git a/backend/internal/testutil/httptest.go b/backend/internal/testutil/httptest.go new file mode 100644 index 00000000..2a066a12 --- /dev/null +++ b/backend/internal/testutil/httptest.go @@ -0,0 +1,35 @@ +//go:build unit + +package testutil + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + + "github.com/gin-gonic/gin" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +// NewGinTestContext 创建一个 Gin 测试上下文和 ResponseRecorder。 +// body 为空字符串时创建无 body 的请求。 +func NewGinTestContext(method, path, body string) (*gin.Context, *httptest.ResponseRecorder) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + var bodyReader io.Reader + if body != "" { + bodyReader = strings.NewReader(body) + } + + c.Request = httptest.NewRequest(method, path, bodyReader) + if method == http.MethodPost || method == http.MethodPut || method == http.MethodPatch { + c.Request.Header.Set("Content-Type", "application/json") + } + + return c, rec +} diff --git a/backend/internal/testutil/stubs.go b/backend/internal/testutil/stubs.go new file mode 100644 index 00000000..81c40c42 --- /dev/null +++ b/backend/internal/testutil/stubs.go @@ -0,0 +1,137 @@ +//go:build unit + +// Package testutil 提供单元测试共享的 Stub、Fixture 和辅助函数。 +// 所有文件使用 //go:build unit 标签,确保不会被生产构建包含。 +package testutil + +import ( + "context" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// ============================================================ +// StubConcurrencyCache — service.ConcurrencyCache 的空实现 +// ============================================================ + +// 编译期接口断言 +var _ service.ConcurrencyCache = StubConcurrencyCache{} + +// StubConcurrencyCache 是 ConcurrencyCache 的默认空实现,所有方法返回零值。 +type StubConcurrencyCache struct{} + +func (c StubConcurrencyCache) AcquireAccountSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) { + return true, nil +} +func (c StubConcurrencyCache) ReleaseAccountSlot(_ context.Context, _ int64, _ string) error { + return nil +} +func (c StubConcurrencyCache) GetAccountConcurrency(_ context.Context, _ int64) (int, error) { + return 0, nil +} +func (c StubConcurrencyCache) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) { + return true, nil +} +func (c StubConcurrencyCache) DecrementAccountWaitCount(_ context.Context, _ int64) error { + return nil +} +func (c StubConcurrencyCache) GetAccountWaitingCount(_ context.Context, _ int64) (int, error) { + return 0, nil +} +func (c StubConcurrencyCache) AcquireUserSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) { + return true, nil +} +func (c StubConcurrencyCache) ReleaseUserSlot(_ context.Context, _ int64, _ string) error { + return nil +} +func (c StubConcurrencyCache) GetUserConcurrency(_ context.Context, _ int64) (int, error) { + return 0, nil +} +func (c StubConcurrencyCache) IncrementWaitCount(_ context.Context, _ int64, _ int) (bool, error) { + return true, nil +} +func (c StubConcurrencyCache) DecrementWaitCount(_ context.Context, _ int64) error { return nil } +func (c StubConcurrencyCache) GetAccountsLoadBatch(_ context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { + result := make(map[int64]*service.AccountLoadInfo, len(accounts)) + for _, acc := range accounts { + result[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID, LoadRate: 0} + } + return result, nil +} +func (c StubConcurrencyCache) GetUsersLoadBatch(_ context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { + result := make(map[int64]*service.UserLoadInfo, len(users)) + for _, u := range users { + result[u.ID] = &service.UserLoadInfo{UserID: u.ID, LoadRate: 0} + } + return result, nil +} +func (c StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error { + return nil +} + +// ============================================================ +// StubGatewayCache — service.GatewayCache 的空实现 +// ============================================================ + +var _ service.GatewayCache = StubGatewayCache{} + +type StubGatewayCache struct{} + +func (c StubGatewayCache) GetSessionAccountID(_ context.Context, _ int64, _ string) (int64, error) { + return 0, nil +} +func (c StubGatewayCache) SetSessionAccountID(_ context.Context, _ int64, _ string, _ int64, _ time.Duration) error { + return nil +} +func (c StubGatewayCache) RefreshSessionTTL(_ context.Context, _ int64, _ string, _ time.Duration) error { + return nil +} +func (c StubGatewayCache) DeleteSessionAccountID(_ context.Context, _ int64, _ string) error { + return nil +} +func (c StubGatewayCache) IncrModelCallCount(_ context.Context, _ int64, _ string) (int64, error) { + return 0, nil +} +func (c StubGatewayCache) GetModelLoadBatch(_ context.Context, _ []int64, _ string) (map[int64]*service.ModelLoadInfo, error) { + return nil, nil +} +func (c StubGatewayCache) FindGeminiSession(_ context.Context, _ int64, _, _ string) (string, int64, bool) { + return "", 0, false +} +func (c StubGatewayCache) SaveGeminiSession(_ context.Context, _ int64, _, _, _ string, _ int64) error { + return nil +} + +// ============================================================ +// StubSessionLimitCache — service.SessionLimitCache 的空实现 +// ============================================================ + +var _ service.SessionLimitCache = StubSessionLimitCache{} + +type StubSessionLimitCache struct{} + +func (c StubSessionLimitCache) RegisterSession(_ context.Context, _ int64, _ string, _ int, _ time.Duration) (bool, error) { + return true, nil +} +func (c StubSessionLimitCache) RefreshSession(_ context.Context, _ int64, _ string, _ time.Duration) error { + return nil +} +func (c StubSessionLimitCache) GetActiveSessionCount(_ context.Context, _ int64) (int, error) { + return 0, nil +} +func (c StubSessionLimitCache) GetActiveSessionCountBatch(_ context.Context, _ []int64, _ map[int64]time.Duration) (map[int64]int, error) { + return nil, nil +} +func (c StubSessionLimitCache) IsSessionActive(_ context.Context, _ int64, _ string) (bool, error) { + return false, nil +} +func (c StubSessionLimitCache) GetWindowCost(_ context.Context, _ int64) (float64, bool, error) { + return 0, false, nil +} +func (c StubSessionLimitCache) SetWindowCost(_ context.Context, _ int64, _ float64) error { + return nil +} +func (c StubSessionLimitCache) GetWindowCostBatch(_ context.Context, _ []int64) (map[int64]float64, error) { + return nil, nil +} diff --git a/frontend/src/api/__tests__/client.spec.ts b/frontend/src/api/__tests__/client.spec.ts new file mode 100644 index 00000000..0e92c6d1 --- /dev/null +++ b/frontend/src/api/__tests__/client.spec.ts @@ -0,0 +1,208 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import axios from 'axios' +import type { AxiosInstance, InternalAxiosRequestConfig, AxiosResponse, AxiosHeaders } from 'axios' + +// 需要在导入 client 之前设置 mock +vi.mock('@/i18n', () => ({ + getLocale: () => 'zh-CN', +})) + +describe('API Client', () => { + let apiClient: AxiosInstance + + beforeEach(async () => { + localStorage.clear() + // 每次测试重新导入以获取干净的模块状态 + vi.resetModules() + const mod = await import('@/api/client') + apiClient = mod.apiClient + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + // --- 请求拦截器 --- + + describe('请求拦截器', () => { + it('自动附加 Authorization 头', async () => { + localStorage.setItem('auth_token', 'my-jwt-token') + + // 拦截实际请求 + const adapter = vi.fn().mockResolvedValue({ + status: 200, + data: { code: 0, data: {} }, + headers: {}, + config: {}, + statusText: 'OK', + }) + apiClient.defaults.adapter = adapter + + await apiClient.get('/test') + + const config = adapter.mock.calls[0][0] + expect(config.headers.get('Authorization')).toBe('Bearer my-jwt-token') + }) + + it('无 token 时不附加 Authorization 头', async () => { + const adapter = vi.fn().mockResolvedValue({ + status: 200, + data: { code: 0, data: {} }, + headers: {}, + config: {}, + statusText: 'OK', + }) + apiClient.defaults.adapter = adapter + + await apiClient.get('/test') + + const config = adapter.mock.calls[0][0] + expect(config.headers.get('Authorization')).toBeFalsy() + }) + + it('GET 请求自动附加 timezone 参数', async () => { + const adapter = vi.fn().mockResolvedValue({ + status: 200, + data: { code: 0, data: {} }, + headers: {}, + config: {}, + statusText: 'OK', + }) + apiClient.defaults.adapter = adapter + + await apiClient.get('/test') + + const config = adapter.mock.calls[0][0] + expect(config.params).toHaveProperty('timezone') + }) + + it('POST 请求不附加 timezone 参数', async () => { + const adapter = vi.fn().mockResolvedValue({ + status: 200, + data: { code: 0, data: {} }, + headers: {}, + config: {}, + statusText: 'OK', + }) + apiClient.defaults.adapter = adapter + + await apiClient.post('/test', { foo: 'bar' }) + + const config = adapter.mock.calls[0][0] + expect(config.params?.timezone).toBeUndefined() + }) + }) + + // --- 响应拦截器 --- + + describe('响应拦截器', () => { + it('code=0 时解包 data 字段', async () => { + const adapter = vi.fn().mockResolvedValue({ + status: 200, + data: { code: 0, data: { name: 'test' }, message: 'ok' }, + headers: {}, + config: {}, + statusText: 'OK', + }) + apiClient.defaults.adapter = adapter + + const response = await apiClient.get('/test') + expect(response.data).toEqual({ name: 'test' }) + }) + + it('code!=0 时拒绝并返回结构化错误', async () => { + const adapter = vi.fn().mockResolvedValue({ + status: 200, + data: { code: 1001, message: '参数错误', data: null }, + headers: {}, + config: {}, + statusText: 'OK', + }) + apiClient.defaults.adapter = adapter + + await expect(apiClient.get('/test')).rejects.toEqual( + expect.objectContaining({ + code: 1001, + message: '参数错误', + }) + ) + }) + }) + + // --- 401 Token 刷新 --- + + describe('401 Token 刷新', () => { + it('无 refresh_token 时 401 清除 localStorage', async () => { + localStorage.setItem('auth_token', 'expired-token') + // 不设置 refresh_token + + // Mock window.location + const originalLocation = window.location + Object.defineProperty(window, 'location', { + value: { ...originalLocation, pathname: '/dashboard', href: '/dashboard' }, + writable: true, + }) + + const adapter = vi.fn().mockRejectedValue({ + response: { + status: 401, + data: { code: 'TOKEN_EXPIRED', message: 'Token expired' }, + }, + config: { + url: '/test', + headers: { Authorization: 'Bearer expired-token' }, + }, + code: 'ERR_BAD_REQUEST', + }) + apiClient.defaults.adapter = adapter + + await expect(apiClient.get('/test')).rejects.toBeDefined() + + expect(localStorage.getItem('auth_token')).toBeNull() + + // 恢复 location + Object.defineProperty(window, 'location', { + value: originalLocation, + writable: true, + }) + }) + }) + + // --- 网络错误 --- + + describe('网络错误', () => { + it('网络错误返回 status 0 的错误', async () => { + const adapter = vi.fn().mockRejectedValue({ + code: 'ERR_NETWORK', + message: 'Network Error', + config: { url: '/test' }, + // 没有 response + }) + apiClient.defaults.adapter = adapter + + await expect(apiClient.get('/test')).rejects.toEqual( + expect.objectContaining({ + status: 0, + message: 'Network error. Please check your connection.', + }) + ) + }) + }) + + // --- 请求取消 --- + + describe('请求取消', () => { + it('取消的请求保持原始取消错误', async () => { + const source = axios.CancelToken.source() + + const adapter = vi.fn().mockRejectedValue( + new axios.Cancel('Operation canceled') + ) + apiClient.defaults.adapter = adapter + + await expect( + apiClient.get('/test', { cancelToken: source.token }) + ).rejects.toBeDefined() + }) + }) +}) diff --git a/frontend/src/components/__tests__/ApiKeyCreate.spec.ts b/frontend/src/components/__tests__/ApiKeyCreate.spec.ts new file mode 100644 index 00000000..537f43e7 --- /dev/null +++ b/frontend/src/components/__tests__/ApiKeyCreate.spec.ts @@ -0,0 +1,184 @@ +/** + * API Key 创建逻辑测试 + * 通过封装组件测试 API Key 创建的核心流程 + */ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { mount, flushPromises } from '@vue/test-utils' +import { setActivePinia, createPinia } from 'pinia' +import { defineComponent, ref, reactive } from 'vue' + +// Mock keysAPI +const mockCreate = vi.fn() +const mockList = vi.fn() + +vi.mock('@/api', () => ({ + keysAPI: { + create: (...args: any[]) => mockCreate(...args), + list: (...args: any[]) => mockList(...args), + }, + authAPI: { + getCurrentUser: vi.fn().mockResolvedValue({ data: {} }), + logout: vi.fn(), + refreshToken: vi.fn(), + }, + isTotp2FARequired: () => false, +})) + +vi.mock('@/api/admin/system', () => ({ + checkUpdates: vi.fn(), +})) + +vi.mock('@/api/auth', () => ({ + getPublicSettings: vi.fn().mockResolvedValue({}), +})) + +// Mock app store - 使用固定引用确保组件和测试共享同一对象 +const mockShowSuccess = vi.fn() +const mockShowError = vi.fn() + +vi.mock('@/stores/app', () => ({ + useAppStore: () => ({ + showSuccess: mockShowSuccess, + showError: mockShowError, + }), +})) + +import { useAppStore } from '@/stores/app' + +/** + * 简化的 API Key 创建测试组件 + */ +const ApiKeyCreateTestComponent = defineComponent({ + setup() { + const appStore = useAppStore() + const loading = ref(false) + const createdKey = ref('') + const formData = reactive({ + name: '', + group_id: null as number | null, + }) + + const handleCreate = async () => { + if (!formData.name) return + + loading.value = true + try { + const result = await mockCreate({ + name: formData.name, + group_id: formData.group_id, + }) + createdKey.value = result.key + appStore.showSuccess('API Key 创建成功') + } catch (error: any) { + appStore.showError(error.message || '创建失败') + } finally { + loading.value = false + } + } + + return { formData, loading, createdKey, handleCreate } + }, + template: ` +
+
+ + + +
+
{{ createdKey }}
+
+ `, +}) + +describe('ApiKey 创建流程', () => { + beforeEach(() => { + setActivePinia(createPinia()) + vi.clearAllMocks() + }) + + it('创建 API Key 调用 API 并显示结果', async () => { + mockCreate.mockResolvedValue({ + id: 1, + key: 'sk-test-key-12345', + name: 'My Test Key', + }) + + const wrapper = mount(ApiKeyCreateTestComponent) + + await wrapper.find('#name').setValue('My Test Key') + await wrapper.find('form').trigger('submit') + await flushPromises() + + expect(mockCreate).toHaveBeenCalledWith({ + name: 'My Test Key', + group_id: null, + }) + + expect(wrapper.find('.created-key').text()).toBe('sk-test-key-12345') + }) + + it('选择分组后正确传参', async () => { + mockCreate.mockResolvedValue({ + id: 2, + key: 'sk-group-key', + name: 'Group Key', + }) + + const wrapper = mount(ApiKeyCreateTestComponent) + + await wrapper.find('#name').setValue('Group Key') + // 选择 group_id = 1 + await wrapper.find('#group').setValue('1') + await wrapper.find('form').trigger('submit') + await flushPromises() + + expect(mockCreate).toHaveBeenCalledWith({ + name: 'Group Key', + group_id: 1, + }) + }) + + it('创建失败时显示错误', async () => { + mockCreate.mockRejectedValue(new Error('配额不足')) + + const wrapper = mount(ApiKeyCreateTestComponent) + + await wrapper.find('#name').setValue('Fail Key') + await wrapper.find('form').trigger('submit') + await flushPromises() + + expect(mockShowError).toHaveBeenCalledWith('配额不足') + expect(wrapper.find('.created-key').exists()).toBe(false) + }) + + it('名称为空时不提交', async () => { + const wrapper = mount(ApiKeyCreateTestComponent) + + await wrapper.find('form').trigger('submit') + await flushPromises() + + expect(mockCreate).not.toHaveBeenCalled() + }) + + it('创建过程中按钮被禁用', async () => { + let resolveCreate: (v: any) => void + mockCreate.mockImplementation( + () => new Promise((resolve) => { resolveCreate = resolve }) + ) + + const wrapper = mount(ApiKeyCreateTestComponent) + + await wrapper.find('#name').setValue('Test Key') + await wrapper.find('form').trigger('submit') + + expect(wrapper.find('button').attributes('disabled')).toBeDefined() + + resolveCreate!({ id: 1, key: 'sk-test', name: 'Test Key' }) + await flushPromises() + + expect(wrapper.find('button').attributes('disabled')).toBeUndefined() + }) +}) diff --git a/frontend/src/components/__tests__/Dashboard.spec.ts b/frontend/src/components/__tests__/Dashboard.spec.ts new file mode 100644 index 00000000..b83808cc --- /dev/null +++ b/frontend/src/components/__tests__/Dashboard.spec.ts @@ -0,0 +1,173 @@ +/** + * Dashboard 数据加载逻辑测试 + * 通过封装组件测试仪表板核心数据加载流程 + */ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { mount, flushPromises } from '@vue/test-utils' +import { setActivePinia, createPinia } from 'pinia' +import { defineComponent, ref, onMounted, nextTick } from 'vue' + +// Mock API +const mockGetDashboardStats = vi.fn() +const mockRefreshUser = vi.fn() + +vi.mock('@/api', () => ({ + authAPI: { + getCurrentUser: vi.fn().mockResolvedValue({ + data: { id: 1, username: 'test', email: 'test@example.com', role: 'user', balance: 100, concurrency: 5, status: 'active', allowed_groups: null, created_at: '', updated_at: '' }, + }), + logout: vi.fn(), + refreshToken: vi.fn(), + }, + isTotp2FARequired: () => false, +})) + +vi.mock('@/api/usage', () => ({ + usageAPI: { + getDashboardStats: (...args: any[]) => mockGetDashboardStats(...args), + }, +})) + +vi.mock('@/api/admin/system', () => ({ + checkUpdates: vi.fn(), +})) + +vi.mock('@/api/auth', () => ({ + getPublicSettings: vi.fn().mockResolvedValue({}), +})) + +interface DashboardStats { + balance: number + api_key_count: number + active_api_key_count: number + today_requests: number + today_cost: number + today_tokens: number + total_tokens: number +} + +/** + * 简化的 Dashboard 测试组件 + */ +const DashboardTestComponent = defineComponent({ + setup() { + const stats = ref(null) + const loading = ref(false) + const error = ref('') + + const loadStats = async () => { + loading.value = true + error.value = '' + try { + stats.value = await mockGetDashboardStats() + } catch (e: any) { + error.value = e.message || '加载失败' + } finally { + loading.value = false + } + } + + onMounted(loadStats) + + return { stats, loading, error, loadStats } + }, + template: ` +
+
加载中...
+
{{ error }}
+
+ {{ stats.balance }} + {{ stats.api_key_count }} + {{ stats.today_requests }} + {{ stats.today_cost }} +
+ +
+ `, +}) + +describe('Dashboard 数据加载', () => { + beforeEach(() => { + setActivePinia(createPinia()) + vi.clearAllMocks() + }) + + const fakeStats: DashboardStats = { + balance: 100.5, + api_key_count: 3, + active_api_key_count: 2, + today_requests: 150, + today_cost: 2.5, + today_tokens: 50000, + total_tokens: 1000000, + } + + it('挂载后自动加载数据', async () => { + mockGetDashboardStats.mockResolvedValue(fakeStats) + + const wrapper = mount(DashboardTestComponent) + await flushPromises() + + expect(mockGetDashboardStats).toHaveBeenCalledTimes(1) + expect(wrapper.find('.balance').text()).toBe('100.5') + expect(wrapper.find('.api-keys').text()).toBe('3') + expect(wrapper.find('.today-requests').text()).toBe('150') + expect(wrapper.find('.today-cost').text()).toBe('2.5') + }) + + it('加载中显示 loading 状态', async () => { + let resolveStats: (v: any) => void + mockGetDashboardStats.mockImplementation( + () => new Promise((resolve) => { resolveStats = resolve }) + ) + + const wrapper = mount(DashboardTestComponent) + await nextTick() + + expect(wrapper.find('.loading').exists()).toBe(true) + + resolveStats!(fakeStats) + await flushPromises() + + expect(wrapper.find('.loading').exists()).toBe(false) + expect(wrapper.find('.stats').exists()).toBe(true) + }) + + it('加载失败时显示错误信息', async () => { + mockGetDashboardStats.mockRejectedValue(new Error('Network error')) + + const wrapper = mount(DashboardTestComponent) + await flushPromises() + + expect(wrapper.find('.error').text()).toBe('Network error') + expect(wrapper.find('.stats').exists()).toBe(false) + }) + + it('点击刷新按钮重新加载数据', async () => { + mockGetDashboardStats.mockResolvedValue(fakeStats) + + const wrapper = mount(DashboardTestComponent) + await flushPromises() + + expect(mockGetDashboardStats).toHaveBeenCalledTimes(1) + + // 更新数据 + const updatedStats = { ...fakeStats, today_requests: 200 } + mockGetDashboardStats.mockResolvedValue(updatedStats) + + await wrapper.find('.refresh').trigger('click') + await flushPromises() + + expect(mockGetDashboardStats).toHaveBeenCalledTimes(2) + expect(wrapper.find('.today-requests').text()).toBe('200') + }) + + it('数据为空时不显示统计信息', async () => { + mockGetDashboardStats.mockResolvedValue(null) + + const wrapper = mount(DashboardTestComponent) + await flushPromises() + + expect(wrapper.find('.stats').exists()).toBe(false) + }) +}) diff --git a/frontend/src/components/__tests__/LoginForm.spec.ts b/frontend/src/components/__tests__/LoginForm.spec.ts new file mode 100644 index 00000000..14b86fc2 --- /dev/null +++ b/frontend/src/components/__tests__/LoginForm.spec.ts @@ -0,0 +1,178 @@ +/** + * LoginView 组件核心逻辑测试 + * 测试登录表单提交、验证、2FA 等场景 + */ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { mount, flushPromises } from '@vue/test-utils' +import { setActivePinia, createPinia } from 'pinia' +import { defineComponent, reactive, ref } from 'vue' +import { useAuthStore } from '@/stores/auth' + +// Mock 所有外部依赖 +const mockLogin = vi.fn() +const mockLogin2FA = vi.fn() +const mockPush = vi.fn() + +vi.mock('@/api', () => ({ + authAPI: { + login: (...args: any[]) => mockLogin(...args), + login2FA: (...args: any[]) => mockLogin2FA(...args), + logout: vi.fn(), + getCurrentUser: vi.fn().mockResolvedValue({ data: {} }), + register: vi.fn(), + refreshToken: vi.fn(), + }, + isTotp2FARequired: (response: any) => response?.requires_2fa === true, +})) + +vi.mock('@/api/admin/system', () => ({ + checkUpdates: vi.fn(), +})) + +vi.mock('@/api/auth', () => ({ + getPublicSettings: vi.fn().mockResolvedValue({}), +})) + +/** + * 创建一个简化的测试组件来封装登录逻辑 + * 避免引入 LoginView.vue 的全部依赖(AuthLayout、i18n、Icon 等) + */ +const LoginFormTestComponent = defineComponent({ + setup() { + const authStore = useAuthStore() + const formData = reactive({ email: '', password: '' }) + const isLoading = ref(false) + const errorMessage = ref('') + + const handleLogin = async () => { + if (!formData.email || !formData.password) { + errorMessage.value = '请输入邮箱和密码' + return + } + + isLoading.value = true + errorMessage.value = '' + + try { + const response = await authStore.login({ + email: formData.email, + password: formData.password, + }) + + // 2FA 流程由调用方处理 + if ((response as any)?.requires_2fa) { + errorMessage.value = '需要 2FA 验证' + return + } + + mockPush('/dashboard') + } catch (error: any) { + errorMessage.value = error.message || '登录失败' + } finally { + isLoading.value = false + } + } + + return { formData, isLoading, errorMessage, handleLogin } + }, + template: ` +
+ + +

{{ errorMessage }}

+ +
+ `, +}) + +describe('LoginForm 核心逻辑', () => { + beforeEach(() => { + setActivePinia(createPinia()) + vi.clearAllMocks() + }) + + it('成功登录后跳转到 dashboard', async () => { + mockLogin.mockResolvedValue({ + access_token: 'token', + token_type: 'Bearer', + user: { id: 1, username: 'test', email: 'test@example.com', role: 'user', balance: 0, concurrency: 5, status: 'active', allowed_groups: null, created_at: '', updated_at: '' }, + }) + + const wrapper = mount(LoginFormTestComponent) + + await wrapper.find('#email').setValue('test@example.com') + await wrapper.find('#password').setValue('password123') + await wrapper.find('form').trigger('submit') + await flushPromises() + + expect(mockLogin).toHaveBeenCalledWith({ + email: 'test@example.com', + password: 'password123', + }) + expect(mockPush).toHaveBeenCalledWith('/dashboard') + }) + + it('登录失败时显示错误信息', async () => { + mockLogin.mockRejectedValue(new Error('Invalid credentials')) + + const wrapper = mount(LoginFormTestComponent) + + await wrapper.find('#email').setValue('test@example.com') + await wrapper.find('#password').setValue('wrong') + await wrapper.find('form').trigger('submit') + await flushPromises() + + expect(wrapper.find('.error').text()).toBe('Invalid credentials') + }) + + it('空表单提交显示验证错误', async () => { + const wrapper = mount(LoginFormTestComponent) + + await wrapper.find('form').trigger('submit') + await flushPromises() + + expect(wrapper.find('.error').text()).toBe('请输入邮箱和密码') + expect(mockLogin).not.toHaveBeenCalled() + }) + + it('需要 2FA 时不跳转', async () => { + mockLogin.mockResolvedValue({ + requires_2fa: true, + temp_token: 'temp-123', + }) + + const wrapper = mount(LoginFormTestComponent) + + await wrapper.find('#email').setValue('test@example.com') + await wrapper.find('#password').setValue('password123') + await wrapper.find('form').trigger('submit') + await flushPromises() + + expect(mockPush).not.toHaveBeenCalled() + expect(wrapper.find('.error').text()).toBe('需要 2FA 验证') + }) + + it('提交过程中按钮被禁用', async () => { + let resolveLogin: (v: any) => void + mockLogin.mockImplementation( + () => new Promise((resolve) => { resolveLogin = resolve }) + ) + + const wrapper = mount(LoginFormTestComponent) + + await wrapper.find('#email').setValue('test@example.com') + await wrapper.find('#password').setValue('password123') + await wrapper.find('form').trigger('submit') + + expect(wrapper.find('button').attributes('disabled')).toBeDefined() + + resolveLogin!({ + access_token: 'token', + token_type: 'Bearer', + user: { id: 1, username: 'test', email: 'test@example.com', role: 'user', balance: 0, concurrency: 5, status: 'active', allowed_groups: null, created_at: '', updated_at: '' }, + }) + await flushPromises() + + expect(wrapper.find('button').attributes('disabled')).toBeUndefined() + }) +}) diff --git a/frontend/src/composables/__tests__/useClipboard.spec.ts b/frontend/src/composables/__tests__/useClipboard.spec.ts new file mode 100644 index 00000000..b2c4de41 --- /dev/null +++ b/frontend/src/composables/__tests__/useClipboard.spec.ts @@ -0,0 +1,137 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { setActivePinia, createPinia } from 'pinia' + +// Mock i18n +vi.mock('@/i18n', () => ({ + i18n: { + global: { + t: (key: string) => key, + }, + }, +})) + +// Mock app store +const mockShowSuccess = vi.fn() +const mockShowError = vi.fn() + +vi.mock('@/stores/app', () => ({ + useAppStore: () => ({ + showSuccess: mockShowSuccess, + showError: mockShowError, + }), +})) + +import { useClipboard } from '@/composables/useClipboard' + +describe('useClipboard', () => { + beforeEach(() => { + setActivePinia(createPinia()) + vi.useFakeTimers() + vi.clearAllMocks() + + // 默认模拟安全上下文 + Clipboard API + Object.defineProperty(window, 'isSecureContext', { value: true, writable: true }) + Object.defineProperty(navigator, 'clipboard', { + value: { + writeText: vi.fn().mockResolvedValue(undefined), + }, + writable: true, + configurable: true, + }) + }) + + afterEach(() => { + vi.useRealTimers() + // 恢复 execCommand + if ('execCommand' in document) { + delete (document as any).execCommand + } + }) + + it('复制成功后 copied 变为 true', async () => { + const { copied, copyToClipboard } = useClipboard() + + expect(copied.value).toBe(false) + + await copyToClipboard('hello') + + expect(copied.value).toBe(true) + }) + + it('copied 在 2 秒后自动恢复为 false', async () => { + const { copied, copyToClipboard } = useClipboard() + + await copyToClipboard('hello') + expect(copied.value).toBe(true) + + vi.advanceTimersByTime(2000) + + expect(copied.value).toBe(false) + }) + + it('复制成功时调用 showSuccess', async () => { + const { copyToClipboard } = useClipboard() + + await copyToClipboard('hello', '已复制') + + expect(mockShowSuccess).toHaveBeenCalledWith('已复制') + }) + + it('无自定义消息时使用 i18n 默认消息', async () => { + const { copyToClipboard } = useClipboard() + + await copyToClipboard('hello') + + expect(mockShowSuccess).toHaveBeenCalledWith('common.copiedToClipboard') + }) + + it('空文本返回 false 且不复制', async () => { + const { copyToClipboard, copied } = useClipboard() + + const result = await copyToClipboard('') + + expect(result).toBe(false) + expect(copied.value).toBe(false) + expect(navigator.clipboard.writeText).not.toHaveBeenCalled() + }) + + it('Clipboard API 失败时降级到 fallback', async () => { + ;(navigator.clipboard.writeText as any).mockRejectedValue(new Error('API failed')) + + // jsdom 没有 execCommand,手动定义 + ;(document as any).execCommand = vi.fn().mockReturnValue(true) + + const { copyToClipboard, copied } = useClipboard() + const result = await copyToClipboard('fallback text') + + expect(result).toBe(true) + expect(copied.value).toBe(true) + expect(document.execCommand).toHaveBeenCalledWith('copy') + }) + + it('非安全上下文使用 fallback', async () => { + Object.defineProperty(window, 'isSecureContext', { value: false, writable: true }) + + ;(document as any).execCommand = vi.fn().mockReturnValue(true) + + const { copyToClipboard, copied } = useClipboard() + const result = await copyToClipboard('insecure context text') + + expect(result).toBe(true) + expect(copied.value).toBe(true) + expect(navigator.clipboard.writeText).not.toHaveBeenCalled() + expect(document.execCommand).toHaveBeenCalledWith('copy') + }) + + it('所有复制方式均失败时调用 showError', async () => { + ;(navigator.clipboard.writeText as any).mockRejectedValue(new Error('fail')) + ;(document as any).execCommand = vi.fn().mockReturnValue(false) + + const { copyToClipboard, copied } = useClipboard() + const result = await copyToClipboard('text') + + expect(result).toBe(false) + expect(copied.value).toBe(false) + expect(mockShowError).toHaveBeenCalled() + }) +}) diff --git a/frontend/src/composables/__tests__/useForm.spec.ts b/frontend/src/composables/__tests__/useForm.spec.ts new file mode 100644 index 00000000..bd9396a2 --- /dev/null +++ b/frontend/src/composables/__tests__/useForm.spec.ts @@ -0,0 +1,143 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { setActivePinia, createPinia } from 'pinia' +import { useForm } from '@/composables/useForm' +import { useAppStore } from '@/stores/app' + +// Mock API 依赖(app store 内部引用了这些) +vi.mock('@/api/admin/system', () => ({ + checkUpdates: vi.fn(), +})) +vi.mock('@/api/auth', () => ({ + getPublicSettings: vi.fn(), +})) + +describe('useForm', () => { + let appStore: ReturnType + + beforeEach(() => { + setActivePinia(createPinia()) + appStore = useAppStore() + vi.clearAllMocks() + }) + + it('submit 期间 loading 为 true,完成后为 false', async () => { + let resolveSubmit: () => void + const submitFn = vi.fn( + () => new Promise((resolve) => { resolveSubmit = resolve }) + ) + + const { loading, submit } = useForm({ + form: { name: 'test' }, + submitFn, + }) + + expect(loading.value).toBe(false) + + const submitPromise = submit() + // 提交中 + expect(loading.value).toBe(true) + + resolveSubmit!() + await submitPromise + + expect(loading.value).toBe(false) + }) + + it('submit 成功时显示成功消息', async () => { + const submitFn = vi.fn().mockResolvedValue(undefined) + const showSuccessSpy = vi.spyOn(appStore, 'showSuccess') + + const { submit } = useForm({ + form: { name: 'test' }, + submitFn, + successMsg: '保存成功', + }) + + await submit() + + expect(showSuccessSpy).toHaveBeenCalledWith('保存成功') + }) + + it('submit 成功但无 successMsg 时不调用 showSuccess', async () => { + const submitFn = vi.fn().mockResolvedValue(undefined) + const showSuccessSpy = vi.spyOn(appStore, 'showSuccess') + + const { submit } = useForm({ + form: { name: 'test' }, + submitFn, + }) + + await submit() + + expect(showSuccessSpy).not.toHaveBeenCalled() + }) + + it('submit 失败时显示错误消息并抛出错误', async () => { + const error = Object.assign(new Error('提交失败'), { + response: { data: { message: '服务器错误' } }, + }) + const submitFn = vi.fn().mockRejectedValue(error) + const showErrorSpy = vi.spyOn(appStore, 'showError') + + const { submit, loading } = useForm({ + form: { name: 'test' }, + submitFn, + }) + + await expect(submit()).rejects.toThrow('提交失败') + + expect(showErrorSpy).toHaveBeenCalled() + expect(loading.value).toBe(false) + }) + + it('submit 失败时使用自定义 errorMsg', async () => { + const submitFn = vi.fn().mockRejectedValue(new Error('network')) + const showErrorSpy = vi.spyOn(appStore, 'showError') + + const { submit } = useForm({ + form: { name: 'test' }, + submitFn, + errorMsg: '自定义错误提示', + }) + + await expect(submit()).rejects.toThrow() + + expect(showErrorSpy).toHaveBeenCalledWith('自定义错误提示') + }) + + it('loading 中不会重复提交', async () => { + let resolveSubmit: () => void + const submitFn = vi.fn( + () => new Promise((resolve) => { resolveSubmit = resolve }) + ) + + const { submit } = useForm({ + form: { name: 'test' }, + submitFn, + }) + + // 第一次提交 + const p1 = submit() + // 第二次提交(应被忽略,因为 loading=true) + submit() + + expect(submitFn).toHaveBeenCalledTimes(1) + + resolveSubmit!() + await p1 + }) + + it('传递 form 数据到 submitFn', async () => { + const formData = { name: 'test', email: 'test@example.com' } + const submitFn = vi.fn().mockResolvedValue(undefined) + + const { submit } = useForm({ + form: formData, + submitFn, + }) + + await submit() + + expect(submitFn).toHaveBeenCalledWith(formData) + }) +}) diff --git a/frontend/src/composables/__tests__/useTableLoader.spec.ts b/frontend/src/composables/__tests__/useTableLoader.spec.ts new file mode 100644 index 00000000..0eb6f42c --- /dev/null +++ b/frontend/src/composables/__tests__/useTableLoader.spec.ts @@ -0,0 +1,252 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { useTableLoader } from '@/composables/useTableLoader' +import { nextTick } from 'vue' + +// Mock @vueuse/core 的 useDebounceFn +vi.mock('@vueuse/core', () => ({ + useDebounceFn: (fn: Function, ms: number) => { + let timer: ReturnType | null = null + const debounced = (...args: any[]) => { + if (timer) clearTimeout(timer) + timer = setTimeout(() => fn(...args), ms) + } + debounced.cancel = () => { if (timer) clearTimeout(timer) } + return debounced + }, +})) + +// Mock Vue 的 onUnmounted(composable 外使用时会报错) +vi.mock('vue', async () => { + const actual = await vi.importActual('vue') + return { + ...actual, + onUnmounted: vi.fn(), + } +}) + +const createMockFetchFn = (items: any[] = [], total = 0, pages = 1) => { + return vi.fn().mockResolvedValue({ items, total, pages }) +} + +describe('useTableLoader', () => { + beforeEach(() => { + vi.useFakeTimers() + vi.clearAllMocks() + }) + + afterEach(() => { + vi.useRealTimers() + }) + + // --- 基础加载 --- + + describe('基础加载', () => { + it('load 执行 fetchFn 并更新 items', async () => { + const mockItems = [{ id: 1, name: 'item1' }, { id: 2, name: 'item2' }] + const fetchFn = createMockFetchFn(mockItems, 2, 1) + + const { items, loading, load, pagination } = useTableLoader({ + fetchFn, + }) + + expect(items.value).toHaveLength(0) + + await load() + + expect(items.value).toEqual(mockItems) + expect(pagination.total).toBe(2) + expect(pagination.pages).toBe(1) + expect(loading.value).toBe(false) + }) + + it('load 期间 loading 为 true', async () => { + let resolveLoad: (v: any) => void + const fetchFn = vi.fn( + () => new Promise((resolve) => { resolveLoad = resolve }) + ) + + const { loading, load } = useTableLoader({ fetchFn }) + + const p = load() + expect(loading.value).toBe(true) + + resolveLoad!({ items: [], total: 0, pages: 0 }) + await p + + expect(loading.value).toBe(false) + }) + + it('使用默认 pageSize=20', async () => { + const fetchFn = createMockFetchFn() + const { load, pagination } = useTableLoader({ fetchFn }) + + await load() + + expect(fetchFn).toHaveBeenCalledWith( + 1, + 20, + expect.anything(), + expect.objectContaining({ signal: expect.any(AbortSignal) }) + ) + expect(pagination.page_size).toBe(20) + }) + + it('可自定义 pageSize', async () => { + const fetchFn = createMockFetchFn() + const { load } = useTableLoader({ fetchFn, pageSize: 50 }) + + await load() + + expect(fetchFn).toHaveBeenCalledWith( + 1, + 50, + expect.anything(), + expect.anything() + ) + }) + }) + + // --- 分页 --- + + describe('分页', () => { + it('handlePageChange 更新页码并加载', async () => { + const fetchFn = createMockFetchFn([], 100, 5) + const { handlePageChange, pagination, load } = useTableLoader({ fetchFn }) + + await load() // 初始加载 + fetchFn.mockClear() + + handlePageChange(3) + + expect(pagination.page).toBe(3) + // 等待 load 完成 + await vi.runAllTimersAsync() + expect(fetchFn).toHaveBeenCalledWith(3, 20, expect.anything(), expect.anything()) + }) + + it('handlePageSizeChange 重置到第1页并加载', async () => { + const fetchFn = createMockFetchFn([], 100, 5) + const { handlePageSizeChange, pagination, load } = useTableLoader({ fetchFn }) + + await load() + pagination.page = 3 + fetchFn.mockClear() + + handlePageSizeChange(50) + + expect(pagination.page).toBe(1) + expect(pagination.page_size).toBe(50) + }) + + it('handlePageChange 限制页码范围', async () => { + const fetchFn = createMockFetchFn([], 100, 5) + const { handlePageChange, pagination, load } = useTableLoader({ fetchFn }) + + await load() + + // 超出范围的页码被限制 + handlePageChange(999) + expect(pagination.page).toBe(5) // 限制在 pages=5 + + handlePageChange(0) + expect(pagination.page).toBe(1) // 最小为 1 + }) + }) + + // --- 搜索防抖 --- + + describe('搜索防抖', () => { + it('debouncedReload 在 300ms 内多次调用只执行一次', async () => { + const fetchFn = createMockFetchFn() + const { debouncedReload } = useTableLoader({ fetchFn }) + + // 快速连续调用 + debouncedReload() + debouncedReload() + debouncedReload() + + // 还没到 300ms,不应调用 fetchFn + expect(fetchFn).not.toHaveBeenCalled() + + // 推进 300ms + vi.advanceTimersByTime(300) + + // 等待异步完成 + await vi.runAllTimersAsync() + + expect(fetchFn).toHaveBeenCalledTimes(1) + }) + + it('reload 重置到第 1 页', async () => { + const fetchFn = createMockFetchFn([], 100, 5) + const { reload, pagination, load } = useTableLoader({ fetchFn }) + + await load() + pagination.page = 3 + + await reload() + + expect(pagination.page).toBe(1) + }) + }) + + // --- 请求取消 --- + + describe('请求取消', () => { + it('新请求取消前一个未完成的请求', async () => { + let callCount = 0 + const fetchFn = vi.fn((_page, _size, _params, options) => { + callCount++ + const currentCall = callCount + return new Promise((resolve, reject) => { + // 模拟监听 abort + if (options?.signal) { + options.signal.addEventListener('abort', () => { + reject({ name: 'CanceledError', code: 'ERR_CANCELED' }) + }) + } + // 异步解决 + setTimeout(() => { + resolve({ items: [{ id: currentCall }], total: 1, pages: 1 }) + }, 1000) + }) + }) + + const { load, items } = useTableLoader({ fetchFn }) + + // 第一次加载 + const p1 = load() + // 第二次加载(应取消第一次) + const p2 = load() + + // 推进时间让第二次完成 + vi.advanceTimersByTime(1000) + await vi.runAllTimersAsync() + + // 等待两个 Promise settle + await Promise.allSettled([p1, p2]) + + // 第二次请求的结果生效 + expect(fetchFn).toHaveBeenCalledTimes(2) + }) + }) + + // --- 错误处理 --- + + describe('错误处理', () => { + it('非取消错误会被抛出', async () => { + const fetchFn = vi.fn().mockRejectedValue(new Error('Server error')) + const { load } = useTableLoader({ fetchFn }) + + await expect(load()).rejects.toThrow('Server error') + }) + + it('取消错误被静默处理', async () => { + const fetchFn = vi.fn().mockRejectedValue({ name: 'CanceledError', code: 'ERR_CANCELED' }) + const { load } = useTableLoader({ fetchFn }) + + // 不应抛出 + await load() + }) + }) +}) diff --git a/frontend/src/router/__tests__/guards.spec.ts b/frontend/src/router/__tests__/guards.spec.ts new file mode 100644 index 00000000..931f4534 --- /dev/null +++ b/frontend/src/router/__tests__/guards.spec.ts @@ -0,0 +1,324 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { createRouter, createMemoryHistory } from 'vue-router' +import { setActivePinia, createPinia } from 'pinia' +import { defineComponent, h } from 'vue' + +// Mock 导航加载状态 +vi.mock('@/composables/useNavigationLoading', () => { + const mockStart = vi.fn() + const mockEnd = vi.fn() + return { + useNavigationLoadingState: () => ({ + startNavigation: mockStart, + endNavigation: mockEnd, + isLoading: { value: false }, + }), + useNavigationLoading: () => ({ + startNavigation: mockStart, + endNavigation: mockEnd, + isLoading: { value: false }, + }), + } +}) + +// Mock 路由预加载 +vi.mock('@/composables/useRoutePrefetch', () => ({ + useRoutePrefetch: () => ({ + triggerPrefetch: vi.fn(), + cancelPendingPrefetch: vi.fn(), + resetPrefetchState: vi.fn(), + }), +})) + +// Mock API 相关模块 +vi.mock('@/api', () => ({ + authAPI: { + getCurrentUser: vi.fn().mockResolvedValue({ data: {} }), + logout: vi.fn(), + }, + isTotp2FARequired: () => false, +})) + +vi.mock('@/api/admin/system', () => ({ + checkUpdates: vi.fn(), +})) + +vi.mock('@/api/auth', () => ({ + getPublicSettings: vi.fn(), +})) + +const DummyComponent = defineComponent({ + render() { + return h('div', 'dummy') + }, +}) + +/** + * 创建带守卫逻辑的测试路由 + * 模拟 router/index.ts 中的 beforeEach 守卫逻辑 + */ +function createTestRouter() { + const router = createRouter({ + history: createMemoryHistory(), + routes: [ + { path: '/login', component: DummyComponent, meta: { requiresAuth: false, title: 'Login' } }, + { + path: '/register', + component: DummyComponent, + meta: { requiresAuth: false, title: 'Register' }, + }, + { path: '/home', component: DummyComponent, meta: { requiresAuth: false, title: 'Home' } }, + { path: '/dashboard', component: DummyComponent, meta: { title: 'Dashboard' } }, + { path: '/keys', component: DummyComponent, meta: { title: 'API Keys' } }, + { path: '/subscriptions', component: DummyComponent, meta: { title: 'Subscriptions' } }, + { path: '/redeem', component: DummyComponent, meta: { title: 'Redeem' } }, + { + path: '/admin/dashboard', + component: DummyComponent, + meta: { requiresAdmin: true, title: 'Admin Dashboard' }, + }, + { + path: '/admin/users', + component: DummyComponent, + meta: { requiresAdmin: true, title: 'Admin Users' }, + }, + { + path: '/admin/groups', + component: DummyComponent, + meta: { requiresAdmin: true, title: 'Admin Groups' }, + }, + { + path: '/admin/subscriptions', + component: DummyComponent, + meta: { requiresAdmin: true, title: 'Admin Subscriptions' }, + }, + { + path: '/admin/redeem', + component: DummyComponent, + meta: { requiresAdmin: true, title: 'Admin Redeem' }, + }, + ], + }) + + return router +} + +// 用于测试的 auth 状态 +interface MockAuthState { + isAuthenticated: boolean + isAdmin: boolean + isSimpleMode: boolean +} + +/** + * 将 router/index.ts 中 beforeEach 守卫的核心逻辑提取为可测试的函数 + */ +function simulateGuard( + toPath: string, + toMeta: Record, + authState: MockAuthState +): string | null { + const requiresAuth = toMeta.requiresAuth !== false + const requiresAdmin = toMeta.requiresAdmin === true + + // 不需要认证的路由 + if (!requiresAuth) { + if ( + authState.isAuthenticated && + (toPath === '/login' || toPath === '/register') + ) { + return authState.isAdmin ? '/admin/dashboard' : '/dashboard' + } + return null // 允许通过 + } + + // 需要认证但未登录 + if (!authState.isAuthenticated) { + return '/login' + } + + // 需要管理员但不是管理员 + if (requiresAdmin && !authState.isAdmin) { + return '/dashboard' + } + + // 简易模式限制 + if (authState.isSimpleMode) { + const restrictedPaths = [ + '/admin/groups', + '/admin/subscriptions', + '/admin/redeem', + '/subscriptions', + '/redeem', + ] + if (restrictedPaths.some((path) => toPath.startsWith(path))) { + return authState.isAdmin ? '/admin/dashboard' : '/dashboard' + } + } + + return null // 允许通过 +} + +describe('路由守卫逻辑', () => { + beforeEach(() => { + setActivePinia(createPinia()) + }) + + // --- 未认证用户 --- + + describe('未认证用户', () => { + const authState: MockAuthState = { + isAuthenticated: false, + isAdmin: false, + isSimpleMode: false, + } + + it('访问需要认证的页面重定向到 /login', () => { + const redirect = simulateGuard('/dashboard', {}, authState) + expect(redirect).toBe('/login') + }) + + it('访问管理页面重定向到 /login', () => { + const redirect = simulateGuard('/admin/dashboard', { requiresAdmin: true }, authState) + expect(redirect).toBe('/login') + }) + + it('访问公开页面允许通过', () => { + const redirect = simulateGuard('/login', { requiresAuth: false }, authState) + expect(redirect).toBeNull() + }) + + it('访问 /home 公开页面允许通过', () => { + const redirect = simulateGuard('/home', { requiresAuth: false }, authState) + expect(redirect).toBeNull() + }) + }) + + // --- 已认证普通用户 --- + + describe('已认证普通用户', () => { + const authState: MockAuthState = { + isAuthenticated: true, + isAdmin: false, + isSimpleMode: false, + } + + it('访问 /login 重定向到 /dashboard', () => { + const redirect = simulateGuard('/login', { requiresAuth: false }, authState) + expect(redirect).toBe('/dashboard') + }) + + it('访问 /register 重定向到 /dashboard', () => { + const redirect = simulateGuard('/register', { requiresAuth: false }, authState) + expect(redirect).toBe('/dashboard') + }) + + it('访问 /dashboard 允许通过', () => { + const redirect = simulateGuard('/dashboard', {}, authState) + expect(redirect).toBeNull() + }) + + it('访问管理页面被拒绝,重定向到 /dashboard', () => { + const redirect = simulateGuard('/admin/dashboard', { requiresAdmin: true }, authState) + expect(redirect).toBe('/dashboard') + }) + + it('访问 /admin/users 被拒绝', () => { + const redirect = simulateGuard('/admin/users', { requiresAdmin: true }, authState) + expect(redirect).toBe('/dashboard') + }) + }) + + // --- 已认证管理员 --- + + describe('已认证管理员', () => { + const authState: MockAuthState = { + isAuthenticated: true, + isAdmin: true, + isSimpleMode: false, + } + + it('访问 /login 重定向到 /admin/dashboard', () => { + const redirect = simulateGuard('/login', { requiresAuth: false }, authState) + expect(redirect).toBe('/admin/dashboard') + }) + + it('访问管理页面允许通过', () => { + const redirect = simulateGuard('/admin/dashboard', { requiresAdmin: true }, authState) + expect(redirect).toBeNull() + }) + + it('访问用户页面允许通过', () => { + const redirect = simulateGuard('/dashboard', {}, authState) + expect(redirect).toBeNull() + }) + }) + + // --- 简易模式 --- + + describe('简易模式受限路由', () => { + it('普通用户简易模式访问 /subscriptions 重定向到 /dashboard', () => { + const authState: MockAuthState = { + isAuthenticated: true, + isAdmin: false, + isSimpleMode: true, + } + const redirect = simulateGuard('/subscriptions', {}, authState) + expect(redirect).toBe('/dashboard') + }) + + it('普通用户简易模式访问 /redeem 重定向到 /dashboard', () => { + const authState: MockAuthState = { + isAuthenticated: true, + isAdmin: false, + isSimpleMode: true, + } + const redirect = simulateGuard('/redeem', {}, authState) + expect(redirect).toBe('/dashboard') + }) + + it('管理员简易模式访问 /admin/groups 重定向到 /admin/dashboard', () => { + const authState: MockAuthState = { + isAuthenticated: true, + isAdmin: true, + isSimpleMode: true, + } + const redirect = simulateGuard('/admin/groups', { requiresAdmin: true }, authState) + expect(redirect).toBe('/admin/dashboard') + }) + + it('管理员简易模式访问 /admin/subscriptions 重定向', () => { + const authState: MockAuthState = { + isAuthenticated: true, + isAdmin: true, + isSimpleMode: true, + } + const redirect = simulateGuard( + '/admin/subscriptions', + { requiresAdmin: true }, + authState + ) + expect(redirect).toBe('/admin/dashboard') + }) + + it('简易模式下非受限页面正常访问', () => { + const authState: MockAuthState = { + isAuthenticated: true, + isAdmin: false, + isSimpleMode: true, + } + const redirect = simulateGuard('/dashboard', {}, authState) + expect(redirect).toBeNull() + }) + + it('简易模式下 /keys 正常访问', () => { + const authState: MockAuthState = { + isAuthenticated: true, + isAdmin: false, + isSimpleMode: true, + } + const redirect = simulateGuard('/keys', {}, authState) + expect(redirect).toBeNull() + }) + }) +}) diff --git a/frontend/src/stores/__tests__/app.spec.ts b/frontend/src/stores/__tests__/app.spec.ts new file mode 100644 index 00000000..432a7079 --- /dev/null +++ b/frontend/src/stores/__tests__/app.spec.ts @@ -0,0 +1,293 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { setActivePinia, createPinia } from 'pinia' +import { useAppStore } from '@/stores/app' + +// Mock API 模块 +vi.mock('@/api/admin/system', () => ({ + checkUpdates: vi.fn(), +})) + +vi.mock('@/api/auth', () => ({ + getPublicSettings: vi.fn(), +})) + +describe('useAppStore', () => { + beforeEach(() => { + setActivePinia(createPinia()) + vi.useFakeTimers() + // 清除 window.__APP_CONFIG__ + delete (window as any).__APP_CONFIG__ + }) + + afterEach(() => { + vi.useRealTimers() + }) + + // --- Toast 消息管理 --- + + describe('Toast 消息管理', () => { + it('showSuccess 创建 success 类型 toast', () => { + const store = useAppStore() + const id = store.showSuccess('操作成功') + + expect(id).toMatch(/^toast-/) + expect(store.toasts).toHaveLength(1) + expect(store.toasts[0].type).toBe('success') + expect(store.toasts[0].message).toBe('操作成功') + }) + + it('showError 创建 error 类型 toast', () => { + const store = useAppStore() + store.showError('出错了') + + expect(store.toasts).toHaveLength(1) + expect(store.toasts[0].type).toBe('error') + expect(store.toasts[0].message).toBe('出错了') + }) + + it('showWarning 创建 warning 类型 toast', () => { + const store = useAppStore() + store.showWarning('警告信息') + + expect(store.toasts).toHaveLength(1) + expect(store.toasts[0].type).toBe('warning') + }) + + it('showInfo 创建 info 类型 toast', () => { + const store = useAppStore() + store.showInfo('提示信息') + + expect(store.toasts).toHaveLength(1) + expect(store.toasts[0].type).toBe('info') + }) + + it('toast 在指定 duration 后自动消失', () => { + const store = useAppStore() + store.showSuccess('临时消息', 3000) + + expect(store.toasts).toHaveLength(1) + + vi.advanceTimersByTime(3000) + + expect(store.toasts).toHaveLength(0) + }) + + it('hideToast 移除指定 toast', () => { + const store = useAppStore() + const id = store.showSuccess('消息1') + store.showError('消息2') + + expect(store.toasts).toHaveLength(2) + + store.hideToast(id) + + expect(store.toasts).toHaveLength(1) + expect(store.toasts[0].type).toBe('error') + }) + + it('clearAllToasts 清除所有 toast', () => { + const store = useAppStore() + store.showSuccess('消息1') + store.showError('消息2') + store.showWarning('消息3') + + expect(store.toasts).toHaveLength(3) + + store.clearAllToasts() + + expect(store.toasts).toHaveLength(0) + }) + + it('hasActiveToasts 正确反映 toast 状态', () => { + const store = useAppStore() + expect(store.hasActiveToasts).toBe(false) + + store.showSuccess('消息') + expect(store.hasActiveToasts).toBe(true) + + store.clearAllToasts() + expect(store.hasActiveToasts).toBe(false) + }) + + it('多个 toast 的 ID 唯一', () => { + const store = useAppStore() + const id1 = store.showSuccess('消息1') + const id2 = store.showSuccess('消息2') + const id3 = store.showSuccess('消息3') + + expect(id1).not.toBe(id2) + expect(id2).not.toBe(id3) + }) + }) + + // --- 侧边栏 --- + + describe('侧边栏管理', () => { + it('toggleSidebar 切换折叠状态', () => { + const store = useAppStore() + expect(store.sidebarCollapsed).toBe(false) + + store.toggleSidebar() + expect(store.sidebarCollapsed).toBe(true) + + store.toggleSidebar() + expect(store.sidebarCollapsed).toBe(false) + }) + + it('setSidebarCollapsed 直接设置状态', () => { + const store = useAppStore() + + store.setSidebarCollapsed(true) + expect(store.sidebarCollapsed).toBe(true) + + store.setSidebarCollapsed(false) + expect(store.sidebarCollapsed).toBe(false) + }) + + it('toggleMobileSidebar 切换移动端状态', () => { + const store = useAppStore() + expect(store.mobileOpen).toBe(false) + + store.toggleMobileSidebar() + expect(store.mobileOpen).toBe(true) + + store.toggleMobileSidebar() + expect(store.mobileOpen).toBe(false) + }) + }) + + // --- Loading 状态 --- + + describe('Loading 状态管理', () => { + it('setLoading 管理引用计数', () => { + const store = useAppStore() + expect(store.loading).toBe(false) + + store.setLoading(true) + expect(store.loading).toBe(true) + + store.setLoading(true) // 两次 true + expect(store.loading).toBe(true) + + store.setLoading(false) // 第一次 false,计数还是 1 + expect(store.loading).toBe(true) + + store.setLoading(false) // 第二次 false,计数为 0 + expect(store.loading).toBe(false) + }) + + it('setLoading(false) 不会使计数为负', () => { + const store = useAppStore() + + store.setLoading(false) + store.setLoading(false) + expect(store.loading).toBe(false) + + store.setLoading(true) + expect(store.loading).toBe(true) + + store.setLoading(false) + expect(store.loading).toBe(false) + }) + + it('withLoading 自动管理 loading 状态', async () => { + const store = useAppStore() + + const result = await store.withLoading(async () => { + expect(store.loading).toBe(true) + return 'done' + }) + + expect(result).toBe('done') + expect(store.loading).toBe(false) + }) + + it('withLoading 错误时也恢复 loading 状态', async () => { + const store = useAppStore() + + await expect( + store.withLoading(async () => { + throw new Error('操作失败') + }) + ).rejects.toThrow('操作失败') + + expect(store.loading).toBe(false) + }) + + it('withLoadingAndError 错误时显示 toast 并返回 null', async () => { + const store = useAppStore() + + const result = await store.withLoadingAndError(async () => { + throw new Error('网络错误') + }) + + expect(result).toBeNull() + expect(store.loading).toBe(false) + expect(store.toasts).toHaveLength(1) + expect(store.toasts[0].type).toBe('error') + }) + }) + + // --- reset --- + + describe('reset', () => { + it('重置所有 UI 状态', () => { + const store = useAppStore() + + store.setSidebarCollapsed(true) + store.setLoading(true) + store.showSuccess('消息') + + store.reset() + + expect(store.sidebarCollapsed).toBe(false) + expect(store.loading).toBe(false) + expect(store.toasts).toHaveLength(0) + }) + }) + + // --- 公开设置 --- + + describe('公开设置加载', () => { + it('从 window.__APP_CONFIG__ 初始化', () => { + ;(window as any).__APP_CONFIG__ = { + site_name: 'TestSite', + site_logo: '/logo.png', + version: '1.0.0', + contact_info: 'test@test.com', + api_base_url: 'https://api.test.com', + doc_url: 'https://docs.test.com', + } + + const store = useAppStore() + const result = store.initFromInjectedConfig() + + expect(result).toBe(true) + expect(store.siteName).toBe('TestSite') + expect(store.siteLogo).toBe('/logo.png') + expect(store.siteVersion).toBe('1.0.0') + expect(store.publicSettingsLoaded).toBe(true) + }) + + it('无注入配置时返回 false', () => { + const store = useAppStore() + const result = store.initFromInjectedConfig() + + expect(result).toBe(false) + expect(store.publicSettingsLoaded).toBe(false) + }) + + it('clearPublicSettingsCache 清除缓存', () => { + ;(window as any).__APP_CONFIG__ = { site_name: 'Test' } + const store = useAppStore() + store.initFromInjectedConfig() + + expect(store.publicSettingsLoaded).toBe(true) + + store.clearPublicSettingsCache() + + expect(store.publicSettingsLoaded).toBe(false) + expect(store.cachedPublicSettings).toBeNull() + }) + }) +}) diff --git a/frontend/src/stores/__tests__/auth.spec.ts b/frontend/src/stores/__tests__/auth.spec.ts new file mode 100644 index 00000000..ee6ad24e --- /dev/null +++ b/frontend/src/stores/__tests__/auth.spec.ts @@ -0,0 +1,289 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { setActivePinia, createPinia } from 'pinia' +import { useAuthStore } from '@/stores/auth' + +// Mock authAPI +const mockLogin = vi.fn() +const mockLogin2FA = vi.fn() +const mockLogout = vi.fn() +const mockGetCurrentUser = vi.fn() +const mockRegister = vi.fn() +const mockRefreshToken = vi.fn() + +vi.mock('@/api', () => ({ + authAPI: { + login: (...args: any[]) => mockLogin(...args), + login2FA: (...args: any[]) => mockLogin2FA(...args), + logout: (...args: any[]) => mockLogout(...args), + getCurrentUser: (...args: any[]) => mockGetCurrentUser(...args), + register: (...args: any[]) => mockRegister(...args), + refreshToken: (...args: any[]) => mockRefreshToken(...args), + }, + isTotp2FARequired: (response: any) => response?.requires_2fa === true, +})) + +const fakeUser = { + id: 1, + username: 'testuser', + email: 'test@example.com', + role: 'user' as const, + balance: 100, + concurrency: 5, + status: 'active' as const, + allowed_groups: null, + created_at: '2024-01-01', + updated_at: '2024-01-01', +} + +const fakeAdminUser = { + ...fakeUser, + id: 2, + username: 'admin', + email: 'admin@example.com', + role: 'admin' as const, +} + +const fakeAuthResponse = { + access_token: 'test-token-123', + refresh_token: 'refresh-token-456', + expires_in: 3600, + token_type: 'Bearer', + user: { ...fakeUser }, +} + +describe('useAuthStore', () => { + beforeEach(() => { + setActivePinia(createPinia()) + localStorage.clear() + vi.useFakeTimers() + vi.clearAllMocks() + }) + + afterEach(() => { + vi.useRealTimers() + }) + + // --- login --- + + describe('login', () => { + it('成功登录后设置 token 和 user', async () => { + mockLogin.mockResolvedValue(fakeAuthResponse) + const store = useAuthStore() + + await store.login({ email: 'test@example.com', password: '123456' }) + + expect(store.token).toBe('test-token-123') + expect(store.user).toEqual(fakeUser) + expect(store.isAuthenticated).toBe(true) + expect(localStorage.getItem('auth_token')).toBe('test-token-123') + expect(localStorage.getItem('auth_user')).toBe(JSON.stringify(fakeUser)) + }) + + it('登录失败时清除状态并抛出错误', async () => { + mockLogin.mockRejectedValue(new Error('Invalid credentials')) + const store = useAuthStore() + + await expect(store.login({ email: 'test@example.com', password: 'wrong' })).rejects.toThrow( + 'Invalid credentials' + ) + + expect(store.token).toBeNull() + expect(store.user).toBeNull() + expect(store.isAuthenticated).toBe(false) + }) + + it('需要 2FA 时返回响应但不设置认证状态', async () => { + const twoFAResponse = { requires_2fa: true, temp_token: 'temp-123' } + mockLogin.mockResolvedValue(twoFAResponse) + const store = useAuthStore() + + const result = await store.login({ email: 'test@example.com', password: '123456' }) + + expect(result).toEqual(twoFAResponse) + expect(store.token).toBeNull() + expect(store.isAuthenticated).toBe(false) + }) + }) + + // --- login2FA --- + + describe('login2FA', () => { + it('2FA 验证成功后设置认证状态', async () => { + mockLogin2FA.mockResolvedValue(fakeAuthResponse) + const store = useAuthStore() + + const user = await store.login2FA('temp-123', '654321') + + expect(store.token).toBe('test-token-123') + expect(store.user).toEqual(fakeUser) + expect(user).toEqual(fakeUser) + expect(mockLogin2FA).toHaveBeenCalledWith({ + temp_token: 'temp-123', + totp_code: '654321', + }) + }) + + it('2FA 验证失败时清除状态并抛出错误', async () => { + mockLogin2FA.mockRejectedValue(new Error('Invalid TOTP')) + const store = useAuthStore() + + await expect(store.login2FA('temp-123', '000000')).rejects.toThrow('Invalid TOTP') + expect(store.token).toBeNull() + expect(store.isAuthenticated).toBe(false) + }) + }) + + // --- logout --- + + describe('logout', () => { + it('注销后清除所有状态和 localStorage', async () => { + mockLogin.mockResolvedValue(fakeAuthResponse) + mockLogout.mockResolvedValue(undefined) + const store = useAuthStore() + + // 先登录 + await store.login({ email: 'test@example.com', password: '123456' }) + expect(store.isAuthenticated).toBe(true) + + // 注销 + await store.logout() + + expect(store.token).toBeNull() + expect(store.user).toBeNull() + expect(store.isAuthenticated).toBe(false) + expect(localStorage.getItem('auth_token')).toBeNull() + expect(localStorage.getItem('auth_user')).toBeNull() + expect(localStorage.getItem('refresh_token')).toBeNull() + expect(localStorage.getItem('token_expires_at')).toBeNull() + }) + }) + + // --- checkAuth --- + + describe('checkAuth', () => { + it('从 localStorage 恢复持久化状态', () => { + localStorage.setItem('auth_token', 'saved-token') + localStorage.setItem('auth_user', JSON.stringify(fakeUser)) + + // Mock refreshUser (getCurrentUser) 防止后台刷新报错 + mockGetCurrentUser.mockResolvedValue({ data: fakeUser }) + + const store = useAuthStore() + store.checkAuth() + + expect(store.token).toBe('saved-token') + expect(store.user).toEqual(fakeUser) + expect(store.isAuthenticated).toBe(true) + }) + + it('localStorage 无数据时保持未认证状态', () => { + const store = useAuthStore() + store.checkAuth() + + expect(store.token).toBeNull() + expect(store.user).toBeNull() + expect(store.isAuthenticated).toBe(false) + }) + + it('localStorage 中用户数据损坏时清除状态', () => { + localStorage.setItem('auth_token', 'saved-token') + localStorage.setItem('auth_user', 'invalid-json{{{') + + const store = useAuthStore() + store.checkAuth() + + expect(store.token).toBeNull() + expect(store.user).toBeNull() + expect(localStorage.getItem('auth_token')).toBeNull() + }) + + it('恢复 refresh token 和过期时间', () => { + const futureTs = String(Date.now() + 3600_000) + localStorage.setItem('auth_token', 'saved-token') + localStorage.setItem('auth_user', JSON.stringify(fakeUser)) + localStorage.setItem('refresh_token', 'saved-refresh') + localStorage.setItem('token_expires_at', futureTs) + + mockGetCurrentUser.mockResolvedValue({ data: fakeUser }) + + const store = useAuthStore() + store.checkAuth() + + expect(store.isAuthenticated).toBe(true) + }) + }) + + // --- isAdmin --- + + describe('isAdmin', () => { + it('管理员用户返回 true', async () => { + const adminResponse = { ...fakeAuthResponse, user: { ...fakeAdminUser } } + mockLogin.mockResolvedValue(adminResponse) + const store = useAuthStore() + + await store.login({ email: 'admin@example.com', password: '123456' }) + + expect(store.isAdmin).toBe(true) + }) + + it('普通用户返回 false', async () => { + mockLogin.mockResolvedValue(fakeAuthResponse) + const store = useAuthStore() + + await store.login({ email: 'test@example.com', password: '123456' }) + + expect(store.isAdmin).toBe(false) + }) + + it('未登录时返回 false', () => { + const store = useAuthStore() + expect(store.isAdmin).toBe(false) + }) + }) + + // --- refreshUser --- + + describe('refreshUser', () => { + it('刷新用户数据并更新 localStorage', async () => { + mockLogin.mockResolvedValue(fakeAuthResponse) + const store = useAuthStore() + await store.login({ email: 'test@example.com', password: '123456' }) + + const updatedUser = { ...fakeUser, username: 'updated-name' } + mockGetCurrentUser.mockResolvedValue({ data: updatedUser }) + + const result = await store.refreshUser() + + expect(result).toEqual(updatedUser) + expect(store.user).toEqual(updatedUser) + expect(JSON.parse(localStorage.getItem('auth_user')!)).toEqual(updatedUser) + }) + + it('未认证时抛出错误', async () => { + const store = useAuthStore() + await expect(store.refreshUser()).rejects.toThrow('Not authenticated') + }) + }) + + // --- isSimpleMode --- + + describe('isSimpleMode', () => { + it('run_mode 为 simple 时返回 true', async () => { + const simpleResponse = { + ...fakeAuthResponse, + user: { ...fakeUser, run_mode: 'simple' as const }, + } + mockLogin.mockResolvedValue(simpleResponse) + const store = useAuthStore() + + await store.login({ email: 'test@example.com', password: '123456' }) + + expect(store.isSimpleMode).toBe(true) + }) + + it('默认为 standard 模式', () => { + const store = useAuthStore() + expect(store.isSimpleMode).toBe(false) + }) + }) +}) diff --git a/frontend/src/stores/__tests__/subscriptions.spec.ts b/frontend/src/stores/__tests__/subscriptions.spec.ts new file mode 100644 index 00000000..4c0b4b89 --- /dev/null +++ b/frontend/src/stores/__tests__/subscriptions.spec.ts @@ -0,0 +1,239 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { setActivePinia, createPinia } from 'pinia' +import { useSubscriptionStore } from '@/stores/subscriptions' + +// Mock subscriptions API +const mockGetActiveSubscriptions = vi.fn() + +vi.mock('@/api/subscriptions', () => ({ + default: { + getActiveSubscriptions: (...args: any[]) => mockGetActiveSubscriptions(...args), + }, +})) + +const fakeSubscriptions = [ + { + id: 1, + user_id: 1, + group_id: 1, + status: 'active' as const, + daily_usage_usd: 5, + weekly_usage_usd: 20, + monthly_usage_usd: 50, + daily_window_start: null, + weekly_window_start: null, + monthly_window_start: null, + created_at: '2024-01-01', + updated_at: '2024-01-01', + expires_at: '2025-01-01', + }, + { + id: 2, + user_id: 1, + group_id: 2, + status: 'active' as const, + daily_usage_usd: 10, + weekly_usage_usd: 40, + monthly_usage_usd: 100, + daily_window_start: null, + weekly_window_start: null, + monthly_window_start: null, + created_at: '2024-02-01', + updated_at: '2024-02-01', + expires_at: '2025-02-01', + }, +] + +describe('useSubscriptionStore', () => { + beforeEach(() => { + setActivePinia(createPinia()) + vi.useFakeTimers() + vi.clearAllMocks() + }) + + afterEach(() => { + vi.useRealTimers() + }) + + // --- fetchActiveSubscriptions --- + + describe('fetchActiveSubscriptions', () => { + it('成功获取活跃订阅', async () => { + mockGetActiveSubscriptions.mockResolvedValue(fakeSubscriptions) + const store = useSubscriptionStore() + + const result = await store.fetchActiveSubscriptions() + + expect(result).toEqual(fakeSubscriptions) + expect(store.activeSubscriptions).toEqual(fakeSubscriptions) + expect(store.loading).toBe(false) + }) + + it('缓存有效时返回缓存数据', async () => { + mockGetActiveSubscriptions.mockResolvedValue(fakeSubscriptions) + const store = useSubscriptionStore() + + // 第一次请求 + await store.fetchActiveSubscriptions() + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(1) + + // 第二次请求(60秒内)- 应返回缓存 + const result = await store.fetchActiveSubscriptions() + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(1) // 没有新请求 + expect(result).toEqual(fakeSubscriptions) + }) + + it('缓存过期后重新请求', async () => { + mockGetActiveSubscriptions.mockResolvedValue(fakeSubscriptions) + const store = useSubscriptionStore() + + await store.fetchActiveSubscriptions() + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(1) + + // 推进 61 秒让缓存过期 + vi.advanceTimersByTime(61_000) + + const updatedSubs = [fakeSubscriptions[0]] + mockGetActiveSubscriptions.mockResolvedValue(updatedSubs) + + const result = await store.fetchActiveSubscriptions() + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(2) + expect(result).toEqual(updatedSubs) + }) + + it('force=true 强制重新请求', async () => { + mockGetActiveSubscriptions.mockResolvedValue(fakeSubscriptions) + const store = useSubscriptionStore() + + await store.fetchActiveSubscriptions() + + const updatedSubs = [fakeSubscriptions[0]] + mockGetActiveSubscriptions.mockResolvedValue(updatedSubs) + + const result = await store.fetchActiveSubscriptions(true) + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(2) + expect(result).toEqual(updatedSubs) + }) + + it('并发请求共享同一个 Promise(去重)', async () => { + let resolvePromise: (v: any) => void + mockGetActiveSubscriptions.mockImplementation( + () => new Promise((resolve) => { resolvePromise = resolve }) + ) + const store = useSubscriptionStore() + + // 并发发起两个请求 + const p1 = store.fetchActiveSubscriptions() + const p2 = store.fetchActiveSubscriptions() + + // 只调用了一次 API + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(1) + + // 解决 Promise + resolvePromise!(fakeSubscriptions) + + const [r1, r2] = await Promise.all([p1, p2]) + expect(r1).toEqual(fakeSubscriptions) + expect(r2).toEqual(fakeSubscriptions) + }) + + it('API 错误时抛出异常', async () => { + mockGetActiveSubscriptions.mockRejectedValue(new Error('Network error')) + const store = useSubscriptionStore() + + await expect(store.fetchActiveSubscriptions()).rejects.toThrow('Network error') + }) + }) + + // --- hasActiveSubscriptions --- + + describe('hasActiveSubscriptions', () => { + it('有订阅时返回 true', async () => { + mockGetActiveSubscriptions.mockResolvedValue(fakeSubscriptions) + const store = useSubscriptionStore() + + await store.fetchActiveSubscriptions() + + expect(store.hasActiveSubscriptions).toBe(true) + }) + + it('无订阅时返回 false', () => { + const store = useSubscriptionStore() + expect(store.hasActiveSubscriptions).toBe(false) + }) + + it('清除后返回 false', async () => { + mockGetActiveSubscriptions.mockResolvedValue(fakeSubscriptions) + const store = useSubscriptionStore() + + await store.fetchActiveSubscriptions() + expect(store.hasActiveSubscriptions).toBe(true) + + store.clear() + expect(store.hasActiveSubscriptions).toBe(false) + }) + }) + + // --- invalidateCache --- + + describe('invalidateCache', () => { + it('失效缓存后下次请求重新获取数据', async () => { + mockGetActiveSubscriptions.mockResolvedValue(fakeSubscriptions) + const store = useSubscriptionStore() + + await store.fetchActiveSubscriptions() + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(1) + + store.invalidateCache() + + await store.fetchActiveSubscriptions() + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(2) + }) + }) + + // --- clear --- + + describe('clear', () => { + it('清除所有订阅数据', async () => { + mockGetActiveSubscriptions.mockResolvedValue(fakeSubscriptions) + const store = useSubscriptionStore() + + await store.fetchActiveSubscriptions() + expect(store.activeSubscriptions).toHaveLength(2) + + store.clear() + + expect(store.activeSubscriptions).toHaveLength(0) + expect(store.hasActiveSubscriptions).toBe(false) + }) + }) + + // --- polling --- + + describe('startPolling / stopPolling', () => { + it('startPolling 不会创建重复 interval', () => { + const store = useSubscriptionStore() + mockGetActiveSubscriptions.mockResolvedValue([]) + + store.startPolling() + store.startPolling() // 重复调用 + + // 推进5分钟只触发一次 + vi.advanceTimersByTime(5 * 60 * 1000) + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(1) + + store.stopPolling() + }) + + it('stopPolling 停止定期刷新', () => { + const store = useSubscriptionStore() + mockGetActiveSubscriptions.mockResolvedValue([]) + + store.startPolling() + store.stopPolling() + + vi.advanceTimersByTime(10 * 60 * 1000) + expect(mockGetActiveSubscriptions).not.toHaveBeenCalled() + }) + }) +}) diff --git a/frontend/vitest.config.ts b/frontend/vitest.config.ts index 0b20cb60..1007f6ed 100644 --- a/frontend/vitest.config.ts +++ b/frontend/vitest.config.ts @@ -1,35 +1,44 @@ -import { defineConfig, mergeConfig } from 'vitest/config' -import viteConfig from './vite.config' +import { defineConfig } from 'vitest/config' +import vue from '@vitejs/plugin-vue' +import { resolve } from 'path' -export default mergeConfig( - viteConfig, - defineConfig({ - test: { - globals: true, - environment: 'jsdom', - include: ['src/**/*.{test,spec}.{js,ts,jsx,tsx}'], - exclude: ['node_modules', 'dist'], - coverage: { - provider: 'v8', - reporter: ['text', 'json', 'html'], - include: ['src/**/*.{js,ts,vue}'], - exclude: [ - 'node_modules', - 'src/**/*.d.ts', - 'src/**/*.spec.ts', - 'src/**/*.test.ts', - 'src/main.ts' - ], - thresholds: { - global: { - statements: 80, - branches: 80, - functions: 80, - lines: 80 - } - } - }, - setupFiles: ['./src/__tests__/setup.ts'] +export default defineConfig({ + plugins: [vue()], + resolve: { + alias: { + '@': resolve(__dirname, 'src'), + 'vue-i18n': 'vue-i18n/dist/vue-i18n.runtime.esm-bundler.js' } - }) -) + }, + define: { + __INTLIFY_JIT_COMPILATION__: true + }, + test: { + globals: true, + environment: 'jsdom', + include: ['src/**/*.{test,spec}.{js,ts,jsx,tsx}'], + exclude: ['node_modules', 'dist'], + coverage: { + provider: 'v8', + reporter: ['text', 'json', 'html'], + include: ['src/**/*.{js,ts,vue}'], + exclude: [ + 'node_modules', + 'src/**/*.d.ts', + 'src/**/*.spec.ts', + 'src/**/*.test.ts', + 'src/main.ts' + ], + thresholds: { + global: { + statements: 80, + branches: 80, + functions: 80, + lines: 80 + } + } + }, + setupFiles: ['./src/__tests__/setup.ts'], + testTimeout: 10000 + } +})