feat: implement SSRF protection settings and update related references
This commit is contained in:
69
service/download.go
Normal file
69
service/download.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// WorkerRequest Worker请求的数据结构
|
||||
type WorkerRequest struct {
|
||||
URL string `json:"url"`
|
||||
Key string `json:"key"`
|
||||
Method string `json:"method,omitempty"`
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
Body json.RawMessage `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
// DoWorkerRequest 通过Worker发送请求
|
||||
func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
|
||||
if !system_setting.EnableWorker() {
|
||||
return nil, fmt.Errorf("worker not enabled")
|
||||
}
|
||||
if !system_setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") {
|
||||
return nil, fmt.Errorf("only support https url")
|
||||
}
|
||||
|
||||
// SSRF防护:验证请求URL
|
||||
fetchSetting := system_setting.GetFetchSetting()
|
||||
if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.WhitelistDomains, fetchSetting.WhitelistIps, fetchSetting.AllowedPorts); err != nil {
|
||||
return nil, fmt.Errorf("request reject: %v", err)
|
||||
}
|
||||
|
||||
workerUrl := system_setting.WorkerUrl
|
||||
if !strings.HasSuffix(workerUrl, "/") {
|
||||
workerUrl += "/"
|
||||
}
|
||||
|
||||
// 序列化worker请求数据
|
||||
workerPayload, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal worker payload: %v", err)
|
||||
}
|
||||
|
||||
return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
|
||||
}
|
||||
|
||||
func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) {
|
||||
if system_setting.EnableWorker() {
|
||||
common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
|
||||
req := &WorkerRequest{
|
||||
URL: originUrl,
|
||||
Key: system_setting.WorkerValidKey,
|
||||
}
|
||||
return DoWorkerRequest(req)
|
||||
} else {
|
||||
// SSRF防护:验证请求URL(非Worker模式)
|
||||
fetchSetting := system_setting.GetFetchSetting()
|
||||
if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.WhitelistDomains, fetchSetting.WhitelistIps, fetchSetting.AllowedPorts); err != nil {
|
||||
return nil, fmt.Errorf("request reject: %v", err)
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("downloading from origin: %s, reason: %s", common.MaskSensitiveInfo(originUrl), strings.Join(reason, ", ")))
|
||||
return http.Get(originUrl)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user