feat: convert gemini format to openai chat completions
This commit is contained in:
@@ -448,3 +448,353 @@ func toJSONString(v interface{}) string {
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
|
||||
openaiRequest := &dto.GeneralOpenAIRequest{
|
||||
Model: info.UpstreamModelName,
|
||||
Stream: info.IsStream,
|
||||
}
|
||||
|
||||
// 转换 messages
|
||||
var messages []dto.Message
|
||||
for _, content := range geminiRequest.Contents {
|
||||
message := dto.Message{
|
||||
Role: convertGeminiRoleToOpenAI(content.Role),
|
||||
}
|
||||
|
||||
// 处理 parts
|
||||
var mediaContents []dto.MediaContent
|
||||
var toolCalls []dto.ToolCallRequest
|
||||
for _, part := range content.Parts {
|
||||
if part.Text != "" {
|
||||
mediaContent := dto.MediaContent{
|
||||
Type: "text",
|
||||
Text: part.Text,
|
||||
}
|
||||
mediaContents = append(mediaContents, mediaContent)
|
||||
} else if part.InlineData != nil {
|
||||
mediaContent := dto.MediaContent{
|
||||
Type: "image_url",
|
||||
ImageUrl: &dto.MessageImageUrl{
|
||||
Url: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data),
|
||||
Detail: "auto",
|
||||
MimeType: part.InlineData.MimeType,
|
||||
},
|
||||
}
|
||||
mediaContents = append(mediaContents, mediaContent)
|
||||
} else if part.FileData != nil {
|
||||
mediaContent := dto.MediaContent{
|
||||
Type: "image_url",
|
||||
ImageUrl: &dto.MessageImageUrl{
|
||||
Url: part.FileData.FileUri,
|
||||
Detail: "auto",
|
||||
MimeType: part.FileData.MimeType,
|
||||
},
|
||||
}
|
||||
mediaContents = append(mediaContents, mediaContent)
|
||||
} else if part.FunctionCall != nil {
|
||||
// 处理 Gemini 的工具调用
|
||||
toolCall := dto.ToolCallRequest{
|
||||
ID: fmt.Sprintf("call_%d", len(toolCalls)+1), // 生成唯一ID
|
||||
Type: "function",
|
||||
Function: dto.FunctionRequest{
|
||||
Name: part.FunctionCall.FunctionName,
|
||||
Arguments: toJSONString(part.FunctionCall.Arguments),
|
||||
},
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
} else if part.FunctionResponse != nil {
|
||||
// 处理 Gemini 的工具响应,创建单独的 tool 消息
|
||||
toolMessage := dto.Message{
|
||||
Role: "tool",
|
||||
ToolCallId: fmt.Sprintf("call_%d", len(toolCalls)), // 使用对应的调用ID
|
||||
}
|
||||
toolMessage.SetStringContent(toJSONString(part.FunctionResponse.Response))
|
||||
messages = append(messages, toolMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// 设置消息内容
|
||||
if len(toolCalls) > 0 {
|
||||
// 如果有工具调用,设置工具调用
|
||||
message.SetToolCalls(toolCalls)
|
||||
} else if len(mediaContents) == 1 && mediaContents[0].Type == "text" {
|
||||
// 如果只有一个文本内容,直接设置字符串
|
||||
message.Content = mediaContents[0].Text
|
||||
} else if len(mediaContents) > 0 {
|
||||
// 如果有多个内容或包含媒体,设置为数组
|
||||
message.SetMediaContent(mediaContents)
|
||||
}
|
||||
|
||||
// 只有当消息有内容或工具调用时才添加
|
||||
if len(message.ParseContent()) > 0 || len(message.ToolCalls) > 0 {
|
||||
messages = append(messages, message)
|
||||
}
|
||||
}
|
||||
|
||||
openaiRequest.Messages = messages
|
||||
|
||||
if geminiRequest.GenerationConfig.Temperature != nil {
|
||||
openaiRequest.Temperature = geminiRequest.GenerationConfig.Temperature
|
||||
}
|
||||
if geminiRequest.GenerationConfig.TopP > 0 {
|
||||
openaiRequest.TopP = geminiRequest.GenerationConfig.TopP
|
||||
}
|
||||
if geminiRequest.GenerationConfig.TopK > 0 {
|
||||
openaiRequest.TopK = int(geminiRequest.GenerationConfig.TopK)
|
||||
}
|
||||
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
||||
openaiRequest.MaxTokens = geminiRequest.GenerationConfig.MaxOutputTokens
|
||||
}
|
||||
// gemini stop sequences 最多 5 个,openai stop 最多 4 个
|
||||
if len(geminiRequest.GenerationConfig.StopSequences) > 0 {
|
||||
openaiRequest.Stop = geminiRequest.GenerationConfig.StopSequences[:4]
|
||||
}
|
||||
if geminiRequest.GenerationConfig.CandidateCount > 0 {
|
||||
openaiRequest.N = geminiRequest.GenerationConfig.CandidateCount
|
||||
}
|
||||
|
||||
// 转换工具调用
|
||||
if len(geminiRequest.Tools) > 0 {
|
||||
var tools []dto.ToolCallRequest
|
||||
for _, tool := range geminiRequest.Tools {
|
||||
if tool.FunctionDeclarations != nil {
|
||||
// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
|
||||
functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
|
||||
if ok {
|
||||
for _, function := range functionDeclarations {
|
||||
openAITool := dto.ToolCallRequest{
|
||||
Type: "function",
|
||||
Function: dto.FunctionRequest{
|
||||
Name: function.Name,
|
||||
Description: function.Description,
|
||||
Parameters: function.Parameters,
|
||||
},
|
||||
}
|
||||
tools = append(tools, openAITool)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
openaiRequest.Tools = tools
|
||||
}
|
||||
}
|
||||
|
||||
// gemini system instructions
|
||||
if geminiRequest.SystemInstructions != nil {
|
||||
// 将系统指令作为第一条消息插入
|
||||
systemMessage := dto.Message{
|
||||
Role: "system",
|
||||
Content: extractTextFromGeminiParts(geminiRequest.SystemInstructions.Parts),
|
||||
}
|
||||
openaiRequest.Messages = append([]dto.Message{systemMessage}, openaiRequest.Messages...)
|
||||
}
|
||||
|
||||
return openaiRequest, nil
|
||||
}
|
||||
|
||||
func convertGeminiRoleToOpenAI(geminiRole string) string {
|
||||
switch geminiRole {
|
||||
case "user":
|
||||
return "user"
|
||||
case "model":
|
||||
return "assistant"
|
||||
case "function":
|
||||
return "function"
|
||||
default:
|
||||
return "user"
|
||||
}
|
||||
}
|
||||
|
||||
func extractTextFromGeminiParts(parts []dto.GeminiPart) string {
|
||||
var texts []string
|
||||
for _, part := range parts {
|
||||
if part.Text != "" {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(texts, "\n")
|
||||
}
|
||||
|
||||
// ResponseOpenAI2Gemini 将 OpenAI 响应转换为 Gemini 格式
|
||||
func ResponseOpenAI2Gemini(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
|
||||
geminiResponse := &dto.GeminiChatResponse{
|
||||
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
|
||||
PromptFeedback: dto.GeminiChatPromptFeedback{
|
||||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: openAIResponse.PromptTokens,
|
||||
CandidatesTokenCount: openAIResponse.CompletionTokens,
|
||||
TotalTokenCount: openAIResponse.PromptTokens + openAIResponse.CompletionTokens,
|
||||
},
|
||||
}
|
||||
|
||||
for _, choice := range openAIResponse.Choices {
|
||||
candidate := dto.GeminiChatCandidate{
|
||||
Index: int64(choice.Index),
|
||||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||
}
|
||||
|
||||
// 设置结束原因
|
||||
var finishReason string
|
||||
switch choice.FinishReason {
|
||||
case "stop":
|
||||
finishReason = "STOP"
|
||||
case "length":
|
||||
finishReason = "MAX_TOKENS"
|
||||
case "content_filter":
|
||||
finishReason = "SAFETY"
|
||||
case "tool_calls":
|
||||
finishReason = "STOP"
|
||||
default:
|
||||
finishReason = "STOP"
|
||||
}
|
||||
candidate.FinishReason = &finishReason
|
||||
|
||||
// 转换消息内容
|
||||
content := dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: make([]dto.GeminiPart, 0),
|
||||
}
|
||||
|
||||
// 处理工具调用
|
||||
toolCalls := choice.Message.ParseToolCalls()
|
||||
if len(toolCalls) > 0 {
|
||||
for _, toolCall := range toolCalls {
|
||||
// 解析参数
|
||||
var args map[string]interface{}
|
||||
if toolCall.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||||
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
|
||||
}
|
||||
} else {
|
||||
args = make(map[string]interface{})
|
||||
}
|
||||
|
||||
part := dto.GeminiPart{
|
||||
FunctionCall: &dto.FunctionCall{
|
||||
FunctionName: toolCall.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
} else {
|
||||
// 处理文本内容
|
||||
textContent := choice.Message.StringContent()
|
||||
if textContent != "" {
|
||||
part := dto.GeminiPart{
|
||||
Text: textContent,
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
}
|
||||
|
||||
candidate.Content = content
|
||||
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
|
||||
}
|
||||
|
||||
return geminiResponse
|
||||
}
|
||||
|
||||
// StreamResponseOpenAI2Gemini 将 OpenAI 流式响应转换为 Gemini 格式
|
||||
func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
|
||||
// 检查是否有实际内容或结束标志
|
||||
hasContent := false
|
||||
hasFinishReason := false
|
||||
for _, choice := range openAIResponse.Choices {
|
||||
if len(choice.Delta.GetContentString()) > 0 || (choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0) {
|
||||
hasContent = true
|
||||
}
|
||||
if choice.FinishReason != nil {
|
||||
hasFinishReason = true
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有实际内容且没有结束标志,跳过。主要针对 openai 流响应开头的空数据
|
||||
if !hasContent && !hasFinishReason {
|
||||
return nil
|
||||
}
|
||||
|
||||
geminiResponse := &dto.GeminiChatResponse{
|
||||
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
|
||||
PromptFeedback: dto.GeminiChatPromptFeedback{
|
||||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: info.PromptTokens,
|
||||
CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息
|
||||
TotalTokenCount: info.PromptTokens,
|
||||
},
|
||||
}
|
||||
|
||||
for _, choice := range openAIResponse.Choices {
|
||||
candidate := dto.GeminiChatCandidate{
|
||||
Index: int64(choice.Index),
|
||||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||
}
|
||||
|
||||
// 设置结束原因
|
||||
if choice.FinishReason != nil {
|
||||
var finishReason string
|
||||
switch *choice.FinishReason {
|
||||
case "stop":
|
||||
finishReason = "STOP"
|
||||
case "length":
|
||||
finishReason = "MAX_TOKENS"
|
||||
case "content_filter":
|
||||
finishReason = "SAFETY"
|
||||
case "tool_calls":
|
||||
finishReason = "STOP"
|
||||
default:
|
||||
finishReason = "STOP"
|
||||
}
|
||||
candidate.FinishReason = &finishReason
|
||||
}
|
||||
|
||||
// 转换消息内容
|
||||
content := dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: make([]dto.GeminiPart, 0),
|
||||
}
|
||||
|
||||
// 处理工具调用
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
for _, toolCall := range choice.Delta.ToolCalls {
|
||||
// 解析参数
|
||||
var args map[string]interface{}
|
||||
if toolCall.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||||
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
|
||||
}
|
||||
} else {
|
||||
args = make(map[string]interface{})
|
||||
}
|
||||
|
||||
part := dto.GeminiPart{
|
||||
FunctionCall: &dto.FunctionCall{
|
||||
FunctionName: toolCall.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
} else {
|
||||
// 处理文本内容
|
||||
textContent := choice.Delta.GetContentString()
|
||||
if textContent != "" {
|
||||
part := dto.GeminiPart{
|
||||
Text: textContent,
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
}
|
||||
|
||||
candidate.Content = content
|
||||
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
|
||||
}
|
||||
|
||||
return geminiResponse
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user