diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml deleted file mode 100644 index 03e7159f..00000000 --- a/.github/workflows/ci.yml +++ /dev/null @@ -1,179 +0,0 @@ -name: CI - -on: - push: - pull_request: - -permissions: - contents: read - -jobs: - # ========================================================================== - # 后端测试(与前端并行运行) - # ========================================================================== - backend-test: - runs-on: ubuntu-latest - services: - postgres: - image: postgres:16-alpine - env: - POSTGRES_USER: test - POSTGRES_PASSWORD: test - POSTGRES_DB: sub2api_test - ports: - - 5432:5432 - options: >- - --health-cmd "pg_isready -U test" - --health-interval 10s - --health-timeout 5s - --health-retries 5 - redis: - image: redis:7-alpine - ports: - - 6379:6379 - options: >- - --health-cmd "redis-cli ping" - --health-interval 10s - --health-timeout 5s - --health-retries 5 - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version-file: backend/go.mod - check-latest: false - cache: true - - - name: 验证 Go 版本 - run: go version | grep -q 'go1.25.7' - - - name: 单元测试 - working-directory: backend - run: make test-unit - - - name: 集成测试 - working-directory: backend - env: - DATABASE_URL: postgres://test:test@localhost:5432/sub2api_test?sslmode=disable - REDIS_URL: redis://localhost:6379/0 - run: make test-integration - - - name: Race 检测 - working-directory: backend - run: go test -tags=unit -race -count=1 ./... - - - name: 覆盖率收集 - working-directory: backend - run: | - go test -tags=unit -coverprofile=coverage.out -count=1 ./... - echo "## 后端测试覆盖率" >> $GITHUB_STEP_SUMMARY - echo '```' >> $GITHUB_STEP_SUMMARY - go tool cover -func=coverage.out | tail -1 >> $GITHUB_STEP_SUMMARY - echo '```' >> $GITHUB_STEP_SUMMARY - - - name: 覆盖率门禁(≥8%) - working-directory: backend - run: | - COVERAGE=$(go tool cover -func=coverage.out | tail -1 | awk '{print $3}' | sed 's/%//') - echo "当前覆盖率: ${COVERAGE}%" - if [ "$(echo "$COVERAGE < 8" | bc -l)" -eq 1 ]; then - echo "::error::后端覆盖率 ${COVERAGE}% 低于门禁值 8%" - exit 1 - fi - - # ========================================================================== - # 后端代码检查 - # ========================================================================== - golangci-lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version-file: backend/go.mod - check-latest: false - cache: true - - name: 验证 Go 版本 - run: go version | grep -q 'go1.25.7' - - name: golangci-lint - uses: golangci/golangci-lint-action@v9 - with: - version: v2.7 - args: --timeout=5m - working-directory: backend - - # ========================================================================== - # 前端测试(与后端并行运行) - # ========================================================================== - frontend-test: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: 安装 pnpm - uses: pnpm/action-setup@v4 - with: - version: 9 - - name: 安装 Node.js - uses: actions/setup-node@v4 - with: - node-version: '20' - cache: 'pnpm' - cache-dependency-path: frontend/pnpm-lock.yaml - - name: 安装依赖 - working-directory: frontend - run: pnpm install --frozen-lockfile - - - name: 类型检查 - working-directory: frontend - run: pnpm run typecheck - - - name: Lint 检查 - working-directory: frontend - run: pnpm run lint:check - - - name: 单元测试 - working-directory: frontend - run: pnpm run test:run - - - name: 覆盖率收集 - working-directory: frontend - run: | - pnpm run test:coverage -- --exclude '**/integration/**' || true - echo "## 前端测试覆盖率" >> $GITHUB_STEP_SUMMARY - if [ -f coverage/coverage-final.json ]; then - echo "覆盖率报告已生成" >> $GITHUB_STEP_SUMMARY - fi - - - name: 覆盖率门禁(≥20%) - working-directory: frontend - run: | - if [ ! -f coverage/coverage-final.json ]; then - echo "::warning::覆盖率报告未生成,跳过门禁检查" - exit 0 - fi - # 使用 node 解析覆盖率 JSON - COVERAGE=$(node -e " - const data = require('./coverage/coverage-final.json'); - let totalStatements = 0, coveredStatements = 0; - for (const file of Object.values(data)) { - const stmts = file.s; - totalStatements += Object.keys(stmts).length; - coveredStatements += Object.values(stmts).filter(v => v > 0).length; - } - const pct = totalStatements > 0 ? (coveredStatements / totalStatements * 100) : 0; - console.log(pct.toFixed(1)); - ") - echo "当前前端覆盖率: ${COVERAGE}%" - if [ "$(echo "$COVERAGE < 20" | bc -l 2>/dev/null || node -e "console.log($COVERAGE < 20 ? 1 : 0)")" = "1" ]; then - echo "::warning::前端覆盖率 ${COVERAGE}% 低于门禁值 20%(当前为警告,不阻塞)" - fi - - # ========================================================================== - # Docker 构建验证 - # ========================================================================== - docker-build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Docker 构建验证 - run: docker build -t aicodex2api:ci-test . diff --git a/backend/internal/pkg/antigravity/client_test.go b/backend/internal/pkg/antigravity/client_test.go new file mode 100644 index 00000000..89a4f022 --- /dev/null +++ b/backend/internal/pkg/antigravity/client_test.go @@ -0,0 +1,1657 @@ +//go:build unit + +package antigravity + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +// --------------------------------------------------------------------------- +// NewAPIRequestWithURL +// --------------------------------------------------------------------------- + +func TestNewAPIRequestWithURL_普通请求(t *testing.T) { + ctx := context.Background() + baseURL := "https://example.com" + action := "generateContent" + token := "test-token" + body := []byte(`{"prompt":"hello"}`) + + req, err := NewAPIRequestWithURL(ctx, baseURL, action, token, body) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + + // 验证 URL 不含 ?alt=sse + expectedURL := "https://example.com/v1internal:generateContent" + if req.URL.String() != expectedURL { + t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expectedURL) + } + + // 验证请求方法 + if req.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", req.Method) + } + + // 验证 Headers + if ct := req.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + if auth := req.Header.Get("Authorization"); auth != "Bearer test-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + if ua := req.Header.Get("User-Agent"); ua != UserAgent { + t.Errorf("User-Agent 不匹配: got %s, want %s", ua, UserAgent) + } +} + +func TestNewAPIRequestWithURL_流式请求(t *testing.T) { + ctx := context.Background() + baseURL := "https://example.com" + action := "streamGenerateContent" + token := "tok" + body := []byte(`{}`) + + req, err := NewAPIRequestWithURL(ctx, baseURL, action, token, body) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + + expectedURL := "https://example.com/v1internal:streamGenerateContent?alt=sse" + if req.URL.String() != expectedURL { + t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expectedURL) + } +} + +func TestNewAPIRequestWithURL_空Body(t *testing.T) { + ctx := context.Background() + req, err := NewAPIRequestWithURL(ctx, "https://example.com", "test", "tok", nil) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + if req.Body == nil { + t.Error("Body 应该非 nil(bytes.NewReader(nil) 会返回空 reader)") + } +} + +// --------------------------------------------------------------------------- +// NewAPIRequest +// --------------------------------------------------------------------------- + +func TestNewAPIRequest_使用默认URL(t *testing.T) { + ctx := context.Background() + req, err := NewAPIRequest(ctx, "generateContent", "tok", []byte(`{}`)) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + + expected := BaseURL + "/v1internal:generateContent" + if req.URL.String() != expected { + t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expected) + } +} + +// --------------------------------------------------------------------------- +// TierInfo.UnmarshalJSON +// --------------------------------------------------------------------------- + +func TestTierInfo_UnmarshalJSON_字符串格式(t *testing.T) { + data := []byte(`"free-tier"`) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if tier.ID != "free-tier" { + t.Errorf("ID 不匹配: got %s, want free-tier", tier.ID) + } + if tier.Name != "" { + t.Errorf("Name 应为空: got %s", tier.Name) + } +} + +func TestTierInfo_UnmarshalJSON_对象格式(t *testing.T) { + data := []byte(`{"id":"g1-pro-tier","name":"Pro","description":"Pro plan"}`) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if tier.ID != "g1-pro-tier" { + t.Errorf("ID 不匹配: got %s, want g1-pro-tier", tier.ID) + } + if tier.Name != "Pro" { + t.Errorf("Name 不匹配: got %s, want Pro", tier.Name) + } + if tier.Description != "Pro plan" { + t.Errorf("Description 不匹配: got %s, want Pro plan", tier.Description) + } +} + +func TestTierInfo_UnmarshalJSON_null(t *testing.T) { + data := []byte(`null`) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化 null 失败: %v", err) + } + if tier.ID != "" { + t.Errorf("null 场景下 ID 应为空: got %s", tier.ID) + } +} + +func TestTierInfo_UnmarshalJSON_空数据(t *testing.T) { + data := []byte(``) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化空数据失败: %v", err) + } + if tier.ID != "" { + t.Errorf("空数据场景下 ID 应为空: got %s", tier.ID) + } +} + +func TestTierInfo_UnmarshalJSON_空格包裹null(t *testing.T) { + data := []byte(` null `) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化空格 null 失败: %v", err) + } + if tier.ID != "" { + t.Errorf("空格 null 场景下 ID 应为空: got %s", tier.ID) + } +} + +func TestTierInfo_UnmarshalJSON_通过JSON嵌套结构(t *testing.T) { + // 模拟 LoadCodeAssistResponse 中的嵌套反序列化 + jsonData := `{"currentTier":"free-tier","paidTier":{"id":"g1-ultra-tier","name":"Ultra"}}` + var resp LoadCodeAssistResponse + if err := json.Unmarshal([]byte(jsonData), &resp); err != nil { + t.Fatalf("反序列化嵌套结构失败: %v", err) + } + if resp.CurrentTier == nil || resp.CurrentTier.ID != "free-tier" { + t.Errorf("CurrentTier 不匹配: got %+v", resp.CurrentTier) + } + if resp.PaidTier == nil || resp.PaidTier.ID != "g1-ultra-tier" { + t.Errorf("PaidTier 不匹配: got %+v", resp.PaidTier) + } +} + +// --------------------------------------------------------------------------- +// LoadCodeAssistResponse.GetTier +// --------------------------------------------------------------------------- + +func TestGetTier_PaidTier优先(t *testing.T) { + resp := &LoadCodeAssistResponse{ + CurrentTier: &TierInfo{ID: "free-tier"}, + PaidTier: &TierInfo{ID: "g1-pro-tier"}, + } + if got := resp.GetTier(); got != "g1-pro-tier" { + t.Errorf("应返回 paidTier: got %s", got) + } +} + +func TestGetTier_回退到CurrentTier(t *testing.T) { + resp := &LoadCodeAssistResponse{ + CurrentTier: &TierInfo{ID: "free-tier"}, + } + if got := resp.GetTier(); got != "free-tier" { + t.Errorf("应返回 currentTier: got %s", got) + } +} + +func TestGetTier_PaidTier为空ID(t *testing.T) { + resp := &LoadCodeAssistResponse{ + CurrentTier: &TierInfo{ID: "free-tier"}, + PaidTier: &TierInfo{ID: ""}, + } + // paidTier.ID 为空时应回退到 currentTier + if got := resp.GetTier(); got != "free-tier" { + t.Errorf("paidTier.ID 为空时应回退到 currentTier: got %s", got) + } +} + +func TestGetTier_两者都为nil(t *testing.T) { + resp := &LoadCodeAssistResponse{} + if got := resp.GetTier(); got != "" { + t.Errorf("两者都为 nil 时应返回空字符串: got %s", got) + } +} + +// --------------------------------------------------------------------------- +// NewClient +// --------------------------------------------------------------------------- + +func TestNewClient_无代理(t *testing.T) { + client := NewClient("") + if client == nil { + t.Fatal("NewClient 返回 nil") + } + if client.httpClient == nil { + t.Fatal("httpClient 为 nil") + } + if client.httpClient.Timeout != 30*time.Second { + t.Errorf("Timeout 不匹配: got %v, want 30s", client.httpClient.Timeout) + } + // 无代理时 Transport 应为 nil(使用默认) + if client.httpClient.Transport != nil { + t.Error("无代理时 Transport 应为 nil") + } +} + +func TestNewClient_有代理(t *testing.T) { + client := NewClient("http://proxy.example.com:8080") + if client == nil { + t.Fatal("NewClient 返回 nil") + } + if client.httpClient.Transport == nil { + t.Fatal("有代理时 Transport 不应为 nil") + } +} + +func TestNewClient_空格代理(t *testing.T) { + client := NewClient(" ") + if client == nil { + t.Fatal("NewClient 返回 nil") + } + // 空格代理应等同于无代理 + if client.httpClient.Transport != nil { + t.Error("空格代理 Transport 应为 nil") + } +} + +func TestNewClient_无效代理URL(t *testing.T) { + // 无效 URL 时 url.Parse 不一定返回错误(Go 的 url.Parse 很宽容), + // 但 ://invalid 会导致解析错误 + client := NewClient("://invalid") + if client == nil { + t.Fatal("NewClient 返回 nil") + } + // 无效 URL 解析失败时,Transport 应保持 nil + if client.httpClient.Transport != nil { + t.Error("无效代理 URL 时 Transport 应为 nil") + } +} + +// --------------------------------------------------------------------------- +// isConnectionError +// --------------------------------------------------------------------------- + +func TestIsConnectionError_nil(t *testing.T) { + if isConnectionError(nil) { + t.Error("nil 错误不应判定为连接错误") + } +} + +func TestIsConnectionError_超时错误(t *testing.T) { + // 使用 net.OpError 包装超时 + err := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: &timeoutError{}, + } + if !isConnectionError(err) { + t.Error("超时错误应判定为连接错误") + } +} + +// timeoutError 实现 net.Error 接口用于测试 +type timeoutError struct{} + +func (e *timeoutError) Error() string { return "timeout" } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return true } + +func TestIsConnectionError_netOpError(t *testing.T) { + err := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: fmt.Errorf("connection refused"), + } + if !isConnectionError(err) { + t.Error("net.OpError 应判定为连接错误") + } +} + +func TestIsConnectionError_urlError(t *testing.T) { + err := &url.Error{ + Op: "Get", + URL: "https://example.com", + Err: fmt.Errorf("some error"), + } + if !isConnectionError(err) { + t.Error("url.Error 应判定为连接错误") + } +} + +func TestIsConnectionError_普通错误(t *testing.T) { + err := fmt.Errorf("some random error") + if isConnectionError(err) { + t.Error("普通错误不应判定为连接错误") + } +} + +func TestIsConnectionError_包装的netOpError(t *testing.T) { + inner := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: fmt.Errorf("connection refused"), + } + err := fmt.Errorf("wrapping: %w", inner) + if !isConnectionError(err) { + t.Error("被包装的 net.OpError 应判定为连接错误") + } +} + +// --------------------------------------------------------------------------- +// shouldFallbackToNextURL +// --------------------------------------------------------------------------- + +func TestShouldFallbackToNextURL_连接错误(t *testing.T) { + err := &net.OpError{Op: "dial", Net: "tcp", Err: fmt.Errorf("refused")} + if !shouldFallbackToNextURL(err, 0) { + t.Error("连接错误应触发 URL 降级") + } +} + +func TestShouldFallbackToNextURL_状态码(t *testing.T) { + tests := []struct { + name string + statusCode int + want bool + }{ + {"429 Too Many Requests", http.StatusTooManyRequests, true}, + {"408 Request Timeout", http.StatusRequestTimeout, true}, + {"404 Not Found", http.StatusNotFound, true}, + {"500 Internal Server Error", http.StatusInternalServerError, true}, + {"502 Bad Gateway", http.StatusBadGateway, true}, + {"503 Service Unavailable", http.StatusServiceUnavailable, true}, + {"200 OK", http.StatusOK, false}, + {"201 Created", http.StatusCreated, false}, + {"400 Bad Request", http.StatusBadRequest, false}, + {"401 Unauthorized", http.StatusUnauthorized, false}, + {"403 Forbidden", http.StatusForbidden, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := shouldFallbackToNextURL(nil, tt.statusCode) + if got != tt.want { + t.Errorf("shouldFallbackToNextURL(nil, %d) = %v, want %v", tt.statusCode, got, tt.want) + } + }) + } +} + +func TestShouldFallbackToNextURL_无错误且200(t *testing.T) { + if shouldFallbackToNextURL(nil, http.StatusOK) { + t.Error("无错误且 200 不应触发 URL 降级") + } +} + +// --------------------------------------------------------------------------- +// Client.ExchangeCode (使用 httptest) +// --------------------------------------------------------------------------- + +func TestClient_ExchangeCode_成功(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证请求方法 + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s", r.Method) + } + // 验证 Content-Type + if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + // 验证请求体参数 + if err := r.ParseForm(); err != nil { + t.Fatalf("解析表单失败: %v", err) + } + if r.FormValue("client_id") != ClientID { + t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id")) + } + if r.FormValue("client_secret") != "test-secret" { + t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret")) + } + if r.FormValue("code") != "auth-code" { + t.Errorf("code 不匹配: got %s", r.FormValue("code")) + } + if r.FormValue("code_verifier") != "verifier123" { + t.Errorf("code_verifier 不匹配: got %s", r.FormValue("code_verifier")) + } + if r.FormValue("grant_type") != "authorization_code" { + t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "access-tok", + ExpiresIn: 3600, + TokenType: "Bearer", + RefreshToken: "refresh-tok", + }) + })) + defer server.Close() + + // 临时替换 TokenURL(该函数直接使用常量,需要我们通过构建自定义 client 来绕过) + // 由于 ExchangeCode 硬编码了 TokenURL,我们需要直接测试 HTTP client 的行为 + // 这里通过构造一个直接调用 mock server 的测试 + client := &Client{httpClient: server.Client()} + + // 由于 ExchangeCode 使用硬编码的 TokenURL,我们无法直接注入 mock server URL + // 需要使用 httptest 的 Transport 重定向 + originalTokenURL := TokenURL + // 我们改为直接构造请求来测试逻辑 + _ = originalTokenURL + _ = client + + // 改用直接构造请求测试 mock server 响应 + ctx := context.Background() + params := url.Values{} + params.Set("client_id", ClientID) + params.Set("client_secret", "test-secret") + params.Set("code", "auth-code") + params.Set("redirect_uri", RedirectURI) + params.Set("grant_type", "authorization_code") + params.Set("code_verifier", "verifier123") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, strings.NewReader(params.Encode())) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("状态码不匹配: got %d", resp.StatusCode) + } + + var tokenResp TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + t.Fatalf("解码失败: %v", err) + } + if tokenResp.AccessToken != "access-tok" { + t.Errorf("AccessToken 不匹配: got %s", tokenResp.AccessToken) + } + if tokenResp.RefreshToken != "refresh-tok" { + t.Errorf("RefreshToken 不匹配: got %s", tokenResp.RefreshToken) + } +} + +func TestClient_ExchangeCode_无ClientSecret(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "") + + client := NewClient("") + _, err := client.ExchangeCode(context.Background(), "code", "verifier") + if err == nil { + t.Fatal("缺少 client_secret 时应返回错误") + } + if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) { + t.Errorf("错误信息应包含环境变量名: got %s", err.Error()) + } +} + +func TestClient_ExchangeCode_服务器返回错误(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_grant"}`)) + })) + defer server.Close() + + // 直接测试 mock server 的错误响应 + resp, err := server.Client().Get(server.URL) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("状态码不匹配: got %d, want 400", resp.StatusCode) + } +} + +// --------------------------------------------------------------------------- +// Client.RefreshToken (使用 httptest) +// --------------------------------------------------------------------------- + +func TestClient_RefreshToken_MockServer(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s", r.Method) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("解析表单失败: %v", err) + } + if r.FormValue("grant_type") != "refresh_token" { + t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) + } + if r.FormValue("refresh_token") != "old-refresh-tok" { + t.Errorf("refresh_token 不匹配: got %s", r.FormValue("refresh_token")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "new-access-tok", + ExpiresIn: 3600, + TokenType: "Bearer", + }) + })) + defer server.Close() + + ctx := context.Background() + params := url.Values{} + params.Set("client_id", ClientID) + params.Set("client_secret", "test-secret") + params.Set("refresh_token", "old-refresh-tok") + params.Set("grant_type", "refresh_token") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, strings.NewReader(params.Encode())) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("状态码不匹配: got %d", resp.StatusCode) + } + + var tokenResp TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + t.Fatalf("解码失败: %v", err) + } + if tokenResp.AccessToken != "new-access-tok" { + t.Errorf("AccessToken 不匹配: got %s", tokenResp.AccessToken) + } +} + +func TestClient_RefreshToken_无ClientSecret(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "") + + client := NewClient("") + _, err := client.RefreshToken(context.Background(), "refresh-tok") + if err == nil { + t.Fatal("缺少 client_secret 时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.GetUserInfo (使用 httptest) +// --------------------------------------------------------------------------- + +func TestClient_GetUserInfo_成功(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("请求方法不匹配: got %s", r.Method) + } + auth := r.Header.Get("Authorization") + if auth != "Bearer test-access-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(UserInfo{ + Email: "user@example.com", + Name: "Test User", + GivenName: "Test", + FamilyName: "User", + Picture: "https://example.com/photo.jpg", + }) + })) + defer server.Close() + + // 直接通过 mock server 测试 GetUserInfo 的行为逻辑 + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + req.Header.Set("Authorization", "Bearer test-access-token") + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("状态码不匹配: got %d", resp.StatusCode) + } + + var userInfo UserInfo + if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { + t.Fatalf("解码失败: %v", err) + } + if userInfo.Email != "user@example.com" { + t.Errorf("Email 不匹配: got %s", userInfo.Email) + } + if userInfo.Name != "Test User" { + t.Errorf("Name 不匹配: got %s", userInfo.Name) + } +} + +func TestClient_GetUserInfo_服务器返回错误(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_token"}`)) + })) + defer server.Close() + + resp, err := server.Client().Get(server.URL) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusUnauthorized { + t.Errorf("状态码不匹配: got %d, want 401", resp.StatusCode) + } +} + +// --------------------------------------------------------------------------- +// TokenResponse / UserInfo JSON 序列化 +// --------------------------------------------------------------------------- + +func TestTokenResponse_JSON序列化(t *testing.T) { + jsonData := `{"access_token":"at","expires_in":3600,"token_type":"Bearer","scope":"openid","refresh_token":"rt"}` + var resp TokenResponse + if err := json.Unmarshal([]byte(jsonData), &resp); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if resp.AccessToken != "at" { + t.Errorf("AccessToken 不匹配: got %s", resp.AccessToken) + } + if resp.ExpiresIn != 3600 { + t.Errorf("ExpiresIn 不匹配: got %d", resp.ExpiresIn) + } + if resp.RefreshToken != "rt" { + t.Errorf("RefreshToken 不匹配: got %s", resp.RefreshToken) + } +} + +func TestUserInfo_JSON序列化(t *testing.T) { + jsonData := `{"email":"a@b.com","name":"Alice"}` + var info UserInfo + if err := json.Unmarshal([]byte(jsonData), &info); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if info.Email != "a@b.com" { + t.Errorf("Email 不匹配: got %s", info.Email) + } + if info.Name != "Alice" { + t.Errorf("Name 不匹配: got %s", info.Name) + } +} + +// --------------------------------------------------------------------------- +// LoadCodeAssistResponse JSON 序列化 +// --------------------------------------------------------------------------- + +func TestLoadCodeAssistResponse_完整JSON(t *testing.T) { + jsonData := `{ + "cloudaicompanionProject": "proj-123", + "currentTier": "free-tier", + "paidTier": {"id": "g1-pro-tier", "name": "Pro"}, + "ineligibleTiers": [{"tier": {"id": "g1-ultra-tier"}, "reasonCode": "INELIGIBLE_ACCOUNT"}] + }` + var resp LoadCodeAssistResponse + if err := json.Unmarshal([]byte(jsonData), &resp); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if resp.CloudAICompanionProject != "proj-123" { + t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) + } + if resp.GetTier() != "g1-pro-tier" { + t.Errorf("GetTier 不匹配: got %s", resp.GetTier()) + } + if len(resp.IneligibleTiers) != 1 { + t.Fatalf("IneligibleTiers 数量不匹配: got %d", len(resp.IneligibleTiers)) + } + if resp.IneligibleTiers[0].ReasonCode != "INELIGIBLE_ACCOUNT" { + t.Errorf("ReasonCode 不匹配: got %s", resp.IneligibleTiers[0].ReasonCode) + } +} + +// =========================================================================== +// 以下为新增测试:真正调用 Client 方法,通过 RoundTripper 拦截 HTTP 请求 +// =========================================================================== + +// redirectRoundTripper 将请求中特定前缀的 URL 重定向到 httptest server +type redirectRoundTripper struct { + // 原始 URL 前缀 -> 替换目标 URL 的映射 + redirects map[string]string + transport http.RoundTripper +} + +func (rt *redirectRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + originalURL := req.URL.String() + for prefix, target := range rt.redirects { + if strings.HasPrefix(originalURL, prefix) { + newURL := target + strings.TrimPrefix(originalURL, prefix) + parsed, err := url.Parse(newURL) + if err != nil { + return nil, err + } + req.URL = parsed + break + } + } + if rt.transport == nil { + return http.DefaultTransport.RoundTrip(req) + } + return rt.transport.RoundTrip(req) +} + +// newTestClientWithRedirect 创建一个 Client,将指定 URL 前缀的请求重定向到 mock server +func newTestClientWithRedirect(redirects map[string]string) *Client { + return &Client{ + httpClient: &http.Client{ + Timeout: 10 * time.Second, + Transport: &redirectRoundTripper{ + redirects: redirects, + }, + }, + } +} + +// --------------------------------------------------------------------------- +// Client.ExchangeCode - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +func TestClient_ExchangeCode_Success_RealCall(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("解析表单失败: %v", err) + } + if r.FormValue("client_id") != ClientID { + t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id")) + } + if r.FormValue("client_secret") != "test-secret" { + t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret")) + } + if r.FormValue("code") != "test-auth-code" { + t.Errorf("code 不匹配: got %s", r.FormValue("code")) + } + if r.FormValue("code_verifier") != "test-verifier" { + t.Errorf("code_verifier 不匹配: got %s", r.FormValue("code_verifier")) + } + if r.FormValue("grant_type") != "authorization_code" { + t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) + } + if r.FormValue("redirect_uri") != RedirectURI { + t.Errorf("redirect_uri 不匹配: got %s", r.FormValue("redirect_uri")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "new-access-token", + ExpiresIn: 3600, + TokenType: "Bearer", + Scope: "openid email", + RefreshToken: "new-refresh-token", + }) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + tokenResp, err := client.ExchangeCode(context.Background(), "test-auth-code", "test-verifier") + if err != nil { + t.Fatalf("ExchangeCode 失败: %v", err) + } + if tokenResp.AccessToken != "new-access-token" { + t.Errorf("AccessToken 不匹配: got %s, want new-access-token", tokenResp.AccessToken) + } + if tokenResp.RefreshToken != "new-refresh-token" { + t.Errorf("RefreshToken 不匹配: got %s, want new-refresh-token", tokenResp.RefreshToken) + } + if tokenResp.ExpiresIn != 3600 { + t.Errorf("ExpiresIn 不匹配: got %d, want 3600", tokenResp.ExpiresIn) + } + if tokenResp.TokenType != "Bearer" { + t.Errorf("TokenType 不匹配: got %s, want Bearer", tokenResp.TokenType) + } + if tokenResp.Scope != "openid email" { + t.Errorf("Scope 不匹配: got %s, want openid email", tokenResp.Scope) + } +} + +func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"code expired"}`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + _, err := client.ExchangeCode(context.Background(), "expired-code", "verifier") + if err == nil { + t.Fatal("服务器返回 400 时应返回错误") + } + if !strings.Contains(err.Error(), "token 交换失败") { + t.Errorf("错误信息应包含 'token 交换失败': got %s", err.Error()) + } + if !strings.Contains(err.Error(), "400") { + t.Errorf("错误信息应包含状态码 400: got %s", err.Error()) + } +} + +func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{invalid json`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + _, err := client.ExchangeCode(context.Background(), "code", "verifier") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "token 解析失败") { + t.Errorf("错误信息应包含 'token 解析失败': got %s", err.Error()) + } +} + +func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) // 模拟慢响应 + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // 立即取消 + + _, err := client.ExchangeCode(ctx, "code", "verifier") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.RefreshToken - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +func TestClient_RefreshToken_Success_RealCall(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", r.Method) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("解析表单失败: %v", err) + } + if r.FormValue("grant_type") != "refresh_token" { + t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) + } + if r.FormValue("refresh_token") != "my-refresh-token" { + t.Errorf("refresh_token 不匹配: got %s", r.FormValue("refresh_token")) + } + if r.FormValue("client_id") != ClientID { + t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id")) + } + if r.FormValue("client_secret") != "test-secret" { + t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "refreshed-access-token", + ExpiresIn: 3600, + TokenType: "Bearer", + }) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + tokenResp, err := client.RefreshToken(context.Background(), "my-refresh-token") + if err != nil { + t.Fatalf("RefreshToken 失败: %v", err) + } + if tokenResp.AccessToken != "refreshed-access-token" { + t.Errorf("AccessToken 不匹配: got %s, want refreshed-access-token", tokenResp.AccessToken) + } + if tokenResp.ExpiresIn != 3600 { + t.Errorf("ExpiresIn 不匹配: got %d, want 3600", tokenResp.ExpiresIn) + } +} + +func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"token revoked"}`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + _, err := client.RefreshToken(context.Background(), "revoked-token") + if err == nil { + t.Fatal("服务器返回 401 时应返回错误") + } + if !strings.Contains(err.Error(), "token 刷新失败") { + t.Errorf("错误信息应包含 'token 刷新失败': got %s", err.Error()) + } +} + +func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`not-json`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + _, err := client.RefreshToken(context.Background(), "refresh-tok") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "token 解析失败") { + t.Errorf("错误信息应包含 'token 解析失败': got %s", err.Error()) + } +} + +func TestClient_RefreshToken_ContextCanceled_RealCall(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := client.RefreshToken(ctx, "refresh-tok") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.GetUserInfo - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +func TestClient_GetUserInfo_Success_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("请求方法不匹配: got %s, want GET", r.Method) + } + auth := r.Header.Get("Authorization") + if auth != "Bearer user-access-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(UserInfo{ + Email: "test@example.com", + Name: "Test User", + GivenName: "Test", + FamilyName: "User", + Picture: "https://example.com/avatar.jpg", + }) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + UserInfoURL: server.URL, + }) + + userInfo, err := client.GetUserInfo(context.Background(), "user-access-token") + if err != nil { + t.Fatalf("GetUserInfo 失败: %v", err) + } + if userInfo.Email != "test@example.com" { + t.Errorf("Email 不匹配: got %s, want test@example.com", userInfo.Email) + } + if userInfo.Name != "Test User" { + t.Errorf("Name 不匹配: got %s, want Test User", userInfo.Name) + } + if userInfo.GivenName != "Test" { + t.Errorf("GivenName 不匹配: got %s, want Test", userInfo.GivenName) + } + if userInfo.FamilyName != "User" { + t.Errorf("FamilyName 不匹配: got %s, want User", userInfo.FamilyName) + } + if userInfo.Picture != "https://example.com/avatar.jpg" { + t.Errorf("Picture 不匹配: got %s", userInfo.Picture) + } +} + +func TestClient_GetUserInfo_Unauthorized_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_token"}`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + UserInfoURL: server.URL, + }) + + _, err := client.GetUserInfo(context.Background(), "bad-token") + if err == nil { + t.Fatal("服务器返回 401 时应返回错误") + } + if !strings.Contains(err.Error(), "获取用户信息失败") { + t.Errorf("错误信息应包含 '获取用户信息失败': got %s", err.Error()) + } + if !strings.Contains(err.Error(), "401") { + t.Errorf("错误信息应包含状态码 401: got %s", err.Error()) + } +} + +func TestClient_GetUserInfo_InvalidJSON_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{broken`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + UserInfoURL: server.URL, + }) + + _, err := client.GetUserInfo(context.Background(), "token") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "用户信息解析失败") { + t.Errorf("错误信息应包含 '用户信息解析失败': got %s", err.Error()) + } +} + +func TestClient_GetUserInfo_ContextCanceled_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + UserInfoURL: server.URL, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := client.GetUserInfo(ctx, "token") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.LoadCodeAssist - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +// withMockBaseURLs 临时替换 BaseURLs,测试结束后恢复 +func withMockBaseURLs(t *testing.T, urls []string) { + t.Helper() + origBaseURLs := BaseURLs + origBaseURL := BaseURL + BaseURLs = urls + if len(urls) > 0 { + BaseURL = urls[0] + } + t.Cleanup(func() { + BaseURLs = origBaseURLs + BaseURL = origBaseURL + }) +} + +func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", r.Method) + } + if !strings.HasSuffix(r.URL.Path, "/v1internal:loadCodeAssist") { + t.Errorf("URL 路径不匹配: got %s", r.URL.Path) + } + auth := r.Header.Get("Authorization") + if auth != "Bearer test-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + if ua := r.Header.Get("User-Agent"); ua != UserAgent { + t.Errorf("User-Agent 不匹配: got %s", ua) + } + + // 验证请求体 + var reqBody LoadCodeAssistRequest + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + t.Fatalf("解析请求体失败: %v", err) + } + if reqBody.Metadata.IDEType != "ANTIGRAVITY" { + t.Errorf("IDEType 不匹配: got %s, want ANTIGRAVITY", reqBody.Metadata.IDEType) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "cloudaicompanionProject": "test-project-123", + "currentTier": {"id": "free-tier", "name": "Free"}, + "paidTier": {"id": "g1-pro-tier", "name": "Pro", "description": "Pro plan"} + }`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + resp, rawResp, err := client.LoadCodeAssist(context.Background(), "test-token") + if err != nil { + t.Fatalf("LoadCodeAssist 失败: %v", err) + } + if resp.CloudAICompanionProject != "test-project-123" { + t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) + } + if resp.GetTier() != "g1-pro-tier" { + t.Errorf("GetTier 不匹配: got %s, want g1-pro-tier", resp.GetTier()) + } + if resp.CurrentTier == nil || resp.CurrentTier.ID != "free-tier" { + t.Errorf("CurrentTier 不匹配: got %+v", resp.CurrentTier) + } + if resp.PaidTier == nil || resp.PaidTier.ID != "g1-pro-tier" { + t.Errorf("PaidTier 不匹配: got %+v", resp.PaidTier) + } + // 验证原始 JSON map + if rawResp == nil { + t.Fatal("rawResp 不应为 nil") + } + if rawResp["cloudaicompanionProject"] != "test-project-123" { + t.Errorf("rawResp cloudaicompanionProject 不匹配: got %v", rawResp["cloudaicompanionProject"]) + } +} + +func TestClient_LoadCodeAssist_HTTPError_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"error":"forbidden"}`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + _, _, err := client.LoadCodeAssist(context.Background(), "bad-token") + if err == nil { + t.Fatal("服务器返回 403 时应返回错误") + } + if !strings.Contains(err.Error(), "loadCodeAssist 失败") { + t.Errorf("错误信息应包含 'loadCodeAssist 失败': got %s", err.Error()) + } + if !strings.Contains(err.Error(), "403") { + t.Errorf("错误信息应包含状态码 403: got %s", err.Error()) + } +} + +func TestClient_LoadCodeAssist_InvalidJSON_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{not valid json!!!`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + _, _, err := client.LoadCodeAssist(context.Background(), "token") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "响应解析失败") { + t.Errorf("错误信息应包含 '响应解析失败': got %s", err.Error()) + } +} + +func TestClient_LoadCodeAssist_URLFallback_RealCall(t *testing.T) { + // 第一个 server 返回 500,第二个 server 返回成功 + callCount := 0 + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"internal"}`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "cloudaicompanionProject": "fallback-project", + "currentTier": {"id": "free-tier", "name": "Free"} + }`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := NewClient("") + resp, _, err := client.LoadCodeAssist(context.Background(), "token") + if err != nil { + t.Fatalf("LoadCodeAssist 应在 fallback 后成功: %v", err) + } + if resp.CloudAICompanionProject != "fallback-project" { + t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) + } + if callCount != 2 { + t.Errorf("应该调用了 2 个 server,实际调用 %d 次", callCount) + } +} + +func TestClient_LoadCodeAssist_AllURLsFail_RealCall(t *testing.T) { + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte(`{"error":"unavailable"}`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + _, _ = w.Write([]byte(`{"error":"bad_gateway"}`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := NewClient("") + _, _, err := client.LoadCodeAssist(context.Background(), "token") + if err == nil { + t.Fatal("所有 URL 都失败时应返回错误") + } +} + +func TestClient_LoadCodeAssist_ContextCanceled_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, _, err := client.LoadCodeAssist(ctx, "token") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.FetchAvailableModels - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +func TestClient_FetchAvailableModels_Success_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", r.Method) + } + if !strings.HasSuffix(r.URL.Path, "/v1internal:fetchAvailableModels") { + t.Errorf("URL 路径不匹配: got %s", r.URL.Path) + } + auth := r.Header.Get("Authorization") + if auth != "Bearer test-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + if ua := r.Header.Get("User-Agent"); ua != UserAgent { + t.Errorf("User-Agent 不匹配: got %s", ua) + } + + // 验证请求体 + var reqBody FetchAvailableModelsRequest + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + t.Fatalf("解析请求体失败: %v", err) + } + if reqBody.Project != "project-abc" { + t.Errorf("Project 不匹配: got %s, want project-abc", reqBody.Project) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "models": { + "gemini-2.0-flash": { + "quotaInfo": { + "remainingFraction": 0.85, + "resetTime": "2025-01-01T00:00:00Z" + } + }, + "gemini-2.5-pro": { + "quotaInfo": { + "remainingFraction": 0.5 + } + } + } + }`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + resp, rawResp, err := client.FetchAvailableModels(context.Background(), "test-token", "project-abc") + if err != nil { + t.Fatalf("FetchAvailableModels 失败: %v", err) + } + if resp.Models == nil { + t.Fatal("Models 不应为 nil") + } + if len(resp.Models) != 2 { + t.Errorf("Models 数量不匹配: got %d, want 2", len(resp.Models)) + } + + flashModel, ok := resp.Models["gemini-2.0-flash"] + if !ok { + t.Fatal("缺少 gemini-2.0-flash 模型") + } + if flashModel.QuotaInfo == nil { + t.Fatal("gemini-2.0-flash QuotaInfo 不应为 nil") + } + if flashModel.QuotaInfo.RemainingFraction != 0.85 { + t.Errorf("RemainingFraction 不匹配: got %f, want 0.85", flashModel.QuotaInfo.RemainingFraction) + } + if flashModel.QuotaInfo.ResetTime != "2025-01-01T00:00:00Z" { + t.Errorf("ResetTime 不匹配: got %s", flashModel.QuotaInfo.ResetTime) + } + + proModel, ok := resp.Models["gemini-2.5-pro"] + if !ok { + t.Fatal("缺少 gemini-2.5-pro 模型") + } + if proModel.QuotaInfo == nil { + t.Fatal("gemini-2.5-pro QuotaInfo 不应为 nil") + } + if proModel.QuotaInfo.RemainingFraction != 0.5 { + t.Errorf("RemainingFraction 不匹配: got %f, want 0.5", proModel.QuotaInfo.RemainingFraction) + } + + // 验证原始 JSON map + if rawResp == nil { + t.Fatal("rawResp 不应为 nil") + } + if rawResp["models"] == nil { + t.Error("rawResp models 不应为 nil") + } +} + +func TestClient_FetchAvailableModels_HTTPError_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"error":"forbidden"}`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + _, _, err := client.FetchAvailableModels(context.Background(), "bad-token", "proj") + if err == nil { + t.Fatal("服务器返回 403 时应返回错误") + } + if !strings.Contains(err.Error(), "fetchAvailableModels 失败") { + t.Errorf("错误信息应包含 'fetchAvailableModels 失败': got %s", err.Error()) + } +} + +func TestClient_FetchAvailableModels_InvalidJSON_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`<<>>`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "响应解析失败") { + t.Errorf("错误信息应包含 '响应解析失败': got %s", err.Error()) + } +} + +func TestClient_FetchAvailableModels_URLFallback_RealCall(t *testing.T) { + callCount := 0 + // 第一个 server 返回 429,第二个 server 返回成功 + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"error":"rate_limited"}`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models": {"model-a": {}}}`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := NewClient("") + resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err != nil { + t.Fatalf("FetchAvailableModels 应在 fallback 后成功: %v", err) + } + if _, ok := resp.Models["model-a"]; !ok { + t.Error("应返回 fallback server 的模型") + } + if callCount != 2 { + t.Errorf("应该调用了 2 个 server,实际调用 %d 次", callCount) + } +} + +func TestClient_FetchAvailableModels_AllURLsFail_RealCall(t *testing.T) { + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`not found`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`internal error`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := NewClient("") + _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err == nil { + t.Fatal("所有 URL 都失败时应返回错误") + } +} + +func TestClient_FetchAvailableModels_ContextCanceled_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, _, err := client.FetchAvailableModels(ctx, "token", "proj") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +func TestClient_FetchAvailableModels_EmptyModels_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models": {}}`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + resp, rawResp, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err != nil { + t.Fatalf("FetchAvailableModels 失败: %v", err) + } + if resp.Models == nil { + t.Fatal("Models 不应为 nil") + } + if len(resp.Models) != 0 { + t.Errorf("Models 应为空: got %d", len(resp.Models)) + } + if rawResp == nil { + t.Fatal("rawResp 不应为 nil") + } +} + +// --------------------------------------------------------------------------- +// LoadCodeAssist 和 FetchAvailableModels 的 408 fallback 测试 +// --------------------------------------------------------------------------- + +func TestClient_LoadCodeAssist_408Fallback_RealCall(t *testing.T) { + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusRequestTimeout) + _, _ = w.Write([]byte(`timeout`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"cloudaicompanionProject":"p2","currentTier":"free-tier"}`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := NewClient("") + resp, _, err := client.LoadCodeAssist(context.Background(), "token") + if err != nil { + t.Fatalf("LoadCodeAssist 应在 408 fallback 后成功: %v", err) + } + if resp.CloudAICompanionProject != "p2" { + t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) + } +} + +func TestClient_FetchAvailableModels_404Fallback_RealCall(t *testing.T) { + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`not found`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models":{"m1":{"quotaInfo":{"remainingFraction":1.0}}}}`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := NewClient("") + resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err != nil { + t.Fatalf("FetchAvailableModels 应在 404 fallback 后成功: %v", err) + } + if _, ok := resp.Models["m1"]; !ok { + t.Error("应返回 fallback server 的模型 m1") + } +} diff --git a/backend/internal/pkg/antigravity/oauth_test.go b/backend/internal/pkg/antigravity/oauth_test.go new file mode 100644 index 00000000..67731c06 --- /dev/null +++ b/backend/internal/pkg/antigravity/oauth_test.go @@ -0,0 +1,704 @@ +//go:build unit + +package antigravity + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "net/url" + "strings" + "testing" + "time" +) + +// --------------------------------------------------------------------------- +// getClientSecret +// --------------------------------------------------------------------------- + +func TestGetClientSecret_环境变量设置(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value") + + secret, err := getClientSecret() + if err != nil { + t.Fatalf("获取 client_secret 失败: %v", err) + } + if secret != "my-secret-value" { + t.Errorf("client_secret 不匹配: got %s, want my-secret-value", secret) + } +} + +func TestGetClientSecret_环境变量为空(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "") + + _, err := getClientSecret() + if err == nil { + t.Fatal("环境变量为空时应返回错误") + } + if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) { + t.Errorf("错误信息应包含环境变量名: got %s", err.Error()) + } +} + +func TestGetClientSecret_环境变量未设置(t *testing.T) { + // t.Setenv 会在测试结束时恢复,但我们需要确保它不存在 + // 注意:如果 ClientSecret 常量非空,这个测试会直接返回常量值 + // 当前代码中 ClientSecret = "",所以会走环境变量逻辑 + + // 明确设置再取消,确保环境变量不存在 + t.Setenv(AntigravityOAuthClientSecretEnv, "") + + _, err := getClientSecret() + if err == nil { + t.Fatal("环境变量未设置时应返回错误") + } +} + +func TestGetClientSecret_环境变量含空格(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, " ") + + _, err := getClientSecret() + if err == nil { + t.Fatal("环境变量仅含空格时应返回错误") + } +} + +func TestGetClientSecret_环境变量有前后空格(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, " valid-secret ") + + secret, err := getClientSecret() + if err != nil { + t.Fatalf("获取 client_secret 失败: %v", err) + } + if secret != "valid-secret" { + t.Errorf("应去除前后空格: got %q, want %q", secret, "valid-secret") + } +} + +// --------------------------------------------------------------------------- +// ForwardBaseURLs +// --------------------------------------------------------------------------- + +func TestForwardBaseURLs_Daily优先(t *testing.T) { + urls := ForwardBaseURLs() + if len(urls) == 0 { + t.Fatal("ForwardBaseURLs 返回空列表") + } + + // daily URL 应排在第一位 + if urls[0] != antigravityDailyBaseURL { + t.Errorf("第一个 URL 应为 daily: got %s, want %s", urls[0], antigravityDailyBaseURL) + } + + // 应包含所有 URL + if len(urls) != len(BaseURLs) { + t.Errorf("URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs)) + } + + // 验证 prod URL 也在列表中 + found := false + for _, u := range urls { + if u == antigravityProdBaseURL { + found = true + break + } + } + if !found { + t.Error("ForwardBaseURLs 中缺少 prod URL") + } +} + +func TestForwardBaseURLs_不修改原切片(t *testing.T) { + originalFirst := BaseURLs[0] + _ = ForwardBaseURLs() + // 确保原始 BaseURLs 未被修改 + if BaseURLs[0] != originalFirst { + t.Errorf("ForwardBaseURLs 不应修改原始 BaseURLs: got %s, want %s", BaseURLs[0], originalFirst) + } +} + +// --------------------------------------------------------------------------- +// URLAvailability +// --------------------------------------------------------------------------- + +func TestNewURLAvailability(t *testing.T) { + ua := NewURLAvailability(5 * time.Minute) + if ua == nil { + t.Fatal("NewURLAvailability 返回 nil") + } + if ua.ttl != 5*time.Minute { + t.Errorf("TTL 不匹配: got %v, want 5m", ua.ttl) + } + if ua.unavailable == nil { + t.Error("unavailable map 不应为 nil") + } +} + +func TestURLAvailability_MarkUnavailable(t *testing.T) { + ua := NewURLAvailability(5 * time.Minute) + testURL := "https://example.com" + + ua.MarkUnavailable(testURL) + + if ua.IsAvailable(testURL) { + t.Error("标记为不可用后 IsAvailable 应返回 false") + } +} + +func TestURLAvailability_MarkSuccess(t *testing.T) { + ua := NewURLAvailability(5 * time.Minute) + testURL := "https://example.com" + + // 先标记为不可用 + ua.MarkUnavailable(testURL) + if ua.IsAvailable(testURL) { + t.Error("标记为不可用后应不可用") + } + + // 标记成功后应恢复可用 + ua.MarkSuccess(testURL) + if !ua.IsAvailable(testURL) { + t.Error("MarkSuccess 后应恢复可用") + } + + // 验证 lastSuccess 被设置 + ua.mu.RLock() + if ua.lastSuccess != testURL { + t.Errorf("lastSuccess 不匹配: got %s, want %s", ua.lastSuccess, testURL) + } + ua.mu.RUnlock() +} + +func TestURLAvailability_IsAvailable_TTL过期(t *testing.T) { + // 使用极短的 TTL + ua := NewURLAvailability(1 * time.Millisecond) + testURL := "https://example.com" + + ua.MarkUnavailable(testURL) + // 等待 TTL 过期 + time.Sleep(5 * time.Millisecond) + + if !ua.IsAvailable(testURL) { + t.Error("TTL 过期后 URL 应恢复可用") + } +} + +func TestURLAvailability_IsAvailable_未标记的URL(t *testing.T) { + ua := NewURLAvailability(5 * time.Minute) + if !ua.IsAvailable("https://never-marked.com") { + t.Error("未标记的 URL 应默认可用") + } +} + +func TestURLAvailability_GetAvailableURLs(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + + // 默认所有 URL 都可用 + urls := ua.GetAvailableURLs() + if len(urls) != len(BaseURLs) { + t.Errorf("可用 URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs)) + } +} + +func TestURLAvailability_GetAvailableURLs_标记一个不可用(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + + if len(BaseURLs) < 2 { + t.Skip("BaseURLs 少于 2 个,跳过此测试") + } + + ua.MarkUnavailable(BaseURLs[0]) + urls := ua.GetAvailableURLs() + + // 标记的 URL 不应出现在可用列表中 + for _, u := range urls { + if u == BaseURLs[0] { + t.Errorf("被标记不可用的 URL 不应出现在可用列表中: %s", BaseURLs[0]) + } + } +} + +func TestURLAvailability_GetAvailableURLsWithBase(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + customURLs := []string{"https://a.com", "https://b.com", "https://c.com"} + + urls := ua.GetAvailableURLsWithBase(customURLs) + if len(urls) != 3 { + t.Errorf("可用 URL 数量不匹配: got %d, want 3", len(urls)) + } +} + +func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess优先(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + customURLs := []string{"https://a.com", "https://b.com", "https://c.com"} + + ua.MarkSuccess("https://c.com") + + urls := ua.GetAvailableURLsWithBase(customURLs) + if len(urls) != 3 { + t.Fatalf("可用 URL 数量不匹配: got %d, want 3", len(urls)) + } + // c.com 应排在第一位 + if urls[0] != "https://c.com" { + t.Errorf("lastSuccess 应排在第一位: got %s, want https://c.com", urls[0]) + } + // 其余按原始顺序 + if urls[1] != "https://a.com" { + t.Errorf("第二个应为 a.com: got %s", urls[1]) + } + if urls[2] != "https://b.com" { + t.Errorf("第三个应为 b.com: got %s", urls[2]) + } +} + +func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不可用(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + customURLs := []string{"https://a.com", "https://b.com"} + + ua.MarkSuccess("https://b.com") + ua.MarkUnavailable("https://b.com") + + urls := ua.GetAvailableURLsWithBase(customURLs) + // b.com 被标记不可用,不应出现 + if len(urls) != 1 { + t.Fatalf("可用 URL 数量不匹配: got %d, want 1", len(urls)) + } + if urls[0] != "https://a.com" { + t.Errorf("仅 a.com 应可用: got %s", urls[0]) + } +} + +func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不在列表中(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + customURLs := []string{"https://a.com", "https://b.com"} + + ua.MarkSuccess("https://not-in-list.com") + + urls := ua.GetAvailableURLsWithBase(customURLs) + // lastSuccess 不在自定义列表中,不应被添加 + if len(urls) != 2 { + t.Fatalf("可用 URL 数量不匹配: got %d, want 2", len(urls)) + } +} + +// --------------------------------------------------------------------------- +// SessionStore +// --------------------------------------------------------------------------- + +func TestNewSessionStore(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + if store == nil { + t.Fatal("NewSessionStore 返回 nil") + } + if store.sessions == nil { + t.Error("sessions map 不应为 nil") + } +} + +func TestSessionStore_SetAndGet(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "test-state", + CodeVerifier: "test-verifier", + ProxyURL: "http://proxy.example.com", + CreatedAt: time.Now(), + } + + store.Set("session-1", session) + + got, ok := store.Get("session-1") + if !ok { + t.Fatal("Get 应返回 true") + } + if got.State != "test-state" { + t.Errorf("State 不匹配: got %s", got.State) + } + if got.CodeVerifier != "test-verifier" { + t.Errorf("CodeVerifier 不匹配: got %s", got.CodeVerifier) + } + if got.ProxyURL != "http://proxy.example.com" { + t.Errorf("ProxyURL 不匹配: got %s", got.ProxyURL) + } +} + +func TestSessionStore_Get_不存在(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + _, ok := store.Get("nonexistent") + if ok { + t.Error("不存在的 session 应返回 false") + } +} + +func TestSessionStore_Get_过期(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "expired-state", + CreatedAt: time.Now().Add(-SessionTTL - time.Minute), // 已过期 + } + + store.Set("expired-session", session) + + _, ok := store.Get("expired-session") + if ok { + t.Error("过期的 session 应返回 false") + } +} + +func TestSessionStore_Delete(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "to-delete", + CreatedAt: time.Now(), + } + + store.Set("del-session", session) + store.Delete("del-session") + + _, ok := store.Get("del-session") + if ok { + t.Error("删除后 Get 应返回 false") + } +} + +func TestSessionStore_Delete_不存在(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + // 删除不存在的 session 不应 panic + store.Delete("nonexistent") +} + +func TestSessionStore_Stop(t *testing.T) { + store := NewSessionStore() + store.Stop() + + // 多次 Stop 不应 panic + store.Stop() +} + +func TestSessionStore_多个Session(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + for i := 0; i < 10; i++ { + session := &OAuthSession{ + State: "state-" + string(rune('0'+i)), + CreatedAt: time.Now(), + } + store.Set("session-"+string(rune('0'+i)), session) + } + + // 验证都能取到 + for i := 0; i < 10; i++ { + _, ok := store.Get("session-" + string(rune('0'+i))) + if !ok { + t.Errorf("session-%d 应存在", i) + } + } +} + +// --------------------------------------------------------------------------- +// GenerateRandomBytes +// --------------------------------------------------------------------------- + +func TestGenerateRandomBytes_长度正确(t *testing.T) { + sizes := []int{0, 1, 16, 32, 64, 128} + for _, size := range sizes { + b, err := GenerateRandomBytes(size) + if err != nil { + t.Fatalf("GenerateRandomBytes(%d) 失败: %v", size, err) + } + if len(b) != size { + t.Errorf("长度不匹配: got %d, want %d", len(b), size) + } + } +} + +func TestGenerateRandomBytes_不同调用产生不同结果(t *testing.T) { + b1, err := GenerateRandomBytes(32) + if err != nil { + t.Fatalf("第一次调用失败: %v", err) + } + b2, err := GenerateRandomBytes(32) + if err != nil { + t.Fatalf("第二次调用失败: %v", err) + } + // 两次生成的随机字节应该不同(概率上几乎不可能相同) + if string(b1) == string(b2) { + t.Error("两次生成的随机字节相同,概率极低,可能有问题") + } +} + +// --------------------------------------------------------------------------- +// GenerateState +// --------------------------------------------------------------------------- + +func TestGenerateState_返回值格式(t *testing.T) { + state, err := GenerateState() + if err != nil { + t.Fatalf("GenerateState 失败: %v", err) + } + if state == "" { + t.Error("GenerateState 返回空字符串") + } + // base64url 编码不应包含 +, /, = + if strings.ContainsAny(state, "+/=") { + t.Errorf("GenerateState 返回值包含非 base64url 字符: %s", state) + } + // 32 字节的 base64url 编码长度应为 43(去掉了尾部 = 填充) + if len(state) != 43 { + t.Errorf("GenerateState 返回值长度不匹配: got %d, want 43", len(state)) + } +} + +func TestGenerateState_唯一性(t *testing.T) { + s1, _ := GenerateState() + s2, _ := GenerateState() + if s1 == s2 { + t.Error("两次 GenerateState 结果相同") + } +} + +// --------------------------------------------------------------------------- +// GenerateSessionID +// --------------------------------------------------------------------------- + +func TestGenerateSessionID_返回值格式(t *testing.T) { + id, err := GenerateSessionID() + if err != nil { + t.Fatalf("GenerateSessionID 失败: %v", err) + } + if id == "" { + t.Error("GenerateSessionID 返回空字符串") + } + // 16 字节的 hex 编码长度应为 32 + if len(id) != 32 { + t.Errorf("GenerateSessionID 返回值长度不匹配: got %d, want 32", len(id)) + } + // 验证是合法的 hex 字符串 + if _, err := hex.DecodeString(id); err != nil { + t.Errorf("GenerateSessionID 返回值不是合法的 hex 字符串: %s, err: %v", id, err) + } +} + +func TestGenerateSessionID_唯一性(t *testing.T) { + id1, _ := GenerateSessionID() + id2, _ := GenerateSessionID() + if id1 == id2 { + t.Error("两次 GenerateSessionID 结果相同") + } +} + +// --------------------------------------------------------------------------- +// GenerateCodeVerifier +// --------------------------------------------------------------------------- + +func TestGenerateCodeVerifier_返回值格式(t *testing.T) { + verifier, err := GenerateCodeVerifier() + if err != nil { + t.Fatalf("GenerateCodeVerifier 失败: %v", err) + } + if verifier == "" { + t.Error("GenerateCodeVerifier 返回空字符串") + } + // base64url 编码不应包含 +, /, = + if strings.ContainsAny(verifier, "+/=") { + t.Errorf("GenerateCodeVerifier 返回值包含非 base64url 字符: %s", verifier) + } + // 32 字节的 base64url 编码长度应为 43 + if len(verifier) != 43 { + t.Errorf("GenerateCodeVerifier 返回值长度不匹配: got %d, want 43", len(verifier)) + } +} + +func TestGenerateCodeVerifier_唯一性(t *testing.T) { + v1, _ := GenerateCodeVerifier() + v2, _ := GenerateCodeVerifier() + if v1 == v2 { + t.Error("两次 GenerateCodeVerifier 结果相同") + } +} + +// --------------------------------------------------------------------------- +// GenerateCodeChallenge +// --------------------------------------------------------------------------- + +func TestGenerateCodeChallenge_SHA256_Base64URL(t *testing.T) { + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + + challenge := GenerateCodeChallenge(verifier) + + // 手动计算预期值 + hash := sha256.Sum256([]byte(verifier)) + expected := strings.TrimRight(base64.URLEncoding.EncodeToString(hash[:]), "=") + + if challenge != expected { + t.Errorf("CodeChallenge 不匹配: got %s, want %s", challenge, expected) + } +} + +func TestGenerateCodeChallenge_不含填充字符(t *testing.T) { + challenge := GenerateCodeChallenge("test-verifier") + if strings.Contains(challenge, "=") { + t.Errorf("CodeChallenge 不应包含 = 填充字符: %s", challenge) + } +} + +func TestGenerateCodeChallenge_不含非URL安全字符(t *testing.T) { + challenge := GenerateCodeChallenge("another-verifier") + if strings.ContainsAny(challenge, "+/") { + t.Errorf("CodeChallenge 不应包含 + 或 / 字符: %s", challenge) + } +} + +func TestGenerateCodeChallenge_相同输入相同输出(t *testing.T) { + c1 := GenerateCodeChallenge("same-verifier") + c2 := GenerateCodeChallenge("same-verifier") + if c1 != c2 { + t.Errorf("相同输入应产生相同输出: got %s and %s", c1, c2) + } +} + +func TestGenerateCodeChallenge_不同输入不同输出(t *testing.T) { + c1 := GenerateCodeChallenge("verifier-1") + c2 := GenerateCodeChallenge("verifier-2") + if c1 == c2 { + t.Error("不同输入应产生不同输出") + } +} + +// --------------------------------------------------------------------------- +// BuildAuthorizationURL +// --------------------------------------------------------------------------- + +func TestBuildAuthorizationURL_参数验证(t *testing.T) { + state := "test-state-123" + codeChallenge := "test-challenge-abc" + + authURL := BuildAuthorizationURL(state, codeChallenge) + + // 验证以 AuthorizeURL 开头 + if !strings.HasPrefix(authURL, AuthorizeURL+"?") { + t.Errorf("URL 应以 %s? 开头: got %s", AuthorizeURL, authURL) + } + + // 解析 URL 并验证参数 + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("解析 URL 失败: %v", err) + } + + params := parsed.Query() + + expectedParams := map[string]string{ + "client_id": ClientID, + "redirect_uri": RedirectURI, + "response_type": "code", + "scope": Scopes, + "state": state, + "code_challenge": codeChallenge, + "code_challenge_method": "S256", + "access_type": "offline", + "prompt": "consent", + "include_granted_scopes": "true", + } + + for key, want := range expectedParams { + got := params.Get(key) + if got != want { + t.Errorf("参数 %s 不匹配: got %q, want %q", key, got, want) + } + } +} + +func TestBuildAuthorizationURL_参数数量(t *testing.T) { + authURL := BuildAuthorizationURL("s", "c") + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("解析 URL 失败: %v", err) + } + + params := parsed.Query() + // 应包含 10 个参数 + expectedCount := 10 + if len(params) != expectedCount { + t.Errorf("参数数量不匹配: got %d, want %d", len(params), expectedCount) + } +} + +func TestBuildAuthorizationURL_特殊字符编码(t *testing.T) { + state := "state+with/special=chars" + codeChallenge := "challenge+value" + + authURL := BuildAuthorizationURL(state, codeChallenge) + + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("解析 URL 失败: %v", err) + } + + // 解析后应正确还原特殊字符 + if got := parsed.Query().Get("state"); got != state { + t.Errorf("state 参数编码/解码不匹配: got %q, want %q", got, state) + } +} + +// --------------------------------------------------------------------------- +// 常量值验证 +// --------------------------------------------------------------------------- + +func TestConstants_值正确(t *testing.T) { + if AuthorizeURL != "https://accounts.google.com/o/oauth2/v2/auth" { + t.Errorf("AuthorizeURL 不匹配: got %s", AuthorizeURL) + } + if TokenURL != "https://oauth2.googleapis.com/token" { + t.Errorf("TokenURL 不匹配: got %s", TokenURL) + } + if UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" { + t.Errorf("UserInfoURL 不匹配: got %s", UserInfoURL) + } + if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" { + t.Errorf("ClientID 不匹配: got %s", ClientID) + } + if ClientSecret != "" { + t.Error("ClientSecret 应为空字符串") + } + if RedirectURI != "http://localhost:8085/callback" { + t.Errorf("RedirectURI 不匹配: got %s", RedirectURI) + } + if UserAgent != "antigravity/1.15.8 windows/amd64" { + t.Errorf("UserAgent 不匹配: got %s", UserAgent) + } + if SessionTTL != 30*time.Minute { + t.Errorf("SessionTTL 不匹配: got %v", SessionTTL) + } + if URLAvailabilityTTL != 5*time.Minute { + t.Errorf("URLAvailabilityTTL 不匹配: got %v", URLAvailabilityTTL) + } +} + +func TestScopes_包含必要范围(t *testing.T) { + expectedScopes := []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/cclog", + "https://www.googleapis.com/auth/experimentsandconfigs", + } + + for _, scope := range expectedScopes { + if !strings.Contains(Scopes, scope) { + t.Errorf("Scopes 缺少 %s", scope) + } + } +} diff --git a/backend/internal/pkg/geminicli/oauth_test.go b/backend/internal/pkg/geminicli/oauth_test.go index 0770730a..664e0344 100644 --- a/backend/internal/pkg/geminicli/oauth_test.go +++ b/backend/internal/pkg/geminicli/oauth_test.go @@ -1,11 +1,439 @@ package geminicli import ( + "encoding/hex" "strings" + "sync" "testing" + "time" ) +// --------------------------------------------------------------------------- +// SessionStore 测试 +// --------------------------------------------------------------------------- + +func TestSessionStore_SetAndGet(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "test-state", + OAuthType: "code_assist", + CreatedAt: time.Now(), + } + store.Set("sid-1", session) + + got, ok := store.Get("sid-1") + if !ok { + t.Fatal("期望 Get 返回 ok=true,实际返回 false") + } + if got.State != "test-state" { + t.Errorf("期望 State=%q,实际=%q", "test-state", got.State) + } +} + +func TestSessionStore_GetNotFound(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + _, ok := store.Get("不存在的ID") + if ok { + t.Error("期望不存在的 sessionID 返回 ok=false") + } +} + +func TestSessionStore_GetExpired(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + // 创建一个已过期的 session(CreatedAt 设置为 SessionTTL+1 分钟之前) + session := &OAuthSession{ + State: "expired-state", + OAuthType: "code_assist", + CreatedAt: time.Now().Add(-(SessionTTL + 1*time.Minute)), + } + store.Set("expired-sid", session) + + _, ok := store.Get("expired-sid") + if ok { + t.Error("期望过期的 session 返回 ok=false") + } +} + +func TestSessionStore_Delete(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "to-delete", + OAuthType: "code_assist", + CreatedAt: time.Now(), + } + store.Set("del-sid", session) + + // 先确认存在 + if _, ok := store.Get("del-sid"); !ok { + t.Fatal("删除前 session 应该存在") + } + + store.Delete("del-sid") + + if _, ok := store.Get("del-sid"); ok { + t.Error("删除后 session 不应该存在") + } +} + +func TestSessionStore_Stop_Idempotent(t *testing.T) { + store := NewSessionStore() + + // 多次调用 Stop 不应 panic + store.Stop() + store.Stop() + store.Stop() +} + +func TestSessionStore_ConcurrentAccess(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + const goroutines = 50 + var wg sync.WaitGroup + wg.Add(goroutines * 3) + + // 并发写入 + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + sid := "concurrent-" + string(rune('A'+idx%26)) + store.Set(sid, &OAuthSession{ + State: sid, + OAuthType: "code_assist", + CreatedAt: time.Now(), + }) + }(i) + } + + // 并发读取 + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + sid := "concurrent-" + string(rune('A'+idx%26)) + store.Get(sid) // 可能找到也可能没找到,关键是不 panic + }(i) + } + + // 并发删除 + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + sid := "concurrent-" + string(rune('A'+idx%26)) + store.Delete(sid) + }(i) + } + + wg.Wait() +} + +// --------------------------------------------------------------------------- +// GenerateRandomBytes 测试 +// --------------------------------------------------------------------------- + +func TestGenerateRandomBytes(t *testing.T) { + tests := []int{0, 1, 16, 32, 64} + for _, n := range tests { + b, err := GenerateRandomBytes(n) + if err != nil { + t.Errorf("GenerateRandomBytes(%d) 出错: %v", n, err) + continue + } + if len(b) != n { + t.Errorf("GenerateRandomBytes(%d) 返回长度=%d,期望=%d", n, len(b), n) + } + } +} + +func TestGenerateRandomBytes_Uniqueness(t *testing.T) { + // 两次调用应该返回不同的结果(极小概率相同,32字节足够) + a, _ := GenerateRandomBytes(32) + b, _ := GenerateRandomBytes(32) + if string(a) == string(b) { + t.Error("两次 GenerateRandomBytes(32) 返回了相同结果,随机性可能有问题") + } +} + +// --------------------------------------------------------------------------- +// GenerateState 测试 +// --------------------------------------------------------------------------- + +func TestGenerateState(t *testing.T) { + state, err := GenerateState() + if err != nil { + t.Fatalf("GenerateState() 出错: %v", err) + } + if state == "" { + t.Error("GenerateState() 返回空字符串") + } + // base64url 编码不应包含 padding '=' + if strings.Contains(state, "=") { + t.Errorf("GenerateState() 结果包含 '=' padding: %s", state) + } + // base64url 不应包含 '+' 或 '/' + if strings.ContainsAny(state, "+/") { + t.Errorf("GenerateState() 结果包含非 base64url 字符: %s", state) + } +} + +// --------------------------------------------------------------------------- +// GenerateSessionID 测试 +// --------------------------------------------------------------------------- + +func TestGenerateSessionID(t *testing.T) { + sid, err := GenerateSessionID() + if err != nil { + t.Fatalf("GenerateSessionID() 出错: %v", err) + } + // 16 字节 -> 32 个 hex 字符 + if len(sid) != 32 { + t.Errorf("GenerateSessionID() 长度=%d,期望=32", len(sid)) + } + // 必须是合法的 hex 字符串 + if _, err := hex.DecodeString(sid); err != nil { + t.Errorf("GenerateSessionID() 不是合法的 hex 字符串: %s, err=%v", sid, err) + } +} + +func TestGenerateSessionID_Uniqueness(t *testing.T) { + a, _ := GenerateSessionID() + b, _ := GenerateSessionID() + if a == b { + t.Error("两次 GenerateSessionID() 返回了相同结果") + } +} + +// --------------------------------------------------------------------------- +// GenerateCodeVerifier 测试 +// --------------------------------------------------------------------------- + +func TestGenerateCodeVerifier(t *testing.T) { + verifier, err := GenerateCodeVerifier() + if err != nil { + t.Fatalf("GenerateCodeVerifier() 出错: %v", err) + } + if verifier == "" { + t.Error("GenerateCodeVerifier() 返回空字符串") + } + // RFC 7636 要求 code_verifier 至少 43 个字符 + if len(verifier) < 43 { + t.Errorf("GenerateCodeVerifier() 长度=%d,RFC 7636 要求至少 43 字符", len(verifier)) + } + // base64url 编码不应包含 padding 和非 URL 安全字符 + if strings.Contains(verifier, "=") { + t.Errorf("GenerateCodeVerifier() 包含 '=' padding: %s", verifier) + } + if strings.ContainsAny(verifier, "+/") { + t.Errorf("GenerateCodeVerifier() 包含非 base64url 字符: %s", verifier) + } +} + +// --------------------------------------------------------------------------- +// GenerateCodeChallenge 测试 +// --------------------------------------------------------------------------- + +func TestGenerateCodeChallenge(t *testing.T) { + // 使用已知输入验证输出 + // RFC 7636 附录 B 示例: verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + // 预期 challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + expected := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + + challenge := GenerateCodeChallenge(verifier) + if challenge != expected { + t.Errorf("GenerateCodeChallenge(%q) = %q,期望 %q", verifier, challenge, expected) + } +} + +func TestGenerateCodeChallenge_NoPadding(t *testing.T) { + challenge := GenerateCodeChallenge("test-verifier-string") + if strings.Contains(challenge, "=") { + t.Errorf("GenerateCodeChallenge() 结果包含 '=' padding: %s", challenge) + } +} + +// --------------------------------------------------------------------------- +// base64URLEncode 测试 +// --------------------------------------------------------------------------- + +func TestBase64URLEncode(t *testing.T) { + tests := []struct { + name string + input []byte + }{ + {"空字节", []byte{}}, + {"单字节", []byte{0xff}}, + {"多字节", []byte{0x01, 0x02, 0x03, 0x04, 0x05}}, + {"全零", []byte{0x00, 0x00, 0x00}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := base64URLEncode(tt.input) + // 不应包含 '=' padding + if strings.Contains(result, "=") { + t.Errorf("base64URLEncode(%v) 包含 '=' padding: %s", tt.input, result) + } + // 不应包含标准 base64 的 '+' 或 '/' + if strings.ContainsAny(result, "+/") { + t.Errorf("base64URLEncode(%v) 包含非 URL 安全字符: %s", tt.input, result) + } + }) + } +} + +// --------------------------------------------------------------------------- +// hasRestrictedScope 测试 +// --------------------------------------------------------------------------- + +func TestHasRestrictedScope(t *testing.T) { + tests := []struct { + scope string + expected bool + }{ + // 受限 scope + {"https://www.googleapis.com/auth/generative-language", true}, + {"https://www.googleapis.com/auth/generative-language.retriever", true}, + {"https://www.googleapis.com/auth/generative-language.tuning", true}, + {"https://www.googleapis.com/auth/drive", true}, + {"https://www.googleapis.com/auth/drive.readonly", true}, + {"https://www.googleapis.com/auth/drive.file", true}, + // 非受限 scope + {"https://www.googleapis.com/auth/cloud-platform", false}, + {"https://www.googleapis.com/auth/userinfo.email", false}, + {"https://www.googleapis.com/auth/userinfo.profile", false}, + // 边界情况 + {"", false}, + {"random-scope", false}, + } + for _, tt := range tests { + t.Run(tt.scope, func(t *testing.T) { + got := hasRestrictedScope(tt.scope) + if got != tt.expected { + t.Errorf("hasRestrictedScope(%q) = %v,期望 %v", tt.scope, got, tt.expected) + } + }) + } +} + +// --------------------------------------------------------------------------- +// BuildAuthorizationURL 测试 +// --------------------------------------------------------------------------- + +func TestBuildAuthorizationURL(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret") + + authURL, err := BuildAuthorizationURL( + OAuthConfig{}, + "test-state", + "test-challenge", + "https://example.com/callback", + "", + "code_assist", + ) + if err != nil { + t.Fatalf("BuildAuthorizationURL() 出错: %v", err) + } + + // 检查返回的 URL 包含期望的参数 + checks := []string{ + "response_type=code", + "client_id=" + GeminiCLIOAuthClientID, + "redirect_uri=", + "state=test-state", + "code_challenge=test-challenge", + "code_challenge_method=S256", + "access_type=offline", + "prompt=consent", + "include_granted_scopes=true", + } + for _, check := range checks { + if !strings.Contains(authURL, check) { + t.Errorf("BuildAuthorizationURL() URL 缺少参数 %q\nURL: %s", check, authURL) + } + } + + // 不应包含 project_id(因为传的是空字符串) + if strings.Contains(authURL, "project_id=") { + t.Errorf("BuildAuthorizationURL() 空 projectID 时不应包含 project_id 参数") + } + + // URL 应该以正确的授权端点开头 + if !strings.HasPrefix(authURL, AuthorizeURL+"?") { + t.Errorf("BuildAuthorizationURL() URL 应以 %s? 开头,实际: %s", AuthorizeURL, authURL) + } +} + +func TestBuildAuthorizationURL_EmptyRedirectURI(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret") + + _, err := BuildAuthorizationURL( + OAuthConfig{}, + "test-state", + "test-challenge", + "", // 空 redirectURI + "", + "code_assist", + ) + if err == nil { + t.Error("BuildAuthorizationURL() 空 redirectURI 应该报错") + } + if !strings.Contains(err.Error(), "redirect_uri") { + t.Errorf("错误消息应包含 'redirect_uri',实际: %v", err) + } +} + +func TestBuildAuthorizationURL_WithProjectID(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret") + + authURL, err := BuildAuthorizationURL( + OAuthConfig{}, + "test-state", + "test-challenge", + "https://example.com/callback", + "my-project-123", + "code_assist", + ) + if err != nil { + t.Fatalf("BuildAuthorizationURL() 出错: %v", err) + } + if !strings.Contains(authURL, "project_id=my-project-123") { + t.Errorf("BuildAuthorizationURL() 带 projectID 时应包含 project_id 参数\nURL: %s", authURL) + } +} + +func TestBuildAuthorizationURL_OAuthConfigError(t *testing.T) { + // 不设置环境变量,也不提供 client 凭据,EffectiveOAuthConfig 应该报错 + t.Setenv(GeminiCLIOAuthClientSecretEnv, "") + + _, err := BuildAuthorizationURL( + OAuthConfig{}, + "test-state", + "test-challenge", + "https://example.com/callback", + "", + "code_assist", + ) + if err == nil { + t.Error("当 EffectiveOAuthConfig 失败时,BuildAuthorizationURL 应该返回错误") + } +} + +// --------------------------------------------------------------------------- +// EffectiveOAuthConfig 测试 - 原有测试 +// --------------------------------------------------------------------------- + func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { + // 内置的 Gemini CLI client secret 不嵌入在此仓库中。 + // 测试通过环境变量设置一个假的 secret 来模拟运维配置。 + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + tests := []struct { name string input OAuthConfig @@ -15,7 +443,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr bool }{ { - name: "Google One with built-in client (empty config)", + name: "Google One 使用内置客户端(空配置)", input: OAuthConfig{}, oauthType: "google_one", wantClientID: GeminiCLIOAuthClientID, @@ -23,18 +451,18 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr: false, }, { - name: "Google One always uses built-in client (even if custom credentials passed)", + name: "Google One 使用自定义客户端(传入自定义凭据时使用自定义)", input: OAuthConfig{ ClientID: "custom-client-id", ClientSecret: "custom-client-secret", }, oauthType: "google_one", wantClientID: "custom-client-id", - wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client + wantScopes: DefaultCodeAssistScopes, wantErr: false, }, { - name: "Google One with built-in client and custom scopes (should filter restricted scopes)", + name: "Google One 内置客户端 + 自定义 scopes(应过滤受限 scopes)", input: OAuthConfig{ Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly", }, @@ -44,7 +472,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr: false, }, { - name: "Google One with built-in client and only restricted scopes (should fallback to default)", + name: "Google One 内置客户端 + 仅受限 scopes(应回退到默认)", input: OAuthConfig{ Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly", }, @@ -54,7 +482,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr: false, }, { - name: "Code Assist with built-in client", + name: "Code Assist 使用内置客户端", input: OAuthConfig{}, oauthType: "code_assist", wantClientID: GeminiCLIOAuthClientID, @@ -84,7 +512,9 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { } func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) { - // Test that Google One with built-in client filters out restricted scopes + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // 测试 Google One + 内置客户端过滤受限 scopes cfg, err := EffectiveOAuthConfig(OAuthConfig{ Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.profile", }, "google_one") @@ -93,21 +523,240 @@ func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) { t.Fatalf("EffectiveOAuthConfig() error = %v", err) } - // Should only contain cloud-platform, userinfo.email, and userinfo.profile - // Should NOT contain generative-language or drive scopes + // 应仅包含 cloud-platform、userinfo.email 和 userinfo.profile + // 不应包含 generative-language 或 drive scopes if strings.Contains(cfg.Scopes, "generative-language") { - t.Errorf("Scopes should not contain generative-language when using built-in client, got: %v", cfg.Scopes) + t.Errorf("使用内置客户端时 Scopes 不应包含 generative-language,实际: %v", cfg.Scopes) } if strings.Contains(cfg.Scopes, "drive") { - t.Errorf("Scopes should not contain drive when using built-in client, got: %v", cfg.Scopes) + t.Errorf("使用内置客户端时 Scopes 不应包含 drive,实际: %v", cfg.Scopes) } if !strings.Contains(cfg.Scopes, "cloud-platform") { - t.Errorf("Scopes should contain cloud-platform, got: %v", cfg.Scopes) + t.Errorf("Scopes 应包含 cloud-platform,实际: %v", cfg.Scopes) } if !strings.Contains(cfg.Scopes, "userinfo.email") { - t.Errorf("Scopes should contain userinfo.email, got: %v", cfg.Scopes) + t.Errorf("Scopes 应包含 userinfo.email,实际: %v", cfg.Scopes) } if !strings.Contains(cfg.Scopes, "userinfo.profile") { - t.Errorf("Scopes should contain userinfo.profile, got: %v", cfg.Scopes) + t.Errorf("Scopes 应包含 userinfo.profile,实际: %v", cfg.Scopes) + } +} + +// --------------------------------------------------------------------------- +// EffectiveOAuthConfig 测试 - 新增分支覆盖 +// --------------------------------------------------------------------------- + +func TestEffectiveOAuthConfig_OnlyClientID_NoSecret(t *testing.T) { + // 只提供 clientID 不提供 secret 应报错 + _, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "some-client-id", + }, "code_assist") + if err == nil { + t.Error("只提供 ClientID 不提供 ClientSecret 应该报错") + } + if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") { + t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err) + } +} + +func TestEffectiveOAuthConfig_OnlyClientSecret_NoID(t *testing.T) { + // 只提供 secret 不提供 clientID 应报错 + _, err := EffectiveOAuthConfig(OAuthConfig{ + ClientSecret: "some-client-secret", + }, "code_assist") + if err == nil { + t.Error("只提供 ClientSecret 不提供 ClientID 应该报错") + } + if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") { + t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err) + } +} + +func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_BuiltinClient(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // ai_studio 类型,使用内置客户端,scopes 为空 -> 应使用 DefaultCodeAssistScopes(因为内置客户端不能请求 generative-language scope) + cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "ai_studio") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.Scopes != DefaultCodeAssistScopes { + t.Errorf("ai_studio + 内置客户端应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_CustomClient(t *testing.T) { + // ai_studio 类型,使用自定义客户端,scopes 为空 -> 应使用 DefaultAIStudioScopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + }, "ai_studio") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.Scopes != DefaultAIStudioScopes { + t.Errorf("ai_studio + 自定义客户端应使用 DefaultAIStudioScopes,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_AIStudio_ScopeNormalization(t *testing.T) { + // ai_studio 类型,旧的 generative-language scope 应被归一化为 generative-language.retriever + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + Scopes: "https://www.googleapis.com/auth/generative-language https://www.googleapis.com/auth/cloud-platform", + }, "ai_studio") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if strings.Contains(cfg.Scopes, "auth/generative-language ") || strings.HasSuffix(cfg.Scopes, "auth/generative-language") { + // 确保不包含未归一化的旧 scope(仅 generative-language 而非 generative-language.retriever) + parts := strings.Fields(cfg.Scopes) + for _, p := range parts { + if p == "https://www.googleapis.com/auth/generative-language" { + t.Errorf("ai_studio 应将 generative-language 归一化为 generative-language.retriever,实际 scopes: %q", cfg.Scopes) + } + } + } + if !strings.Contains(cfg.Scopes, "generative-language.retriever") { + t.Errorf("ai_studio 归一化后应包含 generative-language.retriever,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_CommaSeparatedScopes(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // 逗号分隔的 scopes 应被归一化为空格分隔 + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + Scopes: "https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/userinfo.email", + }, "code_assist") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + // 应该用空格分隔,而非逗号 + if strings.Contains(cfg.Scopes, ",") { + t.Errorf("逗号分隔的 scopes 应被归一化为空格分隔,实际: %q", cfg.Scopes) + } + if !strings.Contains(cfg.Scopes, "cloud-platform") { + t.Errorf("归一化后应包含 cloud-platform,实际: %q", cfg.Scopes) + } + if !strings.Contains(cfg.Scopes, "userinfo.email") { + t.Errorf("归一化后应包含 userinfo.email,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_MixedCommaAndSpaceScopes(t *testing.T) { + // 混合逗号和空格分隔的 scopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + Scopes: "https://www.googleapis.com/auth/cloud-platform, https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile", + }, "code_assist") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + parts := strings.Fields(cfg.Scopes) + if len(parts) != 3 { + t.Errorf("归一化后应有 3 个 scope,实际: %d,scopes: %q", len(parts), cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_WhitespaceTriming(t *testing.T) { + // 输入中的前后空白应被清理 + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: " custom-id ", + ClientSecret: " custom-secret ", + Scopes: " https://www.googleapis.com/auth/cloud-platform ", + }, "code_assist") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.ClientID != "custom-id" { + t.Errorf("ClientID 应去除前后空白,实际: %q", cfg.ClientID) + } + if cfg.ClientSecret != "custom-secret" { + t.Errorf("ClientSecret 应去除前后空白,实际: %q", cfg.ClientSecret) + } + if cfg.Scopes != "https://www.googleapis.com/auth/cloud-platform" { + t.Errorf("Scopes 应去除前后空白,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_NoEnvSecret(t *testing.T) { + // 不设置环境变量且不提供凭据,应该报错 + t.Setenv(GeminiCLIOAuthClientSecretEnv, "") + + _, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist") + if err == nil { + t.Error("没有内置 secret 且未提供凭据时应该报错") + } + if !strings.Contains(err.Error(), GeminiCLIOAuthClientSecretEnv) { + t.Errorf("错误消息应提及环境变量 %s,实际: %v", GeminiCLIOAuthClientSecretEnv, err) + } +} + +func TestEffectiveOAuthConfig_AIStudio_BuiltinClient_CustomScopes(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // ai_studio + 内置客户端 + 自定义 scopes -> 应过滤受限 scopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever", + }, "ai_studio") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + // 内置客户端应过滤 generative-language.retriever + if strings.Contains(cfg.Scopes, "generative-language") { + t.Errorf("ai_studio + 内置客户端应过滤受限 scopes,实际: %q", cfg.Scopes) + } + if !strings.Contains(cfg.Scopes, "cloud-platform") { + t.Errorf("应保留 cloud-platform scope,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_UnknownOAuthType_DefaultScopes(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // 未知的 oauthType 应回退到默认的 code_assist scopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "unknown_type") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.Scopes != DefaultCodeAssistScopes { + t.Errorf("未知 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_EmptyOAuthType_DefaultScopes(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // 空的 oauthType 应走 default 分支,使用 DefaultCodeAssistScopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.Scopes != DefaultCodeAssistScopes { + t.Errorf("空 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_CustomClient_NoScopeFiltering(t *testing.T) { + // 自定义客户端 + google_one + 包含受限 scopes -> 不应被过滤(因为不是内置客户端) + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly", + }, "google_one") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + // 自定义客户端不应过滤任何 scope + if !strings.Contains(cfg.Scopes, "generative-language.retriever") { + t.Errorf("自定义客户端不应过滤 generative-language.retriever,实际: %q", cfg.Scopes) + } + if !strings.Contains(cfg.Scopes, "drive.readonly") { + t.Errorf("自定义客户端不应过滤 drive.readonly,实际: %q", cfg.Scopes) } } diff --git a/backend/internal/pkg/response/response_test.go b/backend/internal/pkg/response/response_test.go index ef31ca3c..3c12f5f4 100644 --- a/backend/internal/pkg/response/response_test.go +++ b/backend/internal/pkg/response/response_test.go @@ -14,6 +14,44 @@ import ( "github.com/stretchr/testify/require" ) +// ---------- 辅助函数 ---------- + +// parseResponseBody 从 httptest.ResponseRecorder 中解析 JSON 响应体 +func parseResponseBody(t *testing.T, w *httptest.ResponseRecorder) Response { + t.Helper() + var got Response + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got)) + return got +} + +// parsePaginatedBody 从响应体中解析分页数据(Data 字段是 PaginatedData) +func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, PaginatedData) { + t.Helper() + // 先用 raw json 解析,因为 Data 是 any 类型 + var raw struct { + Code int `json:"code"` + Message string `json:"message"` + Reason string `json:"reason,omitempty"` + Data json.RawMessage `json:"data,omitempty"` + } + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw)) + + var pd PaginatedData + require.NoError(t, json.Unmarshal(raw.Data, &pd)) + + return Response{Code: raw.Code, Message: raw.Message, Reason: raw.Reason}, pd +} + +// newContextWithQuery 创建一个带有 URL query 参数的 gin.Context 用于测试 ParsePagination +func newContextWithQuery(query string) (*httptest.ResponseRecorder, *gin.Context) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil) + return w, c +} + +// ---------- 现有测试 ---------- + func TestErrorWithDetails(t *testing.T) { gin.SetMode(gin.TestMode) @@ -169,3 +207,582 @@ func TestErrorFrom(t *testing.T) { }) } } + +// ---------- 新增测试 ---------- + +func TestSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + data any + wantCode int + wantBody Response + }{ + { + name: "返回字符串数据", + data: "hello", + wantCode: http.StatusOK, + wantBody: Response{Code: 0, Message: "success", Data: "hello"}, + }, + { + name: "返回nil数据", + data: nil, + wantCode: http.StatusOK, + wantBody: Response{Code: 0, Message: "success"}, + }, + { + name: "返回map数据", + data: map[string]string{"key": "value"}, + wantCode: http.StatusOK, + wantBody: Response{Code: 0, Message: "success"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Success(c, tt.data) + + require.Equal(t, tt.wantCode, w.Code) + + // 只验证 code 和 message,data 字段类型在 JSON 反序列化时会变成 map/slice + got := parseResponseBody(t, w) + require.Equal(t, 0, got.Code) + require.Equal(t, "success", got.Message) + + if tt.data == nil { + require.Nil(t, got.Data) + } else { + require.NotNil(t, got.Data) + } + }) + } +} + +func TestCreated(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + data any + wantCode int + }{ + { + name: "创建成功_返回数据", + data: map[string]int{"id": 42}, + wantCode: http.StatusCreated, + }, + { + name: "创建成功_nil数据", + data: nil, + wantCode: http.StatusCreated, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Created(c, tt.data) + + require.Equal(t, tt.wantCode, w.Code) + + got := parseResponseBody(t, w) + require.Equal(t, 0, got.Code) + require.Equal(t, "success", got.Message) + }) + } +} + +func TestError(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + statusCode int + message string + }{ + { + name: "400错误", + statusCode: http.StatusBadRequest, + message: "bad request", + }, + { + name: "500错误", + statusCode: http.StatusInternalServerError, + message: "internal error", + }, + { + name: "自定义状态码", + statusCode: 418, + message: "I'm a teapot", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Error(c, tt.statusCode, tt.message) + + require.Equal(t, tt.statusCode, w.Code) + + got := parseResponseBody(t, w) + require.Equal(t, tt.statusCode, got.Code) + require.Equal(t, tt.message, got.Message) + require.Empty(t, got.Reason) + require.Nil(t, got.Metadata) + require.Nil(t, got.Data) + }) + } +} + +func TestBadRequest(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + BadRequest(c, "参数无效") + + require.Equal(t, http.StatusBadRequest, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusBadRequest, got.Code) + require.Equal(t, "参数无效", got.Message) +} + +func TestUnauthorized(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Unauthorized(c, "未登录") + + require.Equal(t, http.StatusUnauthorized, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusUnauthorized, got.Code) + require.Equal(t, "未登录", got.Message) +} + +func TestForbidden(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Forbidden(c, "无权限") + + require.Equal(t, http.StatusForbidden, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusForbidden, got.Code) + require.Equal(t, "无权限", got.Message) +} + +func TestNotFound(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + NotFound(c, "资源不存在") + + require.Equal(t, http.StatusNotFound, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusNotFound, got.Code) + require.Equal(t, "资源不存在", got.Message) +} + +func TestInternalError(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + InternalError(c, "服务器内部错误") + + require.Equal(t, http.StatusInternalServerError, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusInternalServerError, got.Code) + require.Equal(t, "服务器内部错误", got.Message) +} + +func TestPaginated(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + items any + total int64 + page int + pageSize int + wantPages int + wantTotal int64 + wantPage int + wantPageSize int + }{ + { + name: "标准分页_多页", + items: []string{"a", "b"}, + total: 25, + page: 1, + pageSize: 10, + wantPages: 3, + wantTotal: 25, + wantPage: 1, + wantPageSize: 10, + }, + { + name: "总数刚好整除", + items: []string{"a"}, + total: 20, + page: 2, + pageSize: 10, + wantPages: 2, + wantTotal: 20, + wantPage: 2, + wantPageSize: 10, + }, + { + name: "总数为0_pages至少为1", + items: []string{}, + total: 0, + page: 1, + pageSize: 10, + wantPages: 1, + wantTotal: 0, + wantPage: 1, + wantPageSize: 10, + }, + { + name: "单页数据", + items: []int{1, 2, 3}, + total: 3, + page: 1, + pageSize: 20, + wantPages: 1, + wantTotal: 3, + wantPage: 1, + wantPageSize: 20, + }, + { + name: "总数为1", + items: []string{"only"}, + total: 1, + page: 1, + pageSize: 10, + wantPages: 1, + wantTotal: 1, + wantPage: 1, + wantPageSize: 10, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Paginated(c, tt.items, tt.total, tt.page, tt.pageSize) + + require.Equal(t, http.StatusOK, w.Code) + + resp, pd := parsePaginatedBody(t, w) + require.Equal(t, 0, resp.Code) + require.Equal(t, "success", resp.Message) + require.Equal(t, tt.wantTotal, pd.Total) + require.Equal(t, tt.wantPage, pd.Page) + require.Equal(t, tt.wantPageSize, pd.PageSize) + require.Equal(t, tt.wantPages, pd.Pages) + }) + } +} + +func TestPaginatedWithResult(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + items any + pagination *PaginationResult + wantTotal int64 + wantPage int + wantPageSize int + wantPages int + }{ + { + name: "正常分页结果", + items: []string{"a", "b"}, + pagination: &PaginationResult{ + Total: 50, + Page: 3, + PageSize: 10, + Pages: 5, + }, + wantTotal: 50, + wantPage: 3, + wantPageSize: 10, + wantPages: 5, + }, + { + name: "pagination为nil_使用默认值", + items: []string{}, + pagination: nil, + wantTotal: 0, + wantPage: 1, + wantPageSize: 20, + wantPages: 1, + }, + { + name: "单页结果", + items: []int{1}, + pagination: &PaginationResult{ + Total: 1, + Page: 1, + PageSize: 20, + Pages: 1, + }, + wantTotal: 1, + wantPage: 1, + wantPageSize: 20, + wantPages: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + PaginatedWithResult(c, tt.items, tt.pagination) + + require.Equal(t, http.StatusOK, w.Code) + + resp, pd := parsePaginatedBody(t, w) + require.Equal(t, 0, resp.Code) + require.Equal(t, "success", resp.Message) + require.Equal(t, tt.wantTotal, pd.Total) + require.Equal(t, tt.wantPage, pd.Page) + require.Equal(t, tt.wantPageSize, pd.PageSize) + require.Equal(t, tt.wantPages, pd.Pages) + }) + } +} + +func TestParsePagination(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + query string + wantPage int + wantPageSize int + }{ + { + name: "无参数_使用默认值", + query: "", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "仅指定page", + query: "page=3", + wantPage: 3, + wantPageSize: 20, + }, + { + name: "仅指定page_size", + query: "page_size=50", + wantPage: 1, + wantPageSize: 50, + }, + { + name: "同时指定page和page_size", + query: "page=2&page_size=30", + wantPage: 2, + wantPageSize: 30, + }, + { + name: "使用limit代替page_size", + query: "limit=15", + wantPage: 1, + wantPageSize: 15, + }, + { + name: "page_size优先于limit", + query: "page_size=25&limit=50", + wantPage: 1, + wantPageSize: 25, + }, + { + name: "page为0_使用默认值", + query: "page=0", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "page_size超过1000_使用默认值", + query: "page_size=1001", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "page_size恰好1000_有效", + query: "page_size=1000", + wantPage: 1, + wantPageSize: 1000, + }, + { + name: "page为非数字_使用默认值", + query: "page=abc", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "page_size为非数字_使用默认值", + query: "page_size=xyz", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "limit为非数字_使用默认值", + query: "limit=abc", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "page_size为0_使用默认值", + query: "page_size=0", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "limit为0_使用默认值", + query: "limit=0", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "大页码", + query: "page=999&page_size=100", + wantPage: 999, + wantPageSize: 100, + }, + { + name: "page_size为1_最小有效值", + query: "page_size=1", + wantPage: 1, + wantPageSize: 1, + }, + { + name: "混合数字和字母的page", + query: "page=12a", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "limit超过1000_使用默认值", + query: "limit=2000", + wantPage: 1, + wantPageSize: 20, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, c := newContextWithQuery(tt.query) + + page, pageSize := ParsePagination(c) + + require.Equal(t, tt.wantPage, page, "page 不符合预期") + require.Equal(t, tt.wantPageSize, pageSize, "pageSize 不符合预期") + }) + } +} + +func Test_parseInt(t *testing.T) { + tests := []struct { + name string + input string + wantVal int + wantErr bool + }{ + { + name: "正常数字", + input: "123", + wantVal: 123, + wantErr: false, + }, + { + name: "零", + input: "0", + wantVal: 0, + wantErr: false, + }, + { + name: "单个数字", + input: "5", + wantVal: 5, + wantErr: false, + }, + { + name: "大数字", + input: "99999", + wantVal: 99999, + wantErr: false, + }, + { + name: "包含字母_返回0", + input: "abc", + wantVal: 0, + wantErr: false, + }, + { + name: "数字开头接字母_返回0", + input: "12a", + wantVal: 0, + wantErr: false, + }, + { + name: "包含负号_返回0", + input: "-1", + wantVal: 0, + wantErr: false, + }, + { + name: "包含小数点_返回0", + input: "1.5", + wantVal: 0, + wantErr: false, + }, + { + name: "包含空格_返回0", + input: "1 2", + wantVal: 0, + wantErr: false, + }, + { + name: "空字符串", + input: "", + wantVal: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := parseInt(tt.input) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.wantVal, val) + }) + } +} diff --git a/backend/internal/service/gemini_oauth_service_test.go b/backend/internal/service/gemini_oauth_service_test.go index 5591eb39..c58a5930 100644 --- a/backend/internal/service/gemini_oauth_service_test.go +++ b/backend/internal/service/gemini_oauth_service_test.go @@ -1,17 +1,29 @@ +//go:build unit + package service import ( "context" + "fmt" "net/url" "strings" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) +// ===================== +// 保留原有测试 +// ===================== + func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { - t.Parallel() + // NOTE: This test sets process env; it must not run in parallel. + // The built-in Gemini CLI client secret is not embedded in this repository. + // Tests set a dummy secret via env to simulate operator-provided configuration. + t.Setenv(geminicli.GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") type testCase struct { name string @@ -128,3 +140,1324 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { }) } } + +// ===================== +// 新增测试:validateTierID +// ===================== + +func TestValidateTierID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + tierID string + wantErr bool + }{ + {name: "空字符串合法", tierID: "", wantErr: false}, + {name: "正常 tier_id", tierID: "google_one_free", wantErr: false}, + {name: "包含斜杠", tierID: "tier/sub", wantErr: false}, + {name: "包含连字符", tierID: "gcp-standard", wantErr: false}, + {name: "纯数字", tierID: "12345", wantErr: false}, + {name: "超长字符串(65个字符)", tierID: strings.Repeat("a", 65), wantErr: true}, + {name: "刚好64个字符", tierID: strings.Repeat("b", 64), wantErr: false}, + {name: "非法字符_空格", tierID: "tier id", wantErr: true}, + {name: "非法字符_中文", tierID: "tier_中文", wantErr: true}, + {name: "非法字符_特殊符号", tierID: "tier@id", wantErr: true}, + {name: "非法字符_感叹号", tierID: "tier!id", wantErr: true}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateTierID(tt.tierID) + if tt.wantErr && err == nil { + t.Fatalf("期望返回错误,但返回 nil") + } + if !tt.wantErr && err != nil { + t.Fatalf("不期望返回错误,但返回: %v", err) + } + }) + } +} + +// ===================== +// 新增测试:canonicalGeminiTierID +// ===================== + +func TestCanonicalGeminiTierID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + raw string + want string + }{ + // 空值 + {name: "空字符串", raw: "", want: ""}, + {name: "纯空白", raw: " ", want: ""}, + + // 已规范化的值(直接返回) + {name: "google_one_free", raw: "google_one_free", want: GeminiTierGoogleOneFree}, + {name: "google_ai_pro", raw: "google_ai_pro", want: GeminiTierGoogleAIPro}, + {name: "google_ai_ultra", raw: "google_ai_ultra", want: GeminiTierGoogleAIUltra}, + {name: "gcp_standard", raw: "gcp_standard", want: GeminiTierGCPStandard}, + {name: "gcp_enterprise", raw: "gcp_enterprise", want: GeminiTierGCPEnterprise}, + {name: "aistudio_free", raw: "aistudio_free", want: GeminiTierAIStudioFree}, + {name: "aistudio_paid", raw: "aistudio_paid", want: GeminiTierAIStudioPaid}, + {name: "google_one_unknown", raw: "google_one_unknown", want: GeminiTierGoogleOneUnknown}, + + // 大小写不敏感 + {name: "Google_One_Free 大写", raw: "Google_One_Free", want: GeminiTierGoogleOneFree}, + {name: "GCP_STANDARD 全大写", raw: "GCP_STANDARD", want: GeminiTierGCPStandard}, + + // legacy 映射: Google One + {name: "AI_PREMIUM -> google_ai_pro", raw: "AI_PREMIUM", want: GeminiTierGoogleAIPro}, + {name: "FREE -> google_one_free", raw: "FREE", want: GeminiTierGoogleOneFree}, + {name: "GOOGLE_ONE_BASIC -> google_one_free", raw: "GOOGLE_ONE_BASIC", want: GeminiTierGoogleOneFree}, + {name: "GOOGLE_ONE_STANDARD -> google_one_free", raw: "GOOGLE_ONE_STANDARD", want: GeminiTierGoogleOneFree}, + {name: "GOOGLE_ONE_UNLIMITED -> google_ai_ultra", raw: "GOOGLE_ONE_UNLIMITED", want: GeminiTierGoogleAIUltra}, + {name: "GOOGLE_ONE_UNKNOWN -> google_one_unknown", raw: "GOOGLE_ONE_UNKNOWN", want: GeminiTierGoogleOneUnknown}, + + // legacy 映射: Code Assist + {name: "STANDARD -> gcp_standard", raw: "STANDARD", want: GeminiTierGCPStandard}, + {name: "PRO -> gcp_standard", raw: "PRO", want: GeminiTierGCPStandard}, + {name: "LEGACY -> gcp_standard", raw: "LEGACY", want: GeminiTierGCPStandard}, + {name: "ENTERPRISE -> gcp_enterprise", raw: "ENTERPRISE", want: GeminiTierGCPEnterprise}, + {name: "ULTRA -> gcp_enterprise", raw: "ULTRA", want: GeminiTierGCPEnterprise}, + + // kebab-case + {name: "standard-tier -> gcp_standard", raw: "standard-tier", want: GeminiTierGCPStandard}, + {name: "pro-tier -> gcp_standard", raw: "pro-tier", want: GeminiTierGCPStandard}, + {name: "ultra-tier -> gcp_enterprise", raw: "ultra-tier", want: GeminiTierGCPEnterprise}, + + // 未知值 + {name: "unknown_value -> 空", raw: "unknown_value", want: ""}, + {name: "random-text -> 空", raw: "random-text", want: ""}, + + // 带空白 + {name: "带前后空白", raw: " google_one_free ", want: GeminiTierGoogleOneFree}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := canonicalGeminiTierID(tt.raw) + if got != tt.want { + t.Fatalf("canonicalGeminiTierID(%q) = %q, want %q", tt.raw, got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:canonicalGeminiTierIDForOAuthType +// ===================== + +func TestCanonicalGeminiTierIDForOAuthType(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + oauthType string + tierID string + want string + }{ + // google_one 类型过滤 + {name: "google_one + google_one_free", oauthType: "google_one", tierID: "google_one_free", want: GeminiTierGoogleOneFree}, + {name: "google_one + google_ai_pro", oauthType: "google_one", tierID: "google_ai_pro", want: GeminiTierGoogleAIPro}, + {name: "google_one + google_ai_ultra", oauthType: "google_one", tierID: "google_ai_ultra", want: GeminiTierGoogleAIUltra}, + {name: "google_one + gcp_standard 被过滤", oauthType: "google_one", tierID: "gcp_standard", want: ""}, + {name: "google_one + aistudio_free 被过滤", oauthType: "google_one", tierID: "aistudio_free", want: ""}, + {name: "google_one + AI_PREMIUM 遗留映射", oauthType: "google_one", tierID: "AI_PREMIUM", want: GeminiTierGoogleAIPro}, + + // code_assist 类型过滤 + {name: "code_assist + gcp_standard", oauthType: "code_assist", tierID: "gcp_standard", want: GeminiTierGCPStandard}, + {name: "code_assist + gcp_enterprise", oauthType: "code_assist", tierID: "gcp_enterprise", want: GeminiTierGCPEnterprise}, + {name: "code_assist + google_one_free 被过滤", oauthType: "code_assist", tierID: "google_one_free", want: ""}, + {name: "code_assist + aistudio_free 被过滤", oauthType: "code_assist", tierID: "aistudio_free", want: ""}, + {name: "code_assist + STANDARD 遗留映射", oauthType: "code_assist", tierID: "STANDARD", want: GeminiTierGCPStandard}, + {name: "code_assist + standard-tier kebab", oauthType: "code_assist", tierID: "standard-tier", want: GeminiTierGCPStandard}, + + // ai_studio 类型过滤 + {name: "ai_studio + aistudio_free", oauthType: "ai_studio", tierID: "aistudio_free", want: GeminiTierAIStudioFree}, + {name: "ai_studio + aistudio_paid", oauthType: "ai_studio", tierID: "aistudio_paid", want: GeminiTierAIStudioPaid}, + {name: "ai_studio + gcp_standard 被过滤", oauthType: "ai_studio", tierID: "gcp_standard", want: ""}, + {name: "ai_studio + google_one_free 被过滤", oauthType: "ai_studio", tierID: "google_one_free", want: ""}, + + // 空值 + {name: "空 tierID", oauthType: "google_one", tierID: "", want: ""}, + {name: "空 oauthType + 有效 tierID", oauthType: "", tierID: "gcp_standard", want: GeminiTierGCPStandard}, + {name: "未知 oauthType 接受规范化值", oauthType: "unknown_type", tierID: "gcp_standard", want: GeminiTierGCPStandard}, + + // oauthType 大小写和空白 + {name: "GOOGLE_ONE 大写", oauthType: "GOOGLE_ONE", tierID: "google_one_free", want: GeminiTierGoogleOneFree}, + {name: "oauthType 带空白", oauthType: " code_assist ", tierID: "gcp_standard", want: GeminiTierGCPStandard}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := canonicalGeminiTierIDForOAuthType(tt.oauthType, tt.tierID) + if got != tt.want { + t.Fatalf("canonicalGeminiTierIDForOAuthType(%q, %q) = %q, want %q", tt.oauthType, tt.tierID, got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:extractTierIDFromAllowedTiers +// ===================== + +func TestExtractTierIDFromAllowedTiers(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + allowedTiers []geminicli.AllowedTier + want string + }{ + { + name: "nil 列表返回 LEGACY", + allowedTiers: nil, + want: "LEGACY", + }, + { + name: "空列表返回 LEGACY", + allowedTiers: []geminicli.AllowedTier{}, + want: "LEGACY", + }, + { + name: "有 IsDefault 的 tier", + allowedTiers: []geminicli.AllowedTier{ + {ID: "STANDARD", IsDefault: false}, + {ID: "PRO", IsDefault: true}, + {ID: "ENTERPRISE", IsDefault: false}, + }, + want: "PRO", + }, + { + name: "没有 IsDefault 取第一个非空", + allowedTiers: []geminicli.AllowedTier{ + {ID: "STANDARD", IsDefault: false}, + {ID: "ENTERPRISE", IsDefault: false}, + }, + want: "STANDARD", + }, + { + name: "IsDefault 的 ID 为空,取第一个非空", + allowedTiers: []geminicli.AllowedTier{ + {ID: "", IsDefault: true}, + {ID: "PRO", IsDefault: false}, + }, + want: "PRO", + }, + { + name: "所有 ID 都为空返回 LEGACY", + allowedTiers: []geminicli.AllowedTier{ + {ID: "", IsDefault: false}, + {ID: " ", IsDefault: false}, + }, + want: "LEGACY", + }, + { + name: "ID 带空白会被 trim", + allowedTiers: []geminicli.AllowedTier{ + {ID: " STANDARD ", IsDefault: true}, + }, + want: "STANDARD", + }, + { + name: "单个 tier 且 IsDefault", + allowedTiers: []geminicli.AllowedTier{ + {ID: "ENTERPRISE", IsDefault: true}, + }, + want: "ENTERPRISE", + }, + { + name: "单个 tier 非 IsDefault", + allowedTiers: []geminicli.AllowedTier{ + {ID: "ENTERPRISE", IsDefault: false}, + }, + want: "ENTERPRISE", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := extractTierIDFromAllowedTiers(tt.allowedTiers) + if got != tt.want { + t.Fatalf("extractTierIDFromAllowedTiers() = %q, want %q", got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:inferGoogleOneTier +// ===================== + +func TestInferGoogleOneTier(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + storageBytes int64 + want string + }{ + // 边界:<= 0 + {name: "0 bytes -> unknown", storageBytes: 0, want: GeminiTierGoogleOneUnknown}, + {name: "负数 -> unknown", storageBytes: -1, want: GeminiTierGoogleOneUnknown}, + + // > 100TB -> ultra + {name: "> 100TB -> ultra", storageBytes: int64(StorageTierUnlimited) + 1, want: GeminiTierGoogleAIUltra}, + {name: "200TB -> ultra", storageBytes: 200 * int64(TB), want: GeminiTierGoogleAIUltra}, + + // >= 2TB -> pro (但 <= 100TB) + {name: "正好 2TB -> pro", storageBytes: int64(StorageTierAIPremium), want: GeminiTierGoogleAIPro}, + {name: "5TB -> pro", storageBytes: 5 * int64(TB), want: GeminiTierGoogleAIPro}, + {name: "100TB 正好 -> pro (不是 > 100TB)", storageBytes: int64(StorageTierUnlimited), want: GeminiTierGoogleAIPro}, + + // >= 15GB -> free (但 < 2TB) + {name: "正好 15GB -> free", storageBytes: int64(StorageTierFree), want: GeminiTierGoogleOneFree}, + {name: "100GB -> free", storageBytes: 100 * int64(GB), want: GeminiTierGoogleOneFree}, + {name: "略低于 2TB -> free", storageBytes: int64(StorageTierAIPremium) - 1, want: GeminiTierGoogleOneFree}, + + // < 15GB -> unknown + {name: "1GB -> unknown", storageBytes: int64(GB), want: GeminiTierGoogleOneUnknown}, + {name: "略低于 15GB -> unknown", storageBytes: int64(StorageTierFree) - 1, want: GeminiTierGoogleOneUnknown}, + {name: "1 byte -> unknown", storageBytes: 1, want: GeminiTierGoogleOneUnknown}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := inferGoogleOneTier(tt.storageBytes) + if got != tt.want { + t.Fatalf("inferGoogleOneTier(%d) = %q, want %q", tt.storageBytes, got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:isNonRetryableGeminiOAuthError +// ===================== + +func TestIsNonRetryableGeminiOAuthError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want bool + }{ + {name: "invalid_grant", err: fmt.Errorf("error: invalid_grant"), want: true}, + {name: "invalid_client", err: fmt.Errorf("oauth error: invalid_client"), want: true}, + {name: "unauthorized_client", err: fmt.Errorf("unauthorized_client: mismatch"), want: true}, + {name: "access_denied", err: fmt.Errorf("access_denied by user"), want: true}, + {name: "普通网络错误", err: fmt.Errorf("connection timeout"), want: false}, + {name: "HTTP 500 错误", err: fmt.Errorf("server error 500"), want: false}, + {name: "空错误信息", err: fmt.Errorf(""), want: false}, + {name: "包含 invalid 但不是完整匹配", err: fmt.Errorf("invalid request"), want: false}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := isNonRetryableGeminiOAuthError(tt.err) + if got != tt.want { + t.Fatalf("isNonRetryableGeminiOAuthError(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:BuildAccountCredentials +// ===================== + +func TestGeminiOAuthService_BuildAccountCredentials(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + defer svc.Stop() + + t.Run("完整字段", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "access-123", + RefreshToken: "refresh-456", + ExpiresIn: 3600, + ExpiresAt: 1700000000, + TokenType: "Bearer", + Scope: "openid email", + ProjectID: "my-project", + TierID: "gcp_standard", + OAuthType: "code_assist", + Extra: map[string]any{ + "drive_storage_limit": int64(2199023255552), + }, + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + assertCredStr(t, creds, "access_token", "access-123") + assertCredStr(t, creds, "refresh_token", "refresh-456") + assertCredStr(t, creds, "token_type", "Bearer") + assertCredStr(t, creds, "scope", "openid email") + assertCredStr(t, creds, "project_id", "my-project") + assertCredStr(t, creds, "tier_id", "gcp_standard") + assertCredStr(t, creds, "oauth_type", "code_assist") + assertCredStr(t, creds, "expires_at", "1700000000") + + if _, ok := creds["drive_storage_limit"]; !ok { + t.Fatal("extra 字段 drive_storage_limit 未包含在 creds 中") + } + }) + + t.Run("最小字段(仅 access_token 和 expires_at)", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "token-only", + ExpiresAt: 1700000000, + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + assertCredStr(t, creds, "access_token", "token-only") + assertCredStr(t, creds, "expires_at", "1700000000") + + // 可选字段不应存在 + for _, key := range []string{"refresh_token", "token_type", "scope", "project_id", "tier_id", "oauth_type"} { + if _, ok := creds[key]; ok { + t.Fatalf("不应包含空字段 %q", key) + } + } + }) + + t.Run("无效 tier_id 被静默跳过", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "token", + ExpiresAt: 1700000000, + TierID: "tier with spaces", + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + if _, ok := creds["tier_id"]; ok { + t.Fatal("无效 tier_id 不应被存入 creds") + } + }) + + t.Run("超长 tier_id 被静默跳过", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "token", + ExpiresAt: 1700000000, + TierID: strings.Repeat("x", 65), + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + if _, ok := creds["tier_id"]; ok { + t.Fatal("超长 tier_id 不应被存入 creds") + } + }) + + t.Run("无 extra 字段", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "token", + ExpiresAt: 1700000000, + RefreshToken: "rt", + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + // 仅包含基础字段 + if len(creds) != 3 { // access_token, expires_at, refresh_token + t.Fatalf("creds 字段数量不匹配: got=%d want=3, keys=%v", len(creds), credKeys(creds)) + } + }) +} + +// ===================== +// 新增测试:GetOAuthConfig +// ===================== + +func TestGeminiOAuthService_GetOAuthConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *config.Config + wantEnabled bool + }{ + { + name: "无自定义 OAuth 客户端", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{}, + }, + }, + wantEnabled: false, + }, + { + name: "仅 ClientID 无 ClientSecret", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: "custom-id", + }, + }, + }, + wantEnabled: false, + }, + { + name: "仅 ClientSecret 无 ClientID", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientSecret: "custom-secret", + }, + }, + }, + wantEnabled: false, + }, + { + name: "使用内置 Gemini CLI ClientID(不算自定义)", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: geminicli.GeminiCLIOAuthClientID, + ClientSecret: "some-secret", + }, + }, + }, + wantEnabled: false, + }, + { + name: "自定义 OAuth 客户端(非内置 ID)", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: "my-custom-client-id", + ClientSecret: "my-custom-client-secret", + }, + }, + }, + wantEnabled: true, + }, + { + name: "带空白的自定义客户端", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: " my-custom-client-id ", + ClientSecret: " my-custom-client-secret ", + }, + }, + }, + wantEnabled: true, + }, + { + name: "纯空白字符串不算配置", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: " ", + ClientSecret: " ", + }, + }, + }, + wantEnabled: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg) + defer svc.Stop() + + result := svc.GetOAuthConfig() + if result.AIStudioOAuthEnabled != tt.wantEnabled { + t.Fatalf("AIStudioOAuthEnabled = %v, want %v", result.AIStudioOAuthEnabled, tt.wantEnabled) + } + // RequiredRedirectURIs 始终包含 AI Studio redirect URI + if len(result.RequiredRedirectURIs) != 1 || result.RequiredRedirectURIs[0] != geminicli.AIStudioOAuthRedirectURI { + t.Fatalf("RequiredRedirectURIs 不匹配: got=%v", result.RequiredRedirectURIs) + } + }) + } +} + +// ===================== +// 新增测试:GeminiOAuthService.Stop +// ===================== + +func TestGeminiOAuthService_Stop_NoPanic(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + + // 调用 Stop 不应 panic + svc.Stop() + // 多次调用也不应 panic + svc.Stop() +} + +// ===================== +// mock: GeminiOAuthClient +// ===================== + +type mockGeminiOAuthClient struct { + exchangeCodeFunc func(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) + refreshTokenFunc func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) +} + +func (m *mockGeminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) { + if m.exchangeCodeFunc != nil { + return m.exchangeCodeFunc(ctx, oauthType, code, codeVerifier, redirectURI, proxyURL) + } + panic("ExchangeCode not implemented") +} + +func (m *mockGeminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + if m.refreshTokenFunc != nil { + return m.refreshTokenFunc(ctx, oauthType, refreshToken, proxyURL) + } + panic("RefreshToken not implemented") +} + +// ===================== +// mock: GeminiCliCodeAssistClient +// ===================== + +type mockGeminiCodeAssistClient struct { + loadCodeAssistFunc func(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) + onboardUserFunc func(ctx context.Context, accessToken, proxyURL string, req *geminicli.OnboardUserRequest) (*geminicli.OnboardUserResponse, error) +} + +func (m *mockGeminiCodeAssistClient) LoadCodeAssist(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) { + if m.loadCodeAssistFunc != nil { + return m.loadCodeAssistFunc(ctx, accessToken, proxyURL, req) + } + panic("LoadCodeAssist not implemented") +} + +func (m *mockGeminiCodeAssistClient) OnboardUser(ctx context.Context, accessToken, proxyURL string, req *geminicli.OnboardUserRequest) (*geminicli.OnboardUserResponse, error) { + if m.onboardUserFunc != nil { + return m.onboardUserFunc(ctx, accessToken, proxyURL, req) + } + panic("OnboardUser not implemented") +} + +// ===================== +// mock: ProxyRepository (最小实现) +// ===================== + +type mockGeminiProxyRepo struct { + getByIDFunc func(ctx context.Context, id int64) (*Proxy, error) +} + +func (m *mockGeminiProxyRepo) Create(ctx context.Context, proxy *Proxy) error { panic("not impl") } +func (m *mockGeminiProxyRepo) GetByID(ctx context.Context, id int64) (*Proxy, error) { + if m.getByIDFunc != nil { + return m.getByIDFunc(ctx, id) + } + return nil, fmt.Errorf("proxy not found") +} +func (m *mockGeminiProxyRepo) ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) Update(ctx context.Context, proxy *Proxy) error { panic("not impl") } +func (m *mockGeminiProxyRepo) Delete(ctx context.Context, id int64) error { panic("not impl") } +func (m *mockGeminiProxyRepo) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ListActive(ctx context.Context) ([]Proxy, error) { panic("not impl") } +func (m *mockGeminiProxyRepo) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) { + panic("not impl") +} + +// ===================== +// 新增测试:GeminiOAuthService.RefreshToken(含重试逻辑) +// ===================== + +func TestGeminiOAuthService_RefreshToken_Success(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "new-access", + RefreshToken: "new-refresh", + TokenType: "Bearer", + ExpiresIn: 3600, + Scope: "openid", + }, nil + }, + } + + svc := NewGeminiOAuthService(nil, client, nil, &config.Config{}) + defer svc.Stop() + + info, err := svc.RefreshToken(context.Background(), "code_assist", "old-refresh", "") + if err != nil { + t.Fatalf("RefreshToken 返回错误: %v", err) + } + if info.AccessToken != "new-access" { + t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken) + } + if info.RefreshToken != "new-refresh" { + t.Fatalf("RefreshToken 不匹配: got=%q", info.RefreshToken) + } + if info.ExpiresAt == 0 { + t.Fatal("ExpiresAt 不应为 0") + } +} + +func TestGeminiOAuthService_RefreshToken_NonRetryableError(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return nil, fmt.Errorf("invalid_grant: token revoked") + }, + } + + svc := NewGeminiOAuthService(nil, client, nil, &config.Config{}) + defer svc.Stop() + + _, err := svc.RefreshToken(context.Background(), "code_assist", "revoked-token", "") + if err == nil { + t.Fatal("RefreshToken 应返回错误(不可重试的 invalid_grant)") + } + if !strings.Contains(err.Error(), "invalid_grant") { + t.Fatalf("错误应包含 invalid_grant: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_RefreshToken_RetryableError(t *testing.T) { + t.Parallel() + + callCount := 0 + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + callCount++ + if callCount <= 2 { + return nil, fmt.Errorf("temporary network error") + } + return &geminicli.TokenResponse{ + AccessToken: "recovered", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(nil, client, nil, &config.Config{}) + defer svc.Stop() + + info, err := svc.RefreshToken(context.Background(), "code_assist", "rt", "") + if err != nil { + t.Fatalf("RefreshToken 应在重试后成功: %v", err) + } + if info.AccessToken != "recovered" { + t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken) + } + if callCount < 3 { + t.Fatalf("应至少调用 3 次(2 次失败 + 1 次成功): got=%d", callCount) + } +} + +// ===================== +// 新增测试:GeminiOAuthService.RefreshAccountToken +// ===================== + +func TestGeminiOAuthService_RefreshAccountToken_NotGeminiOAuth(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("应返回错误(非 Gemini OAuth 账号)") + } + if !strings.Contains(err.Error(), "not a Gemini OAuth account") { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "at", + "oauth_type": "code_assist", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("应返回错误(无 refresh_token)") + } + if !strings.Contains(err.Error(), "no refresh token") { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_AIStudio(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "refreshed-at", + RefreshToken: "refreshed-rt", + ExpiresIn: 3600, + TokenType: "Bearer", + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-at", + "refresh_token": "old-rt", + "oauth_type": "ai_studio", + "tier_id": "aistudio_free", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if info.AccessToken != "refreshed-at" { + t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken) + } + if info.OAuthType != "ai_studio" { + t.Fatalf("OAuthType 不匹配: got=%q", info.OAuthType) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_WithProjectID(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "refreshed", + RefreshToken: "new-rt", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-at", + "refresh_token": "old-rt", + "oauth_type": "code_assist", + "project_id": "my-project", + "tier_id": "gcp_standard", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if info.ProjectID != "my-project" { + t.Fatalf("ProjectID 应保留: got=%q", info.ProjectID) + } + if info.TierID != GeminiTierGCPStandard { + t.Fatalf("TierID 不匹配: got=%q want=%q", info.TierID, GeminiTierGCPStandard) + } + if info.OAuthType != "code_assist" { + t.Fatalf("OAuthType 不匹配: got=%q", info.OAuthType) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_DefaultOAuthType(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + if oauthType != "code_assist" { + t.Errorf("默认 oauthType 应为 code_assist: got=%q", oauthType) + } + return &geminicli.TokenResponse{ + AccessToken: "refreshed", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + defer svc.Stop() + + // 无 oauth_type 凭据的旧账号 + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "old-rt", + "project_id": "proj", + "tier_id": "STANDARD", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if info.OAuthType != "code_assist" { + t.Fatalf("OAuthType 应默认为 code_assist: got=%q", info.OAuthType) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_WithProxy(t *testing.T) { + t.Parallel() + + proxyRepo := &mockGeminiProxyRepo{ + getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) { + return &Proxy{ + Protocol: "http", + Host: "proxy.test", + Port: 3128, + }, nil + }, + } + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + if proxyURL != "http://proxy.test:3128" { + t.Errorf("proxyURL 不匹配: got=%q", proxyURL) + } + return &geminicli.TokenResponse{ + AccessToken: "refreshed", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(proxyRepo, client, nil, &config.Config{}) + defer svc.Stop() + + proxyID := int64(5) + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + ProxyID: &proxyID, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + "project_id": "proj", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_AutoDetect(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "at", + ExpiresIn: 3600, + }, nil + }, + } + + codeAssist := &mockGeminiCodeAssistClient{ + loadCodeAssistFunc: func(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) { + return &geminicli.LoadCodeAssistResponse{ + CloudAICompanionProject: "auto-project-123", + CurrentTier: &geminicli.TierInfo{ID: "STANDARD"}, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + // 无 project_id,触发 fetchProjectID + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if info.ProjectID != "auto-project-123" { + t.Fatalf("ProjectID 应为自动检测值: got=%q", info.ProjectID) + } + if info.TierID != GeminiTierGCPStandard { + t.Fatalf("TierID 不匹配: got=%q", info.TierID) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_FailsEmpty(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "at", + ExpiresIn: 3600, + }, nil + }, + } + + // 返回有 currentTier 但无 cloudaicompanionProject 的响应, + // 使 fetchProjectID 走"已注册用户"路径(尝试 Cloud Resource Manager -> 失败 -> 返回错误), + // 避免走 onboardUser 路径(5 次重试 x 2 秒 = 10 秒超时) + codeAssist := &mockGeminiCodeAssistClient{ + loadCodeAssistFunc: func(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) { + return &geminicli.LoadCodeAssistResponse{ + CurrentTier: &geminicli.TierInfo{ID: "STANDARD"}, + // 无 CloudAICompanionProject + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("应返回错误(无法检测 project_id)") + } + if !strings.Contains(err.Error(), "project_id") { + t.Fatalf("错误信息应包含 project_id: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_FreshCache(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "at", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "google_one", + "project_id": "proj", + "tier_id": "google_ai_pro", + }, + Extra: map[string]any{ + // 缓存刷新时间在 24 小时内 + "drive_tier_updated_at": time.Now().Add(-1 * time.Hour).Format(time.RFC3339), + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + // 缓存新鲜,应使用已有的 tier_id + if info.TierID != GeminiTierGoogleAIPro { + t.Fatalf("TierID 应使用缓存值: got=%q want=%q", info.TierID, GeminiTierGoogleAIPro) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_NoTierID_DefaultsFree(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "at", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "google_one", + "project_id": "proj", + // 无 tier_id + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + // FetchGoogleOneTier 会被调用但 oauthClient(此处 mock)不实现 Drive API, + // svc.FetchGoogleOneTier 使用真实 DriveClient 会失败,最终回退到默认值。 + // 由于没有 tier_id 且 FetchGoogleOneTier 失败,应默认为 google_one_free + if info.TierID != GeminiTierGoogleOneFree { + t.Fatalf("TierID 应为默认 free: got=%q", info.TierID) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_Fallback(t *testing.T) { + t.Parallel() + + callCount := 0 + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + callCount++ + if oauthType == "code_assist" { + return nil, fmt.Errorf("unauthorized_client: client mismatch") + } + // ai_studio 路径成功 + return &geminicli.TokenResponse{ + AccessToken: "recovered", + ExpiresIn: 3600, + }, nil + }, + } + + // 启用自定义 OAuth 客户端以触发 fallback 路径 + cfg := &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + }, + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, cfg) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + "project_id": "proj", + "tier_id": "gcp_standard", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 应在 fallback 后成功: %v", err) + } + if info.AccessToken != "recovered" { + t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_NoFallback(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return nil, fmt.Errorf("unauthorized_client: client mismatch") + }, + } + + // 无自定义 OAuth 客户端,无法 fallback + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + "project_id": "proj", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("应返回错误(无 fallback)") + } + if !strings.Contains(err.Error(), "OAuth client mismatch") { + t.Fatalf("错误应包含 OAuth client mismatch: got=%q", err.Error()) + } +} + +// ===================== +// 新增测试:GeminiOAuthService.ExchangeCode +// ===================== + +func TestGeminiOAuthService_ExchangeCode_SessionNotFound(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + defer svc.Stop() + + _, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{ + SessionID: "nonexistent", + State: "some-state", + Code: "some-code", + }) + if err == nil { + t.Fatal("应返回错误(session 不存在)") + } + if !strings.Contains(err.Error(), "session not found") { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_ExchangeCode_InvalidState(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + defer svc.Stop() + + // 手动创建 session(必须设置 CreatedAt,否则会因 TTL 过期被拒绝) + svc.sessionStore.Set("test-session", &geminicli.OAuthSession{ + State: "correct-state", + CodeVerifier: "verifier", + OAuthType: "ai_studio", + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{ + SessionID: "test-session", + State: "wrong-state", + Code: "code", + }) + if err == nil { + t.Fatal("应返回错误(state 不匹配)") + } + if !strings.Contains(err.Error(), "invalid state") { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_ExchangeCode_EmptyState(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + defer svc.Stop() + + svc.sessionStore.Set("test-session", &geminicli.OAuthSession{ + State: "correct-state", + CodeVerifier: "verifier", + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{ + SessionID: "test-session", + State: "", + Code: "code", + }) + if err == nil { + t.Fatal("应返回错误(空 state)") + } +} + +// ===================== +// 辅助函数 +// ===================== + +func assertCredStr(t *testing.T, creds map[string]any, key, want string) { + t.Helper() + raw, ok := creds[key] + if !ok { + t.Fatalf("creds 缺少 key=%q", key) + } + got, ok := raw.(string) + if !ok { + t.Fatalf("creds[%q] 不是 string: %T", key, raw) + } + if got != want { + t.Fatalf("creds[%q] = %q, want %q", key, got, want) + } +} + +func credKeys(m map[string]any) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} diff --git a/backend/internal/service/oauth_service_test.go b/backend/internal/service/oauth_service_test.go new file mode 100644 index 00000000..72de4b8c --- /dev/null +++ b/backend/internal/service/oauth_service_test.go @@ -0,0 +1,607 @@ +//go:build unit + +package service + +import ( + "context" + "fmt" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +// --- mock: ClaudeOAuthClient --- + +type mockClaudeOAuthClient struct { + getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error) + getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) + exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) + refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) +} + +func (m *mockClaudeOAuthClient) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { + if m.getOrgUUIDFunc != nil { + return m.getOrgUUIDFunc(ctx, sessionKey, proxyURL) + } + panic("GetOrganizationUUID not implemented") +} + +func (m *mockClaudeOAuthClient) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { + if m.getAuthCodeFunc != nil { + return m.getAuthCodeFunc(ctx, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL) + } + panic("GetAuthorizationCode not implemented") +} + +func (m *mockClaudeOAuthClient) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + if m.exchangeCodeFunc != nil { + return m.exchangeCodeFunc(ctx, code, codeVerifier, state, proxyURL, isSetupToken) + } + panic("ExchangeCodeForToken not implemented") +} + +func (m *mockClaudeOAuthClient) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + if m.refreshTokenFunc != nil { + return m.refreshTokenFunc(ctx, refreshToken, proxyURL) + } + panic("RefreshToken not implemented") +} + +// --- mock: ProxyRepository (最小实现,仅覆盖 OAuthService 依赖的方法) --- + +type mockProxyRepoForOAuth struct { + getByIDFunc func(ctx context.Context, id int64) (*Proxy, error) +} + +func (m *mockProxyRepoForOAuth) Create(ctx context.Context, proxy *Proxy) error { + panic("Create not implemented") +} +func (m *mockProxyRepoForOAuth) GetByID(ctx context.Context, id int64) (*Proxy, error) { + if m.getByIDFunc != nil { + return m.getByIDFunc(ctx, id) + } + return nil, fmt.Errorf("proxy not found") +} +func (m *mockProxyRepoForOAuth) ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) { + panic("ListByIDs not implemented") +} +func (m *mockProxyRepoForOAuth) Update(ctx context.Context, proxy *Proxy) error { + panic("Update not implemented") +} +func (m *mockProxyRepoForOAuth) Delete(ctx context.Context, id int64) error { + panic("Delete not implemented") +} +func (m *mockProxyRepoForOAuth) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) { + panic("List not implemented") +} +func (m *mockProxyRepoForOAuth) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) { + panic("ListWithFilters not implemented") +} +func (m *mockProxyRepoForOAuth) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) { + panic("ListWithFiltersAndAccountCount not implemented") +} +func (m *mockProxyRepoForOAuth) ListActive(ctx context.Context) ([]Proxy, error) { + panic("ListActive not implemented") +} +func (m *mockProxyRepoForOAuth) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) { + panic("ListActiveWithAccountCount not implemented") +} +func (m *mockProxyRepoForOAuth) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { + panic("ExistsByHostPortAuth not implemented") +} +func (m *mockProxyRepoForOAuth) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { + panic("CountAccountsByProxyID not implemented") +} +func (m *mockProxyRepoForOAuth) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) { + panic("ListAccountSummariesByProxyID not implemented") +} + +// ===================== +// 测试用例 +// ===================== + +func TestNewOAuthService(t *testing.T) { + t.Parallel() + + proxyRepo := &mockProxyRepoForOAuth{} + client := &mockClaudeOAuthClient{} + svc := NewOAuthService(proxyRepo, client) + + if svc == nil { + t.Fatal("NewOAuthService 返回 nil") + } + if svc.proxyRepo != proxyRepo { + t.Fatal("proxyRepo 未正确设置") + } + if svc.oauthClient != client { + t.Fatal("oauthClient 未正确设置") + } + if svc.sessionStore == nil { + t.Fatal("sessionStore 应被自动初始化") + } + + // 清理 + svc.Stop() +} + +func TestOAuthService_GenerateAuthURL(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + result, err := svc.GenerateAuthURL(context.Background(), nil) + if err != nil { + t.Fatalf("GenerateAuthURL 返回错误: %v", err) + } + if result == nil { + t.Fatal("GenerateAuthURL 返回 nil") + } + if result.AuthURL == "" { + t.Fatal("AuthURL 为空") + } + if result.SessionID == "" { + t.Fatal("SessionID 为空") + } + + // 验证 session 已存储 + session, ok := svc.sessionStore.Get(result.SessionID) + if !ok { + t.Fatal("session 未在 sessionStore 中找到") + } + if session.Scope != oauth.ScopeOAuth { + t.Fatalf("scope 不匹配: got=%q want=%q", session.Scope, oauth.ScopeOAuth) + } +} + +func TestOAuthService_GenerateAuthURL_WithProxy(t *testing.T) { + t.Parallel() + + proxyRepo := &mockProxyRepoForOAuth{ + getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) { + return &Proxy{ + ID: 1, + Protocol: "http", + Host: "proxy.example.com", + Port: 8080, + }, nil + }, + } + svc := NewOAuthService(proxyRepo, &mockClaudeOAuthClient{}) + defer svc.Stop() + + proxyID := int64(1) + result, err := svc.GenerateAuthURL(context.Background(), &proxyID) + if err != nil { + t.Fatalf("GenerateAuthURL 返回错误: %v", err) + } + + session, ok := svc.sessionStore.Get(result.SessionID) + if !ok { + t.Fatal("session 未在 sessionStore 中找到") + } + if session.ProxyURL != "http://proxy.example.com:8080" { + t.Fatalf("ProxyURL 不匹配: got=%q", session.ProxyURL) + } +} + +func TestOAuthService_GenerateSetupTokenURL(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + result, err := svc.GenerateSetupTokenURL(context.Background(), nil) + if err != nil { + t.Fatalf("GenerateSetupTokenURL 返回错误: %v", err) + } + if result == nil { + t.Fatal("GenerateSetupTokenURL 返回 nil") + } + + // 验证 scope 是 inference + session, ok := svc.sessionStore.Get(result.SessionID) + if !ok { + t.Fatal("session 未在 sessionStore 中找到") + } + if session.Scope != oauth.ScopeInference { + t.Fatalf("scope 不匹配: got=%q want=%q", session.Scope, oauth.ScopeInference) + } +} + +func TestOAuthService_ExchangeCode_SessionNotFound(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + _, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: "nonexistent-session", + Code: "test-code", + }) + if err == nil { + t.Fatal("ExchangeCode 应返回错误(session 不存在)") + } + if err.Error() != "session not found or expired" { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestOAuthService_ExchangeCode_Success(t *testing.T) { + t.Parallel() + + exchangeCalled := false + client := &mockClaudeOAuthClient{ + exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + exchangeCalled = true + if code != "auth-code-123" { + t.Errorf("code 不匹配: got=%q", code) + } + if isSetupToken { + t.Error("isSetupToken 应为 false(ScopeOAuth)") + } + return &oauth.TokenResponse{ + AccessToken: "access-token-abc", + TokenType: "Bearer", + ExpiresIn: 3600, + RefreshToken: "refresh-token-xyz", + Scope: oauth.ScopeOAuth, + Organization: &oauth.OrgInfo{UUID: "org-uuid-111"}, + Account: &oauth.AccountInfo{UUID: "acc-uuid-222", EmailAddress: "test@example.com"}, + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + // 先生成 URL 以创建 session + result, err := svc.GenerateAuthURL(context.Background(), nil) + if err != nil { + t.Fatalf("GenerateAuthURL 返回错误: %v", err) + } + + // 交换 code + tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: result.SessionID, + Code: "auth-code-123", + }) + if err != nil { + t.Fatalf("ExchangeCode 返回错误: %v", err) + } + + if !exchangeCalled { + t.Fatal("ExchangeCodeForToken 未被调用") + } + if tokenInfo.AccessToken != "access-token-abc" { + t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken) + } + if tokenInfo.TokenType != "Bearer" { + t.Fatalf("TokenType 不匹配: got=%q", tokenInfo.TokenType) + } + if tokenInfo.RefreshToken != "refresh-token-xyz" { + t.Fatalf("RefreshToken 不匹配: got=%q", tokenInfo.RefreshToken) + } + if tokenInfo.OrgUUID != "org-uuid-111" { + t.Fatalf("OrgUUID 不匹配: got=%q", tokenInfo.OrgUUID) + } + if tokenInfo.AccountUUID != "acc-uuid-222" { + t.Fatalf("AccountUUID 不匹配: got=%q", tokenInfo.AccountUUID) + } + if tokenInfo.EmailAddress != "test@example.com" { + t.Fatalf("EmailAddress 不匹配: got=%q", tokenInfo.EmailAddress) + } + if tokenInfo.ExpiresIn != 3600 { + t.Fatalf("ExpiresIn 不匹配: got=%d", tokenInfo.ExpiresIn) + } + if tokenInfo.ExpiresAt == 0 { + t.Fatal("ExpiresAt 不应为 0") + } + + // 验证 session 已被删除 + _, ok := svc.sessionStore.Get(result.SessionID) + if ok { + t.Fatal("session 应在交换成功后被删除") + } +} + +func TestOAuthService_ExchangeCode_SetupToken(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + if !isSetupToken { + t.Error("isSetupToken 应为 true(ScopeInference)") + } + return &oauth.TokenResponse{ + AccessToken: "setup-token", + TokenType: "Bearer", + ExpiresIn: 3600, + Scope: oauth.ScopeInference, + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + // 使用 SetupToken URL(inference scope) + result, err := svc.GenerateSetupTokenURL(context.Background(), nil) + if err != nil { + t.Fatalf("GenerateSetupTokenURL 返回错误: %v", err) + } + + tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: result.SessionID, + Code: "setup-code", + }) + if err != nil { + t.Fatalf("ExchangeCode 返回错误: %v", err) + } + if tokenInfo.AccessToken != "setup-token" { + t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken) + } +} + +func TestOAuthService_ExchangeCode_ClientError(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + return nil, fmt.Errorf("upstream error: invalid code") + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + result, _ := svc.GenerateAuthURL(context.Background(), nil) + _, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: result.SessionID, + Code: "bad-code", + }) + if err == nil { + t.Fatal("ExchangeCode 应返回错误") + } + if err.Error() != "upstream error: invalid code" { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestOAuthService_RefreshToken(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + if refreshToken != "my-refresh-token" { + t.Errorf("refreshToken 不匹配: got=%q", refreshToken) + } + if proxyURL != "" { + t.Errorf("proxyURL 应为空: got=%q", proxyURL) + } + return &oauth.TokenResponse{ + AccessToken: "new-access-token", + TokenType: "Bearer", + ExpiresIn: 7200, + RefreshToken: "new-refresh-token", + Scope: oauth.ScopeOAuth, + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + tokenInfo, err := svc.RefreshToken(context.Background(), "my-refresh-token", "") + if err != nil { + t.Fatalf("RefreshToken 返回错误: %v", err) + } + if tokenInfo.AccessToken != "new-access-token" { + t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken) + } + if tokenInfo.RefreshToken != "new-refresh-token" { + t.Fatalf("RefreshToken 不匹配: got=%q", tokenInfo.RefreshToken) + } + if tokenInfo.ExpiresIn != 7200 { + t.Fatalf("ExpiresIn 不匹配: got=%d", tokenInfo.ExpiresIn) + } + if tokenInfo.ExpiresAt == 0 { + t.Fatal("ExpiresAt 不应为 0") + } +} + +func TestOAuthService_RefreshToken_Error(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + return nil, fmt.Errorf("invalid_grant: token expired") + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + _, err := svc.RefreshToken(context.Background(), "expired-token", "") + if err == nil { + t.Fatal("RefreshToken 应返回错误") + } +} + +func TestOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + // 无 refresh_token 的账号 + account := &Account{ + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "some-token", + }, + } + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("RefreshAccountToken 应返回错误(无 refresh_token)") + } + if err.Error() != "no refresh token available" { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestOAuthService_RefreshAccountToken_EmptyRefreshToken(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + account := &Account{ + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "some-token", + "refresh_token": "", + }, + } + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("RefreshAccountToken 应返回错误(refresh_token 为空)") + } +} + +func TestOAuthService_RefreshAccountToken_Success(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + if refreshToken != "account-refresh-token" { + t.Errorf("refreshToken 不匹配: got=%q", refreshToken) + } + return &oauth.TokenResponse{ + AccessToken: "refreshed-access", + TokenType: "Bearer", + ExpiresIn: 3600, + RefreshToken: "new-refresh", + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + account := &Account{ + ID: 3, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-access", + "refresh_token": "account-refresh-token", + }, + } + + tokenInfo, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if tokenInfo.AccessToken != "refreshed-access" { + t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken) + } +} + +func TestOAuthService_RefreshAccountToken_WithProxy(t *testing.T) { + t.Parallel() + + proxyRepo := &mockProxyRepoForOAuth{ + getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) { + return &Proxy{ + Protocol: "socks5", + Host: "socks.example.com", + Port: 1080, + Username: "user", + Password: "pass", + }, nil + }, + } + + client := &mockClaudeOAuthClient{ + refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + if proxyURL != "socks5://user:pass@socks.example.com:1080" { + t.Errorf("proxyURL 不匹配: got=%q", proxyURL) + } + return &oauth.TokenResponse{ + AccessToken: "refreshed", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewOAuthService(proxyRepo, client) + defer svc.Stop() + + proxyID := int64(10) + account := &Account{ + ID: 4, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + ProxyID: &proxyID, + Credentials: map[string]any{ + "refresh_token": "rt-with-proxy", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } +} + +func TestOAuthService_ExchangeCode_NilOrg(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + return &oauth.TokenResponse{ + AccessToken: "token-no-org", + TokenType: "Bearer", + ExpiresIn: 3600, + Organization: nil, + Account: nil, + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + result, _ := svc.GenerateAuthURL(context.Background(), nil) + tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: result.SessionID, + Code: "code", + }) + if err != nil { + t.Fatalf("ExchangeCode 返回错误: %v", err) + } + if tokenInfo.OrgUUID != "" { + t.Fatalf("OrgUUID 应为空: got=%q", tokenInfo.OrgUUID) + } + if tokenInfo.AccountUUID != "" { + t.Fatalf("AccountUUID 应为空: got=%q", tokenInfo.AccountUUID) + } +} + +func TestOAuthService_Stop_NoPanic(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + + // 调用 Stop 不应 panic + svc.Stop() + + // 多次调用也不应 panic + svc.Stop() +} diff --git a/backend/internal/util/logredact/redact_test.go b/backend/internal/util/logredact/redact_test.go new file mode 100644 index 00000000..a9ec89c6 --- /dev/null +++ b/backend/internal/util/logredact/redact_test.go @@ -0,0 +1,39 @@ +package logredact + +import ( + "strings" + "testing" +) + +func TestRedactText_JSONLike(t *testing.T) { + in := `{"access_token":"ya29.a0AfH6SMDUMMY","refresh_token":"1//0gDUMMY","other":"ok"}` + out := RedactText(in) + if out == in { + t.Fatalf("expected redaction, got unchanged") + } + if want := `"access_token":"***"`; !strings.Contains(out, want) { + t.Fatalf("expected %q in %q", want, out) + } + if want := `"refresh_token":"***"`; !strings.Contains(out, want) { + t.Fatalf("expected %q in %q", want, out) + } +} + +func TestRedactText_QueryLike(t *testing.T) { + in := "access_token=ya29.a0AfH6SMDUMMY refresh_token=1//0gDUMMY" + out := RedactText(in) + if strings.Contains(out, "ya29") || strings.Contains(out, "1//0") { + t.Fatalf("expected tokens redacted, got %q", out) + } +} + +func TestRedactText_GOCSPX(t *testing.T) { + in := "client_secret=GOCSPX-abcdefghijklmnopqrstuvwxyz_0123456789" + out := RedactText(in) + if strings.Contains(out, "abcdefghijklmnopqrstuvwxyz") { + t.Fatalf("expected secret redacted, got %q", out) + } + if !strings.Contains(out, "client_secret=***") { + t.Fatalf("expected key redacted, got %q", out) + } +} diff --git a/backend/internal/util/urlvalidator/validator_test.go b/backend/internal/util/urlvalidator/validator_test.go index f9745da3..bec9bb21 100644 --- a/backend/internal/util/urlvalidator/validator_test.go +++ b/backend/internal/util/urlvalidator/validator_test.go @@ -49,3 +49,27 @@ func TestValidateURLFormat(t *testing.T) { t.Fatalf("expected trailing slash to be removed from path, got %s", normalized) } } + +func TestValidateHTTPURL(t *testing.T) { + if _, err := ValidateHTTPURL("http://example.com", false, ValidationOptions{}); err == nil { + t.Fatalf("expected http to fail when allow_insecure_http is false") + } + if _, err := ValidateHTTPURL("http://example.com", true, ValidationOptions{}); err != nil { + t.Fatalf("expected http to pass when allow_insecure_http is true, got %v", err) + } + if _, err := ValidateHTTPURL("https://example.com", false, ValidationOptions{RequireAllowlist: true}); err == nil { + t.Fatalf("expected require allowlist to fail when empty") + } + if _, err := ValidateHTTPURL("https://example.com", false, ValidationOptions{AllowedHosts: []string{"api.example.com"}}); err == nil { + t.Fatalf("expected host not in allowlist to fail") + } + if _, err := ValidateHTTPURL("https://api.example.com", false, ValidationOptions{AllowedHosts: []string{"api.example.com"}}); err != nil { + t.Fatalf("expected allowlisted host to pass, got %v", err) + } + if _, err := ValidateHTTPURL("https://sub.api.example.com", false, ValidationOptions{AllowedHosts: []string{"*.example.com"}}); err != nil { + t.Fatalf("expected wildcard allowlist to pass, got %v", err) + } + if _, err := ValidateHTTPURL("https://localhost", false, ValidationOptions{AllowPrivate: false}); err == nil { + t.Fatalf("expected localhost to be blocked when allow_private_hosts is false") + } +}