diff --git a/dto/dalle.go b/dto/dalle.go deleted file mode 100644 index 6ad5b6b6..00000000 --- a/dto/dalle.go +++ /dev/null @@ -1,117 +0,0 @@ -package dto - -import ( - "encoding/json" - "reflect" -) - -type ImageRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt" binding:"required"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - Quality string `json:"quality,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Style string `json:"style,omitempty"` - User string `json:"user,omitempty"` - ExtraFields json.RawMessage `json:"extra_fields,omitempty"` - Background string `json:"background,omitempty"` - Moderation string `json:"moderation,omitempty"` - OutputFormat string `json:"output_format,omitempty"` - // 用匿名字段接住额外的字段 - Extra map[string]json.RawMessage `json:"-"` -} - -func (r *ImageRequest) UnmarshalJSON(data []byte) error { - // 先解析成 map[string]interface{} - var rawMap map[string]json.RawMessage - if err := json.Unmarshal(data, &rawMap); err != nil { - return err - } - - // 用 struct tag 获取所有已定义字段名 - knownFields := GetJSONFieldNames(reflect.TypeOf(*r)) - - // 再正常解析已定义字段 - type Alias ImageRequest - var known Alias - if err := json.Unmarshal(data, &known); err != nil { - return err - } - *r = ImageRequest(known) - - // 提取多余字段 - r.Extra = make(map[string]json.RawMessage) - for k, v := range rawMap { - if _, ok := knownFields[k]; !ok { - r.Extra[k] = v - } - } - return nil -} - -func (r ImageRequest) MarshalJSON() ([]byte, error) { - // 将已定义字段转为 map - type Alias ImageRequest - alias := Alias(r) - base, err := json.Marshal(alias) - if err != nil { - return nil, err - } - - var baseMap map[string]json.RawMessage - if err := json.Unmarshal(base, &baseMap); err != nil { - return nil, err - } - - // 合并 ExtraFields - for k, v := range r.Extra { - baseMap[k] = v - } - - return json.Marshal(baseMap) -} - -type ImageResponse struct { - Data []ImageData `json:"data"` - Created int64 `json:"created"` -} -type ImageData struct { - Url string `json:"url"` - B64Json string `json:"b64_json"` - RevisedPrompt string `json:"revised_prompt"` -} - -func GetJSONFieldNames(t reflect.Type) map[string]struct{} { - fields := make(map[string]struct{}) - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - - // 跳过匿名字段(例如 ExtraFields) - if field.Anonymous { - continue - } - - tag := field.Tag.Get("json") - if tag == "-" || tag == "" { - continue - } - - // 取逗号前字段名(排除 omitempty 等) - name := tag - if commaIdx := indexComma(tag); commaIdx != -1 { - name = tag[:commaIdx] - } - fields[name] = struct{}{} - } - return fields -} - -func indexComma(s string) int { - for i := 0; i < len(s); i++ { - if s[i] == ',' { - return i - } - } - return -1 -} diff --git a/dto/openai_image.go b/dto/openai_image.go index 8833e774..9e838688 100644 --- a/dto/openai_image.go +++ b/dto/openai_image.go @@ -2,7 +2,9 @@ package dto import ( "encoding/json" + "one-api/common" "one-api/types" + "reflect" "strings" "github.com/gin-gonic/gin" @@ -29,6 +31,68 @@ type ImageRequest struct { Extra map[string]json.RawMessage `json:"-"` } +func (i *ImageRequest) UnmarshalJSON(data []byte) error { + // 先解析成 map[string]interface{} + var rawMap map[string]json.RawMessage + if err := common.Unmarshal(data, &rawMap); err != nil { + return err + } + + // 用 struct tag 获取所有已定义字段名 + knownFields := GetJSONFieldNames(reflect.TypeOf(*i)) + + // 再正常解析已定义字段 + type Alias ImageRequest + var known Alias + if err := common.Unmarshal(data, &known); err != nil { + return err + } + *i = ImageRequest(known) + + // 提取多余字段 + i.Extra = make(map[string]json.RawMessage) + for k, v := range rawMap { + if _, ok := knownFields[k]; !ok { + i.Extra[k] = v + } + } + return nil +} + +func GetJSONFieldNames(t reflect.Type) map[string]struct{} { + fields := make(map[string]struct{}) + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + // 跳过匿名字段(例如 ExtraFields) + if field.Anonymous { + continue + } + + tag := field.Tag.Get("json") + if tag == "-" || tag == "" { + continue + } + + // 取逗号前字段名(排除 omitempty 等) + name := tag + if commaIdx := indexComma(tag); commaIdx != -1 { + name = tag[:commaIdx] + } + fields[name] = struct{}{} + } + return fields +} + +func indexComma(s string) int { + for i := 0; i < len(s); i++ { + if s[i] == ',' { + return i + } + } + return -1 +} + func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta { var sizeRatio = 1.0 var qualityRatio = 1.0 diff --git a/dto/openai_request.go b/dto/openai_request.go index 02f969a7..881ec224 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -67,8 +67,8 @@ type GeneralOpenAIRequest struct { Reasoning json.RawMessage `json:"reasoning,omitempty"` // Ali Qwen Params VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"` - // 用匿名参数接收额外参数,例如ollama的think参数在此接收 - Extra map[string]json.RawMessage `json:"-"` + // ollama Params + Think json.RawMessage `json:"think,omitempty"` } func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta { diff --git a/middleware/distributor.go b/middleware/distributor.go index a80ed3c6..1e6df872 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -185,7 +185,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { modelRequest.Model = modelName } c.Set("relay_mode", relayMode) - } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") { + } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { err = common.UnmarshalBodyReusable(c, &modelRequest) } if err != nil { @@ -208,7 +208,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e") } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") { - modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1") + //modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1") + if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { + modelRequest.Model = c.PostForm("model") + } } if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { relayMode := relayconstant.RelayModeAudioSpeech diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index c676badc..3ce9e22d 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -3,7 +3,6 @@ package ali import ( "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" @@ -14,6 +13,8 @@ import ( "one-api/relay/constant" "one-api/types" "strings" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -44,6 +45,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl) case constant.RelayModeImagesGenerations: fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl) + case constant.RelayModeImagesEdits: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl) case constant.RelayModeCompletions: fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl) default: @@ -66,6 +69,9 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel if info.RelayMode == constant.RelayModeImagesGenerations { req.Set("X-DashScope-Async", "enable") } + if info.RelayMode == constant.RelayModeImagesEdits { + req.Set("Content-Type", "application/json") + } return nil } @@ -93,11 +99,30 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - aliRequest, err := oaiImage2Ali(request) - if err != nil { - return nil, fmt.Errorf("convert image request failed: %w", err) + if info.RelayMode == constant.RelayModeImagesGenerations { + aliRequest, err := oaiImage2Ali(request) + if err != nil { + return nil, fmt.Errorf("convert image request failed: %w", err) + } + return aliRequest, nil + } else if info.RelayMode == constant.RelayModeImagesEdits { + // ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416 + // 如果用户使用表单,则需要解析表单数据 + if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { + aliRequest, err := oaiFormEdit2AliImageEdit(c, info, request) + if err != nil { + return nil, fmt.Errorf("convert image edit form request failed: %w", err) + } + return aliRequest, nil + } else { + aliRequest, err := oaiImage2Ali(request) + if err != nil { + return nil, fmt.Errorf("convert image request failed: %w", err) + } + return aliRequest, nil + } } - return aliRequest, nil + return nil, fmt.Errorf("unsupported image relay mode: %d", info.RelayMode) } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { @@ -134,6 +159,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom switch info.RelayMode { case constant.RelayModeImagesGenerations: err, usage = aliImageHandler(c, resp, info) + case constant.RelayModeImagesEdits: + err, usage = aliImageEditHandler(c, resp, info) case constant.RelayModeRerank: err, usage = RerankHandler(c, resp, info) default: diff --git a/relay/channel/ali/dto.go b/relay/channel/ali/dto.go index d40e077d..0873c99f 100644 --- a/relay/channel/ali/dto.go +++ b/relay/channel/ali/dto.go @@ -3,10 +3,15 @@ package ali import "one-api/dto" type AliMessage struct { - Content string `json:"content"` + Content any `json:"content"` Role string `json:"role"` } +type AliMediaContent struct { + Image string `json:"image,omitempty"` + Text string `json:"text,omitempty"` +} + type AliInput struct { Prompt string `json:"prompt,omitempty"` //History []AliMessage `json:"history,omitempty"` @@ -70,13 +75,14 @@ type TaskResult struct { } type AliOutput struct { - TaskId string `json:"task_id,omitempty"` - TaskStatus string `json:"task_status,omitempty"` - Text string `json:"text"` - FinishReason string `json:"finish_reason"` - Message string `json:"message,omitempty"` - Code string `json:"code,omitempty"` - Results []TaskResult `json:"results,omitempty"` + TaskId string `json:"task_id,omitempty"` + TaskStatus string `json:"task_status,omitempty"` + Text string `json:"text"` + FinishReason string `json:"finish_reason"` + Message string `json:"message,omitempty"` + Code string `json:"code,omitempty"` + Results []TaskResult `json:"results,omitempty"` + Choices []map[string]any `json:"choices,omitempty"` } type AliResponse struct { @@ -101,8 +107,9 @@ type AliImageParameters struct { } type AliImageInput struct { - Prompt string `json:"prompt"` - NegativePrompt string `json:"negative_prompt,omitempty"` + Prompt string `json:"prompt,omitempty"` + NegativePrompt string `json:"negative_prompt,omitempty"` + Messages []AliMessage `json:"messages,omitempty"` } type AliRerankParameters struct { diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go index 78b0c334..490c9d0a 100644 --- a/relay/channel/ali/image.go +++ b/relay/channel/ali/image.go @@ -1,9 +1,12 @@ package ali import ( + "context" + "encoding/base64" "errors" "fmt" "io" + "mime/multipart" "net/http" "one-api/common" "one-api/dto" @@ -21,7 +24,7 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) { var imageRequest AliImageRequest imageRequest.Model = request.Model imageRequest.ResponseFormat = request.ResponseFormat - + logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra) if request.Extra != nil { if val, ok := request.Extra["parameters"]; ok { err := common.Unmarshal(val, &imageRequest.Parameters) @@ -54,6 +57,100 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) { return &imageRequest, nil } +func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) { + var imageRequest AliImageRequest + imageRequest.Model = request.Model + imageRequest.ResponseFormat = request.ResponseFormat + + mf := c.Request.MultipartForm + if mf == nil { + if _, err := c.MultipartForm(); err != nil { + return nil, fmt.Errorf("failed to parse image edit form request: %w", err) + } + mf = c.Request.MultipartForm + } + + var imageFiles []*multipart.FileHeader + var exists bool + + // First check for standard "image" field + if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 { + // If not found, check for "image[]" field + if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 { + // If still not found, iterate through all fields to find any that start with "image[" + foundArrayImages := false + for fieldName, files := range mf.File { + if strings.HasPrefix(fieldName, "image[") && len(files) > 0 { + foundArrayImages = true + imageFiles = append(imageFiles, files...) + } + } + + // If no image fields found at all + if !foundArrayImages && (len(imageFiles) == 0) { + return nil, errors.New("image is required") + } + } + } + + if len(imageFiles) == 0 { + return nil, errors.New("image is required") + } + + if len(imageFiles) > 1 { + return nil, errors.New("only one image is supported for qwen edit") + } + + // 获取base64编码的图片 + var imageBase64s []string + for _, file := range imageFiles { + image, err := file.Open() + if err != nil { + return nil, errors.New("failed to open image file") + } + + // 读取文件内容 + imageData, err := io.ReadAll(image) + if err != nil { + return nil, errors.New("failed to read image file") + } + + // 获取MIME类型 + mimeType := http.DetectContentType(imageData) + + // 编码为base64 + base64Data := base64.StdEncoding.EncodeToString(imageData) + + // 构造data URL格式 + dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data) + imageBase64s = append(imageBase64s, dataURL) + image.Close() + } + + //dto.MediaContent{} + mediaContents := make([]AliMediaContent, len(imageBase64s)) + for i, b64 := range imageBase64s { + mediaContents[i] = AliMediaContent{ + Image: b64, + } + } + mediaContents = append(mediaContents, AliMediaContent{ + Text: request.Prompt, + }) + imageRequest.Input = AliImageInput{ + Messages: []AliMessage{ + { + Role: "user", + Content: mediaContents, + }, + }, + } + imageRequest.Parameters = AliImageParameters{ + Watermark: request.Watermark, + } + return &imageRequest, nil +} + func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) { url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID) @@ -196,8 +293,47 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - c.Writer.Write(jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) + return nil, &dto.Usage{} +} + +func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { + var aliResponse AliResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil + } + + service.CloseResponseBodyGracefully(resp) + err = common.Unmarshal(responseBody, &aliResponse) + if err != nil { + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil + } + + if aliResponse.Message != "" { + logger.LogError(c, "ali_task_failed: "+aliResponse.Message) + return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil + } + var fullTextResponse dto.ImageResponse + if len(aliResponse.Output.Choices) > 0 { + fullTextResponse = dto.ImageResponse{ + Created: info.StartTime.Unix(), + Data: []dto.ImageData{ + { + Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string), + B64Json: "", + }, + }, + } + } + + var mapResponse map[string]any + _ = common.Unmarshal(responseBody, &mapResponse) + fullTextResponse.Extra = mapResponse + jsonResponse, err := common.Marshal(fullTextResponse) + if err != nil { + return types.NewError(err, types.ErrorCodeBadResponseBody), nil + } + service.IOCopyBytesGracefully(c, resp, jsonResponse) return nil, &dto.Usage{} } diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index be2029f5..27c67b4e 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -68,9 +68,7 @@ func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*O StreamOptions: request.StreamOptions, Suffix: request.Suffix, } - if think, ok := request.Extra["think"]; ok { - ollamaRequest.Think = think - } + ollamaRequest.Think = request.Think return ollamaRequest, nil } diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go index 285f26aa..4d1c1f9b 100644 --- a/relay/helper/valid_request.go +++ b/relay/helper/valid_request.go @@ -132,30 +132,34 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq switch relayMode { case relayconstant.RelayModeImagesEdits: - _, err := c.MultipartForm() - if err != nil { - return nil, fmt.Errorf("failed to parse image edit form request: %w", err) - } - formData := c.Request.PostForm - imageRequest.Prompt = formData.Get("prompt") - imageRequest.Model = formData.Get("model") - imageRequest.N = uint(common.String2Int(formData.Get("n"))) - imageRequest.Quality = formData.Get("quality") - imageRequest.Size = formData.Get("size") - - if imageRequest.Model == "gpt-image-1" { - if imageRequest.Quality == "" { - imageRequest.Quality = "standard" + if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { + _, err := c.MultipartForm() + if err != nil { + return nil, fmt.Errorf("failed to parse image edit form request: %w", err) } - } - if imageRequest.N == 0 { - imageRequest.N = 1 - } + formData := c.Request.PostForm + imageRequest.Prompt = formData.Get("prompt") + imageRequest.Model = formData.Get("model") + imageRequest.N = uint(common.String2Int(formData.Get("n"))) + imageRequest.Quality = formData.Get("quality") + imageRequest.Size = formData.Get("size") - watermark := formData.Has("watermark") - if watermark { - imageRequest.Watermark = &watermark + if imageRequest.Model == "gpt-image-1" { + if imageRequest.Quality == "" { + imageRequest.Quality = "standard" + } + } + if imageRequest.N == 0 { + imageRequest.N = 1 + } + + watermark := formData.Has("watermark") + if watermark { + imageRequest.Watermark = &watermark + } + break } + fallthrough default: err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { @@ -163,7 +167,8 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq } if imageRequest.Model == "" { - imageRequest.Model = "dall-e-3" + //imageRequest.Model = "dall-e-3" + return nil, errors.New("model is required") } if strings.Contains(imageRequest.Size, "×") { @@ -194,9 +199,9 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq } } - if imageRequest.Prompt == "" { - return nil, errors.New("prompt is required") - } + //if imageRequest.Prompt == "" { + // return nil, errors.New("prompt is required") + //} if imageRequest.N == 0 { imageRequest.N = 1 diff --git a/relay/image_handler.go b/relay/image_handler.go index c700424f..14a7103c 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -2,14 +2,13 @@ package relay import ( "bytes" - "encoding/json" "fmt" "io" "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" "one-api/setting/model_setting" @@ -56,10 +55,12 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed) } - if info.RelayMode == relayconstant.RelayModeImagesEdits { + + switch convertedRequest.(type) { + case *bytes.Buffer: requestBody = convertedRequest.(io.Reader) - } else { - jsonData, err := json.Marshal(convertedRequest) + default: + jsonData, err := common.Marshal(convertedRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -73,7 +74,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type } if common.DebugEnabled { - println(fmt.Sprintf("image request body: %s", string(jsonData))) + logger.LogDebug(c, fmt.Sprintf("image request body: %s", string(jsonData))) } requestBody = bytes.NewBuffer(jsonData) }