fix(安全): 修复上游校验与 URL 清理问题

增加请求阶段 DNS 解析校验,阻断重绑定到私网
补充默认透传 WWW-Authenticate 头,保留认证挑战
前端相对 URL 过滤拒绝 // 协议相对路径

测试: go test ./internal/repository -run TestGitHubReleaseServiceSuite
测试: go test ./internal/repository -run TestTurnstileServiceSuite
测试: go test ./internal/repository -run TestProxyProbeServiceSuite
测试: go test ./internal/repository -run TestClaudeUsageServiceSuite
This commit is contained in:
yangjianbo
2026-01-03 10:52:24 +08:00
parent bd4bf00856
commit 25e1632628
18 changed files with 168 additions and 58 deletions

View File

@@ -118,7 +118,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
pricingRemoteClient := repository.NewPricingRemoteClient()
pricingRemoteClient := repository.NewPricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
return nil, err

View File

@@ -26,6 +26,7 @@ import (
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"golang.org/x/net/proxy"
)
@@ -43,6 +44,8 @@ type Options struct {
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
ValidateResolvedIP bool // 是否校验解析后的 IP防止 DNS Rebinding
AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
// 可选的连接池参数(不设置则使用默认值)
MaxIdleConns int // 最大空闲连接总数(默认 100
@@ -86,8 +89,12 @@ func buildClient(opts Options) (*http.Client, error) {
return nil, err
}
var rt http.RoundTripper = transport
if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
rt = &validatedTransport{base: transport}
}
return &http.Client{
Transport: transport,
Transport: rt,
Timeout: opts.Timeout,
}, nil
}
@@ -144,14 +151,32 @@ func buildTransport(opts Options) (*http.Transport, error) {
}
func buildClientKey(opts Options) string {
return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d",
return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify,
opts.ProxyStrict,
opts.ValidateResolvedIP,
opts.AllowPrivateHosts,
opts.MaxIdleConns,
opts.MaxIdleConnsPerHost,
opts.MaxConnsPerHost,
)
}
type validatedTransport struct {
base http.RoundTripper
}
func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req != nil && req.URL != nil {
host := strings.TrimSpace(req.URL.Hostname())
if host != "" {
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
return nil, err
}
}
}
return t.base.RoundTrip(req)
}

View File

@@ -15,7 +15,8 @@ import (
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
type claudeUsageService struct {
usageURL string
usageURL string
allowPrivateHosts bool
}
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
@@ -24,8 +25,10 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}

View File

@@ -45,7 +45,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
}`)
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
require.NoError(s.T(), err, "FetchUsage")
@@ -64,7 +67,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
_, _ = io.WriteString(w, "nope")
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
@@ -78,7 +84,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
_, _ = io.WriteString(w, "not-json")
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
@@ -91,7 +100,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
<-r.Context().Done()
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately

View File

@@ -14,18 +14,23 @@ import (
)
type githubReleaseClient struct {
httpClient *http.Client
httpClient *http.Client
allowPrivateHosts bool
}
func NewGitHubReleaseClient() service.GitHubReleaseClient {
allowPrivate := false
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: allowPrivate,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &githubReleaseClient{
httpClient: sharedClient,
httpClient: sharedClient,
allowPrivateHosts: allowPrivate,
}
}
@@ -64,7 +69,9 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
}
downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute,
Timeout: 10 * time.Minute,
ValidateResolvedIP: true,
AllowPrivateHosts: c.allowPrivateHosts,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}

View File

@@ -37,6 +37,13 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return http.DefaultTransport.RoundTrip(newReq)
}
func newTestGitHubReleaseClient() *githubReleaseClient {
return &githubReleaseClient{
httpClient: &http.Client{},
allowPrivateHosts: true,
}
}
func (s *GitHubReleaseServiceSuite) SetupTest() {
s.tempDir = s.T().TempDir()
}
@@ -55,9 +62,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file1.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
@@ -82,9 +87,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
}
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file2.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
@@ -108,9 +111,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
}
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file3.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
@@ -127,9 +128,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
w.WriteHeader(http.StatusNotFound)
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "notfound.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
@@ -145,9 +144,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
_, _ = w.Write([]byte("sum"))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.NoError(s.T(), err, "FetchChecksumFile")
@@ -159,9 +156,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
w.WriteHeader(http.StatusInternalServerError)
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.Error(s.T(), err, "expected error for non-200")
@@ -172,9 +167,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
<-r.Context().Done()
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
ctx, cancel := context.WithCancel(context.Background())
cancel()
@@ -185,9 +178,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "invalid.bin")
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
@@ -200,9 +191,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
_, _ = w.Write([]byte("content"))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
// Use a path that cannot be created (directory doesn't exist)
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
@@ -211,9 +200,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
require.Error(s.T(), err, "expected error for invalid URL")
@@ -247,6 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
}
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -266,6 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -283,6 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -298,6 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
}
ctx, cancel := context.WithCancel(context.Background())
@@ -312,9 +303,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
<-r.Context().Done()
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
ctx, cancel := context.WithCancel(context.Background())
cancel()

View File

@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
// 默认配置常量
@@ -119,6 +120,10 @@ func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
// - 调用方必须关闭 resp.Body否则会导致 inFlight 计数泄漏
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
if err := s.validateRequestHost(req); err != nil {
return nil, err
}
// 获取或创建对应的客户端,并标记请求占用
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
if err != nil {
@@ -144,6 +149,37 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
return resp, nil
}
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
if s.cfg == nil {
return false
}
return !s.cfg.Security.URLAllowlist.AllowPrivateHosts
}
func (s *httpUpstreamService) validateRequestHost(req *http.Request) error {
if !s.shouldValidateResolvedIP() {
return nil
}
if req == nil || req.URL == nil {
return errors.New("request url is nil")
}
host := strings.TrimSpace(req.URL.Hostname())
if host == "" {
return errors.New("request host is empty")
}
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
return err
}
return nil
}
func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return errors.New("stopped after 10 redirects")
}
return s.validateRequestHost(req)
}
// acquireClient 获取或创建客户端,并标记为进行中请求
// 用于请求路径,避免在获取后被淘汰
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
@@ -226,6 +262,9 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
// 缓存未命中或需要重建,创建新客户端
settings := s.resolvePoolSettings(isolation, accountConcurrency)
client := &http.Client{Transport: buildUpstreamTransport(settings, parsedProxy)}
if s.shouldValidateResolvedIP() {
client.CheckRedirect = s.redirectChecker
}
entry := &upstreamClientEntry{
client: client,
proxyKey: proxyKey,

View File

@@ -23,7 +23,13 @@ type HTTPUpstreamSuite struct {
// SetupTest 每个测试用例执行前的初始化
// 创建空配置,各测试用例可按需覆盖
func (s *HTTPUpstreamSuite) SetupTest() {
s.cfg = &config.Config{}
s.cfg = &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
AllowPrivateHosts: true,
},
},
}
}
// newService 创建测试用的 httpUpstreamService 实例

View File

@@ -8,6 +8,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
@@ -16,9 +17,15 @@ type pricingRemoteClient struct {
httpClient *http.Client
}
func NewPricingRemoteClient() service.PricingRemoteClient {
func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
allowPrivate := false
if cfg != nil {
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
}
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: allowPrivate,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}

View File

@@ -6,6 +6,7 @@ import (
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@@ -19,7 +20,13 @@ type PricingServiceSuite struct {
func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background()
client, ok := NewPricingRemoteClient().(*pricingRemoteClient)
client, ok := NewPricingRemoteClient(&config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
AllowPrivateHosts: true,
},
},
}).(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
}

View File

@@ -5,19 +5,21 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
"log"
)
func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
insecure := false
allowPrivate := false
if cfg != nil {
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
}
if insecure {
log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.")
@@ -25,6 +27,7 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
return &proxyProbeService{
ipInfoURL: defaultIPInfoURL,
insecureSkipVerify: insecure,
allowPrivateHosts: allowPrivate,
}
}
@@ -33,6 +36,7 @@ const defaultIPInfoURL = "https://ipinfo.io/json"
type proxyProbeService struct {
ipInfoURL string
insecureSkipVerify bool
allowPrivateHosts bool
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
@@ -41,6 +45,8 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
Timeout: 15 * time.Second,
InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true,
ValidateResolvedIP: true,
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)

View File

@@ -20,7 +20,10 @@ type ProxyProbeServiceSuite struct {
func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background()
s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"}
s.prober = &proxyProbeService{
ipInfoURL: "http://ipinfo.test/json",
allowPrivateHosts: true,
}
}
func (s *ProxyProbeServiceSuite) TearDownTest() {

View File

@@ -22,7 +22,8 @@ type turnstileVerifier struct {
func NewTurnstileVerifier() service.TurnstileVerifier {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Second,
Timeout: 10 * time.Second,
ValidateResolvedIP: true,
})
if err != nil {
sharedClient = &http.Client{Timeout: 10 * time.Second}

View File

@@ -41,6 +41,7 @@ func (s *TurnstileServiceSuite) TearDownTest() {
func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.verifier.verifyURL = s.srv.URL
s.verifier.httpClient = s.srv.Client()
}
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {

View File

@@ -203,7 +203,9 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
client, err := httpclient.GetClient(httpclient.Options{
Timeout: 20 * time.Second,
Timeout: 20 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 20 * time.Second}

View File

@@ -498,8 +498,9 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: strings.TrimSpace(proxyURL),
Timeout: 30 * time.Second,
ProxyURL: strings.TrimSpace(proxyURL),
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}

View File

@@ -31,6 +31,7 @@ var defaultAllowed = map[string]struct{}{
"x-ratelimit-reset-tokens": {},
"retry-after": {},
"location": {},
"www-authenticate": {},
}
// hopByHopHeaders 是跳过的 hop-by-hop 头部,这些头部由 HTTP 库自动处理

View File

@@ -14,7 +14,7 @@ export function sanitizeUrl(value: string, options: SanitizeOptions = {}): strin
return ''
}
if (options.allowRelative && trimmed.startsWith('/')) {
if (options.allowRelative && trimmed.startsWith('/') && !trimmed.startsWith('//')) {
return trimmed
}