package urlvalidator

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/url"
	"strconv"
	"strings"
	"time"
)

// ValidationOptions configures the host checks performed by ValidateHTTPSURL.
type ValidationOptions struct {
	// AllowedHosts lists the hosts a URL may point at. Entries are matched
	// case-insensitively; a "host:port" entry is reduced to its host, and a
	// "*.suffix" entry matches the suffix itself and any subdomain of it.
	AllowedHosts []string
	// RequireAllowlist makes validation fail when AllowedHosts contains no
	// usable entry.
	RequireAllowlist bool
	// AllowPrivate disables the localhost / non-public IP literal check.
	AllowPrivate bool
}

// ValidateURLFormat performs a minimal format check on raw: it only
// guarantees that the URL parses, that its scheme is acceptable, that it has
// a host, and that an explicit port lies in 1..65535. It performs no
// allowlist, private-network or SSRF checks.
//
// The scheme must be https, or http when allowInsecureHTTP is true. On
// success the whitespace-trimmed input is returned unchanged.
func ValidateURLFormat(raw string, allowInsecureHTTP bool) (string, error) {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return "", errors.New("url is required")
	}

	parsed, err := url.Parse(trimmed)
	if err != nil || parsed.Scheme == "" || parsed.Host == "" {
		return "", fmt.Errorf("invalid url: %s", trimmed)
	}

	switch strings.ToLower(parsed.Scheme) {
	case "https":
		// Always accepted.
	case "http":
		if !allowInsecureHTTP {
			return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
		}
	default:
		return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
	}

	// Host may be non-empty while Hostname is empty (for example ":443").
	if strings.TrimSpace(parsed.Hostname()) == "" {
		return "", errors.New("invalid host")
	}

	if port := parsed.Port(); port != "" {
		num, err := strconv.Atoi(port)
		if err != nil || num <= 0 || num > 65535 {
			return "", fmt.Errorf("invalid port: %s", port)
		}
	}

	return trimmed, nil
}

// ValidateHTTPSURL validates raw as an https URL and returns it in normalized
// form.
//
// Checks, in order: the URL parses and has a scheme and host; the scheme is
// https (case-insensitive); unless opts.AllowPrivate is set, the host is not
// localhost or a loopback/private/link-local/unspecified IP literal; if
// opts.RequireAllowlist is set, the allowlist has at least one usable entry;
// and if the allowlist is non-empty, the host matches one of its entries.
//
// The returned URL has trailing slashes removed from its path and is
// re-encoded from the decoded path. Trailing slashes are also trimmed from
// the rendered string as a whole, which affects a query or fragment that
// ends in "/".
func ValidateHTTPSURL(raw string, opts ValidationOptions) (string, error) {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return "", errors.New("url is required")
	}

	parsed, err := url.Parse(trimmed)
	if err != nil || parsed.Scheme == "" || parsed.Host == "" {
		return "", fmt.Errorf("invalid url: %s", trimmed)
	}
	if !strings.EqualFold(parsed.Scheme, "https") {
		return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
	}

	host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
	if host == "" {
		return "", errors.New("invalid host")
	}
	if !opts.AllowPrivate && isBlockedHost(host) {
		return "", fmt.Errorf("host is not allowed: %s", host)
	}

	allowlist := normalizeAllowlist(opts.AllowedHosts)
	if opts.RequireAllowlist && len(allowlist) == 0 {
		return "", errors.New("allowlist is not configured")
	}
	if len(allowlist) > 0 && !isAllowedHost(host, allowlist) {
		return "", fmt.Errorf("host is not allowed: %s", host)
	}

	parsed.Path = strings.TrimRight(parsed.Path, "/")
	// Clearing RawPath makes String() encode the trimmed Path instead of the
	// original escaped form.
	parsed.RawPath = ""
	return strings.TrimRight(parsed.String(), "/"), nil
}

// ValidateResolvedIP resolves host and verifies that every resulting IP
// address is safe to connect to. It is meant to guard against DNS rebinding:
// call it at the time of the actual HTTP request to check the resolved
// addresses.
//
// The lookup uses the default resolver with a 5 second timeout. An error is
// returned when resolution fails or when any resolved address is loopback,
// private, link-local or unspecified.
func ValidateResolvedIP(host string) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
	if err != nil {
		return fmt.Errorf("dns resolution failed: %w", err)
	}
	for _, ip := range ips {
		if isNonPublicIP(ip) {
			return fmt.Errorf("resolved ip %s is not allowed", ip.String())
		}
	}
	return nil
}

// normalizeAllowlist lower-cases and trims every entry, drops empty ones and
// reduces "host:port" entries to the host part. It returns nil for empty
// input.
func normalizeAllowlist(values []string) []string {
	if len(values) == 0 {
		return nil
	}
	normalized := make([]string, 0, len(values))
	for _, v := range values {
		entry := strings.ToLower(strings.TrimSpace(v))
		if entry == "" {
			continue
		}
		// SplitHostPort fails for entries without a port; those are kept as is.
		if host, _, err := net.SplitHostPort(entry); err == nil {
			entry = host
		}
		normalized = append(normalized, entry)
	}
	return normalized
}

// isAllowedHost reports whether host matches an allowlist entry. A "*.suffix"
// entry matches the suffix itself and any subdomain of it; every other entry
// must equal host exactly. Both arguments are expected to be lower-case.
func isAllowedHost(host string, allowlist []string) bool {
	for _, entry := range allowlist {
		if entry == "" {
			continue
		}
		if strings.HasPrefix(entry, "*.") {
			suffix := strings.TrimPrefix(entry, "*.")
			if host == suffix || strings.HasSuffix(host, "."+suffix) {
				return true
			}
			continue
		}
		if host == entry {
			return true
		}
	}
	return false
}

// isBlockedHost reports whether host names the local machine or is an IP
// literal in a loopback, private, link-local or unspecified range.
//
// The host is canonicalized before matching so that equivalent spellings
// cannot slip past the filter: trailing dots of a fully-qualified name
// ("localhost.") are dropped and an IPv6 zone identifier ("fe80::1%eth0") is
// removed, because net.ParseIP does not accept zones.
func isBlockedHost(host string) bool {
	name := strings.ToLower(strings.TrimRight(host, "."))
	if name == "localhost" || strings.HasSuffix(name, ".localhost") {
		return true
	}

	literal := name
	if i := strings.IndexByte(literal, '%'); i >= 0 {
		literal = literal[:i]
	}
	ip := net.ParseIP(literal)
	return ip != nil && isNonPublicIP(ip)
}

// isNonPublicIP reports whether ip is a loopback, private, link-local
// (unicast or multicast) or unspecified address.
func isNonPublicIP(ip net.IP) bool {
	return ip.IsLoopback() ||
		ip.IsPrivate() ||
		ip.IsLinkLocalUnicast() ||
		ip.IsLinkLocalMulticast() ||
		ip.IsUnspecified()
}