feat: convert gemini format to openai chat completions

This commit is contained in:
creamlike1024
2025-08-01 22:23:35 +08:00
parent 9758a9e60d
commit d2183af23f
36 changed files with 648 additions and 66 deletions

View File

@@ -448,3 +448,353 @@ func toJSONString(v interface{}) string {
}
return string(b)
}
func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
openaiRequest := &dto.GeneralOpenAIRequest{
Model: info.UpstreamModelName,
Stream: info.IsStream,
}
// 转换 messages
var messages []dto.Message
for _, content := range geminiRequest.Contents {
message := dto.Message{
Role: convertGeminiRoleToOpenAI(content.Role),
}
// 处理 parts
var mediaContents []dto.MediaContent
var toolCalls []dto.ToolCallRequest
for _, part := range content.Parts {
if part.Text != "" {
mediaContent := dto.MediaContent{
Type: "text",
Text: part.Text,
}
mediaContents = append(mediaContents, mediaContent)
} else if part.InlineData != nil {
mediaContent := dto.MediaContent{
Type: "image_url",
ImageUrl: &dto.MessageImageUrl{
Url: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data),
Detail: "auto",
MimeType: part.InlineData.MimeType,
},
}
mediaContents = append(mediaContents, mediaContent)
} else if part.FileData != nil {
mediaContent := dto.MediaContent{
Type: "image_url",
ImageUrl: &dto.MessageImageUrl{
Url: part.FileData.FileUri,
Detail: "auto",
MimeType: part.FileData.MimeType,
},
}
mediaContents = append(mediaContents, mediaContent)
} else if part.FunctionCall != nil {
// 处理 Gemini 的工具调用
toolCall := dto.ToolCallRequest{
ID: fmt.Sprintf("call_%d", len(toolCalls)+1), // 生成唯一ID
Type: "function",
Function: dto.FunctionRequest{
Name: part.FunctionCall.FunctionName,
Arguments: toJSONString(part.FunctionCall.Arguments),
},
}
toolCalls = append(toolCalls, toolCall)
} else if part.FunctionResponse != nil {
// 处理 Gemini 的工具响应,创建单独的 tool 消息
toolMessage := dto.Message{
Role: "tool",
ToolCallId: fmt.Sprintf("call_%d", len(toolCalls)), // 使用对应的调用ID
}
toolMessage.SetStringContent(toJSONString(part.FunctionResponse.Response))
messages = append(messages, toolMessage)
}
}
// 设置消息内容
if len(toolCalls) > 0 {
// 如果有工具调用,设置工具调用
message.SetToolCalls(toolCalls)
} else if len(mediaContents) == 1 && mediaContents[0].Type == "text" {
// 如果只有一个文本内容,直接设置字符串
message.Content = mediaContents[0].Text
} else if len(mediaContents) > 0 {
// 如果有多个内容或包含媒体,设置为数组
message.SetMediaContent(mediaContents)
}
// 只有当消息有内容或工具调用时才添加
if len(message.ParseContent()) > 0 || len(message.ToolCalls) > 0 {
messages = append(messages, message)
}
}
openaiRequest.Messages = messages
if geminiRequest.GenerationConfig.Temperature != nil {
openaiRequest.Temperature = geminiRequest.GenerationConfig.Temperature
}
if geminiRequest.GenerationConfig.TopP > 0 {
openaiRequest.TopP = geminiRequest.GenerationConfig.TopP
}
if geminiRequest.GenerationConfig.TopK > 0 {
openaiRequest.TopK = int(geminiRequest.GenerationConfig.TopK)
}
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
openaiRequest.MaxTokens = geminiRequest.GenerationConfig.MaxOutputTokens
}
// gemini stop sequences 最多 5 个openai stop 最多 4 个
if len(geminiRequest.GenerationConfig.StopSequences) > 0 {
openaiRequest.Stop = geminiRequest.GenerationConfig.StopSequences[:4]
}
if geminiRequest.GenerationConfig.CandidateCount > 0 {
openaiRequest.N = geminiRequest.GenerationConfig.CandidateCount
}
// 转换工具调用
if len(geminiRequest.Tools) > 0 {
var tools []dto.ToolCallRequest
for _, tool := range geminiRequest.Tools {
if tool.FunctionDeclarations != nil {
// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
if ok {
for _, function := range functionDeclarations {
openAITool := dto.ToolCallRequest{
Type: "function",
Function: dto.FunctionRequest{
Name: function.Name,
Description: function.Description,
Parameters: function.Parameters,
},
}
tools = append(tools, openAITool)
}
}
}
}
if len(tools) > 0 {
openaiRequest.Tools = tools
}
}
// gemini system instructions
if geminiRequest.SystemInstructions != nil {
// 将系统指令作为第一条消息插入
systemMessage := dto.Message{
Role: "system",
Content: extractTextFromGeminiParts(geminiRequest.SystemInstructions.Parts),
}
openaiRequest.Messages = append([]dto.Message{systemMessage}, openaiRequest.Messages...)
}
return openaiRequest, nil
}
func convertGeminiRoleToOpenAI(geminiRole string) string {
switch geminiRole {
case "user":
return "user"
case "model":
return "assistant"
case "function":
return "function"
default:
return "user"
}
}
func extractTextFromGeminiParts(parts []dto.GeminiPart) string {
var texts []string
for _, part := range parts {
if part.Text != "" {
texts = append(texts, part.Text)
}
}
return strings.Join(texts, "\n")
}
// ResponseOpenAI2Gemini 将 OpenAI 响应转换为 Gemini 格式
func ResponseOpenAI2Gemini(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
geminiResponse := &dto.GeminiChatResponse{
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
PromptFeedback: dto.GeminiChatPromptFeedback{
SafetyRatings: []dto.GeminiChatSafetyRating{},
},
UsageMetadata: dto.GeminiUsageMetadata{
PromptTokenCount: openAIResponse.PromptTokens,
CandidatesTokenCount: openAIResponse.CompletionTokens,
TotalTokenCount: openAIResponse.PromptTokens + openAIResponse.CompletionTokens,
},
}
for _, choice := range openAIResponse.Choices {
candidate := dto.GeminiChatCandidate{
Index: int64(choice.Index),
SafetyRatings: []dto.GeminiChatSafetyRating{},
}
// 设置结束原因
var finishReason string
switch choice.FinishReason {
case "stop":
finishReason = "STOP"
case "length":
finishReason = "MAX_TOKENS"
case "content_filter":
finishReason = "SAFETY"
case "tool_calls":
finishReason = "STOP"
default:
finishReason = "STOP"
}
candidate.FinishReason = &finishReason
// 转换消息内容
content := dto.GeminiChatContent{
Role: "model",
Parts: make([]dto.GeminiPart, 0),
}
// 处理工具调用
toolCalls := choice.Message.ParseToolCalls()
if len(toolCalls) > 0 {
for _, toolCall := range toolCalls {
// 解析参数
var args map[string]interface{}
if toolCall.Function.Arguments != "" {
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
}
} else {
args = make(map[string]interface{})
}
part := dto.GeminiPart{
FunctionCall: &dto.FunctionCall{
FunctionName: toolCall.Function.Name,
Arguments: args,
},
}
content.Parts = append(content.Parts, part)
}
} else {
// 处理文本内容
textContent := choice.Message.StringContent()
if textContent != "" {
part := dto.GeminiPart{
Text: textContent,
}
content.Parts = append(content.Parts, part)
}
}
candidate.Content = content
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
}
return geminiResponse
}
// StreamResponseOpenAI2Gemini 将 OpenAI 流式响应转换为 Gemini 格式
func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
// 检查是否有实际内容或结束标志
hasContent := false
hasFinishReason := false
for _, choice := range openAIResponse.Choices {
if len(choice.Delta.GetContentString()) > 0 || (choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0) {
hasContent = true
}
if choice.FinishReason != nil {
hasFinishReason = true
}
}
// 如果没有实际内容且没有结束标志,跳过。主要针对 openai 流响应开头的空数据
if !hasContent && !hasFinishReason {
return nil
}
geminiResponse := &dto.GeminiChatResponse{
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
PromptFeedback: dto.GeminiChatPromptFeedback{
SafetyRatings: []dto.GeminiChatSafetyRating{},
},
UsageMetadata: dto.GeminiUsageMetadata{
PromptTokenCount: info.PromptTokens,
CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息
TotalTokenCount: info.PromptTokens,
},
}
for _, choice := range openAIResponse.Choices {
candidate := dto.GeminiChatCandidate{
Index: int64(choice.Index),
SafetyRatings: []dto.GeminiChatSafetyRating{},
}
// 设置结束原因
if choice.FinishReason != nil {
var finishReason string
switch *choice.FinishReason {
case "stop":
finishReason = "STOP"
case "length":
finishReason = "MAX_TOKENS"
case "content_filter":
finishReason = "SAFETY"
case "tool_calls":
finishReason = "STOP"
default:
finishReason = "STOP"
}
candidate.FinishReason = &finishReason
}
// 转换消息内容
content := dto.GeminiChatContent{
Role: "model",
Parts: make([]dto.GeminiPart, 0),
}
// 处理工具调用
if choice.Delta.ToolCalls != nil {
for _, toolCall := range choice.Delta.ToolCalls {
// 解析参数
var args map[string]interface{}
if toolCall.Function.Arguments != "" {
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
}
} else {
args = make(map[string]interface{})
}
part := dto.GeminiPart{
FunctionCall: &dto.FunctionCall{
FunctionName: toolCall.Function.Name,
Arguments: args,
},
}
content.Parts = append(content.Parts, part)
}
} else {
// 处理文本内容
textContent := choice.Delta.GetContentString()
if textContent != "" {
part := dto.GeminiPart{
Text: textContent,
}
content.Parts = append(content.Parts, part)
}
}
candidate.Content = content
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
}
return geminiResponse
}