fix(Sora): 加固直连安全与下载限制
补充图片输入 SSRF 防护与重定向限制\n增加媒体下载超时/大小上限配置并更新示例\n完善 recent_tasks 轮询回退策略与相关测试\n\n测试: go test ./... -tags=unit
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
@@ -26,6 +27,9 @@ var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
|
||||
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
|
||||
|
||||
const soraRewriteBufferLimit = 2048
|
||||
const soraImageInputMaxBytes = 20 << 20
|
||||
const soraImageInputMaxRedirects = 3
|
||||
const soraImageInputTimeout = 20 * time.Second
|
||||
|
||||
var soraImageSizeMap = map[string]string{
|
||||
"gpt-image": "360",
|
||||
@@ -33,6 +37,29 @@ var soraImageSizeMap = map[string]string{
|
||||
"gpt-image-portrait": "540",
|
||||
}
|
||||
|
||||
var soraBlockedHostnames = map[string]struct{}{
|
||||
"localhost": {},
|
||||
"localhost.localdomain": {},
|
||||
"metadata.google.internal": {},
|
||||
"metadata.google.internal.": {},
|
||||
}
|
||||
|
||||
var soraBlockedCIDRs = mustParseCIDRs([]string{
|
||||
"0.0.0.0/8",
|
||||
"10.0.0.0/8",
|
||||
"100.64.0.0/10",
|
||||
"127.0.0.0/8",
|
||||
"169.254.0.0/16",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"224.0.0.0/4",
|
||||
"240.0.0.0/4",
|
||||
"::/128",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
"fe80::/10",
|
||||
})
|
||||
|
||||
type soraStreamingResult struct {
|
||||
mediaType string
|
||||
mediaURLs []string
|
||||
@@ -1233,11 +1260,24 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
|
||||
}
|
||||
|
||||
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
|
||||
parsed, err := validateSoraImageURL(rawURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: soraImageInputTimeout,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= soraImageInputMaxRedirects {
|
||||
return errors.New("too many redirects")
|
||||
}
|
||||
return validateSoraImageURLValue(req.URL)
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -1245,14 +1285,88 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode)
|
||||
}
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, 20<<20))
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, soraImageInputMaxBytes))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
ext := fileExtFromURL(rawURL)
|
||||
ext := fileExtFromURL(parsed.String())
|
||||
if ext == "" {
|
||||
ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
|
||||
}
|
||||
filename := "image" + ext
|
||||
return data, filename, nil
|
||||
}
|
||||
|
||||
func validateSoraImageURL(raw string) (*url.URL, error) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return nil, errors.New("empty image url")
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid image url: %w", err)
|
||||
}
|
||||
if err := validateSoraImageURLValue(parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func validateSoraImageURLValue(parsed *url.URL) error {
|
||||
if parsed == nil {
|
||||
return errors.New("invalid image url")
|
||||
}
|
||||
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
||||
if scheme != "http" && scheme != "https" {
|
||||
return errors.New("only http/https image url is allowed")
|
||||
}
|
||||
if parsed.User != nil {
|
||||
return errors.New("image url cannot contain userinfo")
|
||||
}
|
||||
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
||||
if host == "" {
|
||||
return errors.New("image url missing host")
|
||||
}
|
||||
if _, blocked := soraBlockedHostnames[host]; blocked {
|
||||
return errors.New("image url is not allowed")
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isSoraBlockedIP(ip) {
|
||||
return errors.New("image url is not allowed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve image url failed: %w", err)
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if isSoraBlockedIP(ip) {
|
||||
return errors.New("image url is not allowed")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isSoraBlockedIP(ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return true
|
||||
}
|
||||
for _, cidr := range soraBlockedCIDRs {
|
||||
if cidr.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func mustParseCIDRs(values []string) []*net.IPNet {
|
||||
out := make([]*net.IPNet, 0, len(values))
|
||||
for _, val := range values {
|
||||
_, cidr, err := net.ParseCIDR(val)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, cidr)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user