@@ -11,16 +11,20 @@ import (
// SSRFProtection SSRF防护配置
type SSRFProtection struct {
AllowPrivateIp bool
WhitelistDomains [ ] string // domain format, e.g. example.com, *.example.com
WhitelistIps [ ] string // CIDR format
DomainFilterMode bool // true: 白名单, false: 黑名单
DomainList [ ] string // domain format, e.g. example.com, *.example.com
IpFilterMode bool // true: 白名单, false: 黑名单
IpList [ ] string // CIDR or single IP
AllowedPorts [ ] int // 允许的端口范围
}
// DefaultSSRFProtection 默认SSRF防护配置
var DefaultSSRFProtection = & SSRFProtection {
AllowPrivateIp : false ,
WhitelistDomains : [ ] string { } ,
WhitelistIps : [ ] string { } ,
DomainFilterMode : true ,
DomainList : [ ] string { } ,
IpFilterMode : true ,
IpList : [ ] string { } ,
AllowedPorts : [ ] int { } ,
}
@@ -138,44 +142,25 @@ func (p *SSRFProtection) isAllowedPort(port int) bool {
return false
}
// isAllowedPortFromRanges 从端口范围字符串检查端口是否被允许
func isAllowedPortFromRanges ( port int , portRanges [ ] string ) bool {
if len ( portRanges ) == 0 {
return true // 如果没有配置端口限制,则允许所有端口
}
allowedPorts , err := parsePortRanges ( portRanges )
if err != nil {
// 如果解析失败,为安全起见拒绝访问
return false
}
for _ , allowedPort := range allowedPorts {
if port == allowedPort {
return true
}
}
return false
}
// isDomainWhitelisted 检查域名是否在白名单中
func ( p * SSRFProtection ) isDomainWhitel isted ( domain string ) bool {
if len ( p . WhitelistDomains ) == 0 {
func isDomainL isted ( domain string , list [ ] string ) bool {
if len ( list ) == 0 {
return false
}
domain = strings . ToLower ( domain )
for _ , wh itelistDomain := range p . WhitelistDomains {
wh itelistDomain = strings . ToLower ( whitelistDomain )
for _ , item := range list {
item = strings . ToLower ( strings . TrimSpace ( item ) )
if item == "" {
continue
}
// 精确匹配
if domain == wh itelistDomain {
if domain == item {
return true
}
// 通配符匹配 (*.example.com)
if strings . HasPrefix ( wh itelistDomain , "*." ) {
suffix := strings . TrimPrefix ( wh itelistDomain , "*." )
if strings . HasPrefix ( item , "*." ) {
suffix := strings . TrimPrefix ( item , "*." )
if strings . HasSuffix ( domain , "." + suffix ) || domain == suffix {
return true
}
@@ -184,13 +169,23 @@ func (p *SSRFProtection) isDomainWhitelisted(domain string) bool {
return false
}
func ( p * SSRFProtection ) isDomainAllowed ( domain string ) bool {
listed := isDomainListed ( domain , p . DomainList )
if p . DomainFilterMode { // 白名单
return listed
}
// 黑名单
return ! listed
}
// isIPWhitelisted 检查IP是否在白名单中
func ( p * SSRFProtection ) isIPWhitelisted ( ip net . IP ) bool {
if len ( p . WhitelistIps ) == 0 {
func isIPListed ( i p net . IP , list [ ] string ) bool {
if len ( list ) == 0 {
return false
}
for _ , whitelistCIDR := range p . White listIps {
for _ , whitelistCIDR := range list {
_ , network , err := net . ParseCIDR ( whitelistCIDR )
if err != nil {
// 尝试作为单个IP处理
@@ -211,22 +206,17 @@ func (p *SSRFProtection) isIPWhitelisted(ip net.IP) bool {
// IsIPAccessAllowed 检查IP是否允许访问
func ( p * SSRFProtection ) IsIPAccessAllowed ( ip net . IP ) bool {
// 如果IP在白名单中, 直接允许访问( 绕过 私有IP检查)
if p . isIPWhitelisted ( ip ) {
return true
}
// 如果IP白名单为空, 允许所有IP( 但仍需通过私有IP检查)
if len ( p . WhitelistIps ) == 0 {
// 检查私有IP限制
// 私有IP限制
if isPrivateIP ( ip ) && ! p . AllowPrivateIp {
return false
}
return true
}
// 如果IP白名单不为空且IP不在白名单中, 拒绝访问
return false
listed := isIPListed ( ip , p . IpList )
if p . IpFilterMode { // 白名单
return listed
}
// 黑名单
return ! listed
}
// ValidateURL 验证URL是否安全
@@ -264,28 +254,44 @@ func (p *SSRFProtection) ValidateURL(urlStr string) error {
return fmt . Errorf ( "port %d is not allowed" , port )
}
// 检查域名白名单
if p . isDomainWhitelisted ( host ) {
return nil // 白名单域名直接通过
// 如果 host 是 IP, 则跳过域名 检查
if i p := net . ParseIP ( host ) ; ip != nil {
if ! p . IsIPAccessAllowed ( ip ) {
if isPrivateIP ( ip ) {
return fmt . Errorf ( "private IP address not allowed: %s" , ip . String ( ) )
}
if p . IpFilterMode {
return fmt . Errorf ( "ip not in whitelist: %s" , ip . String ( ) )
}
return fmt . Errorf ( "ip in blacklist: %s" , ip . String ( ) )
}
return nil
}
// DNS解析获取IP地址
// 先进行域名过滤
if ! p . isDomainAllowed ( host ) {
if p . DomainFilterMode {
return fmt . Errorf ( "domain not in whitelist: %s" , host )
}
return fmt . Errorf ( "domain in blacklist: %s" , host )
}
// 解析域名对应IP并检查
ips , err := net . LookupIP ( host )
if err != nil {
return fmt . Errorf ( "DNS resolution failed for %s: %v" , host , err )
}
// 检查所有解析的IP地址
for _ , ip := range ips {
if ! p . IsIPAccessAllowed ( ip ) {
if isPrivateIP ( ip ) {
if isPrivateIP ( ip ) && ! p . AllowPrivateIp {
return fmt . Errorf ( "private IP address not allowed: %s resolves to %s" , host , ip . String ( ) )
} else {
return fmt . Errorf ( "IP address not in whitelist: %s resolves to %s" , host , ip . String ( ) )
}
if p . IpFilterMode {
return fmt . Errorf ( "ip not in whitelist: %s resolves to %s" , host , ip . String ( ) )
}
return fmt . Errorf ( "ip in blacklist: %s resolves to %s" , host , ip . String ( ) )
}
}
}
return nil
}
@@ -295,7 +301,7 @@ func ValidateURLWithDefaults(urlStr string) error {
}
// ValidateURLWithFetchSetting 使用FetchSetting配置验证URL
func ValidateURLWithFetchSetting ( urlStr string , enableSSRFProtection , allowPrivateIp bool , whitelistD omains , whitelistIps , allowedPorts [ ] string ) error {
func ValidateURLWithFetchSetting ( urlStr string , enableSSRFProtection , allowPrivateIp bool , domainFilterMode bool , ipFilterMode bool , d omainList , ipList , allowedPorts [ ] string ) error {
// 如果SSRF防护被禁用, 直接返回成功
if ! enableSSRFProtection {
return nil
@@ -309,76 +315,11 @@ func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPriva
protection := & SSRFProtection {
AllowPrivateIp : allowPrivateIp ,
WhitelistDomains : whitelistDomains ,
WhitelistIps : whitelistIps ,
DomainFilterMode : domainFilterMode ,
DomainList : domainList ,
IpFilterMode : ipFilterMode ,
IpList : ipList ,
AllowedPorts : allowedPortInts ,
}
return protection . ValidateURL ( urlStr )
}
// ValidateURLWithPortRanges 直接使用端口范围字符串验证URL( 更高效的版本)
func ValidateURLWithPortRanges ( urlStr string , allowPrivateIp bool , whitelistDomains , whitelistIps , allowedPorts [ ] string ) error {
// 解析URL
u , err := url . Parse ( urlStr )
if err != nil {
return fmt . Errorf ( "invalid URL format: %v" , err )
}
// 只允许HTTP/HTTPS协议
if u . Scheme != "http" && u . Scheme != "https" {
return fmt . Errorf ( "unsupported protocol: %s (only http/https allowed)" , u . Scheme )
}
// 解析主机和端口
host , portStr , err := net . SplitHostPort ( u . Host )
if err != nil {
// 没有端口,使用默认端口
host = u . Host
if u . Scheme == "https" {
portStr = "443"
} else {
portStr = "80"
}
}
// 验证端口
port , err := strconv . Atoi ( portStr )
if err != nil {
return fmt . Errorf ( "invalid port: %s" , portStr )
}
if ! isAllowedPortFromRanges ( port , allowedPorts ) {
return fmt . Errorf ( "port %d is not allowed" , port )
}
// 创建临时的SSRFProtection来复用域名和IP检查逻辑
protection := & SSRFProtection {
AllowPrivateIp : allowPrivateIp ,
WhitelistDomains : whitelistDomains ,
WhitelistIps : whitelistIps ,
}
// 检查域名白名单
if protection . isDomainWhitelisted ( host ) {
return nil // 白名单域名直接通过
}
// DNS解析获取IP地址
ips , err := net . LookupIP ( host )
if err != nil {
return fmt . Errorf ( "DNS resolution failed for %s: %v" , host , err )
}
// 检查所有解析的IP地址
for _ , ip := range ips {
if ! protection . IsIPAccessAllowed ( ip ) {
if isPrivateIP ( ip ) {
return fmt . Errorf ( "private IP address not allowed: %s resolves to %s" , host , ip . String ( ) )
} else {
return fmt . Errorf ( "IP address not in whitelist: %s resolves to %s" , host , ip . String ( ) )
}
}
}
return nil
}