From 88a2fec19050327ee76603e4a2852de667e88644 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Sat, 22 Feb 2025 16:29:48 +0800 Subject: [PATCH] fix: mistral --- dto/openai_request.go | 2 +- relay/channel/mistral/adaptor.go | 2 +- relay/channel/mistral/text.go | 20 ++++++++------------ relay/channel/openai/relay-openai.go | 3 +++ 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/dto/openai_request.go b/dto/openai_request.go index 38303642..028e0286 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -101,7 +101,7 @@ type Message struct { type MediaContent struct { Type string `json:"type"` - Text string `json:"text"` + Text string `json:"text,omitempty"` ImageUrl any `json:"image_url,omitempty"` InputAudio any `json:"input_audio,omitempty"` } diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index f599df63..fcea169a 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -41,7 +41,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re if request == nil { return nil, errors.New("request is nil") } - return request, nil + return requestOpenAI2Mistral(request), nil } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { diff --git a/relay/channel/mistral/text.go b/relay/channel/mistral/text.go index 04add067..8987b8f0 100644 --- a/relay/channel/mistral/text.go +++ b/relay/channel/mistral/text.go @@ -1,25 +1,21 @@ package mistral import ( - "encoding/json" "one-api/dto" ) -func requestOpenAI2Mistral(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { +func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { messages := make([]dto.Message, 0, len(request.Messages)) for _, message := range request.Messages { - if !message.IsStringContent() { - mediaMessages := message.ParseContent() - for j, mediaMessage := range mediaMessages { - if mediaMessage.Type == dto.ContentTypeImageURL { - imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl) - mediaMessage.ImageUrl = imageUrl.Url - mediaMessages[j] = mediaMessage - } + mediaMessages := message.ParseContent() + for j, mediaMessage := range mediaMessages { + if mediaMessage.Type == dto.ContentTypeImageURL { + imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl) + mediaMessage.ImageUrl = imageUrl.Url + mediaMessages[j] = mediaMessage } - messageRaw, _ := json.Marshal(mediaMessages) - message.Content = messageRaw } + message.SetMediaContent(mediaMessages) messages = append(messages, dto.Message{ Role: message.Role, Content: message.Content, diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index fb952711..33cdea48 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -87,6 +87,9 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel info.SetFirstResponseTime() ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) data := scanner.Text() + if common.DebugEnabled { + println(data) + } if len(data) < 6 { // ignore blank line or wrong format continue }