fix(安全): 修复上游校验与 URL 清理问题
增加请求阶段 DNS 解析校验,阻断重绑定到私网 补充默认透传 WWW-Authenticate 头,保留认证挑战 前端相对 URL 过滤拒绝 // 协议相对路径 测试: go test ./internal/repository -run TestGitHubReleaseServiceSuite 测试: go test ./internal/repository -run TestTurnstileServiceSuite 测试: go test ./internal/repository -run TestProxyProbeServiceSuite 测试: go test ./internal/repository -run TestClaudeUsageServiceSuite
This commit is contained in:
@@ -118,7 +118,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
||||
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
||||
pricingRemoteClient := repository.NewPricingRemoteClient()
|
||||
pricingRemoteClient := repository.NewPricingRemoteClient(configConfig)
|
||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
@@ -43,6 +44,8 @@ type Options struct {
|
||||
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
|
||||
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
|
||||
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
|
||||
ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding)
|
||||
AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
|
||||
|
||||
// 可选的连接池参数(不设置则使用默认值)
|
||||
MaxIdleConns int // 最大空闲连接总数(默认 100)
|
||||
@@ -86,8 +89,12 @@ func buildClient(opts Options) (*http.Client, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rt http.RoundTripper = transport
|
||||
if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
|
||||
rt = &validatedTransport{base: transport}
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
Transport: rt,
|
||||
Timeout: opts.Timeout,
|
||||
}, nil
|
||||
}
|
||||
@@ -144,14 +151,32 @@ func buildTransport(opts Options) (*http.Transport, error) {
|
||||
}
|
||||
|
||||
func buildClientKey(opts Options) string {
|
||||
return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d",
|
||||
return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d",
|
||||
strings.TrimSpace(opts.ProxyURL),
|
||||
opts.Timeout.String(),
|
||||
opts.ResponseHeaderTimeout.String(),
|
||||
opts.InsecureSkipVerify,
|
||||
opts.ProxyStrict,
|
||||
opts.ValidateResolvedIP,
|
||||
opts.AllowPrivateHosts,
|
||||
opts.MaxIdleConns,
|
||||
opts.MaxIdleConnsPerHost,
|
||||
opts.MaxConnsPerHost,
|
||||
)
|
||||
}
|
||||
|
||||
type validatedTransport struct {
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req != nil && req.URL != nil {
|
||||
host := strings.TrimSpace(req.URL.Hostname())
|
||||
if host != "" {
|
||||
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return t.base.RoundTrip(req)
|
||||
}
|
||||
|
||||
@@ -15,7 +15,8 @@ import (
|
||||
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||
|
||||
type claudeUsageService struct {
|
||||
usageURL string
|
||||
usageURL string
|
||||
allowPrivateHosts bool
|
||||
}
|
||||
|
||||
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||
@@ -24,8 +25,10 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||
|
||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: s.allowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
|
||||
@@ -45,7 +45,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
|
||||
}`)
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
s.fetcher = &claudeUsageService{
|
||||
usageURL: s.srv.URL,
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
|
||||
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
|
||||
require.NoError(s.T(), err, "FetchUsage")
|
||||
@@ -64,7 +67,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
|
||||
_, _ = io.WriteString(w, "nope")
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
s.fetcher = &claudeUsageService{
|
||||
usageURL: s.srv.URL,
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
|
||||
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||
require.Error(s.T(), err)
|
||||
@@ -78,7 +84,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
s.fetcher = &claudeUsageService{
|
||||
usageURL: s.srv.URL,
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
|
||||
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||
require.Error(s.T(), err)
|
||||
@@ -91,7 +100,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||
s.fetcher = &claudeUsageService{
|
||||
usageURL: s.srv.URL,
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
@@ -14,18 +14,23 @@ import (
|
||||
)
|
||||
|
||||
type githubReleaseClient struct {
|
||||
httpClient *http.Client
|
||||
httpClient *http.Client
|
||||
allowPrivateHosts bool
|
||||
}
|
||||
|
||||
func NewGitHubReleaseClient() service.GitHubReleaseClient {
|
||||
allowPrivate := false
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 30 * time.Second,
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: allowPrivate,
|
||||
})
|
||||
if err != nil {
|
||||
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
return &githubReleaseClient{
|
||||
httpClient: sharedClient,
|
||||
httpClient: sharedClient,
|
||||
allowPrivateHosts: allowPrivate,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,7 +69,9 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
||||
}
|
||||
|
||||
downloadClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 10 * time.Minute,
|
||||
Timeout: 10 * time.Minute,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: c.allowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
downloadClient = &http.Client{Timeout: 10 * time.Minute}
|
||||
|
||||
@@ -37,6 +37,13 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return http.DefaultTransport.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
func newTestGitHubReleaseClient() *githubReleaseClient {
|
||||
return &githubReleaseClient{
|
||||
httpClient: &http.Client{},
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) SetupTest() {
|
||||
s.tempDir = s.T().TempDir()
|
||||
}
|
||||
@@ -55,9 +62,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng
|
||||
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "file1.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
|
||||
@@ -82,9 +87,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
|
||||
}
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "file2.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
|
||||
@@ -108,9 +111,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
|
||||
}
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "file3.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
|
||||
@@ -127,9 +128,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "notfound.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
|
||||
@@ -145,9 +144,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
|
||||
_, _ = w.Write([]byte("sum"))
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
|
||||
require.NoError(s.T(), err, "FetchChecksumFile")
|
||||
@@ -159,9 +156,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
|
||||
require.Error(s.T(), err, "expected error for non-200")
|
||||
@@ -172,9 +167,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
@@ -185,9 +178,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "invalid.bin")
|
||||
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
|
||||
@@ -200,9 +191,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
|
||||
_, _ = w.Write([]byte("content"))
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
// Use a path that cannot be created (directory doesn't exist)
|
||||
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
|
||||
@@ -211,9 +200,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
|
||||
require.Error(s.T(), err, "expected error for invalid URL")
|
||||
@@ -247,6 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
|
||||
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
@@ -266,6 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
|
||||
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
@@ -283,6 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
|
||||
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
@@ -298,6 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -312,9 +303,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
|
||||
// 默认配置常量
|
||||
@@ -119,6 +120,10 @@ func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
|
||||
// - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏
|
||||
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
|
||||
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
if err := s.validateRequestHost(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取或创建对应的客户端,并标记请求占用
|
||||
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
|
||||
if err != nil {
|
||||
@@ -144,6 +149,37 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
|
||||
if s.cfg == nil {
|
||||
return false
|
||||
}
|
||||
return !s.cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) validateRequestHost(req *http.Request) error {
|
||||
if !s.shouldValidateResolvedIP() {
|
||||
return nil
|
||||
}
|
||||
if req == nil || req.URL == nil {
|
||||
return errors.New("request url is nil")
|
||||
}
|
||||
host := strings.TrimSpace(req.URL.Hostname())
|
||||
if host == "" {
|
||||
return errors.New("request host is empty")
|
||||
}
|
||||
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 10 {
|
||||
return errors.New("stopped after 10 redirects")
|
||||
}
|
||||
return s.validateRequestHost(req)
|
||||
}
|
||||
|
||||
// acquireClient 获取或创建客户端,并标记为进行中请求
|
||||
// 用于请求路径,避免在获取后被淘汰
|
||||
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
|
||||
@@ -226,6 +262,9 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
|
||||
// 缓存未命中或需要重建,创建新客户端
|
||||
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
||||
client := &http.Client{Transport: buildUpstreamTransport(settings, parsedProxy)}
|
||||
if s.shouldValidateResolvedIP() {
|
||||
client.CheckRedirect = s.redirectChecker
|
||||
}
|
||||
entry := &upstreamClientEntry{
|
||||
client: client,
|
||||
proxyKey: proxyKey,
|
||||
|
||||
@@ -23,7 +23,13 @@ type HTTPUpstreamSuite struct {
|
||||
// SetupTest 每个测试用例执行前的初始化
|
||||
// 创建空配置,各测试用例可按需覆盖
|
||||
func (s *HTTPUpstreamSuite) SetupTest() {
|
||||
s.cfg = &config.Config{}
|
||||
s.cfg = &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{
|
||||
AllowPrivateHosts: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newService 创建测试用的 httpUpstreamService 实例
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
@@ -16,9 +17,15 @@ type pricingRemoteClient struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewPricingRemoteClient() service.PricingRemoteClient {
|
||||
func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
|
||||
allowPrivate := false
|
||||
if cfg != nil {
|
||||
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||
}
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 30 * time.Second,
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: allowPrivate,
|
||||
})
|
||||
if err != nil {
|
||||
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
@@ -19,7 +20,13 @@ type PricingServiceSuite struct {
|
||||
|
||||
func (s *PricingServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
client, ok := NewPricingRemoteClient().(*pricingRemoteClient)
|
||||
client, ok := NewPricingRemoteClient(&config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{
|
||||
AllowPrivateHosts: true,
|
||||
},
|
||||
},
|
||||
}).(*pricingRemoteClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
}
|
||||
|
||||
@@ -5,19 +5,21 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"log"
|
||||
)
|
||||
|
||||
func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
insecure := false
|
||||
allowPrivate := false
|
||||
if cfg != nil {
|
||||
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
|
||||
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||
}
|
||||
if insecure {
|
||||
log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.")
|
||||
@@ -25,6 +27,7 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
return &proxyProbeService{
|
||||
ipInfoURL: defaultIPInfoURL,
|
||||
insecureSkipVerify: insecure,
|
||||
allowPrivateHosts: allowPrivate,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +36,7 @@ const defaultIPInfoURL = "https://ipinfo.io/json"
|
||||
type proxyProbeService struct {
|
||||
ipInfoURL string
|
||||
insecureSkipVerify bool
|
||||
allowPrivateHosts bool
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||
@@ -41,6 +45,8 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
Timeout: 15 * time.Second,
|
||||
InsecureSkipVerify: s.insecureSkipVerify,
|
||||
ProxyStrict: true,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: s.allowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
|
||||
|
||||
@@ -20,7 +20,10 @@ type ProxyProbeServiceSuite struct {
|
||||
|
||||
func (s *ProxyProbeServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"}
|
||||
s.prober = &proxyProbeService{
|
||||
ipInfoURL: "http://ipinfo.test/json",
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TearDownTest() {
|
||||
|
||||
@@ -22,7 +22,8 @@ type turnstileVerifier struct {
|
||||
|
||||
func NewTurnstileVerifier() service.TurnstileVerifier {
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
})
|
||||
if err != nil {
|
||||
sharedClient = &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
@@ -41,6 +41,7 @@ func (s *TurnstileServiceSuite) TearDownTest() {
|
||||
func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||
s.srv = httptest.NewServer(handler)
|
||||
s.verifier.verifyURL = s.srv.URL
|
||||
s.verifier.httpClient = s.srv.Client()
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
|
||||
|
||||
@@ -203,7 +203,9 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 20 * time.Second,
|
||||
Timeout: 20 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 20 * time.Second}
|
||||
|
||||
@@ -498,8 +498,9 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
|
||||
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
|
||||
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: strings.TrimSpace(proxyURL),
|
||||
Timeout: 30 * time.Second,
|
||||
ProxyURL: strings.TrimSpace(proxyURL),
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
|
||||
@@ -31,6 +31,7 @@ var defaultAllowed = map[string]struct{}{
|
||||
"x-ratelimit-reset-tokens": {},
|
||||
"retry-after": {},
|
||||
"location": {},
|
||||
"www-authenticate": {},
|
||||
}
|
||||
|
||||
// hopByHopHeaders 是跳过的 hop-by-hop 头部,这些头部由 HTTP 库自动处理
|
||||
|
||||
@@ -14,7 +14,7 @@ export function sanitizeUrl(value: string, options: SanitizeOptions = {}): strin
|
||||
return ''
|
||||
}
|
||||
|
||||
if (options.allowRelative && trimmed.startsWith('/')) {
|
||||
if (options.allowRelative && trimmed.startsWith('/') && !trimmed.startsWith('//')) {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user