diff --git a/relay/channel/mistral/text.go b/relay/channel/mistral/text.go index 75272e34..a25c1492 100644 --- a/relay/channel/mistral/text.go +++ b/relay/channel/mistral/text.go @@ -1,13 +1,55 @@ package mistral import ( + "one-api/common" "one-api/dto" + "regexp" ) +var mistralToolCallIdRegexp = regexp.MustCompile("^[a-zA-Z0-9]{9}$") + func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { messages := make([]dto.Message, 0, len(request.Messages)) + idMap := make(map[string]string) for _, message := range request.Messages { + // 1. tool_calls.id + toolCalls := message.ParseToolCalls() + if toolCalls != nil { + for i := range toolCalls { + if !mistralToolCallIdRegexp.MatchString(toolCalls[i].ID) { + if newId, ok := idMap[toolCalls[i].ID]; ok { + toolCalls[i].ID = newId + } else { + newId, err := common.GenerateRandomCharsKey(9) + if err == nil { + idMap[toolCalls[i].ID] = newId + toolCalls[i].ID = newId + } + } + } + } + message.SetToolCalls(toolCalls) + } + + // 2. tool_call_id + if message.ToolCallId != "" { + if newId, ok := idMap[message.ToolCallId]; ok { + message.ToolCallId = newId + } else { + if !mistralToolCallIdRegexp.MatchString(message.ToolCallId) { + newId, err := common.GenerateRandomCharsKey(9) + if err == nil { + idMap[message.ToolCallId] = newId + message.ToolCallId = newId + } + } + } + } + mediaMessages := message.ParseContent() + if message.Role == "assistant" && message.ToolCalls != nil && string(message.Content) == "null" { + mediaMessages = []dto.MediaContent{} + } for j, mediaMessage := range mediaMessages { if mediaMessage.Type == dto.ContentTypeImageURL { imageUrl := mediaMessage.GetImageMedia()