From e77effaf8b5b5054872c8c579a4dd4d5240aeee5 Mon Sep 17 00:00:00 2001 From: CaIon Date: Tue, 12 Aug 2025 19:57:56 +0800 Subject: [PATCH] fix(adaptor): optimize multipart form handling and resource management --- relay/channel/openai/adaptor.go | 53 ++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 4c7ba60e..fc1749a0 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -359,40 +359,42 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf writer := multipart.NewWriter(&requestBody) writer.WriteField("model", request.Model) - // 获取所有表单字段 - formData := c.Request.PostForm - // 遍历表单字段并打印输出 - for key, values := range formData { - if key == "model" { - continue + // 使用已解析的 multipart 表单,避免重复解析 + mf := c.Request.MultipartForm + if mf == nil { + if _, err := c.MultipartForm(); err != nil { + return nil, errors.New("failed to parse multipart form") } - for _, value := range values { - writer.WriteField(key, value) + mf = c.Request.MultipartForm + } + + // 写入所有非文件字段 + if mf != nil { + for key, values := range mf.Value { + if key == "model" { + continue + } + for _, value := range values { + writer.WriteField(key, value) + } } } - // Parse the multipart form to handle both single image and multiple images - if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory - return nil, errors.New("failed to parse multipart form") - } - - if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil { + if mf != nil && mf.File != nil { // Check if "image" field exists in any form, including array notation var imageFiles []*multipart.FileHeader var exists bool // First check for standard "image" field - if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 { + if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 { // If not found, check for "image[]" field - if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 { + if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 { // If still not found, iterate through all fields to find any that start with "image[" foundArrayImages := false - for fieldName, files := range c.Request.MultipartForm.File { + for fieldName, files := range mf.File { if strings.HasPrefix(fieldName, "image[") && len(files) > 0 { foundArrayImages = true - for _, file := range files { - imageFiles = append(imageFiles, file) - } + imageFiles = append(imageFiles, files...) } } @@ -409,7 +411,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf if err != nil { return nil, fmt.Errorf("failed to open image file %d: %w", i, err) } - defer file.Close() // If multiple images, use image[] as the field name fieldName := "image" @@ -433,15 +434,18 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf if _, err := io.Copy(part, file); err != nil { return nil, fmt.Errorf("copy file failed for image %d: %w", i, err) } + + // 复制完立即关闭,避免在循环内使用 defer 占用资源 + _ = file.Close() } // Handle mask file if present - if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 { + if maskFiles, exists := mf.File["mask"]; exists && len(maskFiles) > 0 { maskFile, err := maskFiles[0].Open() if err != nil { return nil, errors.New("failed to open mask file") } - defer maskFile.Close() + // 复制完立即关闭,避免在循环内使用 defer 占用资源 // Determine MIME type for mask file mimeType := detectImageMimeType(maskFiles[0].Filename) @@ -459,6 +463,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf if _, err := io.Copy(maskPart, maskFile); err != nil { return nil, errors.New("copy mask file failed") } + _ = maskFile.Close() } } else { return nil, errors.New("no multipart form data found") @@ -467,7 +472,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf // 关闭 multipart 编写器以设置分界线 writer.Close() c.Request.Header.Set("Content-Type", writer.FormDataContentType()) - return bytes.NewReader(requestBody.Bytes()), nil + return &requestBody, nil default: return request, nil