diff --git a/middleware/distributor.go b/middleware/distributor.go index 525270b1..a33ca5af 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -174,14 +174,22 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { relayMode := relayconstant.RelayModeUnknown if c.Request.Method == http.MethodPost { relayMode = relayconstant.RelayModeVideoSubmit - form, err := common.ParseMultipartFormReusable(c) - if err != nil { - return nil, false, errors.New("无效的video请求, " + err.Error()) - } - defer form.RemoveAll() - if form != nil { - if values, ok := form.Value["model"]; ok && len(values) > 0 { - modelRequest.Model = values[0] + contentType := c.Request.Header.Get("Content-Type") + if strings.HasPrefix(contentType, "multipart/form-data") { + form, err := common.ParseMultipartFormReusable(c) + if err != nil { + return nil, false, errors.New("无效的video请求, " + err.Error()) + } + defer form.RemoveAll() + if form != nil { + if values, ok := form.Value["model"]; ok && len(values) > 0 { + modelRequest.Model = values[0] + } + } + } else if strings.HasPrefix(contentType, "application/json") { + err = common.UnmarshalBodyReusable(c, &modelRequest) + if err != nil { + return nil, false, errors.New("无效的video请求, " + err.Error()) } } } else if c.Request.Method == http.MethodGet { diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index f18c4374..f7601149 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -106,25 +106,53 @@ func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string } func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError { - form, err := common.ParseMultipartFormReusable(c) - if err != nil { - return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true) - } - defer form.RemoveAll() + contentType := c.GetHeader("Content-Type") + var prompt string + var hasInputReference bool - prompts, ok := form.Value["prompt"] - if !ok || len(prompts) == 0 { - return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true) + if strings.HasPrefix(contentType, "multipart/form-data") { + form, err := common.ParseMultipartFormReusable(c) + if err != nil { + return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true) + } + defer form.RemoveAll() + + prompts, ok := form.Value["prompt"] + if !ok || len(prompts) == 0 { + return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true) + } + prompt = prompts[0] + + if _, ok := form.Value["model"]; !ok { + return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true) + } + + if _, ok := form.File["input_reference"]; ok { + hasInputReference = true + } + } else { + var req TaskSubmitReq + if err := common.UnmarshalBodyReusable(c, &req); err != nil { + return createTaskError(err, "invalid_json", http.StatusBadRequest, true) + } + + prompt = req.Prompt + + if strings.TrimSpace(req.Model) == "" { + return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true) + } + + if req.HasImage() { + hasInputReference = true + } } - if taskErr := validatePrompt(prompts[0]); taskErr != nil { + + if taskErr := validatePrompt(prompt); taskErr != nil { return taskErr } - if _, ok := form.Value["model"]; !ok { - return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true) - } action := constant.TaskActionTextGenerate - if _, ok := form.File["input_reference"]; ok { + if hasInputReference { action = constant.TaskActionGenerate } info.Action = action