fix(安全): 修复上游校验与 URL 清理问题
增加请求阶段 DNS 解析校验,阻断重绑定到私网 补充默认透传 WWW-Authenticate 头,保留认证挑战 前端相对 URL 过滤拒绝 // 协议相对路径 测试: go test ./internal/repository -run TestGitHubReleaseServiceSuite 测试: go test ./internal/repository -run TestTurnstileServiceSuite 测试: go test ./internal/repository -run TestProxyProbeServiceSuite 测试: go test ./internal/repository -run TestClaudeUsageServiceSuite
This commit is contained in:
@@ -118,7 +118,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
||||||
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
|
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
|
||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
||||||
pricingRemoteClient := repository.NewPricingRemoteClient()
|
pricingRemoteClient := repository.NewPricingRemoteClient(configConfig)
|
||||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||||
"golang.org/x/net/proxy"
|
"golang.org/x/net/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -43,6 +44,8 @@ type Options struct {
|
|||||||
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
|
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
|
||||||
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
|
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
|
||||||
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
|
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
|
||||||
|
ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding)
|
||||||
|
AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
|
||||||
|
|
||||||
// 可选的连接池参数(不设置则使用默认值)
|
// 可选的连接池参数(不设置则使用默认值)
|
||||||
MaxIdleConns int // 最大空闲连接总数(默认 100)
|
MaxIdleConns int // 最大空闲连接总数(默认 100)
|
||||||
@@ -86,8 +89,12 @@ func buildClient(opts Options) (*http.Client, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var rt http.RoundTripper = transport
|
||||||
|
if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
|
||||||
|
rt = &validatedTransport{base: transport}
|
||||||
|
}
|
||||||
return &http.Client{
|
return &http.Client{
|
||||||
Transport: transport,
|
Transport: rt,
|
||||||
Timeout: opts.Timeout,
|
Timeout: opts.Timeout,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -144,14 +151,32 @@ func buildTransport(opts Options) (*http.Transport, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func buildClientKey(opts Options) string {
|
func buildClientKey(opts Options) string {
|
||||||
return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d",
|
return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d",
|
||||||
strings.TrimSpace(opts.ProxyURL),
|
strings.TrimSpace(opts.ProxyURL),
|
||||||
opts.Timeout.String(),
|
opts.Timeout.String(),
|
||||||
opts.ResponseHeaderTimeout.String(),
|
opts.ResponseHeaderTimeout.String(),
|
||||||
opts.InsecureSkipVerify,
|
opts.InsecureSkipVerify,
|
||||||
opts.ProxyStrict,
|
opts.ProxyStrict,
|
||||||
|
opts.ValidateResolvedIP,
|
||||||
|
opts.AllowPrivateHosts,
|
||||||
opts.MaxIdleConns,
|
opts.MaxIdleConns,
|
||||||
opts.MaxIdleConnsPerHost,
|
opts.MaxIdleConnsPerHost,
|
||||||
opts.MaxConnsPerHost,
|
opts.MaxConnsPerHost,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type validatedTransport struct {
|
||||||
|
base http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
if req != nil && req.URL != nil {
|
||||||
|
host := strings.TrimSpace(req.URL.Hostname())
|
||||||
|
if host != "" {
|
||||||
|
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return t.base.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,7 +15,8 @@ import (
|
|||||||
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
|
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||||
|
|
||||||
type claudeUsageService struct {
|
type claudeUsageService struct {
|
||||||
usageURL string
|
usageURL string
|
||||||
|
allowPrivateHosts bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||||
@@ -24,8 +25,10 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
|||||||
|
|
||||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||||
client, err := httpclient.GetClient(httpclient.Options{
|
client, err := httpclient.GetClient(httpclient.Options{
|
||||||
ProxyURL: proxyURL,
|
ProxyURL: proxyURL,
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
|
ValidateResolvedIP: true,
|
||||||
|
AllowPrivateHosts: s.allowPrivateHosts,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client = &http.Client{Timeout: 30 * time.Second}
|
client = &http.Client{Timeout: 30 * time.Second}
|
||||||
|
|||||||
@@ -45,7 +45,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
|
|||||||
}`)
|
}`)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
s.fetcher = &claudeUsageService{
|
||||||
|
usageURL: s.srv.URL,
|
||||||
|
allowPrivateHosts: true,
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
|
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
|
||||||
require.NoError(s.T(), err, "FetchUsage")
|
require.NoError(s.T(), err, "FetchUsage")
|
||||||
@@ -64,7 +67,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
|
|||||||
_, _ = io.WriteString(w, "nope")
|
_, _ = io.WriteString(w, "nope")
|
||||||
}))
|
}))
|
||||||
|
|
||||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
s.fetcher = &claudeUsageService{
|
||||||
|
usageURL: s.srv.URL,
|
||||||
|
allowPrivateHosts: true,
|
||||||
|
}
|
||||||
|
|
||||||
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||||
require.Error(s.T(), err)
|
require.Error(s.T(), err)
|
||||||
@@ -78,7 +84,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
|
|||||||
_, _ = io.WriteString(w, "not-json")
|
_, _ = io.WriteString(w, "not-json")
|
||||||
}))
|
}))
|
||||||
|
|
||||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
s.fetcher = &claudeUsageService{
|
||||||
|
usageURL: s.srv.URL,
|
||||||
|
allowPrivateHosts: true,
|
||||||
|
}
|
||||||
|
|
||||||
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||||
require.Error(s.T(), err)
|
require.Error(s.T(), err)
|
||||||
@@ -91,7 +100,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
|
|||||||
<-r.Context().Done()
|
<-r.Context().Done()
|
||||||
}))
|
}))
|
||||||
|
|
||||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
s.fetcher = &claudeUsageService{
|
||||||
|
usageURL: s.srv.URL,
|
||||||
|
allowPrivateHosts: true,
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
cancel() // Cancel immediately
|
cancel() // Cancel immediately
|
||||||
|
|||||||
@@ -14,18 +14,23 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type githubReleaseClient struct {
|
type githubReleaseClient struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
|
allowPrivateHosts bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGitHubReleaseClient() service.GitHubReleaseClient {
|
func NewGitHubReleaseClient() service.GitHubReleaseClient {
|
||||||
|
allowPrivate := false
|
||||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
|
ValidateResolvedIP: true,
|
||||||
|
AllowPrivateHosts: allowPrivate,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
||||||
}
|
}
|
||||||
return &githubReleaseClient{
|
return &githubReleaseClient{
|
||||||
httpClient: sharedClient,
|
httpClient: sharedClient,
|
||||||
|
allowPrivateHosts: allowPrivate,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,7 +69,9 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
|||||||
}
|
}
|
||||||
|
|
||||||
downloadClient, err := httpclient.GetClient(httpclient.Options{
|
downloadClient, err := httpclient.GetClient(httpclient.Options{
|
||||||
Timeout: 10 * time.Minute,
|
Timeout: 10 * time.Minute,
|
||||||
|
ValidateResolvedIP: true,
|
||||||
|
AllowPrivateHosts: c.allowPrivateHosts,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
downloadClient = &http.Client{Timeout: 10 * time.Minute}
|
downloadClient = &http.Client{Timeout: 10 * time.Minute}
|
||||||
|
|||||||
@@ -37,6 +37,13 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||||||
return http.DefaultTransport.RoundTrip(newReq)
|
return http.DefaultTransport.RoundTrip(newReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTestGitHubReleaseClient() *githubReleaseClient {
|
||||||
|
return &githubReleaseClient{
|
||||||
|
httpClient: &http.Client{},
|
||||||
|
allowPrivateHosts: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GitHubReleaseServiceSuite) SetupTest() {
|
func (s *GitHubReleaseServiceSuite) SetupTest() {
|
||||||
s.tempDir = s.T().TempDir()
|
s.tempDir = s.T().TempDir()
|
||||||
}
|
}
|
||||||
@@ -55,9 +62,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng
|
|||||||
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
|
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
|
||||||
}))
|
}))
|
||||||
|
|
||||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
s.client = newTestGitHubReleaseClient()
|
||||||
require.True(s.T(), ok, "type assertion failed")
|
|
||||||
s.client = client
|
|
||||||
|
|
||||||
dest := filepath.Join(s.tempDir, "file1.bin")
|
dest := filepath.Join(s.tempDir, "file1.bin")
|
||||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
|
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
|
||||||
@@ -82,9 +87,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
|
|||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
s.client = newTestGitHubReleaseClient()
|
||||||
require.True(s.T(), ok, "type assertion failed")
|
|
||||||
s.client = client
|
|
||||||
|
|
||||||
dest := filepath.Join(s.tempDir, "file2.bin")
|
dest := filepath.Join(s.tempDir, "file2.bin")
|
||||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
|
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
|
||||||
@@ -108,9 +111,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
|
|||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
s.client = newTestGitHubReleaseClient()
|
||||||
require.True(s.T(), ok, "type assertion failed")
|
|
||||||
s.client = client
|
|
||||||
|
|
||||||
dest := filepath.Join(s.tempDir, "file3.bin")
|
dest := filepath.Join(s.tempDir, "file3.bin")
|
||||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
|
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
|
||||||
@@ -127,9 +128,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
|
|||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
s.client = newTestGitHubReleaseClient()
|
||||||
require.True(s.T(), ok, "type assertion failed")
|
|
||||||
s.client = client
|
|
||||||
|
|
||||||
dest := filepath.Join(s.tempDir, "notfound.bin")
|
dest := filepath.Join(s.tempDir, "notfound.bin")
|
||||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
|
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
|
||||||
@@ -145,9 +144,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
|
|||||||
_, _ = w.Write([]byte("sum"))
|
_, _ = w.Write([]byte("sum"))
|
||||||
}))
|
}))
|
||||||
|
|
||||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
s.client = newTestGitHubReleaseClient()
|
||||||
require.True(s.T(), ok, "type assertion failed")
|
|
||||||
s.client = client
|
|
||||||
|
|
||||||
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
|
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
|
||||||
require.NoError(s.T(), err, "FetchChecksumFile")
|
require.NoError(s.T(), err, "FetchChecksumFile")
|
||||||
@@ -159,9 +156,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
|
|||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
s.client = newTestGitHubReleaseClient()
|
||||||
require.True(s.T(), ok, "type assertion failed")
|
|
||||||
s.client = client
|
|
||||||
|
|
||||||
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
|
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
|
||||||
require.Error(s.T(), err, "expected error for non-200")
|
require.Error(s.T(), err, "expected error for non-200")
|
||||||
@@ -172,9 +167,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
|
|||||||
<-r.Context().Done()
|
<-r.Context().Done()
|
||||||
}))
|
}))
|
||||||
|
|
||||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
s.client = newTestGitHubReleaseClient()
|
||||||
require.True(s.T(), ok, "type assertion failed")
|
|
||||||
s.client = client
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
cancel()
|
cancel()
|
||||||
@@ -185,9 +178,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
|
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
|
||||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
s.client = newTestGitHubReleaseClient()
|
||||||
require.True(s.T(), ok, "type assertion failed")
|
|
||||||
s.client = client
|
|
||||||
|
|
||||||
dest := filepath.Join(s.tempDir, "invalid.bin")
|
dest := filepath.Join(s.tempDir, "invalid.bin")
|
||||||
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
|
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
|
||||||
@@ -200,9 +191,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
|
|||||||
_, _ = w.Write([]byte("content"))
|
_, _ = w.Write([]byte("content"))
|
||||||
}))
|
}))
|
||||||
|
|
||||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
s.client = newTestGitHubReleaseClient()
|
||||||
require.True(s.T(), ok, "type assertion failed")
|
|
||||||
s.client = client
|
|
||||||
|
|
||||||
// Use a path that cannot be created (directory doesn't exist)
|
// Use a path that cannot be created (directory doesn't exist)
|
||||||
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
|
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
|
||||||
@@ -211,9 +200,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
|
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
|
||||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
s.client = newTestGitHubReleaseClient()
|
||||||
require.True(s.T(), ok, "type assertion failed")
|
|
||||||
s.client = client
|
|
||||||
|
|
||||||
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
|
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
|
||||||
require.Error(s.T(), err, "expected error for invalid URL")
|
require.Error(s.T(), err, "expected error for invalid URL")
|
||||||
@@ -247,6 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
|
|||||||
httpClient: &http.Client{
|
httpClient: &http.Client{
|
||||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||||
},
|
},
|
||||||
|
allowPrivateHosts: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||||
@@ -266,6 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
|
|||||||
httpClient: &http.Client{
|
httpClient: &http.Client{
|
||||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||||
},
|
},
|
||||||
|
allowPrivateHosts: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||||
@@ -283,6 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
|
|||||||
httpClient: &http.Client{
|
httpClient: &http.Client{
|
||||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||||
},
|
},
|
||||||
|
allowPrivateHosts: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||||
@@ -298,6 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
|
|||||||
httpClient: &http.Client{
|
httpClient: &http.Client{
|
||||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||||
},
|
},
|
||||||
|
allowPrivateHosts: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
@@ -312,9 +303,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
|
|||||||
<-r.Context().Done()
|
<-r.Context().Done()
|
||||||
}))
|
}))
|
||||||
|
|
||||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
s.client = newTestGitHubReleaseClient()
|
||||||
require.True(s.T(), ok, "type assertion failed")
|
|
||||||
s.client = client
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
cancel()
|
cancel()
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 默认配置常量
|
// 默认配置常量
|
||||||
@@ -119,6 +120,10 @@ func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
|
|||||||
// - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏
|
// - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏
|
||||||
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
|
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
|
||||||
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||||
|
if err := s.validateRequestHost(req); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// 获取或创建对应的客户端,并标记请求占用
|
// 获取或创建对应的客户端,并标记请求占用
|
||||||
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
|
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -144,6 +149,37 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
|
||||||
|
if s.cfg == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !s.cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *httpUpstreamService) validateRequestHost(req *http.Request) error {
|
||||||
|
if !s.shouldValidateResolvedIP() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if req == nil || req.URL == nil {
|
||||||
|
return errors.New("request url is nil")
|
||||||
|
}
|
||||||
|
host := strings.TrimSpace(req.URL.Hostname())
|
||||||
|
if host == "" {
|
||||||
|
return errors.New("request host is empty")
|
||||||
|
}
|
||||||
|
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Request) error {
|
||||||
|
if len(via) >= 10 {
|
||||||
|
return errors.New("stopped after 10 redirects")
|
||||||
|
}
|
||||||
|
return s.validateRequestHost(req)
|
||||||
|
}
|
||||||
|
|
||||||
// acquireClient 获取或创建客户端,并标记为进行中请求
|
// acquireClient 获取或创建客户端,并标记为进行中请求
|
||||||
// 用于请求路径,避免在获取后被淘汰
|
// 用于请求路径,避免在获取后被淘汰
|
||||||
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
|
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
|
||||||
@@ -226,6 +262,9 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
|
|||||||
// 缓存未命中或需要重建,创建新客户端
|
// 缓存未命中或需要重建,创建新客户端
|
||||||
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
||||||
client := &http.Client{Transport: buildUpstreamTransport(settings, parsedProxy)}
|
client := &http.Client{Transport: buildUpstreamTransport(settings, parsedProxy)}
|
||||||
|
if s.shouldValidateResolvedIP() {
|
||||||
|
client.CheckRedirect = s.redirectChecker
|
||||||
|
}
|
||||||
entry := &upstreamClientEntry{
|
entry := &upstreamClientEntry{
|
||||||
client: client,
|
client: client,
|
||||||
proxyKey: proxyKey,
|
proxyKey: proxyKey,
|
||||||
|
|||||||
@@ -23,7 +23,13 @@ type HTTPUpstreamSuite struct {
|
|||||||
// SetupTest 每个测试用例执行前的初始化
|
// SetupTest 每个测试用例执行前的初始化
|
||||||
// 创建空配置,各测试用例可按需覆盖
|
// 创建空配置,各测试用例可按需覆盖
|
||||||
func (s *HTTPUpstreamSuite) SetupTest() {
|
func (s *HTTPUpstreamSuite) SetupTest() {
|
||||||
s.cfg = &config.Config{}
|
s.cfg = &config.Config{
|
||||||
|
Security: config.SecurityConfig{
|
||||||
|
URLAllowlist: config.URLAllowlistConfig{
|
||||||
|
AllowPrivateHosts: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// newService 创建测试用的 httpUpstreamService 实例
|
// newService 创建测试用的 httpUpstreamService 实例
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
)
|
)
|
||||||
@@ -16,9 +17,15 @@ type pricingRemoteClient struct {
|
|||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPricingRemoteClient() service.PricingRemoteClient {
|
func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
|
||||||
|
allowPrivate := false
|
||||||
|
if cfg != nil {
|
||||||
|
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||||
|
}
|
||||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
|
ValidateResolvedIP: true,
|
||||||
|
AllowPrivateHosts: allowPrivate,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
)
|
)
|
||||||
@@ -19,7 +20,13 @@ type PricingServiceSuite struct {
|
|||||||
|
|
||||||
func (s *PricingServiceSuite) SetupTest() {
|
func (s *PricingServiceSuite) SetupTest() {
|
||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
client, ok := NewPricingRemoteClient().(*pricingRemoteClient)
|
client, ok := NewPricingRemoteClient(&config.Config{
|
||||||
|
Security: config.SecurityConfig{
|
||||||
|
URLAllowlist: config.URLAllowlistConfig{
|
||||||
|
AllowPrivateHosts: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}).(*pricingRemoteClient)
|
||||||
require.True(s.T(), ok, "type assertion failed")
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
s.client = client
|
s.client = client
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,19 +5,21 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||||
insecure := false
|
insecure := false
|
||||||
|
allowPrivate := false
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
|
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
|
||||||
|
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||||
}
|
}
|
||||||
if insecure {
|
if insecure {
|
||||||
log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.")
|
log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.")
|
||||||
@@ -25,6 +27,7 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
|||||||
return &proxyProbeService{
|
return &proxyProbeService{
|
||||||
ipInfoURL: defaultIPInfoURL,
|
ipInfoURL: defaultIPInfoURL,
|
||||||
insecureSkipVerify: insecure,
|
insecureSkipVerify: insecure,
|
||||||
|
allowPrivateHosts: allowPrivate,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,6 +36,7 @@ const defaultIPInfoURL = "https://ipinfo.io/json"
|
|||||||
type proxyProbeService struct {
|
type proxyProbeService struct {
|
||||||
ipInfoURL string
|
ipInfoURL string
|
||||||
insecureSkipVerify bool
|
insecureSkipVerify bool
|
||||||
|
allowPrivateHosts bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||||
@@ -41,6 +45,8 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
|||||||
Timeout: 15 * time.Second,
|
Timeout: 15 * time.Second,
|
||||||
InsecureSkipVerify: s.insecureSkipVerify,
|
InsecureSkipVerify: s.insecureSkipVerify,
|
||||||
ProxyStrict: true,
|
ProxyStrict: true,
|
||||||
|
ValidateResolvedIP: true,
|
||||||
|
AllowPrivateHosts: s.allowPrivateHosts,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
|
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
|
||||||
|
|||||||
@@ -20,7 +20,10 @@ type ProxyProbeServiceSuite struct {
|
|||||||
|
|
||||||
func (s *ProxyProbeServiceSuite) SetupTest() {
|
func (s *ProxyProbeServiceSuite) SetupTest() {
|
||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"}
|
s.prober = &proxyProbeService{
|
||||||
|
ipInfoURL: "http://ipinfo.test/json",
|
||||||
|
allowPrivateHosts: true,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyProbeServiceSuite) TearDownTest() {
|
func (s *ProxyProbeServiceSuite) TearDownTest() {
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ type turnstileVerifier struct {
|
|||||||
|
|
||||||
func NewTurnstileVerifier() service.TurnstileVerifier {
|
func NewTurnstileVerifier() service.TurnstileVerifier {
|
||||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||||
Timeout: 10 * time.Second,
|
Timeout: 10 * time.Second,
|
||||||
|
ValidateResolvedIP: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sharedClient = &http.Client{Timeout: 10 * time.Second}
|
sharedClient = &http.Client{Timeout: 10 * time.Second}
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ func (s *TurnstileServiceSuite) TearDownTest() {
|
|||||||
func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) {
|
func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||||
s.srv = httptest.NewServer(handler)
|
s.srv = httptest.NewServer(handler)
|
||||||
s.verifier.verifyURL = s.srv.URL
|
s.verifier.verifyURL = s.srv.URL
|
||||||
|
s.verifier.httpClient = s.srv.Client()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
|
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
|
||||||
|
|||||||
@@ -203,7 +203,9 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
|||||||
}
|
}
|
||||||
|
|
||||||
client, err := httpclient.GetClient(httpclient.Options{
|
client, err := httpclient.GetClient(httpclient.Options{
|
||||||
Timeout: 20 * time.Second,
|
Timeout: 20 * time.Second,
|
||||||
|
ValidateResolvedIP: true,
|
||||||
|
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client = &http.Client{Timeout: 20 * time.Second}
|
client = &http.Client{Timeout: 20 * time.Second}
|
||||||
|
|||||||
@@ -498,8 +498,9 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
|
|||||||
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
|
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
|
||||||
|
|
||||||
client, err := httpclient.GetClient(httpclient.Options{
|
client, err := httpclient.GetClient(httpclient.Options{
|
||||||
ProxyURL: strings.TrimSpace(proxyURL),
|
ProxyURL: strings.TrimSpace(proxyURL),
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
|
ValidateResolvedIP: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client = &http.Client{Timeout: 30 * time.Second}
|
client = &http.Client{Timeout: 30 * time.Second}
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ var defaultAllowed = map[string]struct{}{
|
|||||||
"x-ratelimit-reset-tokens": {},
|
"x-ratelimit-reset-tokens": {},
|
||||||
"retry-after": {},
|
"retry-after": {},
|
||||||
"location": {},
|
"location": {},
|
||||||
|
"www-authenticate": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
// hopByHopHeaders 是跳过的 hop-by-hop 头部,这些头部由 HTTP 库自动处理
|
// hopByHopHeaders 是跳过的 hop-by-hop 头部,这些头部由 HTTP 库自动处理
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ export function sanitizeUrl(value: string, options: SanitizeOptions = {}): strin
|
|||||||
return ''
|
return ''
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options.allowRelative && trimmed.startsWith('/')) {
|
if (options.allowRelative && trimmed.startsWith('/') && !trimmed.startsWith('//')) {
|
||||||
return trimmed
|
return trimmed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user