From 971aea09ee91358d60582c249ebd159dca038e8c Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Wed, 19 Feb 2025 20:45:42 +0800 Subject: [PATCH] feat: Improve image handling for Ollama channels --- dto/openai_request.go | 19 +++++++++++++-- relay/channel/ollama/adaptor.go | 2 +- relay/channel/ollama/relay-ollama.go | 30 ++++++++++++++++++++---- relay/channel/zhipu_4v/relay-zhipu_v4.go | 3 +-- 4 files changed, 45 insertions(+), 9 deletions(-) diff --git a/dto/openai_request.go b/dto/openai_request.go index a142b437..642fa71c 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -88,8 +88,10 @@ func (r GeneralOpenAIRequest) ParseInput() []string { } type Message struct { - Role string `json:"role"` - Content json.RawMessage `json:"content"` + Role string `json:"role"` + Content json.RawMessage `json:"content"` + // parsedContent not json field + parsedContent []MediaContent Name *string `json:"name,omitempty"` Prefix *bool `json:"prefix,omitempty"` ReasoningContent string `json:"reasoning_content,omitempty"` @@ -160,6 +162,11 @@ func (m *Message) SetStringContent(content string) { m.Content = jsonContent } +func (m *Message) SetMediaContent(content []MediaContent) { + jsonContent, _ := json.Marshal(content) + m.Content = jsonContent +} + func (m *Message) IsStringContent() bool { var stringContent string if err := json.Unmarshal(m.Content, &stringContent); err == nil { @@ -169,7 +176,15 @@ func (m *Message) IsStringContent() bool { } func (m *Message) ParseContent() []MediaContent { + if m.parsedContent != nil { + return m.parsedContent + } var contentList []MediaContent + defer func() { + if len(contentList) > 0 { + m.parsedContent = contentList + } + }() var stringContent string if err := json.Unmarshal(m.Content, &stringContent); err == nil { contentList = append(contentList, MediaContent{ diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 36889cb8..e15e2410 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -46,7 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re if request == nil { return nil, errors.New("request is nil") } - return requestOpenAI2Ollama(*request), nil + return requestOpenAI2Ollama(*request) } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index 5a1d50c8..8b53fbfb 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -9,14 +9,36 @@ import ( "net/http" "one-api/dto" "one-api/service" + "strings" ) -func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest { +func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) { messages := make([]dto.Message, 0, len(request.Messages)) for _, message := range request.Messages { + if !message.IsStringContent() { + mediaMessages := message.ParseContent() + for j, mediaMessage := range mediaMessages { + if mediaMessage.Type == dto.ContentTypeImageURL { + imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl) + // check if not base64 + if strings.HasPrefix(imageUrl.Url, "http") { + fileData, err := service.GetFileBase64FromUrl(imageUrl.Url) + if err != nil { + return nil, err + } + imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data) + } + mediaMessage.ImageUrl = imageUrl + mediaMessages[j] = mediaMessage + } + } + message.SetMediaContent(mediaMessages) + } messages = append(messages, dto.Message{ - Role: message.Role, - Content: message.Content, + Role: message.Role, + Content: message.Content, + ToolCalls: message.ToolCalls, + ToolCallId: message.ToolCallId, }) } str, ok := request.Stop.(string) @@ -42,7 +64,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest { Prompt: request.Prompt, StreamOptions: request.StreamOptions, Suffix: request.Suffix, - } + }, nil } func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest { diff --git a/relay/channel/zhipu_4v/relay-zhipu_v4.go b/relay/channel/zhipu_4v/relay-zhipu_v4.go index 06f306f6..97d82c71 100644 --- a/relay/channel/zhipu_4v/relay-zhipu_v4.go +++ b/relay/channel/zhipu_4v/relay-zhipu_v4.go @@ -90,8 +90,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq mediaMessages[j] = mediaMessage } } - messageRaw, _ := json.Marshal(mediaMessages) - message.Content = messageRaw + message.SetMediaContent(mediaMessages) } messages = append(messages, dto.Message{ Role: message.Role,