diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 4f6fea37..763ed829 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -265,8 +265,13 @@ type CSPConfig struct { } type ProxyFallbackConfig struct { - // AllowDirectOnError 当代理初始化失败时是否允许回退直连。 - // 默认 false:避免因代理配置错误导致 IP 泄露/关联。 + // AllowDirectOnError 当辅助服务的代理初始化失败时是否允许回退直连。 + // 仅影响以下非 AI 账号连接的辅助服务: + // - GitHub Release 更新检查 + // - 定价数据拉取 + // 不影响 AI 账号网关连接(Claude/OpenAI/Gemini/Antigravity), + // 这些关键路径的代理失败始终返回错误,不会回退直连。 + // 默认 false:避免因代理配置错误导致服务器真实 IP 泄露。 AllowDirectOnError bool `mapstructure:"allow_direct_on_error"` } @@ -1105,6 +1110,9 @@ func setDefaults() { viper.SetDefault("security.csp.policy", DefaultCSPPolicy) viper.SetDefault("security.proxy_probe.insecure_skip_verify", false) + // Security - disable direct fallback on proxy error + viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false) + // Billing viper.SetDefault("billing.circuit_breaker.enabled", true) viper.SetDefault("billing.circuit_breaker.failure_threshold", 5) @@ -1415,9 +1423,6 @@ func setDefaults() { viper.SetDefault("gemini.oauth.scopes", "") viper.SetDefault("gemini.quota.policy", "") - // Security - proxy fallback - viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false) - // Subscription Maintenance (bounded queue + worker pool) viper.SetDefault("subscription_maintenance.worker_count", 2) viper.SetDefault("subscription_maintenance.queue_size", 1024) diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 1998221a..d46bbc45 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -14,6 +14,9 @@ import ( "net/url" "strings" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" ) // NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点) @@ -149,22 +152,26 @@ type Client struct { httpClient *http.Client } -func NewClient(proxyURL string) *Client { +func NewClient(proxyURL string) (*Client, error) { client := &http.Client{ Timeout: 30 * time.Second, } - if strings.TrimSpace(proxyURL) != "" { - if proxyURLParsed, err := url.Parse(proxyURL); err == nil { - client.Transport = &http.Transport{ - Proxy: http.ProxyURL(proxyURLParsed), - } + _, parsed, err := proxyurl.Parse(proxyURL) + if err != nil { + return nil, err + } + if parsed != nil { + transport := &http.Transport{} + if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { + return nil, fmt.Errorf("configure proxy: %w", err) } + client.Transport = transport } return &Client{ httpClient: client, - } + }, nil } // isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) diff --git a/backend/internal/pkg/antigravity/client_test.go b/backend/internal/pkg/antigravity/client_test.go index 394b6128..20b57833 100644 --- a/backend/internal/pkg/antigravity/client_test.go +++ b/backend/internal/pkg/antigravity/client_test.go @@ -228,8 +228,20 @@ func TestGetTier_两者都为nil(t *testing.T) { // NewClient // --------------------------------------------------------------------------- +func mustNewClient(t *testing.T, proxyURL string) *Client { + t.Helper() + client, err := NewClient(proxyURL) + if err != nil { + t.Fatalf("NewClient(%q) failed: %v", proxyURL, err) + } + return client +} + func TestNewClient_无代理(t *testing.T) { - client := NewClient("") + client, err := NewClient("") + if err != nil { + t.Fatalf("NewClient 返回错误: %v", err) + } if client == nil { t.Fatal("NewClient 返回 nil") } @@ -246,7 +258,10 @@ func TestNewClient_无代理(t *testing.T) { } func TestNewClient_有代理(t *testing.T) { - client := NewClient("http://proxy.example.com:8080") + client, err := NewClient("http://proxy.example.com:8080") + if err != nil { + t.Fatalf("NewClient 返回错误: %v", err) + } if client == nil { t.Fatal("NewClient 返回 nil") } @@ -256,7 +271,10 @@ func TestNewClient_有代理(t *testing.T) { } func TestNewClient_空格代理(t *testing.T) { - client := NewClient(" ") + client, err := NewClient(" ") + if err != nil { + t.Fatalf("NewClient 返回错误: %v", err) + } if client == nil { t.Fatal("NewClient 返回 nil") } @@ -267,15 +285,13 @@ func TestNewClient_空格代理(t *testing.T) { } func TestNewClient_无效代理URL(t *testing.T) { - // 无效 URL 时 url.Parse 不一定返回错误(Go 的 url.Parse 很宽容), - // 但 ://invalid 会导致解析错误 - client := NewClient("://invalid") - if client == nil { - t.Fatal("NewClient 返回 nil") + // 无效 URL 应返回 error + _, err := NewClient("://invalid") + if err == nil { + t.Fatal("无效代理 URL 应返回错误") } - // 无效 URL 解析失败时,Transport 应保持 nil - if client.httpClient.Transport != nil { - t.Error("无效代理 URL 时 Transport 应为 nil") + if !strings.Contains(err.Error(), "invalid proxy URL") { + t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error()) } } @@ -499,7 +515,7 @@ func TestClient_ExchangeCode_无ClientSecret(t *testing.T) { defaultClientSecret = "" t.Cleanup(func() { defaultClientSecret = old }) - client := NewClient("") + client := mustNewClient(t, "") _, err := client.ExchangeCode(context.Background(), "code", "verifier") if err == nil { t.Fatal("缺少 client_secret 时应返回错误") @@ -602,7 +618,7 @@ func TestClient_RefreshToken_无ClientSecret(t *testing.T) { defaultClientSecret = "" t.Cleanup(func() { defaultClientSecret = old }) - client := NewClient("") + client := mustNewClient(t, "") _, err := client.RefreshToken(context.Background(), "refresh-tok") if err == nil { t.Fatal("缺少 client_secret 时应返回错误") @@ -1242,7 +1258,7 @@ func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, rawResp, err := client.LoadCodeAssist(context.Background(), "test-token") if err != nil { t.Fatalf("LoadCodeAssist 失败: %v", err) @@ -1277,7 +1293,7 @@ func TestClient_LoadCodeAssist_HTTPError_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.LoadCodeAssist(context.Background(), "bad-token") if err == nil { t.Fatal("服务器返回 403 时应返回错误") @@ -1300,7 +1316,7 @@ func TestClient_LoadCodeAssist_InvalidJSON_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.LoadCodeAssist(context.Background(), "token") if err == nil { t.Fatal("无效 JSON 响应应返回错误") @@ -1333,7 +1349,7 @@ func TestClient_LoadCodeAssist_URLFallback_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, _, err := client.LoadCodeAssist(context.Background(), "token") if err != nil { t.Fatalf("LoadCodeAssist 应在 fallback 后成功: %v", err) @@ -1361,7 +1377,7 @@ func TestClient_LoadCodeAssist_AllURLsFail_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.LoadCodeAssist(context.Background(), "token") if err == nil { t.Fatal("所有 URL 都失败时应返回错误") @@ -1377,7 +1393,7 @@ func TestClient_LoadCodeAssist_ContextCanceled_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -1441,7 +1457,7 @@ func TestClient_FetchAvailableModels_Success_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, rawResp, err := client.FetchAvailableModels(context.Background(), "test-token", "project-abc") if err != nil { t.Fatalf("FetchAvailableModels 失败: %v", err) @@ -1496,7 +1512,7 @@ func TestClient_FetchAvailableModels_HTTPError_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.FetchAvailableModels(context.Background(), "bad-token", "proj") if err == nil { t.Fatal("服务器返回 403 时应返回错误") @@ -1516,7 +1532,7 @@ func TestClient_FetchAvailableModels_InvalidJSON_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err == nil { t.Fatal("无效 JSON 响应应返回错误") @@ -1546,7 +1562,7 @@ func TestClient_FetchAvailableModels_URLFallback_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err != nil { t.Fatalf("FetchAvailableModels 应在 fallback 后成功: %v", err) @@ -1574,7 +1590,7 @@ func TestClient_FetchAvailableModels_AllURLsFail_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err == nil { t.Fatal("所有 URL 都失败时应返回错误") @@ -1590,7 +1606,7 @@ func TestClient_FetchAvailableModels_ContextCanceled_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -1610,7 +1626,7 @@ func TestClient_FetchAvailableModels_EmptyModels_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, rawResp, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err != nil { t.Fatalf("FetchAvailableModels 失败: %v", err) @@ -1646,7 +1662,7 @@ func TestClient_LoadCodeAssist_408Fallback_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, _, err := client.LoadCodeAssist(context.Background(), "token") if err != nil { t.Fatalf("LoadCodeAssist 应在 408 fallback 后成功: %v", err) @@ -1672,7 +1688,7 @@ func TestClient_FetchAvailableModels_404Fallback_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err != nil { t.Fatalf("FetchAvailableModels 应在 404 fallback 后成功: %v", err) diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go index 6ef3d714..32e4bc5b 100644 --- a/backend/internal/pkg/httpclient/pool.go +++ b/backend/internal/pkg/httpclient/pool.go @@ -18,11 +18,11 @@ package httpclient import ( "fmt" "net/http" - "net/url" "strings" "sync" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" ) @@ -41,7 +41,6 @@ type Options struct { Timeout time.Duration // 请求总超时时间 ResponseHeaderTimeout time.Duration // 等待响应头超时时间 InsecureSkipVerify bool // 是否跳过 TLS 证书验证(已禁用,不允许设置为 true) - ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退 ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding) AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用) @@ -120,15 +119,13 @@ func buildTransport(opts Options) (*http.Transport, error) { return nil, fmt.Errorf("insecure_skip_verify is not allowed; install a trusted certificate instead") } - proxyURL := strings.TrimSpace(opts.ProxyURL) - if proxyURL == "" { - return transport, nil - } - - parsed, err := url.Parse(proxyURL) + _, parsed, err := proxyurl.Parse(opts.ProxyURL) if err != nil { return nil, err } + if parsed == nil { + return transport, nil + } if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { return nil, err @@ -138,12 +135,11 @@ func buildTransport(opts Options) (*http.Transport, error) { } func buildClientKey(opts Options) string { - return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d", + return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%d|%d|%d", strings.TrimSpace(opts.ProxyURL), opts.Timeout.String(), opts.ResponseHeaderTimeout.String(), opts.InsecureSkipVerify, - opts.ProxyStrict, opts.ValidateResolvedIP, opts.AllowPrivateHosts, opts.MaxIdleConns, diff --git a/backend/internal/pkg/proxyurl/parse.go b/backend/internal/pkg/proxyurl/parse.go new file mode 100644 index 00000000..217556f2 --- /dev/null +++ b/backend/internal/pkg/proxyurl/parse.go @@ -0,0 +1,66 @@ +// Package proxyurl 提供代理 URL 的统一验证(fail-fast,无效代理不回退直连) +// +// 所有需要解析代理 URL 的地方必须通过此包的 Parse 函数。 +// 直接使用 url.Parse 处理代理 URL 是被禁止的。 +// 这确保了 fail-fast 行为:无效代理配置在创建时立即失败, +// 而不是在运行时静默回退到直连(产生 IP 关联风险)。 +package proxyurl + +import ( + "fmt" + "net/url" + "strings" +) + +// allowedSchemes 代理协议白名单 +var allowedSchemes = map[string]bool{ + "http": true, + "https": true, + "socks5": true, + "socks5h": true, +} + +// Parse 解析并验证代理 URL。 +// +// 语义: +// - 空字符串 → ("", nil, nil),表示直连 +// - 非空且有效 → (trimmed, *url.URL, nil) +// - 非空但无效 → ("", nil, error),fail-fast 不回退 +// +// 验证规则: +// - TrimSpace 后为空视为直连 +// - url.Parse 失败返回 error(不含原始 URL,防凭据泄露) +// - Host 为空返回 error(用 Redacted() 脱敏) +// - Scheme 必须为 http/https/socks5/socks5h +// - socks5:// 自动升级为 socks5h://(确保 DNS 由代理端解析,防止 DNS 泄漏) +func Parse(raw string) (trimmed string, parsed *url.URL, err error) { + trimmed = strings.TrimSpace(raw) + if trimmed == "" { + return "", nil, nil + } + + parsed, err = url.Parse(trimmed) + if err != nil { + // 不使用 %w 包装,避免 url.Parse 的底层错误消息泄漏原始 URL(可能含凭据) + return "", nil, fmt.Errorf("invalid proxy URL: %v", err) + } + + if parsed.Host == "" || parsed.Hostname() == "" { + return "", nil, fmt.Errorf("proxy URL missing host: %s", parsed.Redacted()) + } + + scheme := strings.ToLower(parsed.Scheme) + if !allowedSchemes[scheme] { + return "", nil, fmt.Errorf("unsupported proxy scheme %q (allowed: http, https, socks5, socks5h)", scheme) + } + + // 自动升级 socks5 → socks5h,确保 DNS 由代理端解析,防止 DNS 泄漏。 + // Go 的 golang.org/x/net/proxy 对 socks5:// 默认在客户端本地解析 DNS, + // 仅 socks5h:// 才将域名发送给代理端做远程 DNS 解析。 + if scheme == "socks5" { + parsed.Scheme = "socks5h" + trimmed = parsed.String() + } + + return trimmed, parsed, nil +} diff --git a/backend/internal/pkg/proxyurl/parse_test.go b/backend/internal/pkg/proxyurl/parse_test.go new file mode 100644 index 00000000..5fb57c16 --- /dev/null +++ b/backend/internal/pkg/proxyurl/parse_test.go @@ -0,0 +1,215 @@ +package proxyurl + +import ( + "strings" + "testing" +) + +func TestParse_空字符串直连(t *testing.T) { + trimmed, parsed, err := Parse("") + if err != nil { + t.Fatalf("空字符串应直连: %v", err) + } + if trimmed != "" { + t.Errorf("trimmed 应为空: got %q", trimmed) + } + if parsed != nil { + t.Errorf("parsed 应为 nil: got %v", parsed) + } +} + +func TestParse_空白字符串直连(t *testing.T) { + trimmed, parsed, err := Parse(" ") + if err != nil { + t.Fatalf("空白字符串应直连: %v", err) + } + if trimmed != "" { + t.Errorf("trimmed 应为空: got %q", trimmed) + } + if parsed != nil { + t.Errorf("parsed 应为 nil: got %v", parsed) + } +} + +func TestParse_有效HTTP代理(t *testing.T) { + trimmed, parsed, err := Parse("http://proxy.example.com:8080") + if err != nil { + t.Fatalf("有效 HTTP 代理应成功: %v", err) + } + if trimmed != "http://proxy.example.com:8080" { + t.Errorf("trimmed 不匹配: got %q", trimmed) + } + if parsed == nil { + t.Fatal("parsed 不应为 nil") + } + if parsed.Host != "proxy.example.com:8080" { + t.Errorf("Host 不匹配: got %q", parsed.Host) + } +} + +func TestParse_有效HTTPS代理(t *testing.T) { + _, parsed, err := Parse("https://proxy.example.com:443") + if err != nil { + t.Fatalf("有效 HTTPS 代理应成功: %v", err) + } + if parsed.Scheme != "https" { + t.Errorf("Scheme 不匹配: got %q", parsed.Scheme) + } +} + +func TestParse_有效SOCKS5代理_自动升级为SOCKS5H(t *testing.T) { + trimmed, parsed, err := Parse("socks5://127.0.0.1:1080") + if err != nil { + t.Fatalf("有效 SOCKS5 代理应成功: %v", err) + } + // socks5 自动升级为 socks5h,确保 DNS 由代理端解析 + if trimmed != "socks5h://127.0.0.1:1080" { + t.Errorf("trimmed 应升级为 socks5h: got %q", trimmed) + } + if parsed.Scheme != "socks5h" { + t.Errorf("Scheme 应升级为 socks5h: got %q", parsed.Scheme) + } +} + +func TestParse_无效URL(t *testing.T) { + _, _, err := Parse("://invalid") + if err == nil { + t.Fatal("无效 URL 应返回错误") + } + if !strings.Contains(err.Error(), "invalid proxy URL") { + t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error()) + } +} + +func TestParse_缺少Host(t *testing.T) { + _, _, err := Parse("http://") + if err == nil { + t.Fatal("缺少 host 应返回错误") + } + if !strings.Contains(err.Error(), "missing host") { + t.Errorf("错误信息应包含 'missing host': got %s", err.Error()) + } +} + +func TestParse_不支持的Scheme(t *testing.T) { + _, _, err := Parse("ftp://proxy.example.com:21") + if err == nil { + t.Fatal("不支持的 scheme 应返回错误") + } + if !strings.Contains(err.Error(), "unsupported proxy scheme") { + t.Errorf("错误信息应包含 'unsupported proxy scheme': got %s", err.Error()) + } +} + +func TestParse_含密码URL脱敏(t *testing.T) { + // 场景 1: 带密码的 socks5 URL 应成功解析并升级为 socks5h + trimmed, parsed, err := Parse("socks5://user:secret_password@proxy.local:1080") + if err != nil { + t.Fatalf("含密码的有效 URL 应成功: %v", err) + } + if trimmed == "" || parsed == nil { + t.Fatal("应返回非空结果") + } + if parsed.Scheme != "socks5h" { + t.Errorf("Scheme 应升级为 socks5h: got %q", parsed.Scheme) + } + if !strings.HasPrefix(trimmed, "socks5h://") { + t.Errorf("trimmed 应以 socks5h:// 开头: got %q", trimmed) + } + if parsed.User == nil { + t.Error("升级后应保留 UserInfo") + } + + // 场景 2: 带密码但缺少 host(触发 Redacted 脱敏路径) + _, _, err = Parse("http://user:secret_password@:0/") + if err == nil { + t.Fatal("缺少 host 应返回错误") + } + if strings.Contains(err.Error(), "secret_password") { + t.Error("错误信息不应包含明文密码") + } + if !strings.Contains(err.Error(), "missing host") { + t.Errorf("错误信息应包含 'missing host': got %s", err.Error()) + } +} + +func TestParse_带空白的有效URL(t *testing.T) { + trimmed, parsed, err := Parse(" http://proxy.example.com:8080 ") + if err != nil { + t.Fatalf("带空白的有效 URL 应成功: %v", err) + } + if trimmed != "http://proxy.example.com:8080" { + t.Errorf("trimmed 应去除空白: got %q", trimmed) + } + if parsed == nil { + t.Fatal("parsed 不应为 nil") + } +} + +func TestParse_Scheme大小写不敏感(t *testing.T) { + // 大写 SOCKS5 应被接受并升级为 socks5h + trimmed, parsed, err := Parse("SOCKS5://proxy.example.com:1080") + if err != nil { + t.Fatalf("大写 SOCKS5 应被接受: %v", err) + } + if parsed.Scheme != "socks5h" { + t.Errorf("大写 SOCKS5 Scheme 应升级为 socks5h: got %q", parsed.Scheme) + } + if !strings.HasPrefix(trimmed, "socks5h://") { + t.Errorf("大写 SOCKS5 trimmed 应升级为 socks5h://: got %q", trimmed) + } + + // 大写 HTTP 应被接受(不变) + _, _, err = Parse("HTTP://proxy.example.com:8080") + if err != nil { + t.Fatalf("大写 HTTP 应被接受: %v", err) + } +} + +func TestParse_带认证的有效代理(t *testing.T) { + trimmed, parsed, err := Parse("http://user:pass@proxy.example.com:8080") + if err != nil { + t.Fatalf("带认证的代理 URL 应成功: %v", err) + } + if parsed.User == nil { + t.Error("应保留 UserInfo") + } + if trimmed != "http://user:pass@proxy.example.com:8080" { + t.Errorf("trimmed 不匹配: got %q", trimmed) + } +} + +func TestParse_IPv6地址(t *testing.T) { + trimmed, parsed, err := Parse("http://[::1]:8080") + if err != nil { + t.Fatalf("IPv6 代理 URL 应成功: %v", err) + } + if parsed.Hostname() != "::1" { + t.Errorf("Hostname 不匹配: got %q", parsed.Hostname()) + } + if trimmed != "http://[::1]:8080" { + t.Errorf("trimmed 不匹配: got %q", trimmed) + } +} + +func TestParse_SOCKS5H保持不变(t *testing.T) { + trimmed, parsed, err := Parse("socks5h://proxy.local:1080") + if err != nil { + t.Fatalf("有效 SOCKS5H 代理应成功: %v", err) + } + // socks5h 不需要升级,应保持原样 + if trimmed != "socks5h://proxy.local:1080" { + t.Errorf("trimmed 不应变化: got %q", trimmed) + } + if parsed.Scheme != "socks5h" { + t.Errorf("Scheme 应保持 socks5h: got %q", parsed.Scheme) + } +} + +func TestParse_无Scheme裸地址(t *testing.T) { + // 无 scheme 的裸地址,Go url.Parse 将其视为 path,Host 为空 + _, _, err := Parse("proxy.example.com:8080") + if err == nil { + t.Fatal("无 scheme 的裸地址应返回错误") + } +} diff --git a/backend/internal/pkg/proxyutil/dialer.go b/backend/internal/pkg/proxyutil/dialer.go index 91b224a2..e437cae3 100644 --- a/backend/internal/pkg/proxyutil/dialer.go +++ b/backend/internal/pkg/proxyutil/dialer.go @@ -2,7 +2,11 @@ // // 支持的代理协议: // - HTTP/HTTPS: 通过 Transport.Proxy 设置 -// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS) +// - SOCKS5: 通过 Transport.DialContext 设置(客户端本地解析 DNS) +// - SOCKS5H: 通过 Transport.DialContext 设置(代理端远程解析 DNS,推荐) +// +// 注意:proxyurl.Parse() 会自动将 socks5:// 升级为 socks5h://, +// 确保 DNS 也由代理端解析,防止 DNS 泄漏。 package proxyutil import ( @@ -20,7 +24,8 @@ import ( // // 支持的协议: // - http/https: 设置 transport.Proxy -// - socks5/socks5h: 设置 transport.DialContext(由代理服务端解析 DNS) +// - socks5: 设置 transport.DialContext(客户端本地解析 DNS) +// - socks5h: 设置 transport.DialContext(代理端远程解析 DNS,推荐) // // 参数: // - transport: 需要配置的 http.Transport diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index 77764881..b754bd55 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -11,6 +11,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/util/logredact" @@ -28,11 +29,14 @@ func NewClaudeOAuthClient() service.ClaudeOAuthClient { type claudeOAuthService struct { baseURL string tokenURL string - clientFactory func(proxyURL string) *req.Client + clientFactory func(proxyURL string) (*req.Client, error) } func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return "", fmt.Errorf("create HTTP client: %w", err) + } var orgs []struct { UUID string `json:"uuid"` @@ -88,7 +92,10 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey } func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return "", fmt.Errorf("create HTTP client: %w", err) + } authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID) @@ -165,7 +172,10 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe } func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } // Parse code which may contain state in format "authCode#state" authCode := code @@ -223,7 +233,10 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod } func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } reqBody := map[string]any{ "grant_type": "refresh_token", @@ -253,16 +266,20 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro return &tokenResp, nil } -func createReqClient(proxyURL string) *req.Client { +func createReqClient(proxyURL string) (*req.Client, error) { // 禁用 CookieJar,确保每次授权都是干净的会话 client := req.C(). SetTimeout(60 * time.Second). ImpersonateChrome(). SetCookieJar(nil) // 禁用 CookieJar - if strings.TrimSpace(proxyURL) != "" { - client.SetProxyURL(strings.TrimSpace(proxyURL)) + trimmed, _, err := proxyurl.Parse(proxyURL) + if err != nil { + return nil, err + } + if trimmed != "" { + client.SetProxyURL(trimmed) } - return client + return client, nil } diff --git a/backend/internal/repository/claude_oauth_service_test.go b/backend/internal/repository/claude_oauth_service_test.go index 7395c6d8..c6383033 100644 --- a/backend/internal/repository/claude_oauth_service_test.go +++ b/backend/internal/repository/claude_oauth_service_test.go @@ -91,7 +91,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.baseURL = "http://in-process" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "") @@ -169,7 +169,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.baseURL = "http://in-process" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "") @@ -276,7 +276,7 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.tokenURL = "http://in-process/token" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken) @@ -372,7 +372,7 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.tokenURL = "http://in-process/token" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } resp, err := s.client.RefreshToken(context.Background(), "rt", "") diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index 1198f472..f6054828 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -83,7 +83,7 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se AllowPrivateHosts: s.allowPrivateHosts, }) if err != nil { - client = &http.Client{Timeout: 30 * time.Second} + return nil, fmt.Errorf("create http client failed: %w", err) } resp, err = client.Do(req) diff --git a/backend/internal/repository/claude_usage_service_test.go b/backend/internal/repository/claude_usage_service_test.go index 2e10f3e5..cbd0b6d3 100644 --- a/backend/internal/repository/claude_usage_service_test.go +++ b/backend/internal/repository/claude_usage_service_test.go @@ -50,7 +50,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() { allowPrivateHosts: true, } - resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url") + resp, err := s.fetcher.FetchUsage(context.Background(), "at", "") require.NoError(s.T(), err, "FetchUsage") require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch") require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch") @@ -112,6 +112,17 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() { require.Error(s.T(), err, "expected error for cancelled context") } +func (s *ClaudeUsageServiceSuite) TestFetchUsage_InvalidProxyReturnsError() { + s.fetcher = &claudeUsageService{ + usageURL: "http://example.com", + allowPrivateHosts: true, + } + + _, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url") + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "create http client failed") +} + func TestClaudeUsageServiceSuite(t *testing.T) { suite.Run(t, new(ClaudeUsageServiceSuite)) } diff --git a/backend/internal/repository/gemini_oauth_client.go b/backend/internal/repository/gemini_oauth_client.go index 8b7fe625..eb14f313 100644 --- a/backend/internal/repository/gemini_oauth_client.go +++ b/backend/internal/repository/gemini_oauth_client.go @@ -26,7 +26,10 @@ func NewGeminiOAuthClient(cfg *config.Config) service.GeminiOAuthClient { } func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) { - client := createGeminiReqClient(proxyURL) + client, err := createGeminiReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } // Use different OAuth clients based on oauthType: // - code_assist: always use built-in Gemini CLI OAuth client (public) @@ -72,7 +75,10 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c } func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { - client := createGeminiReqClient(proxyURL) + client, err := createGeminiReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } oauthCfgInput := geminicli.OAuthConfig{ ClientID: c.cfg.Gemini.OAuth.ClientID, @@ -111,7 +117,7 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh return &tokenResp, nil } -func createGeminiReqClient(proxyURL string) *req.Client { +func createGeminiReqClient(proxyURL string) (*req.Client, error) { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 60 * time.Second, diff --git a/backend/internal/repository/geminicli_codeassist_client.go b/backend/internal/repository/geminicli_codeassist_client.go index 4f63280d..b5bc6497 100644 --- a/backend/internal/repository/geminicli_codeassist_client.go +++ b/backend/internal/repository/geminicli_codeassist_client.go @@ -26,7 +26,11 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo } var out geminicli.LoadCodeAssistResponse - resp, err := createGeminiCliReqClient(proxyURL).R(). + client, err := createGeminiCliReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } + resp, err := client.R(). SetContext(ctx). SetHeader("Authorization", "Bearer "+accessToken). SetHeader("Content-Type", "application/json"). @@ -66,7 +70,11 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken fmt.Printf("[CodeAssist] OnboardUser request body: %+v\n", reqBody) var out geminicli.OnboardUserResponse - resp, err := createGeminiCliReqClient(proxyURL).R(). + client, err := createGeminiCliReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } + resp, err := client.R(). SetContext(ctx). SetHeader("Authorization", "Bearer "+accessToken). SetHeader("Content-Type", "application/json"). @@ -98,7 +106,7 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken return &out, nil } -func createGeminiCliReqClient(proxyURL string) *req.Client { +func createGeminiCliReqClient(proxyURL string) (*req.Client, error) { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 30 * time.Second, diff --git a/backend/internal/repository/github_release_service.go b/backend/internal/repository/github_release_service.go index 28efe914..ad1f22e3 100644 --- a/backend/internal/repository/github_release_service.go +++ b/backend/internal/repository/github_release_service.go @@ -5,8 +5,10 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "net/http" "os" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" @@ -24,13 +26,19 @@ type githubReleaseClientError struct { // NewGitHubReleaseClient 创建 GitHub Release 客户端 // proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议 +// 代理配置失败时行为由 allowDirectOnProxyError 控制: +// - false(默认):返回错误占位客户端,禁止回退到直连 +// - true:回退到直连(仅限管理员显式开启) func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) service.GitHubReleaseClient { + // 安全说明:httpclient.GetClient 的错误链(url.Parse / proxyutil)不含明文代理凭据, + // 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。 sharedClient, err := httpclient.GetClient(httpclient.Options{ Timeout: 30 * time.Second, ProxyURL: proxyURL, }) if err != nil { - if proxyURL != "" && !allowDirectOnProxyError { + if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { + slog.Warn("proxy client init failed, all requests will fail", "service", "github_release", "error", err) return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} } sharedClient = &http.Client{Timeout: 30 * time.Second} @@ -42,7 +50,8 @@ func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) servi ProxyURL: proxyURL, }) if err != nil { - if proxyURL != "" && !allowDirectOnProxyError { + if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { + slog.Warn("proxy download client init failed, all requests will fail", "service", "github_release", "error", err) return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} } downloadClient = &http.Client{Timeout: 10 * time.Minute} diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index b0f15f19..a4674c1a 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -14,6 +14,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/Wei-Shaw/sub2api/internal/service" @@ -235,7 +236,10 @@ func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID in // TLS 指纹客户端使用独立的缓存键,与普通客户端隔离 func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { isolation := s.getIsolationMode() - proxyKey, parsedProxy := normalizeProxyURL(proxyURL) + proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL) + if err != nil { + return nil, err + } // TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀 cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID) poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls" @@ -373,9 +377,8 @@ func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, ac // - proxy: 按代理地址隔离,同一代理共享客户端 // - account: 按账户隔离,同一账户共享客户端(代理变更时重建) // - account_proxy: 按账户+代理组合隔离,最细粒度 -func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry { - entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false) - return entry +func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) { + return s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false) } // getClientEntry 获取或创建客户端条目 @@ -385,7 +388,10 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a // 获取隔离模式 isolation := s.getIsolationMode() // 标准化代理 URL 并解析 - proxyKey, parsedProxy := normalizeProxyURL(proxyURL) + proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL) + if err != nil { + return nil, err + } // 构建缓存键(根据隔离策略不同) cacheKey := buildCacheKey(isolation, proxyKey, accountID) // 构建连接池配置键(用于检测配置变更) @@ -680,17 +686,18 @@ func buildCacheKey(isolation, proxyKey string, accountID int64) string { // - raw: 原始代理 URL 字符串 // // 返回: -// - string: 标准化的代理键(空或解析失败返回 "direct") -// - *url.URL: 解析后的 URL(空或解析失败返回 nil) -func normalizeProxyURL(raw string) (string, *url.URL) { - proxyURL := strings.TrimSpace(raw) - if proxyURL == "" { - return directProxyKey, nil - } - parsed, err := url.Parse(proxyURL) +// - string: 标准化的代理键(空返回 "direct") +// - *url.URL: 解析后的 URL(空返回 nil) +// - error: 非空代理 URL 解析失败时返回错误(禁止回退到直连) +func normalizeProxyURL(raw string) (string, *url.URL, error) { + _, parsed, err := proxyurl.Parse(raw) if err != nil { - return directProxyKey, nil + return "", nil, err } + if parsed == nil { + return directProxyKey, nil, nil + } + // 规范化:小写 scheme/host,去除路径和查询参数 parsed.Scheme = strings.ToLower(parsed.Scheme) parsed.Host = strings.ToLower(parsed.Host) parsed.Path = "" @@ -710,7 +717,7 @@ func normalizeProxyURL(raw string) (string, *url.URL) { parsed.Host = hostname } } - return parsed.String(), parsed + return parsed.String(), parsed, nil } // defaultPoolSettings 获取默认连接池配置 diff --git a/backend/internal/repository/http_upstream_benchmark_test.go b/backend/internal/repository/http_upstream_benchmark_test.go index 1e7430a3..89892b3b 100644 --- a/backend/internal/repository/http_upstream_benchmark_test.go +++ b/backend/internal/repository/http_upstream_benchmark_test.go @@ -59,7 +59,10 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { // 模拟优化后的行为,从缓存获取客户端 b.Run("复用", func(b *testing.B) { // 预热:确保客户端已缓存 - entry := svc.getOrCreateClient(proxyURL, 1, 1) + entry, err := svc.getOrCreateClient(proxyURL, 1, 1) + if err != nil { + b.Fatalf("getOrCreateClient: %v", err) + } client := entry.client b.ResetTimer() // 重置计时器,排除预热时间 for i := 0; i < b.N; i++ { diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go index fbe44c5e..b3268463 100644 --- a/backend/internal/repository/http_upstream_test.go +++ b/backend/internal/repository/http_upstream_test.go @@ -44,7 +44,7 @@ func (s *HTTPUpstreamSuite) newService() *httpUpstreamService { // 验证未配置时使用 300 秒默认值 func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { svc := s.newService() - entry := svc.getOrCreateClient("", 0, 0) + entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") @@ -55,25 +55,27 @@ func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() { s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7} svc := s.newService() - entry := svc.getOrCreateClient("", 0, 0) + entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") } -// TestGetOrCreateClient_InvalidURLFallsBackToDirect 测试无效代理 URL 回退 -// 验证解析失败时回退到直连模式 -func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect() { +// TestGetOrCreateClient_InvalidURLReturnsError 测试无效代理 URL 返回错误 +// 验证解析失败时拒绝回退到直连模式 +func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLReturnsError() { svc := s.newService() - entry := svc.getOrCreateClient("://bad-proxy-url", 1, 1) - require.Equal(s.T(), directProxyKey, entry.proxyKey, "expected direct proxy fallback") + _, err := svc.getClientEntry("://bad-proxy-url", 1, 1, false, false) + require.Error(s.T(), err, "expected error for invalid proxy URL") } // TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化 // 验证等价地址能够映射到同一缓存键 func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() { - key1, _ := normalizeProxyURL("http://proxy.local:8080") - key2, _ := normalizeProxyURL("http://proxy.local:8080/") + key1, _, err1 := normalizeProxyURL("http://proxy.local:8080") + require.NoError(s.T(), err1) + key2, _, err2 := normalizeProxyURL("http://proxy.local:8080/") + require.NoError(s.T(), err2) require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match") } @@ -171,8 +173,8 @@ func (s *HTTPUpstreamSuite) TestAccountIsolation_DifferentAccounts() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 同一代理,不同账户 - entry1 := svc.getOrCreateClient("http://proxy.local:8080", 1, 3) - entry2 := svc.getOrCreateClient("http://proxy.local:8080", 2, 3) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 1, 3) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 2, 3) require.NotSame(s.T(), entry1, entry2, "不同账号不应共享连接池") require.Equal(s.T(), 2, len(svc.clients), "账号隔离应缓存两个客户端") } @@ -183,8 +185,8 @@ func (s *HTTPUpstreamSuite) TestAccountProxyIsolation_DifferentProxy() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy} svc := s.newService() // 同一账户,不同代理 - entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3) - entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3) require.NotSame(s.T(), entry1, entry2, "账号+代理隔离应区分不同代理") require.Equal(s.T(), 2, len(svc.clients), "账号+代理隔离应缓存两个客户端") } @@ -195,8 +197,8 @@ func (s *HTTPUpstreamSuite) TestAccountModeProxyChangeClearsPool() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 同一账户,先后使用不同代理 - entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3) - entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3) require.NotSame(s.T(), entry1, entry2, "账号切换代理应创建新连接池") require.Equal(s.T(), 1, len(svc.clients), "账号模式下应仅保留一个连接池") require.False(s.T(), hasEntry(svc, entry1), "旧连接池应被清理") @@ -208,7 +210,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyOverridesPoolSettings() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 账户并发数为 12 - entry := svc.getOrCreateClient("", 1, 12) + entry := mustGetOrCreateClient(s.T(), svc, "", 1, 12) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") // 连接池参数应与并发数一致 @@ -228,7 +230,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyFallbackToDefault() { } svc := s.newService() // 账户并发数为 0,应使用全局配置 - entry := svc.getOrCreateClient("", 1, 0) + entry := mustGetOrCreateClient(s.T(), svc, "", 1, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 66, transport.MaxConnsPerHost, "MaxConnsPerHost fallback mismatch") @@ -245,12 +247,12 @@ func (s *HTTPUpstreamSuite) TestEvictOverLimitRemovesOldestIdle() { } svc := s.newService() // 创建两个客户端,设置不同的最后使用时间 - entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 1) - entry2 := svc.getOrCreateClient("http://proxy-b:8080", 2, 1) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 1) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 2, 1) atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Hour).UnixNano()) // 最久 atomic.StoreInt64(&entry2.lastUsed, time.Now().Add(-time.Hour).UnixNano()) // 创建第三个客户端,触发淘汰 - _ = svc.getOrCreateClient("http://proxy-c:8080", 3, 1) + _ = mustGetOrCreateClient(s.T(), svc, "http://proxy-c:8080", 3, 1) require.LessOrEqual(s.T(), len(svc.clients), 2, "应保持在缓存上限内") require.False(s.T(), hasEntry(svc, entry1), "最久未使用的连接池应被清理") @@ -264,12 +266,12 @@ func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() { ClientIdleTTLSeconds: 1, // 1 秒空闲超时 } svc := s.newService() - entry1 := svc.getOrCreateClient("", 1, 1) + entry1 := mustGetOrCreateClient(s.T(), svc, "", 1, 1) // 设置为很久之前使用,但有活跃请求 atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Minute).UnixNano()) atomic.StoreInt64(&entry1.inFlight, 1) // 模拟有活跃请求 // 创建新客户端,触发淘汰检查 - _ = svc.getOrCreateClient("", 2, 1) + _, _ = svc.getOrCreateClient("", 2, 1) require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收") } @@ -279,6 +281,14 @@ func TestHTTPUpstreamSuite(t *testing.T) { suite.Run(t, new(HTTPUpstreamSuite)) } +// mustGetOrCreateClient 测试辅助函数,调用 getOrCreateClient 并断言无错误 +func mustGetOrCreateClient(t *testing.T, svc *httpUpstreamService, proxyURL string, accountID int64, concurrency int) *upstreamClientEntry { + t.Helper() + entry, err := svc.getOrCreateClient(proxyURL, accountID, concurrency) + require.NoError(t, err, "getOrCreateClient(%q, %d, %d)", proxyURL, accountID, concurrency) + return entry +} + // hasEntry 检查客户端是否存在于缓存中 // 辅助函数,用于验证淘汰逻辑 func hasEntry(svc *httpUpstreamService, target *upstreamClientEntry) bool { diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index 3e155971..dca0b612 100644 --- a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -23,7 +23,10 @@ type openaiOAuthService struct { } func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { - client := createOpenAIReqClient(proxyURL) + client, err := createOpenAIReqClient(proxyURL) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err) + } if redirectURI == "" { redirectURI = openai.DefaultRedirectURI @@ -74,7 +77,10 @@ func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refre } func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) { - client := createOpenAIReqClient(proxyURL) + client, err := createOpenAIReqClient(proxyURL) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err) + } formData := url.Values{} formData.Set("grant_type", "refresh_token") @@ -102,7 +108,7 @@ func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refre return &tokenResp, nil } -func createOpenAIReqClient(proxyURL string) *req.Client { +func createOpenAIReqClient(proxyURL string) (*req.Client, error) { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 120 * time.Second, diff --git a/backend/internal/repository/pricing_service.go b/backend/internal/repository/pricing_service.go index 07d796b8..ee8e1749 100644 --- a/backend/internal/repository/pricing_service.go +++ b/backend/internal/repository/pricing_service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "log/slog" "net/http" "strings" "time" @@ -16,14 +17,37 @@ type pricingRemoteClient struct { httpClient *http.Client } +// pricingRemoteClientError 代理初始化失败时的错误占位客户端 +// 所有请求直接返回初始化错误,禁止回退到直连 +type pricingRemoteClientError struct { + err error +} + +func (c *pricingRemoteClientError) FetchPricingJSON(_ context.Context, _ string) ([]byte, error) { + return nil, c.err +} + +func (c *pricingRemoteClientError) FetchHashText(_ context.Context, _ string) (string, error) { + return "", c.err +} + // NewPricingRemoteClient 创建定价数据远程客户端 // proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议 -func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient { +// 代理配置失败时行为由 allowDirectOnProxyError 控制: +// - false(默认):返回错误占位客户端,禁止回退到直连 +// - true:回退到直连(仅限管理员显式开启) +func NewPricingRemoteClient(proxyURL string, allowDirectOnProxyError bool) service.PricingRemoteClient { + // 安全说明:httpclient.GetClient 的错误链(url.Parse / proxyutil)不含明文代理凭据, + // 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。 sharedClient, err := httpclient.GetClient(httpclient.Options{ Timeout: 30 * time.Second, ProxyURL: proxyURL, }) if err != nil { + if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { + slog.Warn("proxy client init failed, all requests will fail", "service", "pricing", "error", err) + return &pricingRemoteClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} + } sharedClient = &http.Client{Timeout: 30 * time.Second} } return &pricingRemoteClient{ diff --git a/backend/internal/repository/pricing_service_test.go b/backend/internal/repository/pricing_service_test.go index 6ea11211..ef2f214b 100644 --- a/backend/internal/repository/pricing_service_test.go +++ b/backend/internal/repository/pricing_service_test.go @@ -19,7 +19,7 @@ type PricingServiceSuite struct { func (s *PricingServiceSuite) SetupTest() { s.ctx = context.Background() - client, ok := NewPricingRemoteClient("").(*pricingRemoteClient) + client, ok := NewPricingRemoteClient("", false).(*pricingRemoteClient) require.True(s.T(), ok, "type assertion failed") s.client = client } @@ -140,6 +140,22 @@ func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() { require.Error(s.T(), err) } +func TestNewPricingRemoteClient_InvalidProxy_NoFallback(t *testing.T) { + client := NewPricingRemoteClient("://bad", false) + _, ok := client.(*pricingRemoteClientError) + require.True(t, ok, "should return error client when proxy is invalid and fallback disabled") + + _, err := client.FetchPricingJSON(context.Background(), "http://example.com") + require.Error(t, err) + require.Contains(t, err.Error(), "proxy client init failed") +} + +func TestNewPricingRemoteClient_InvalidProxy_WithFallback(t *testing.T) { + client := NewPricingRemoteClient("://bad", true) + _, ok := client.(*pricingRemoteClient) + require.True(t, ok, "should fallback to direct client when allowed") +} + func TestPricingServiceSuite(t *testing.T) { suite.Run(t, new(PricingServiceSuite)) } diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index 54de2897..b4aeab71 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -66,7 +66,6 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s ProxyURL: proxyURL, Timeout: defaultProxyProbeTimeout, InsecureSkipVerify: s.insecureSkipVerify, - ProxyStrict: true, ValidateResolvedIP: s.validateResolvedIP, AllowPrivateHosts: s.allowPrivateHosts, }) diff --git a/backend/internal/repository/req_client_pool.go b/backend/internal/repository/req_client_pool.go index af71a7ee..79b24396 100644 --- a/backend/internal/repository/req_client_pool.go +++ b/backend/internal/repository/req_client_pool.go @@ -6,6 +6,8 @@ import ( "sync" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" + "github.com/imroc/req/v3" ) @@ -33,11 +35,11 @@ var sharedReqClients sync.Map // getSharedReqClient 获取共享的 req 客户端实例 // 性能优化:相同配置复用同一客户端,避免重复创建 -func getSharedReqClient(opts reqClientOptions) *req.Client { +func getSharedReqClient(opts reqClientOptions) (*req.Client, error) { key := buildReqClientKey(opts) if cached, ok := sharedReqClients.Load(key); ok { if c, ok := cached.(*req.Client); ok { - return c + return c, nil } } @@ -48,15 +50,19 @@ func getSharedReqClient(opts reqClientOptions) *req.Client { if opts.Impersonate { client = client.ImpersonateChrome() } - if strings.TrimSpace(opts.ProxyURL) != "" { - client.SetProxyURL(strings.TrimSpace(opts.ProxyURL)) + trimmed, _, err := proxyurl.Parse(opts.ProxyURL) + if err != nil { + return nil, err + } + if trimmed != "" { + client.SetProxyURL(trimmed) } actual, _ := sharedReqClients.LoadOrStore(key, client) if c, ok := actual.(*req.Client); ok { - return c + return c, nil } - return client + return client, nil } func buildReqClientKey(opts reqClientOptions) string { diff --git a/backend/internal/repository/req_client_pool_test.go b/backend/internal/repository/req_client_pool_test.go index 904ed4d6..9067d012 100644 --- a/backend/internal/repository/req_client_pool_test.go +++ b/backend/internal/repository/req_client_pool_test.go @@ -26,11 +26,13 @@ func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) { ProxyURL: "http://proxy.local:8080", Timeout: time.Second, } - clientDefault := getSharedReqClient(base) + clientDefault, err := getSharedReqClient(base) + require.NoError(t, err) force := base force.ForceHTTP2 = true - clientForce := getSharedReqClient(force) + clientForce, err := getSharedReqClient(force) + require.NoError(t, err) require.NotSame(t, clientDefault, clientForce) require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force)) @@ -42,8 +44,10 @@ func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) { ProxyURL: "http://proxy.local:8080", Timeout: 2 * time.Second, } - first := getSharedReqClient(opts) - second := getSharedReqClient(opts) + first, err := getSharedReqClient(opts) + require.NoError(t, err) + second, err := getSharedReqClient(opts) + require.NoError(t, err) require.Same(t, first, second) } @@ -56,7 +60,8 @@ func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) { key := buildReqClientKey(opts) sharedReqClients.Store(key, "invalid") - client := getSharedReqClient(opts) + client, err := getSharedReqClient(opts) + require.NoError(t, err) require.NotNil(t, client) loaded, ok := sharedReqClients.Load(key) @@ -71,20 +76,45 @@ func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) { Timeout: 4 * time.Second, Impersonate: true, } - client := getSharedReqClient(opts) + client, err := getSharedReqClient(opts) + require.NoError(t, err) require.NotNil(t, client) require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts)) } +func TestGetSharedReqClient_InvalidProxyURL(t *testing.T) { + sharedReqClients = sync.Map{} + opts := reqClientOptions{ + ProxyURL: "://missing-scheme", + Timeout: time.Second, + } + _, err := getSharedReqClient(opts) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid proxy URL") +} + +func TestGetSharedReqClient_ProxyURLMissingHost(t *testing.T) { + sharedReqClients = sync.Map{} + opts := reqClientOptions{ + ProxyURL: "http://", + Timeout: time.Second, + } + _, err := getSharedReqClient(opts) + require.Error(t, err) + require.Contains(t, err.Error(), "proxy URL missing host") +} + func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) { sharedReqClients = sync.Map{} - client := createOpenAIReqClient("http://proxy.local:8080") + client, err := createOpenAIReqClient("http://proxy.local:8080") + require.NoError(t, err) require.Equal(t, 120*time.Second, client.GetClient().Timeout) } func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) { sharedReqClients = sync.Map{} - client := createGeminiReqClient("http://proxy.local:8080") + client, err := createGeminiReqClient("http://proxy.local:8080") + require.NoError(t, err) require.Equal(t, "", forceHTTPVersion(t, client)) } diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 2344035c..ee796d98 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -34,7 +34,7 @@ func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient // ProvidePricingRemoteClient 创建定价数据远程客户端 // 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据 func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient { - return NewPricingRemoteClient(cfg.Update.ProxyURL) + return NewPricingRemoteClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError) } // ProvideSessionLimitCache 创建会话限制缓存 diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index bdd1aa4a..7e6982d3 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -2028,7 +2028,6 @@ func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*Pr ProxyURL: proxyURL, Timeout: proxyQualityRequestTimeout, ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout, - ProxyStrict: true, }) if err != nil { result.Items = append(result.Items, ProxyQualityCheckItem{ diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go index b67c7faf..5f6691be 100644 --- a/backend/internal/service/antigravity_oauth_service.go +++ b/backend/internal/service/antigravity_oauth_service.go @@ -112,7 +112,10 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig } } - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } // 交换 token tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier) @@ -167,7 +170,10 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken time.Sleep(backoff) } - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } tokenResp, err := client.RefreshToken(ctx, refreshToken) if err == nil { now := time.Now() @@ -209,7 +215,10 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr } // 获取用户信息(email) - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken) if err != nil { fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err) @@ -309,7 +318,10 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac time.Sleep(backoff) } - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return "", fmt.Errorf("create antigravity client failed: %w", err) + } loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken) if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" { diff --git a/backend/internal/service/antigravity_quota_fetcher.go b/backend/internal/service/antigravity_quota_fetcher.go index 07eb563d..e950ec1d 100644 --- a/backend/internal/service/antigravity_quota_fetcher.go +++ b/backend/internal/service/antigravity_quota_fetcher.go @@ -2,6 +2,7 @@ package service import ( "context" + "fmt" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" @@ -31,7 +32,10 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou accessToken := account.GetCredential("access_token") projectID := account.GetCredential("project_id") - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } // 调用 API 获取配额 modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID) diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index 040b2357..6a916740 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -221,7 +221,7 @@ func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username, AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts, }) if err != nil { - client = &http.Client{Timeout: 20 * time.Second} + return nil, fmt.Errorf("create http client failed: %w", err) } adminToken, err := crsLogin(ctx, client, normalizedURL, username, password) diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index e866bdc3..08a74a37 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -1045,7 +1045,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR ValidateResolvedIP: true, }) if err != nil { - client = &http.Client{Timeout: 30 * time.Second} + return "", fmt.Errorf("create http client failed: %w", err) } resp, err := client.Do(req) diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index 07cb5472..72f4bbb0 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -7,7 +7,6 @@ import ( "io" "log/slog" "net/http" - "net/url" "regexp" "sort" "strconv" @@ -15,6 +14,7 @@ import ( "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" ) @@ -273,7 +273,13 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi req.Header.Set("Referer", "https://sora.chatgpt.com/") req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") - client := newOpenAIOAuthHTTPClient(proxyURL) + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: proxyURL, + Timeout: 120 * time.Second, + }) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err) + } resp, err := client.Do(req) if err != nil { return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err) @@ -530,19 +536,6 @@ func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64 return proxy.URL(), nil } -func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client { - transport := &http.Transport{} - if strings.TrimSpace(proxyURL) != "" { - if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" { - transport.Proxy = http.ProxyURL(parsed) - } - } - return &http.Client{ - Timeout: 120 * time.Second, - Transport: transport, - } -} - func normalizeOpenAIOAuthPlatform(platform string) string { switch strings.ToLower(strings.TrimSpace(platform)) { case PlatformSora: diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index faa85854..e2eb3130 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -134,6 +134,12 @@ security: # Allow skipping TLS verification for proxy probe (debug only) # 允许代理探测时跳过 TLS 证书验证(仅用于调试) insecure_skip_verify: false + proxy_fallback: + # Allow auxiliary services (update check, pricing data) to fallback to direct + # connection when proxy initialization fails. Does NOT affect AI gateway connections. + # 辅助服务(更新检查、定价数据拉取)代理初始化失败时是否允许回退直连。 + # 不影响 AI 账号网关连接。默认 false:fail-fast 防止 IP 泄露。 + allow_direct_on_error: false # ============================================================================= # Gateway Configuration