fix(Sora): 加固直连安全与下载限制

补充图片输入 SSRF 防护与重定向限制\n增加媒体下载超时/大小上限配置并更新示例\n完善 recent_tasks 轮询回退策略与相关测试\n\n测试: go test ./... -tags=unit
This commit is contained in:
yangjianbo
2026-02-01 22:10:15 +08:00
parent dcf5f60237
commit 99250ec527
8 changed files with 301 additions and 13 deletions

View File

@@ -10,6 +10,7 @@ import (
"fmt"
"io"
"mime"
"net"
"net/http"
"net/url"
"regexp"
@@ -26,6 +27,9 @@ var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
const soraRewriteBufferLimit = 2048
const soraImageInputMaxBytes = 20 << 20
const soraImageInputMaxRedirects = 3
const soraImageInputTimeout = 20 * time.Second
var soraImageSizeMap = map[string]string{
"gpt-image": "360",
@@ -33,6 +37,29 @@ var soraImageSizeMap = map[string]string{
"gpt-image-portrait": "540",
}
var soraBlockedHostnames = map[string]struct{}{
"localhost": {},
"localhost.localdomain": {},
"metadata.google.internal": {},
"metadata.google.internal.": {},
}
var soraBlockedCIDRs = mustParseCIDRs([]string{
"0.0.0.0/8",
"10.0.0.0/8",
"100.64.0.0/10",
"127.0.0.0/8",
"169.254.0.0/16",
"172.16.0.0/12",
"192.168.0.0/16",
"224.0.0.0/4",
"240.0.0.0/4",
"::/128",
"::1/128",
"fc00::/7",
"fe80::/10",
})
type soraStreamingResult struct {
mediaType string
mediaURLs []string
@@ -1233,11 +1260,24 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
}
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
parsed, err := validateSoraImageURL(rawURL)
if err != nil {
return nil, "", err
}
resp, err := http.DefaultClient.Do(req)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
if err != nil {
return nil, "", err
}
client := &http.Client{
Timeout: soraImageInputTimeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= soraImageInputMaxRedirects {
return errors.New("too many redirects")
}
return validateSoraImageURLValue(req.URL)
},
}
resp, err := client.Do(req)
if err != nil {
return nil, "", err
}
@@ -1245,14 +1285,88 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
if resp.StatusCode != http.StatusOK {
return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode)
}
data, err := io.ReadAll(io.LimitReader(resp.Body, 20<<20))
data, err := io.ReadAll(io.LimitReader(resp.Body, soraImageInputMaxBytes))
if err != nil {
return nil, "", err
}
ext := fileExtFromURL(rawURL)
ext := fileExtFromURL(parsed.String())
if ext == "" {
ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
}
filename := "image" + ext
return data, filename, nil
}
func validateSoraImageURL(raw string) (*url.URL, error) {
if strings.TrimSpace(raw) == "" {
return nil, errors.New("empty image url")
}
parsed, err := url.Parse(raw)
if err != nil {
return nil, fmt.Errorf("invalid image url: %w", err)
}
if err := validateSoraImageURLValue(parsed); err != nil {
return nil, err
}
return parsed, nil
}
func validateSoraImageURLValue(parsed *url.URL) error {
if parsed == nil {
return errors.New("invalid image url")
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
if scheme != "http" && scheme != "https" {
return errors.New("only http/https image url is allowed")
}
if parsed.User != nil {
return errors.New("image url cannot contain userinfo")
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host == "" {
return errors.New("image url missing host")
}
if _, blocked := soraBlockedHostnames[host]; blocked {
return errors.New("image url is not allowed")
}
if ip := net.ParseIP(host); ip != nil {
if isSoraBlockedIP(ip) {
return errors.New("image url is not allowed")
}
return nil
}
ips, err := net.LookupIP(host)
if err != nil {
return fmt.Errorf("resolve image url failed: %w", err)
}
for _, ip := range ips {
if isSoraBlockedIP(ip) {
return errors.New("image url is not allowed")
}
}
return nil
}
func isSoraBlockedIP(ip net.IP) bool {
if ip == nil {
return true
}
for _, cidr := range soraBlockedCIDRs {
if cidr.Contains(ip) {
return true
}
}
return false
}
func mustParseCIDRs(values []string) []*net.IPNet {
out := make([]*net.IPNet, 0, len(values))
for _, val := range values {
_, cidr, err := net.ParseCIDR(val)
if err != nil {
continue
}
out = append(out, cidr)
}
return out
}