//go:build unit package antigravity import ( "context" "encoding/json" "fmt" "net" "net/http" "net/http/httptest" "net/url" "strings" "testing" "time" ) // --------------------------------------------------------------------------- // NewAPIRequestWithURL // --------------------------------------------------------------------------- func TestNewAPIRequestWithURL_普通请求(t *testing.T) { ctx := context.Background() baseURL := "https://example.com" action := "generateContent" token := "test-token" body := []byte(`{"prompt":"hello"}`) req, err := NewAPIRequestWithURL(ctx, baseURL, action, token, body) if err != nil { t.Fatalf("创建请求失败: %v", err) } // 验证 URL 不含 ?alt=sse expectedURL := "https://example.com/v1internal:generateContent" if req.URL.String() != expectedURL { t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expectedURL) } // 验证请求方法 if req.Method != http.MethodPost { t.Errorf("请求方法不匹配: got %s, want POST", req.Method) } // 验证 Headers if ct := req.Header.Get("Content-Type"); ct != "application/json" { t.Errorf("Content-Type 不匹配: got %s", ct) } if auth := req.Header.Get("Authorization"); auth != "Bearer test-token" { t.Errorf("Authorization 不匹配: got %s", auth) } if ua := req.Header.Get("User-Agent"); ua != UserAgent { t.Errorf("User-Agent 不匹配: got %s, want %s", ua, UserAgent) } } func TestNewAPIRequestWithURL_流式请求(t *testing.T) { ctx := context.Background() baseURL := "https://example.com" action := "streamGenerateContent" token := "tok" body := []byte(`{}`) req, err := NewAPIRequestWithURL(ctx, baseURL, action, token, body) if err != nil { t.Fatalf("创建请求失败: %v", err) } expectedURL := "https://example.com/v1internal:streamGenerateContent?alt=sse" if req.URL.String() != expectedURL { t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expectedURL) } } func TestNewAPIRequestWithURL_空Body(t *testing.T) { ctx := context.Background() req, err := NewAPIRequestWithURL(ctx, "https://example.com", "test", "tok", nil) if err != nil { t.Fatalf("创建请求失败: %v", err) } if req.Body == nil { t.Error("Body 应该非 nil(bytes.NewReader(nil) 会返回空 reader)") } } // --------------------------------------------------------------------------- // NewAPIRequest // --------------------------------------------------------------------------- func TestNewAPIRequest_使用默认URL(t *testing.T) { ctx := context.Background() req, err := NewAPIRequest(ctx, "generateContent", "tok", []byte(`{}`)) if err != nil { t.Fatalf("创建请求失败: %v", err) } expected := BaseURL + "/v1internal:generateContent" if req.URL.String() != expected { t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expected) } } // --------------------------------------------------------------------------- // TierInfo.UnmarshalJSON // --------------------------------------------------------------------------- func TestTierInfo_UnmarshalJSON_字符串格式(t *testing.T) { data := []byte(`"free-tier"`) var tier TierInfo if err := tier.UnmarshalJSON(data); err != nil { t.Fatalf("反序列化失败: %v", err) } if tier.ID != "free-tier" { t.Errorf("ID 不匹配: got %s, want free-tier", tier.ID) } if tier.Name != "" { t.Errorf("Name 应为空: got %s", tier.Name) } } func TestTierInfo_UnmarshalJSON_对象格式(t *testing.T) { data := []byte(`{"id":"g1-pro-tier","name":"Pro","description":"Pro plan"}`) var tier TierInfo if err := tier.UnmarshalJSON(data); err != nil { t.Fatalf("反序列化失败: %v", err) } if tier.ID != "g1-pro-tier" { t.Errorf("ID 不匹配: got %s, want g1-pro-tier", tier.ID) } if tier.Name != "Pro" { t.Errorf("Name 不匹配: got %s, want Pro", tier.Name) } if tier.Description != "Pro plan" { t.Errorf("Description 不匹配: got %s, want Pro plan", tier.Description) } } func TestTierInfo_UnmarshalJSON_null(t *testing.T) { data := []byte(`null`) var tier TierInfo if err := tier.UnmarshalJSON(data); err != nil { t.Fatalf("反序列化 null 失败: %v", err) } if tier.ID != "" { t.Errorf("null 场景下 ID 应为空: got %s", tier.ID) } } func TestTierInfo_UnmarshalJSON_空数据(t *testing.T) { data := []byte(``) var tier TierInfo if err := tier.UnmarshalJSON(data); err != nil { t.Fatalf("反序列化空数据失败: %v", err) } if tier.ID != "" { t.Errorf("空数据场景下 ID 应为空: got %s", tier.ID) } } func TestTierInfo_UnmarshalJSON_空格包裹null(t *testing.T) { data := []byte(` null `) var tier TierInfo if err := tier.UnmarshalJSON(data); err != nil { t.Fatalf("反序列化空格 null 失败: %v", err) } if tier.ID != "" { t.Errorf("空格 null 场景下 ID 应为空: got %s", tier.ID) } } func TestTierInfo_UnmarshalJSON_通过JSON嵌套结构(t *testing.T) { // 模拟 LoadCodeAssistResponse 中的嵌套反序列化 jsonData := `{"currentTier":"free-tier","paidTier":{"id":"g1-ultra-tier","name":"Ultra"}}` var resp LoadCodeAssistResponse if err := json.Unmarshal([]byte(jsonData), &resp); err != nil { t.Fatalf("反序列化嵌套结构失败: %v", err) } if resp.CurrentTier == nil || resp.CurrentTier.ID != "free-tier" { t.Errorf("CurrentTier 不匹配: got %+v", resp.CurrentTier) } if resp.PaidTier == nil || resp.PaidTier.ID != "g1-ultra-tier" { t.Errorf("PaidTier 不匹配: got %+v", resp.PaidTier) } } // --------------------------------------------------------------------------- // LoadCodeAssistResponse.GetTier // --------------------------------------------------------------------------- func TestGetTier_PaidTier优先(t *testing.T) { resp := &LoadCodeAssistResponse{ CurrentTier: &TierInfo{ID: "free-tier"}, PaidTier: &TierInfo{ID: "g1-pro-tier"}, } if got := resp.GetTier(); got != "g1-pro-tier" { t.Errorf("应返回 paidTier: got %s", got) } } func TestGetTier_回退到CurrentTier(t *testing.T) { resp := &LoadCodeAssistResponse{ CurrentTier: &TierInfo{ID: "free-tier"}, } if got := resp.GetTier(); got != "free-tier" { t.Errorf("应返回 currentTier: got %s", got) } } func TestGetTier_PaidTier为空ID(t *testing.T) { resp := &LoadCodeAssistResponse{ CurrentTier: &TierInfo{ID: "free-tier"}, PaidTier: &TierInfo{ID: ""}, } // paidTier.ID 为空时应回退到 currentTier if got := resp.GetTier(); got != "free-tier" { t.Errorf("paidTier.ID 为空时应回退到 currentTier: got %s", got) } } func TestGetTier_两者都为nil(t *testing.T) { resp := &LoadCodeAssistResponse{} if got := resp.GetTier(); got != "" { t.Errorf("两者都为 nil 时应返回空字符串: got %s", got) } } // --------------------------------------------------------------------------- // NewClient // --------------------------------------------------------------------------- func TestNewClient_无代理(t *testing.T) { client := NewClient("") if client == nil { t.Fatal("NewClient 返回 nil") } if client.httpClient == nil { t.Fatal("httpClient 为 nil") } if client.httpClient.Timeout != 30*time.Second { t.Errorf("Timeout 不匹配: got %v, want 30s", client.httpClient.Timeout) } // 无代理时 Transport 应为 nil(使用默认) if client.httpClient.Transport != nil { t.Error("无代理时 Transport 应为 nil") } } func TestNewClient_有代理(t *testing.T) { client := NewClient("http://proxy.example.com:8080") if client == nil { t.Fatal("NewClient 返回 nil") } if client.httpClient.Transport == nil { t.Fatal("有代理时 Transport 不应为 nil") } } func TestNewClient_空格代理(t *testing.T) { client := NewClient(" ") if client == nil { t.Fatal("NewClient 返回 nil") } // 空格代理应等同于无代理 if client.httpClient.Transport != nil { t.Error("空格代理 Transport 应为 nil") } } func TestNewClient_无效代理URL(t *testing.T) { // 无效 URL 时 url.Parse 不一定返回错误(Go 的 url.Parse 很宽容), // 但 ://invalid 会导致解析错误 client := NewClient("://invalid") if client == nil { t.Fatal("NewClient 返回 nil") } // 无效 URL 解析失败时,Transport 应保持 nil if client.httpClient.Transport != nil { t.Error("无效代理 URL 时 Transport 应为 nil") } } // --------------------------------------------------------------------------- // isConnectionError // --------------------------------------------------------------------------- func TestIsConnectionError_nil(t *testing.T) { if isConnectionError(nil) { t.Error("nil 错误不应判定为连接错误") } } func TestIsConnectionError_超时错误(t *testing.T) { // 使用 net.OpError 包装超时 err := &net.OpError{ Op: "dial", Net: "tcp", Err: &timeoutError{}, } if !isConnectionError(err) { t.Error("超时错误应判定为连接错误") } } // timeoutError 实现 net.Error 接口用于测试 type timeoutError struct{} func (e *timeoutError) Error() string { return "timeout" } func (e *timeoutError) Timeout() bool { return true } func (e *timeoutError) Temporary() bool { return true } func TestIsConnectionError_netOpError(t *testing.T) { err := &net.OpError{ Op: "dial", Net: "tcp", Err: fmt.Errorf("connection refused"), } if !isConnectionError(err) { t.Error("net.OpError 应判定为连接错误") } } func TestIsConnectionError_urlError(t *testing.T) { err := &url.Error{ Op: "Get", URL: "https://example.com", Err: fmt.Errorf("some error"), } if !isConnectionError(err) { t.Error("url.Error 应判定为连接错误") } } func TestIsConnectionError_普通错误(t *testing.T) { err := fmt.Errorf("some random error") if isConnectionError(err) { t.Error("普通错误不应判定为连接错误") } } func TestIsConnectionError_包装的netOpError(t *testing.T) { inner := &net.OpError{ Op: "dial", Net: "tcp", Err: fmt.Errorf("connection refused"), } err := fmt.Errorf("wrapping: %w", inner) if !isConnectionError(err) { t.Error("被包装的 net.OpError 应判定为连接错误") } } // --------------------------------------------------------------------------- // shouldFallbackToNextURL // --------------------------------------------------------------------------- func TestShouldFallbackToNextURL_连接错误(t *testing.T) { err := &net.OpError{Op: "dial", Net: "tcp", Err: fmt.Errorf("refused")} if !shouldFallbackToNextURL(err, 0) { t.Error("连接错误应触发 URL 降级") } } func TestShouldFallbackToNextURL_状态码(t *testing.T) { tests := []struct { name string statusCode int want bool }{ {"429 Too Many Requests", http.StatusTooManyRequests, true}, {"408 Request Timeout", http.StatusRequestTimeout, true}, {"404 Not Found", http.StatusNotFound, true}, {"500 Internal Server Error", http.StatusInternalServerError, true}, {"502 Bad Gateway", http.StatusBadGateway, true}, {"503 Service Unavailable", http.StatusServiceUnavailable, true}, {"200 OK", http.StatusOK, false}, {"201 Created", http.StatusCreated, false}, {"400 Bad Request", http.StatusBadRequest, false}, {"401 Unauthorized", http.StatusUnauthorized, false}, {"403 Forbidden", http.StatusForbidden, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := shouldFallbackToNextURL(nil, tt.statusCode) if got != tt.want { t.Errorf("shouldFallbackToNextURL(nil, %d) = %v, want %v", tt.statusCode, got, tt.want) } }) } } func TestShouldFallbackToNextURL_无错误且200(t *testing.T) { if shouldFallbackToNextURL(nil, http.StatusOK) { t.Error("无错误且 200 不应触发 URL 降级") } } // --------------------------------------------------------------------------- // Client.ExchangeCode (使用 httptest) // --------------------------------------------------------------------------- func TestClient_ExchangeCode_成功(t *testing.T) { t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 验证请求方法 if r.Method != http.MethodPost { t.Errorf("请求方法不匹配: got %s", r.Method) } // 验证 Content-Type if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { t.Errorf("Content-Type 不匹配: got %s", ct) } // 验证请求体参数 if err := r.ParseForm(); err != nil { t.Fatalf("解析表单失败: %v", err) } if r.FormValue("client_id") != ClientID { t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id")) } if r.FormValue("client_secret") != "test-secret" { t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret")) } if r.FormValue("code") != "auth-code" { t.Errorf("code 不匹配: got %s", r.FormValue("code")) } if r.FormValue("code_verifier") != "verifier123" { t.Errorf("code_verifier 不匹配: got %s", r.FormValue("code_verifier")) } if r.FormValue("grant_type") != "authorization_code" { t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _ = json.NewEncoder(w).Encode(TokenResponse{ AccessToken: "access-tok", ExpiresIn: 3600, TokenType: "Bearer", RefreshToken: "refresh-tok", }) })) defer server.Close() // 临时替换 TokenURL(该函数直接使用常量,需要我们通过构建自定义 client 来绕过) // 由于 ExchangeCode 硬编码了 TokenURL,我们需要直接测试 HTTP client 的行为 // 这里通过构造一个直接调用 mock server 的测试 client := &Client{httpClient: server.Client()} // 由于 ExchangeCode 使用硬编码的 TokenURL,我们无法直接注入 mock server URL // 需要使用 httptest 的 Transport 重定向 originalTokenURL := TokenURL // 我们改为直接构造请求来测试逻辑 _ = originalTokenURL _ = client // 改用直接构造请求测试 mock server 响应 ctx := context.Background() params := url.Values{} params.Set("client_id", ClientID) params.Set("client_secret", "test-secret") params.Set("code", "auth-code") params.Set("redirect_uri", RedirectURI) params.Set("grant_type", "authorization_code") params.Set("code_verifier", "verifier123") req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, strings.NewReader(params.Encode())) if err != nil { t.Fatalf("创建请求失败: %v", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := server.Client().Do(req) if err != nil { t.Fatalf("请求失败: %v", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Fatalf("状态码不匹配: got %d", resp.StatusCode) } var tokenResp TokenResponse if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { t.Fatalf("解码失败: %v", err) } if tokenResp.AccessToken != "access-tok" { t.Errorf("AccessToken 不匹配: got %s", tokenResp.AccessToken) } if tokenResp.RefreshToken != "refresh-tok" { t.Errorf("RefreshToken 不匹配: got %s", tokenResp.RefreshToken) } } func TestClient_ExchangeCode_无ClientSecret(t *testing.T) { t.Setenv(AntigravityOAuthClientSecretEnv, "") client := NewClient("") _, err := client.ExchangeCode(context.Background(), "code", "verifier") if err == nil { t.Fatal("缺少 client_secret 时应返回错误") } if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) { t.Errorf("错误信息应包含环境变量名: got %s", err.Error()) } } func TestClient_ExchangeCode_服务器返回错误(t *testing.T) { t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) _, _ = w.Write([]byte(`{"error":"invalid_grant"}`)) })) defer server.Close() // 直接测试 mock server 的错误响应 resp, err := server.Client().Get(server.URL) if err != nil { t.Fatalf("请求失败: %v", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusBadRequest { t.Errorf("状态码不匹配: got %d, want 400", resp.StatusCode) } } // --------------------------------------------------------------------------- // Client.RefreshToken (使用 httptest) // --------------------------------------------------------------------------- func TestClient_RefreshToken_MockServer(t *testing.T) { t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { t.Errorf("请求方法不匹配: got %s", r.Method) } if err := r.ParseForm(); err != nil { t.Fatalf("解析表单失败: %v", err) } if r.FormValue("grant_type") != "refresh_token" { t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) } if r.FormValue("refresh_token") != "old-refresh-tok" { t.Errorf("refresh_token 不匹配: got %s", r.FormValue("refresh_token")) } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _ = json.NewEncoder(w).Encode(TokenResponse{ AccessToken: "new-access-tok", ExpiresIn: 3600, TokenType: "Bearer", }) })) defer server.Close() ctx := context.Background() params := url.Values{} params.Set("client_id", ClientID) params.Set("client_secret", "test-secret") params.Set("refresh_token", "old-refresh-tok") params.Set("grant_type", "refresh_token") req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, strings.NewReader(params.Encode())) if err != nil { t.Fatalf("创建请求失败: %v", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := server.Client().Do(req) if err != nil { t.Fatalf("请求失败: %v", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Fatalf("状态码不匹配: got %d", resp.StatusCode) } var tokenResp TokenResponse if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { t.Fatalf("解码失败: %v", err) } if tokenResp.AccessToken != "new-access-tok" { t.Errorf("AccessToken 不匹配: got %s", tokenResp.AccessToken) } } func TestClient_RefreshToken_无ClientSecret(t *testing.T) { t.Setenv(AntigravityOAuthClientSecretEnv, "") client := NewClient("") _, err := client.RefreshToken(context.Background(), "refresh-tok") if err == nil { t.Fatal("缺少 client_secret 时应返回错误") } } // --------------------------------------------------------------------------- // Client.GetUserInfo (使用 httptest) // --------------------------------------------------------------------------- func TestClient_GetUserInfo_成功(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { t.Errorf("请求方法不匹配: got %s", r.Method) } auth := r.Header.Get("Authorization") if auth != "Bearer test-access-token" { t.Errorf("Authorization 不匹配: got %s", auth) } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _ = json.NewEncoder(w).Encode(UserInfo{ Email: "user@example.com", Name: "Test User", GivenName: "Test", FamilyName: "User", Picture: "https://example.com/photo.jpg", }) })) defer server.Close() // 直接通过 mock server 测试 GetUserInfo 的行为逻辑 ctx := context.Background() req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) if err != nil { t.Fatalf("创建请求失败: %v", err) } req.Header.Set("Authorization", "Bearer test-access-token") resp, err := server.Client().Do(req) if err != nil { t.Fatalf("请求失败: %v", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Fatalf("状态码不匹配: got %d", resp.StatusCode) } var userInfo UserInfo if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { t.Fatalf("解码失败: %v", err) } if userInfo.Email != "user@example.com" { t.Errorf("Email 不匹配: got %s", userInfo.Email) } if userInfo.Name != "Test User" { t.Errorf("Name 不匹配: got %s", userInfo.Name) } } func TestClient_GetUserInfo_服务器返回错误(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) _, _ = w.Write([]byte(`{"error":"invalid_token"}`)) })) defer server.Close() resp, err := server.Client().Get(server.URL) if err != nil { t.Fatalf("请求失败: %v", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusUnauthorized { t.Errorf("状态码不匹配: got %d, want 401", resp.StatusCode) } } // --------------------------------------------------------------------------- // TokenResponse / UserInfo JSON 序列化 // --------------------------------------------------------------------------- func TestTokenResponse_JSON序列化(t *testing.T) { jsonData := `{"access_token":"at","expires_in":3600,"token_type":"Bearer","scope":"openid","refresh_token":"rt"}` var resp TokenResponse if err := json.Unmarshal([]byte(jsonData), &resp); err != nil { t.Fatalf("反序列化失败: %v", err) } if resp.AccessToken != "at" { t.Errorf("AccessToken 不匹配: got %s", resp.AccessToken) } if resp.ExpiresIn != 3600 { t.Errorf("ExpiresIn 不匹配: got %d", resp.ExpiresIn) } if resp.RefreshToken != "rt" { t.Errorf("RefreshToken 不匹配: got %s", resp.RefreshToken) } } func TestUserInfo_JSON序列化(t *testing.T) { jsonData := `{"email":"a@b.com","name":"Alice"}` var info UserInfo if err := json.Unmarshal([]byte(jsonData), &info); err != nil { t.Fatalf("反序列化失败: %v", err) } if info.Email != "a@b.com" { t.Errorf("Email 不匹配: got %s", info.Email) } if info.Name != "Alice" { t.Errorf("Name 不匹配: got %s", info.Name) } } // --------------------------------------------------------------------------- // LoadCodeAssistResponse JSON 序列化 // --------------------------------------------------------------------------- func TestLoadCodeAssistResponse_完整JSON(t *testing.T) { jsonData := `{ "cloudaicompanionProject": "proj-123", "currentTier": "free-tier", "paidTier": {"id": "g1-pro-tier", "name": "Pro"}, "ineligibleTiers": [{"tier": {"id": "g1-ultra-tier"}, "reasonCode": "INELIGIBLE_ACCOUNT"}] }` var resp LoadCodeAssistResponse if err := json.Unmarshal([]byte(jsonData), &resp); err != nil { t.Fatalf("反序列化失败: %v", err) } if resp.CloudAICompanionProject != "proj-123" { t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) } if resp.GetTier() != "g1-pro-tier" { t.Errorf("GetTier 不匹配: got %s", resp.GetTier()) } if len(resp.IneligibleTiers) != 1 { t.Fatalf("IneligibleTiers 数量不匹配: got %d", len(resp.IneligibleTiers)) } if resp.IneligibleTiers[0].ReasonCode != "INELIGIBLE_ACCOUNT" { t.Errorf("ReasonCode 不匹配: got %s", resp.IneligibleTiers[0].ReasonCode) } } // =========================================================================== // 以下为新增测试:真正调用 Client 方法,通过 RoundTripper 拦截 HTTP 请求 // =========================================================================== // redirectRoundTripper 将请求中特定前缀的 URL 重定向到 httptest server type redirectRoundTripper struct { // 原始 URL 前缀 -> 替换目标 URL 的映射 redirects map[string]string transport http.RoundTripper } func (rt *redirectRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { originalURL := req.URL.String() for prefix, target := range rt.redirects { if strings.HasPrefix(originalURL, prefix) { newURL := target + strings.TrimPrefix(originalURL, prefix) parsed, err := url.Parse(newURL) if err != nil { return nil, err } req.URL = parsed break } } if rt.transport == nil { return http.DefaultTransport.RoundTrip(req) } return rt.transport.RoundTrip(req) } // newTestClientWithRedirect 创建一个 Client,将指定 URL 前缀的请求重定向到 mock server func newTestClientWithRedirect(redirects map[string]string) *Client { return &Client{ httpClient: &http.Client{ Timeout: 10 * time.Second, Transport: &redirectRoundTripper{ redirects: redirects, }, }, } } // --------------------------------------------------------------------------- // Client.ExchangeCode - 真正调用方法的测试 // --------------------------------------------------------------------------- func TestClient_ExchangeCode_Success_RealCall(t *testing.T) { t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { t.Errorf("请求方法不匹配: got %s, want POST", r.Method) } if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { t.Errorf("Content-Type 不匹配: got %s", ct) } if err := r.ParseForm(); err != nil { t.Fatalf("解析表单失败: %v", err) } if r.FormValue("client_id") != ClientID { t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id")) } if r.FormValue("client_secret") != "test-secret" { t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret")) } if r.FormValue("code") != "test-auth-code" { t.Errorf("code 不匹配: got %s", r.FormValue("code")) } if r.FormValue("code_verifier") != "test-verifier" { t.Errorf("code_verifier 不匹配: got %s", r.FormValue("code_verifier")) } if r.FormValue("grant_type") != "authorization_code" { t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) } if r.FormValue("redirect_uri") != RedirectURI { t.Errorf("redirect_uri 不匹配: got %s", r.FormValue("redirect_uri")) } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _ = json.NewEncoder(w).Encode(TokenResponse{ AccessToken: "new-access-token", ExpiresIn: 3600, TokenType: "Bearer", Scope: "openid email", RefreshToken: "new-refresh-token", }) })) defer server.Close() client := newTestClientWithRedirect(map[string]string{ TokenURL: server.URL, }) tokenResp, err := client.ExchangeCode(context.Background(), "test-auth-code", "test-verifier") if err != nil { t.Fatalf("ExchangeCode 失败: %v", err) } if tokenResp.AccessToken != "new-access-token" { t.Errorf("AccessToken 不匹配: got %s, want new-access-token", tokenResp.AccessToken) } if tokenResp.RefreshToken != "new-refresh-token" { t.Errorf("RefreshToken 不匹配: got %s, want new-refresh-token", tokenResp.RefreshToken) } if tokenResp.ExpiresIn != 3600 { t.Errorf("ExpiresIn 不匹配: got %d, want 3600", tokenResp.ExpiresIn) } if tokenResp.TokenType != "Bearer" { t.Errorf("TokenType 不匹配: got %s, want Bearer", tokenResp.TokenType) } if tokenResp.Scope != "openid email" { t.Errorf("Scope 不匹配: got %s, want openid email", tokenResp.Scope) } } func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) { t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) _, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"code expired"}`)) })) defer server.Close() client := newTestClientWithRedirect(map[string]string{ TokenURL: server.URL, }) _, err := client.ExchangeCode(context.Background(), "expired-code", "verifier") if err == nil { t.Fatal("服务器返回 400 时应返回错误") } if !strings.Contains(err.Error(), "token 交换失败") { t.Errorf("错误信息应包含 'token 交换失败': got %s", err.Error()) } if !strings.Contains(err.Error(), "400") { t.Errorf("错误信息应包含状态码 400: got %s", err.Error()) } } func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) { t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{invalid json`)) })) defer server.Close() client := newTestClientWithRedirect(map[string]string{ TokenURL: server.URL, }) _, err := client.ExchangeCode(context.Background(), "code", "verifier") if err == nil { t.Fatal("无效 JSON 响应应返回错误") } if !strings.Contains(err.Error(), "token 解析失败") { t.Errorf("错误信息应包含 'token 解析失败': got %s", err.Error()) } } func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) { t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(5 * time.Second) // 模拟慢响应 w.WriteHeader(http.StatusOK) })) defer server.Close() client := newTestClientWithRedirect(map[string]string{ TokenURL: server.URL, }) ctx, cancel := context.WithCancel(context.Background()) cancel() // 立即取消 _, err := client.ExchangeCode(ctx, "code", "verifier") if err == nil { t.Fatal("context 取消时应返回错误") } } // --------------------------------------------------------------------------- // Client.RefreshToken - 真正调用方法的测试 // --------------------------------------------------------------------------- func TestClient_RefreshToken_Success_RealCall(t *testing.T) { t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { t.Errorf("请求方法不匹配: got %s, want POST", r.Method) } if err := r.ParseForm(); err != nil { t.Fatalf("解析表单失败: %v", err) } if r.FormValue("grant_type") != "refresh_token" { t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) } if r.FormValue("refresh_token") != "my-refresh-token" { t.Errorf("refresh_token 不匹配: got %s", r.FormValue("refresh_token")) } if r.FormValue("client_id") != ClientID { t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id")) } if r.FormValue("client_secret") != "test-secret" { t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret")) } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _ = json.NewEncoder(w).Encode(TokenResponse{ AccessToken: "refreshed-access-token", ExpiresIn: 3600, TokenType: "Bearer", }) })) defer server.Close() client := newTestClientWithRedirect(map[string]string{ TokenURL: server.URL, }) tokenResp, err := client.RefreshToken(context.Background(), "my-refresh-token") if err != nil { t.Fatalf("RefreshToken 失败: %v", err) } if tokenResp.AccessToken != "refreshed-access-token" { t.Errorf("AccessToken 不匹配: got %s, want refreshed-access-token", tokenResp.AccessToken) } if tokenResp.ExpiresIn != 3600 { t.Errorf("ExpiresIn 不匹配: got %d, want 3600", tokenResp.ExpiresIn) } } func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) { t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) _, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"token revoked"}`)) })) defer server.Close() client := newTestClientWithRedirect(map[string]string{ TokenURL: server.URL, }) _, err := client.RefreshToken(context.Background(), "revoked-token") if err == nil { t.Fatal("服务器返回 401 时应返回错误") } if !strings.Contains(err.Error(), "token 刷新失败") { t.Errorf("错误信息应包含 'token 刷新失败': got %s", err.Error()) } } func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) { t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`not-json`)) })) defer server.Close() client := newTestClientWithRedirect(map[string]string{ TokenURL: server.URL, }) _, err := client.RefreshToken(context.Background(), "refresh-tok") if err == nil { t.Fatal("无效 JSON 响应应返回错误") } if !strings.Contains(err.Error(), "token 解析失败") { t.Errorf("错误信息应包含 'token 解析失败': got %s", err.Error()) } } func TestClient_RefreshToken_ContextCanceled_RealCall(t *testing.T) { t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(5 * time.Second) w.WriteHeader(http.StatusOK) })) defer server.Close() client := newTestClientWithRedirect(map[string]string{ TokenURL: server.URL, }) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := client.RefreshToken(ctx, "refresh-tok") if err == nil { t.Fatal("context 取消时应返回错误") } } // --------------------------------------------------------------------------- // Client.GetUserInfo - 真正调用方法的测试 // --------------------------------------------------------------------------- func TestClient_GetUserInfo_Success_RealCall(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { t.Errorf("请求方法不匹配: got %s, want GET", r.Method) } auth := r.Header.Get("Authorization") if auth != "Bearer user-access-token" { t.Errorf("Authorization 不匹配: got %s", auth) } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _ = json.NewEncoder(w).Encode(UserInfo{ Email: "test@example.com", Name: "Test User", GivenName: "Test", FamilyName: "User", Picture: "https://example.com/avatar.jpg", }) })) defer server.Close() client := newTestClientWithRedirect(map[string]string{ UserInfoURL: server.URL, }) userInfo, err := client.GetUserInfo(context.Background(), "user-access-token") if err != nil { t.Fatalf("GetUserInfo 失败: %v", err) } if userInfo.Email != "test@example.com" { t.Errorf("Email 不匹配: got %s, want test@example.com", userInfo.Email) } if userInfo.Name != "Test User" { t.Errorf("Name 不匹配: got %s, want Test User", userInfo.Name) } if userInfo.GivenName != "Test" { t.Errorf("GivenName 不匹配: got %s, want Test", userInfo.GivenName) } if userInfo.FamilyName != "User" { t.Errorf("FamilyName 不匹配: got %s, want User", userInfo.FamilyName) } if userInfo.Picture != "https://example.com/avatar.jpg" { t.Errorf("Picture 不匹配: got %s", userInfo.Picture) } } func TestClient_GetUserInfo_Unauthorized_RealCall(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) _, _ = w.Write([]byte(`{"error":"invalid_token"}`)) })) defer server.Close() client := newTestClientWithRedirect(map[string]string{ UserInfoURL: server.URL, }) _, err := client.GetUserInfo(context.Background(), "bad-token") if err == nil { t.Fatal("服务器返回 401 时应返回错误") } if !strings.Contains(err.Error(), "获取用户信息失败") { t.Errorf("错误信息应包含 '获取用户信息失败': got %s", err.Error()) } if !strings.Contains(err.Error(), "401") { t.Errorf("错误信息应包含状态码 401: got %s", err.Error()) } } func TestClient_GetUserInfo_InvalidJSON_RealCall(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{broken`)) })) defer server.Close() client := newTestClientWithRedirect(map[string]string{ UserInfoURL: server.URL, }) _, err := client.GetUserInfo(context.Background(), "token") if err == nil { t.Fatal("无效 JSON 响应应返回错误") } if !strings.Contains(err.Error(), "用户信息解析失败") { t.Errorf("错误信息应包含 '用户信息解析失败': got %s", err.Error()) } } func TestClient_GetUserInfo_ContextCanceled_RealCall(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(5 * time.Second) w.WriteHeader(http.StatusOK) })) defer server.Close() client := newTestClientWithRedirect(map[string]string{ UserInfoURL: server.URL, }) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := client.GetUserInfo(ctx, "token") if err == nil { t.Fatal("context 取消时应返回错误") } } // --------------------------------------------------------------------------- // Client.LoadCodeAssist - 真正调用方法的测试 // --------------------------------------------------------------------------- // withMockBaseURLs 临时替换 BaseURLs,测试结束后恢复 func withMockBaseURLs(t *testing.T, urls []string) { t.Helper() origBaseURLs := BaseURLs origBaseURL := BaseURL BaseURLs = urls if len(urls) > 0 { BaseURL = urls[0] } t.Cleanup(func() { BaseURLs = origBaseURLs BaseURL = origBaseURL }) } func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { t.Errorf("请求方法不匹配: got %s, want POST", r.Method) } if !strings.HasSuffix(r.URL.Path, "/v1internal:loadCodeAssist") { t.Errorf("URL 路径不匹配: got %s", r.URL.Path) } auth := r.Header.Get("Authorization") if auth != "Bearer test-token" { t.Errorf("Authorization 不匹配: got %s", auth) } if ct := r.Header.Get("Content-Type"); ct != "application/json" { t.Errorf("Content-Type 不匹配: got %s", ct) } if ua := r.Header.Get("User-Agent"); ua != UserAgent { t.Errorf("User-Agent 不匹配: got %s", ua) } // 验证请求体 var reqBody LoadCodeAssistRequest if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { t.Fatalf("解析请求体失败: %v", err) } if reqBody.Metadata.IDEType != "ANTIGRAVITY" { t.Errorf("IDEType 不匹配: got %s, want ANTIGRAVITY", reqBody.Metadata.IDEType) } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{ "cloudaicompanionProject": "test-project-123", "currentTier": {"id": "free-tier", "name": "Free"}, "paidTier": {"id": "g1-pro-tier", "name": "Pro", "description": "Pro plan"} }`)) })) defer server.Close() withMockBaseURLs(t, []string{server.URL}) client := NewClient("") resp, rawResp, err := client.LoadCodeAssist(context.Background(), "test-token") if err != nil { t.Fatalf("LoadCodeAssist 失败: %v", err) } if resp.CloudAICompanionProject != "test-project-123" { t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) } if resp.GetTier() != "g1-pro-tier" { t.Errorf("GetTier 不匹配: got %s, want g1-pro-tier", resp.GetTier()) } if resp.CurrentTier == nil || resp.CurrentTier.ID != "free-tier" { t.Errorf("CurrentTier 不匹配: got %+v", resp.CurrentTier) } if resp.PaidTier == nil || resp.PaidTier.ID != "g1-pro-tier" { t.Errorf("PaidTier 不匹配: got %+v", resp.PaidTier) } // 验证原始 JSON map if rawResp == nil { t.Fatal("rawResp 不应为 nil") } if rawResp["cloudaicompanionProject"] != "test-project-123" { t.Errorf("rawResp cloudaicompanionProject 不匹配: got %v", rawResp["cloudaicompanionProject"]) } } func TestClient_LoadCodeAssist_HTTPError_RealCall(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusForbidden) _, _ = w.Write([]byte(`{"error":"forbidden"}`)) })) defer server.Close() withMockBaseURLs(t, []string{server.URL}) client := NewClient("") _, _, err := client.LoadCodeAssist(context.Background(), "bad-token") if err == nil { t.Fatal("服务器返回 403 时应返回错误") } if !strings.Contains(err.Error(), "loadCodeAssist 失败") { t.Errorf("错误信息应包含 'loadCodeAssist 失败': got %s", err.Error()) } if !strings.Contains(err.Error(), "403") { t.Errorf("错误信息应包含状态码 403: got %s", err.Error()) } } func TestClient_LoadCodeAssist_InvalidJSON_RealCall(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{not valid json!!!`)) })) defer server.Close() withMockBaseURLs(t, []string{server.URL}) client := NewClient("") _, _, err := client.LoadCodeAssist(context.Background(), "token") if err == nil { t.Fatal("无效 JSON 响应应返回错误") } if !strings.Contains(err.Error(), "响应解析失败") { t.Errorf("错误信息应包含 '响应解析失败': got %s", err.Error()) } } func TestClient_LoadCodeAssist_URLFallback_RealCall(t *testing.T) { // 第一个 server 返回 500,第二个 server 返回成功 callCount := 0 server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ w.WriteHeader(http.StatusInternalServerError) _, _ = w.Write([]byte(`{"error":"internal"}`)) })) defer server1.Close() server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{ "cloudaicompanionProject": "fallback-project", "currentTier": {"id": "free-tier", "name": "Free"} }`)) })) defer server2.Close() withMockBaseURLs(t, []string{server1.URL, server2.URL}) client := NewClient("") resp, _, err := client.LoadCodeAssist(context.Background(), "token") if err != nil { t.Fatalf("LoadCodeAssist 应在 fallback 后成功: %v", err) } if resp.CloudAICompanionProject != "fallback-project" { t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) } if callCount != 2 { t.Errorf("应该调用了 2 个 server,实际调用 %d 次", callCount) } } func TestClient_LoadCodeAssist_AllURLsFail_RealCall(t *testing.T) { server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusServiceUnavailable) _, _ = w.Write([]byte(`{"error":"unavailable"}`)) })) defer server1.Close() server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadGateway) _, _ = w.Write([]byte(`{"error":"bad_gateway"}`)) })) defer server2.Close() withMockBaseURLs(t, []string{server1.URL, server2.URL}) client := NewClient("") _, _, err := client.LoadCodeAssist(context.Background(), "token") if err == nil { t.Fatal("所有 URL 都失败时应返回错误") } } func TestClient_LoadCodeAssist_ContextCanceled_RealCall(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(5 * time.Second) w.WriteHeader(http.StatusOK) })) defer server.Close() withMockBaseURLs(t, []string{server.URL}) client := NewClient("") ctx, cancel := context.WithCancel(context.Background()) cancel() _, _, err := client.LoadCodeAssist(ctx, "token") if err == nil { t.Fatal("context 取消时应返回错误") } } // --------------------------------------------------------------------------- // Client.FetchAvailableModels - 真正调用方法的测试 // --------------------------------------------------------------------------- func TestClient_FetchAvailableModels_Success_RealCall(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { t.Errorf("请求方法不匹配: got %s, want POST", r.Method) } if !strings.HasSuffix(r.URL.Path, "/v1internal:fetchAvailableModels") { t.Errorf("URL 路径不匹配: got %s", r.URL.Path) } auth := r.Header.Get("Authorization") if auth != "Bearer test-token" { t.Errorf("Authorization 不匹配: got %s", auth) } if ct := r.Header.Get("Content-Type"); ct != "application/json" { t.Errorf("Content-Type 不匹配: got %s", ct) } if ua := r.Header.Get("User-Agent"); ua != UserAgent { t.Errorf("User-Agent 不匹配: got %s", ua) } // 验证请求体 var reqBody FetchAvailableModelsRequest if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { t.Fatalf("解析请求体失败: %v", err) } if reqBody.Project != "project-abc" { t.Errorf("Project 不匹配: got %s, want project-abc", reqBody.Project) } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{ "models": { "gemini-2.0-flash": { "quotaInfo": { "remainingFraction": 0.85, "resetTime": "2025-01-01T00:00:00Z" } }, "gemini-2.5-pro": { "quotaInfo": { "remainingFraction": 0.5 } } } }`)) })) defer server.Close() withMockBaseURLs(t, []string{server.URL}) client := NewClient("") resp, rawResp, err := client.FetchAvailableModels(context.Background(), "test-token", "project-abc") if err != nil { t.Fatalf("FetchAvailableModels 失败: %v", err) } if resp.Models == nil { t.Fatal("Models 不应为 nil") } if len(resp.Models) != 2 { t.Errorf("Models 数量不匹配: got %d, want 2", len(resp.Models)) } flashModel, ok := resp.Models["gemini-2.0-flash"] if !ok { t.Fatal("缺少 gemini-2.0-flash 模型") } if flashModel.QuotaInfo == nil { t.Fatal("gemini-2.0-flash QuotaInfo 不应为 nil") } if flashModel.QuotaInfo.RemainingFraction != 0.85 { t.Errorf("RemainingFraction 不匹配: got %f, want 0.85", flashModel.QuotaInfo.RemainingFraction) } if flashModel.QuotaInfo.ResetTime != "2025-01-01T00:00:00Z" { t.Errorf("ResetTime 不匹配: got %s", flashModel.QuotaInfo.ResetTime) } proModel, ok := resp.Models["gemini-2.5-pro"] if !ok { t.Fatal("缺少 gemini-2.5-pro 模型") } if proModel.QuotaInfo == nil { t.Fatal("gemini-2.5-pro QuotaInfo 不应为 nil") } if proModel.QuotaInfo.RemainingFraction != 0.5 { t.Errorf("RemainingFraction 不匹配: got %f, want 0.5", proModel.QuotaInfo.RemainingFraction) } // 验证原始 JSON map if rawResp == nil { t.Fatal("rawResp 不应为 nil") } if rawResp["models"] == nil { t.Error("rawResp models 不应为 nil") } } func TestClient_FetchAvailableModels_HTTPError_RealCall(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusForbidden) _, _ = w.Write([]byte(`{"error":"forbidden"}`)) })) defer server.Close() withMockBaseURLs(t, []string{server.URL}) client := NewClient("") _, _, err := client.FetchAvailableModels(context.Background(), "bad-token", "proj") if err == nil { t.Fatal("服务器返回 403 时应返回错误") } if !strings.Contains(err.Error(), "fetchAvailableModels 失败") { t.Errorf("错误信息应包含 'fetchAvailableModels 失败': got %s", err.Error()) } } func TestClient_FetchAvailableModels_InvalidJSON_RealCall(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`<<>>`)) })) defer server.Close() withMockBaseURLs(t, []string{server.URL}) client := NewClient("") _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err == nil { t.Fatal("无效 JSON 响应应返回错误") } if !strings.Contains(err.Error(), "响应解析失败") { t.Errorf("错误信息应包含 '响应解析失败': got %s", err.Error()) } } func TestClient_FetchAvailableModels_URLFallback_RealCall(t *testing.T) { callCount := 0 // 第一个 server 返回 429,第二个 server 返回成功 server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ w.WriteHeader(http.StatusTooManyRequests) _, _ = w.Write([]byte(`{"error":"rate_limited"}`)) })) defer server1.Close() server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{"models": {"model-a": {}}}`)) })) defer server2.Close() withMockBaseURLs(t, []string{server1.URL, server2.URL}) client := NewClient("") resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err != nil { t.Fatalf("FetchAvailableModels 应在 fallback 后成功: %v", err) } if _, ok := resp.Models["model-a"]; !ok { t.Error("应返回 fallback server 的模型") } if callCount != 2 { t.Errorf("应该调用了 2 个 server,实际调用 %d 次", callCount) } } func TestClient_FetchAvailableModels_AllURLsFail_RealCall(t *testing.T) { server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) _, _ = w.Write([]byte(`not found`)) })) defer server1.Close() server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) _, _ = w.Write([]byte(`internal error`)) })) defer server2.Close() withMockBaseURLs(t, []string{server1.URL, server2.URL}) client := NewClient("") _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err == nil { t.Fatal("所有 URL 都失败时应返回错误") } } func TestClient_FetchAvailableModels_ContextCanceled_RealCall(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(5 * time.Second) w.WriteHeader(http.StatusOK) })) defer server.Close() withMockBaseURLs(t, []string{server.URL}) client := NewClient("") ctx, cancel := context.WithCancel(context.Background()) cancel() _, _, err := client.FetchAvailableModels(ctx, "token", "proj") if err == nil { t.Fatal("context 取消时应返回错误") } } func TestClient_FetchAvailableModels_EmptyModels_RealCall(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{"models": {}}`)) })) defer server.Close() withMockBaseURLs(t, []string{server.URL}) client := NewClient("") resp, rawResp, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err != nil { t.Fatalf("FetchAvailableModels 失败: %v", err) } if resp.Models == nil { t.Fatal("Models 不应为 nil") } if len(resp.Models) != 0 { t.Errorf("Models 应为空: got %d", len(resp.Models)) } if rawResp == nil { t.Fatal("rawResp 不应为 nil") } } // --------------------------------------------------------------------------- // LoadCodeAssist 和 FetchAvailableModels 的 408 fallback 测试 // --------------------------------------------------------------------------- func TestClient_LoadCodeAssist_408Fallback_RealCall(t *testing.T) { server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusRequestTimeout) _, _ = w.Write([]byte(`timeout`)) })) defer server1.Close() server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{"cloudaicompanionProject":"p2","currentTier":"free-tier"}`)) })) defer server2.Close() withMockBaseURLs(t, []string{server1.URL, server2.URL}) client := NewClient("") resp, _, err := client.LoadCodeAssist(context.Background(), "token") if err != nil { t.Fatalf("LoadCodeAssist 应在 408 fallback 后成功: %v", err) } if resp.CloudAICompanionProject != "p2" { t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) } } func TestClient_FetchAvailableModels_404Fallback_RealCall(t *testing.T) { server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) _, _ = w.Write([]byte(`not found`)) })) defer server1.Close() server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{"models":{"m1":{"quotaInfo":{"remainingFraction":1.0}}}}`)) })) defer server2.Close() withMockBaseURLs(t, []string{server1.URL, server2.URL}) client := NewClient("") resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err != nil { t.Fatalf("FetchAvailableModels 应在 404 fallback 后成功: %v", err) } if _, ok := resp.Models["m1"]; !ok { t.Error("应返回 fallback server 的模型 m1") } }