feat(proxy): 集中代理 URL 验证并实现全局 fail-fast

提取 proxyurl.Parse() 公共包,将分散在 6 处的代理 URL 验证逻辑
统一收敛,确保无效代理配置在创建时立即失败,永不静默回退直连。

主要变更:
- 新增 proxyurl 包:统一 TrimSpace → url.Parse → Host 校验 → Scheme 白名单
- socks5:// 自动升级为 socks5h://,防止 DNS 泄漏(大小写不敏感)
- antigravity: http.ProxyURL → proxyutil.ConfigureTransportProxy 支持 SOCKS5
- openai_oauth: 删除 newOpenAIOAuthHTTPClient,收编至 httpclient.GetClient
- 移除未使用的 ProxyStrict 字段(fail-fast 已是全局默认行为)
- 补充 15 个 proxyurl 测试 + pricing/usage fail-fast 测试
This commit is contained in:
QTom
2026-03-02 15:53:26 +08:00
parent 445bfdf242
commit fdcbf7aacf
31 changed files with 633 additions and 157 deletions

View File

@@ -11,6 +11,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
@@ -28,11 +29,14 @@ func NewClaudeOAuthClient() service.ClaudeOAuthClient {
type claudeOAuthService struct {
baseURL string
tokenURL string
clientFactory func(proxyURL string) *req.Client
clientFactory func(proxyURL string) (*req.Client, error)
}
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
client := s.clientFactory(proxyURL)
client, err := s.clientFactory(proxyURL)
if err != nil {
return "", fmt.Errorf("create HTTP client: %w", err)
}
var orgs []struct {
UUID string `json:"uuid"`
@@ -88,7 +92,10 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
}
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
client := s.clientFactory(proxyURL)
client, err := s.clientFactory(proxyURL)
if err != nil {
return "", fmt.Errorf("create HTTP client: %w", err)
}
authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID)
@@ -165,7 +172,10 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
}
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL)
client, err := s.clientFactory(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
// Parse code which may contain state in format "authCode#state"
authCode := code
@@ -223,7 +233,10 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
}
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL)
client, err := s.clientFactory(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
reqBody := map[string]any{
"grant_type": "refresh_token",
@@ -253,16 +266,20 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
return &tokenResp, nil
}
func createReqClient(proxyURL string) *req.Client {
func createReqClient(proxyURL string) (*req.Client, error) {
// 禁用 CookieJar确保每次授权都是干净的会话
client := req.C().
SetTimeout(60 * time.Second).
ImpersonateChrome().
SetCookieJar(nil) // 禁用 CookieJar
if strings.TrimSpace(proxyURL) != "" {
client.SetProxyURL(strings.TrimSpace(proxyURL))
trimmed, _, err := proxyurl.Parse(proxyURL)
if err != nil {
return nil, err
}
if trimmed != "" {
client.SetProxyURL(trimmed)
}
return client
return client, nil
}

View File

@@ -91,7 +91,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
@@ -169,7 +169,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "")
@@ -276,7 +276,7 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = "http://in-process/token"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)
@@ -372,7 +372,7 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = "http://in-process/token"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
resp, err := s.client.RefreshToken(context.Background(), "rt", "")

View File

@@ -83,7 +83,7 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
return nil, fmt.Errorf("create http client failed: %w", err)
}
resp, err = client.Do(req)

View File

@@ -50,7 +50,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
allowPrivateHosts: true,
}
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.NoError(s.T(), err, "FetchUsage")
require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch")
require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch")
@@ -112,6 +112,17 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
require.Error(s.T(), err, "expected error for cancelled context")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_InvalidProxyReturnsError() {
s.fetcher = &claudeUsageService{
usageURL: "http://example.com",
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "create http client failed")
}
func TestClaudeUsageServiceSuite(t *testing.T) {
suite.Run(t, new(ClaudeUsageServiceSuite))
}

View File

@@ -26,7 +26,10 @@ func NewGeminiOAuthClient(cfg *config.Config) service.GeminiOAuthClient {
}
func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) {
client := createGeminiReqClient(proxyURL)
client, err := createGeminiReqClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
// Use different OAuth clients based on oauthType:
// - code_assist: always use built-in Gemini CLI OAuth client (public)
@@ -72,7 +75,10 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
}
func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) {
client := createGeminiReqClient(proxyURL)
client, err := createGeminiReqClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID,
@@ -111,7 +117,7 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
return &tokenResp, nil
}
func createGeminiReqClient(proxyURL string) *req.Client {
func createGeminiReqClient(proxyURL string) (*req.Client, error) {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,

View File

@@ -26,7 +26,11 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo
}
var out geminicli.LoadCodeAssistResponse
resp, err := createGeminiCliReqClient(proxyURL).R().
client, err := createGeminiCliReqClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
resp, err := client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Content-Type", "application/json").
@@ -66,7 +70,11 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
fmt.Printf("[CodeAssist] OnboardUser request body: %+v\n", reqBody)
var out geminicli.OnboardUserResponse
resp, err := createGeminiCliReqClient(proxyURL).R().
client, err := createGeminiCliReqClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
resp, err := client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Content-Type", "application/json").
@@ -98,7 +106,7 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
return &out, nil
}
func createGeminiCliReqClient(proxyURL string) *req.Client {
func createGeminiCliReqClient(proxyURL string) (*req.Client, error) {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,

View File

@@ -5,8 +5,10 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
@@ -24,13 +26,19 @@ type githubReleaseClientError struct {
// NewGitHubReleaseClient 创建 GitHub Release 客户端
// proxyURL 为空时直连 GitHub支持 http/https/socks5/socks5h 协议
// 代理配置失败时行为由 allowDirectOnProxyError 控制:
// - false默认返回错误占位客户端禁止回退到直连
// - true回退到直连仅限管理员显式开启
func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) service.GitHubReleaseClient {
// 安全说明httpclient.GetClient 的错误链url.Parse / proxyutil不含明文代理凭据
// 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
})
if err != nil {
if proxyURL != "" && !allowDirectOnProxyError {
if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError {
slog.Warn("proxy client init failed, all requests will fail", "service", "github_release", "error", err)
return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
}
sharedClient = &http.Client{Timeout: 30 * time.Second}
@@ -42,7 +50,8 @@ func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) servi
ProxyURL: proxyURL,
})
if err != nil {
if proxyURL != "" && !allowDirectOnProxyError {
if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError {
slog.Warn("proxy download client init failed, all requests will fail", "service", "github_release", "error", err)
return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
}
downloadClient = &http.Client{Timeout: 10 * time.Minute}

View File

@@ -14,6 +14,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -235,7 +236,10 @@ func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID in
// TLS 指纹客户端使用独立的缓存键,与普通客户端隔离
func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
isolation := s.getIsolationMode()
proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL)
if err != nil {
return nil, err
}
// TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀
cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID)
poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls"
@@ -373,9 +377,8 @@ func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, ac
// - proxy: 按代理地址隔离,同一代理共享客户端
// - account: 按账户隔离,同一账户共享客户端(代理变更时重建)
// - account_proxy: 按账户+代理组合隔离,最细粒度
func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry {
entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false)
return entry
func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
return s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false)
}
// getClientEntry 获取或创建客户端条目
@@ -385,7 +388,10 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
// 获取隔离模式
isolation := s.getIsolationMode()
// 标准化代理 URL 并解析
proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL)
if err != nil {
return nil, err
}
// 构建缓存键(根据隔离策略不同)
cacheKey := buildCacheKey(isolation, proxyKey, accountID)
// 构建连接池配置键(用于检测配置变更)
@@ -680,17 +686,18 @@ func buildCacheKey(isolation, proxyKey string, accountID int64) string {
// - raw: 原始代理 URL 字符串
//
// 返回:
// - string: 标准化的代理键(空或解析失败返回 "direct"
// - *url.URL: 解析后的 URL或解析失败返回 nil
func normalizeProxyURL(raw string) (string, *url.URL) {
proxyURL := strings.TrimSpace(raw)
if proxyURL == "" {
return directProxyKey, nil
}
parsed, err := url.Parse(proxyURL)
// - string: 标准化的代理键(空返回 "direct"
// - *url.URL: 解析后的 URL空返回 nil
// - error: 非空代理 URL 解析失败时返回错误(禁止回退到直连)
func normalizeProxyURL(raw string) (string, *url.URL, error) {
_, parsed, err := proxyurl.Parse(raw)
if err != nil {
return directProxyKey, nil
return "", nil, err
}
if parsed == nil {
return directProxyKey, nil, nil
}
// 规范化:小写 scheme/host去除路径和查询参数
parsed.Scheme = strings.ToLower(parsed.Scheme)
parsed.Host = strings.ToLower(parsed.Host)
parsed.Path = ""
@@ -710,7 +717,7 @@ func normalizeProxyURL(raw string) (string, *url.URL) {
parsed.Host = hostname
}
}
return parsed.String(), parsed
return parsed.String(), parsed, nil
}
// defaultPoolSettings 获取默认连接池配置

View File

@@ -59,7 +59,10 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
// 模拟优化后的行为,从缓存获取客户端
b.Run("复用", func(b *testing.B) {
// 预热:确保客户端已缓存
entry := svc.getOrCreateClient(proxyURL, 1, 1)
entry, err := svc.getOrCreateClient(proxyURL, 1, 1)
if err != nil {
b.Fatalf("getOrCreateClient: %v", err)
}
client := entry.client
b.ResetTimer() // 重置计时器,排除预热时间
for i := 0; i < b.N; i++ {

View File

@@ -44,7 +44,7 @@ func (s *HTTPUpstreamSuite) newService() *httpUpstreamService {
// 验证未配置时使用 300 秒默认值
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
svc := s.newService()
entry := svc.getOrCreateClient("", 0, 0)
entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
@@ -55,25 +55,27 @@ func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7}
svc := s.newService()
entry := svc.getOrCreateClient("", 0, 0)
entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
}
// TestGetOrCreateClient_InvalidURLFallsBackToDirect 测试无效代理 URL 回退
// 验证解析失败时回退到直连模式
func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect() {
// TestGetOrCreateClient_InvalidURLReturnsError 测试无效代理 URL 返回错误
// 验证解析失败时拒绝回退到直连模式
func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLReturnsError() {
svc := s.newService()
entry := svc.getOrCreateClient("://bad-proxy-url", 1, 1)
require.Equal(s.T(), directProxyKey, entry.proxyKey, "expected direct proxy fallback")
_, err := svc.getClientEntry("://bad-proxy-url", 1, 1, false, false)
require.Error(s.T(), err, "expected error for invalid proxy URL")
}
// TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化
// 验证等价地址能够映射到同一缓存键
func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() {
key1, _ := normalizeProxyURL("http://proxy.local:8080")
key2, _ := normalizeProxyURL("http://proxy.local:8080/")
key1, _, err1 := normalizeProxyURL("http://proxy.local:8080")
require.NoError(s.T(), err1)
key2, _, err2 := normalizeProxyURL("http://proxy.local:8080/")
require.NoError(s.T(), err2)
require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match")
}
@@ -171,8 +173,8 @@ func (s *HTTPUpstreamSuite) TestAccountIsolation_DifferentAccounts() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 同一代理,不同账户
entry1 := svc.getOrCreateClient("http://proxy.local:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy.local:8080", 2, 3)
entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 1, 3)
entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 2, 3)
require.NotSame(s.T(), entry1, entry2, "不同账号不应共享连接池")
require.Equal(s.T(), 2, len(svc.clients), "账号隔离应缓存两个客户端")
}
@@ -183,8 +185,8 @@ func (s *HTTPUpstreamSuite) TestAccountProxyIsolation_DifferentProxy() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy}
svc := s.newService()
// 同一账户,不同代理
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3)
entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3)
entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3)
require.NotSame(s.T(), entry1, entry2, "账号+代理隔离应区分不同代理")
require.Equal(s.T(), 2, len(svc.clients), "账号+代理隔离应缓存两个客户端")
}
@@ -195,8 +197,8 @@ func (s *HTTPUpstreamSuite) TestAccountModeProxyChangeClearsPool() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 同一账户,先后使用不同代理
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3)
entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3)
entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3)
require.NotSame(s.T(), entry1, entry2, "账号切换代理应创建新连接池")
require.Equal(s.T(), 1, len(svc.clients), "账号模式下应仅保留一个连接池")
require.False(s.T(), hasEntry(svc, entry1), "旧连接池应被清理")
@@ -208,7 +210,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyOverridesPoolSettings() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 账户并发数为 12
entry := svc.getOrCreateClient("", 1, 12)
entry := mustGetOrCreateClient(s.T(), svc, "", 1, 12)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
// 连接池参数应与并发数一致
@@ -228,7 +230,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyFallbackToDefault() {
}
svc := s.newService()
// 账户并发数为 0应使用全局配置
entry := svc.getOrCreateClient("", 1, 0)
entry := mustGetOrCreateClient(s.T(), svc, "", 1, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 66, transport.MaxConnsPerHost, "MaxConnsPerHost fallback mismatch")
@@ -245,12 +247,12 @@ func (s *HTTPUpstreamSuite) TestEvictOverLimitRemovesOldestIdle() {
}
svc := s.newService()
// 创建两个客户端,设置不同的最后使用时间
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 1)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 2, 1)
entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 1)
entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 2, 1)
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Hour).UnixNano()) // 最久
atomic.StoreInt64(&entry2.lastUsed, time.Now().Add(-time.Hour).UnixNano())
// 创建第三个客户端,触发淘汰
_ = svc.getOrCreateClient("http://proxy-c:8080", 3, 1)
_ = mustGetOrCreateClient(s.T(), svc, "http://proxy-c:8080", 3, 1)
require.LessOrEqual(s.T(), len(svc.clients), 2, "应保持在缓存上限内")
require.False(s.T(), hasEntry(svc, entry1), "最久未使用的连接池应被清理")
@@ -264,12 +266,12 @@ func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() {
ClientIdleTTLSeconds: 1, // 1 秒空闲超时
}
svc := s.newService()
entry1 := svc.getOrCreateClient("", 1, 1)
entry1 := mustGetOrCreateClient(s.T(), svc, "", 1, 1)
// 设置为很久之前使用,但有活跃请求
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Minute).UnixNano())
atomic.StoreInt64(&entry1.inFlight, 1) // 模拟有活跃请求
// 创建新客户端,触发淘汰检查
_ = svc.getOrCreateClient("", 2, 1)
_, _ = svc.getOrCreateClient("", 2, 1)
require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收")
}
@@ -279,6 +281,14 @@ func TestHTTPUpstreamSuite(t *testing.T) {
suite.Run(t, new(HTTPUpstreamSuite))
}
// mustGetOrCreateClient 测试辅助函数,调用 getOrCreateClient 并断言无错误
func mustGetOrCreateClient(t *testing.T, svc *httpUpstreamService, proxyURL string, accountID int64, concurrency int) *upstreamClientEntry {
t.Helper()
entry, err := svc.getOrCreateClient(proxyURL, accountID, concurrency)
require.NoError(t, err, "getOrCreateClient(%q, %d, %d)", proxyURL, accountID, concurrency)
return entry
}
// hasEntry 检查客户端是否存在于缓存中
// 辅助函数,用于验证淘汰逻辑
func hasEntry(svc *httpUpstreamService, target *upstreamClientEntry) bool {

View File

@@ -23,7 +23,10 @@ type openaiOAuthService struct {
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
client, err := createOpenAIReqClient(proxyURL)
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err)
}
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
@@ -74,7 +77,10 @@ func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
}
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
client, err := createOpenAIReqClient(proxyURL)
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err)
}
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
@@ -102,7 +108,7 @@ func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refre
return &tokenResp, nil
}
func createOpenAIReqClient(proxyURL string) *req.Client {
func createOpenAIReqClient(proxyURL string) (*req.Client, error) {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 120 * time.Second,

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"time"
@@ -16,14 +17,37 @@ type pricingRemoteClient struct {
httpClient *http.Client
}
// pricingRemoteClientError 代理初始化失败时的错误占位客户端
// 所有请求直接返回初始化错误,禁止回退到直连
type pricingRemoteClientError struct {
err error
}
func (c *pricingRemoteClientError) FetchPricingJSON(_ context.Context, _ string) ([]byte, error) {
return nil, c.err
}
func (c *pricingRemoteClientError) FetchHashText(_ context.Context, _ string) (string, error) {
return "", c.err
}
// NewPricingRemoteClient 创建定价数据远程客户端
// proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议
func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient {
// 代理配置失败时行为由 allowDirectOnProxyError 控制:
// - false默认返回错误占位客户端禁止回退到直连
// - true回退到直连仅限管理员显式开启
func NewPricingRemoteClient(proxyURL string, allowDirectOnProxyError bool) service.PricingRemoteClient {
// 安全说明httpclient.GetClient 的错误链url.Parse / proxyutil不含明文代理凭据
// 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
})
if err != nil {
if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError {
slog.Warn("proxy client init failed, all requests will fail", "service", "pricing", "error", err)
return &pricingRemoteClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
}
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &pricingRemoteClient{

View File

@@ -19,7 +19,7 @@ type PricingServiceSuite struct {
func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background()
client, ok := NewPricingRemoteClient("").(*pricingRemoteClient)
client, ok := NewPricingRemoteClient("", false).(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
}
@@ -140,6 +140,22 @@ func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
require.Error(s.T(), err)
}
func TestNewPricingRemoteClient_InvalidProxy_NoFallback(t *testing.T) {
client := NewPricingRemoteClient("://bad", false)
_, ok := client.(*pricingRemoteClientError)
require.True(t, ok, "should return error client when proxy is invalid and fallback disabled")
_, err := client.FetchPricingJSON(context.Background(), "http://example.com")
require.Error(t, err)
require.Contains(t, err.Error(), "proxy client init failed")
}
func TestNewPricingRemoteClient_InvalidProxy_WithFallback(t *testing.T) {
client := NewPricingRemoteClient("://bad", true)
_, ok := client.(*pricingRemoteClient)
require.True(t, ok, "should fallback to direct client when allowed")
}
func TestPricingServiceSuite(t *testing.T) {
suite.Run(t, new(PricingServiceSuite))
}

View File

@@ -66,7 +66,6 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
ProxyURL: proxyURL,
Timeout: defaultProxyProbeTimeout,
InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true,
ValidateResolvedIP: s.validateResolvedIP,
AllowPrivateHosts: s.allowPrivateHosts,
})

View File

@@ -6,6 +6,8 @@ import (
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/imroc/req/v3"
)
@@ -33,11 +35,11 @@ var sharedReqClients sync.Map
// getSharedReqClient 获取共享的 req 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建
func getSharedReqClient(opts reqClientOptions) *req.Client {
func getSharedReqClient(opts reqClientOptions) (*req.Client, error) {
key := buildReqClientKey(opts)
if cached, ok := sharedReqClients.Load(key); ok {
if c, ok := cached.(*req.Client); ok {
return c
return c, nil
}
}
@@ -48,15 +50,19 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
if opts.Impersonate {
client = client.ImpersonateChrome()
}
if strings.TrimSpace(opts.ProxyURL) != "" {
client.SetProxyURL(strings.TrimSpace(opts.ProxyURL))
trimmed, _, err := proxyurl.Parse(opts.ProxyURL)
if err != nil {
return nil, err
}
if trimmed != "" {
client.SetProxyURL(trimmed)
}
actual, _ := sharedReqClients.LoadOrStore(key, client)
if c, ok := actual.(*req.Client); ok {
return c
return c, nil
}
return client
return client, nil
}
func buildReqClientKey(opts reqClientOptions) string {

View File

@@ -26,11 +26,13 @@ func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
ProxyURL: "http://proxy.local:8080",
Timeout: time.Second,
}
clientDefault := getSharedReqClient(base)
clientDefault, err := getSharedReqClient(base)
require.NoError(t, err)
force := base
force.ForceHTTP2 = true
clientForce := getSharedReqClient(force)
clientForce, err := getSharedReqClient(force)
require.NoError(t, err)
require.NotSame(t, clientDefault, clientForce)
require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
@@ -42,8 +44,10 @@ func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
ProxyURL: "http://proxy.local:8080",
Timeout: 2 * time.Second,
}
first := getSharedReqClient(opts)
second := getSharedReqClient(opts)
first, err := getSharedReqClient(opts)
require.NoError(t, err)
second, err := getSharedReqClient(opts)
require.NoError(t, err)
require.Same(t, first, second)
}
@@ -56,7 +60,8 @@ func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
key := buildReqClientKey(opts)
sharedReqClients.Store(key, "invalid")
client := getSharedReqClient(opts)
client, err := getSharedReqClient(opts)
require.NoError(t, err)
require.NotNil(t, client)
loaded, ok := sharedReqClients.Load(key)
@@ -71,20 +76,45 @@ func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
Timeout: 4 * time.Second,
Impersonate: true,
}
client := getSharedReqClient(opts)
client, err := getSharedReqClient(opts)
require.NoError(t, err)
require.NotNil(t, client)
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
}
func TestGetSharedReqClient_InvalidProxyURL(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: "://missing-scheme",
Timeout: time.Second,
}
_, err := getSharedReqClient(opts)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid proxy URL")
}
func TestGetSharedReqClient_ProxyURLMissingHost(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: "http://",
Timeout: time.Second,
}
_, err := getSharedReqClient(opts)
require.Error(t, err)
require.Contains(t, err.Error(), "proxy URL missing host")
}
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("http://proxy.local:8080")
client, err := createOpenAIReqClient("http://proxy.local:8080")
require.NoError(t, err)
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
}
func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
sharedReqClients = sync.Map{}
client := createGeminiReqClient("http://proxy.local:8080")
client, err := createGeminiReqClient("http://proxy.local:8080")
require.NoError(t, err)
require.Equal(t, "", forceHTTPVersion(t, client))
}

View File

@@ -34,7 +34,7 @@ func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient
// ProvidePricingRemoteClient 创建定价数据远程客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据
func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
return NewPricingRemoteClient(cfg.Update.ProxyURL)
return NewPricingRemoteClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError)
}
// ProvideSessionLimitCache 创建会话限制缓存