801 lines
25 KiB
Go
801 lines
25 KiB
Go
package service
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"one-api/common"
|
||
"one-api/constant"
|
||
"one-api/dto"
|
||
"one-api/relay/channel/openrouter"
|
||
relaycommon "one-api/relay/common"
|
||
"strings"
|
||
)
|
||
|
||
func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
|
||
openAIRequest := dto.GeneralOpenAIRequest{
|
||
Model: claudeRequest.Model,
|
||
MaxTokens: claudeRequest.MaxTokens,
|
||
Temperature: claudeRequest.Temperature,
|
||
TopP: claudeRequest.TopP,
|
||
Stream: claudeRequest.Stream,
|
||
}
|
||
|
||
isOpenRouter := info.ChannelType == constant.ChannelTypeOpenRouter
|
||
|
||
if claudeRequest.Thinking != nil && claudeRequest.Thinking.Type == "enabled" {
|
||
if isOpenRouter {
|
||
reasoning := openrouter.RequestReasoning{
|
||
MaxTokens: claudeRequest.Thinking.GetBudgetTokens(),
|
||
}
|
||
reasoningJSON, err := json.Marshal(reasoning)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to marshal reasoning: %w", err)
|
||
}
|
||
openAIRequest.Reasoning = reasoningJSON
|
||
} else {
|
||
thinkingSuffix := "-thinking"
|
||
if strings.HasSuffix(info.OriginModelName, thinkingSuffix) &&
|
||
!strings.HasSuffix(openAIRequest.Model, thinkingSuffix) {
|
||
openAIRequest.Model = openAIRequest.Model + thinkingSuffix
|
||
}
|
||
}
|
||
}
|
||
|
||
// Convert stop sequences
|
||
if len(claudeRequest.StopSequences) == 1 {
|
||
openAIRequest.Stop = claudeRequest.StopSequences[0]
|
||
} else if len(claudeRequest.StopSequences) > 1 {
|
||
openAIRequest.Stop = claudeRequest.StopSequences
|
||
}
|
||
|
||
// Convert tools
|
||
tools, _ := common.Any2Type[[]dto.Tool](claudeRequest.Tools)
|
||
openAITools := make([]dto.ToolCallRequest, 0)
|
||
for _, claudeTool := range tools {
|
||
openAITool := dto.ToolCallRequest{
|
||
Type: "function",
|
||
Function: dto.FunctionRequest{
|
||
Name: claudeTool.Name,
|
||
Description: claudeTool.Description,
|
||
Parameters: claudeTool.InputSchema,
|
||
},
|
||
}
|
||
openAITools = append(openAITools, openAITool)
|
||
}
|
||
openAIRequest.Tools = openAITools
|
||
|
||
// Convert messages
|
||
openAIMessages := make([]dto.Message, 0)
|
||
|
||
// Add system message if present
|
||
if claudeRequest.System != nil {
|
||
if claudeRequest.IsStringSystem() && claudeRequest.GetStringSystem() != "" {
|
||
openAIMessage := dto.Message{
|
||
Role: "system",
|
||
}
|
||
openAIMessage.SetStringContent(claudeRequest.GetStringSystem())
|
||
openAIMessages = append(openAIMessages, openAIMessage)
|
||
} else {
|
||
systems := claudeRequest.ParseSystem()
|
||
if len(systems) > 0 {
|
||
openAIMessage := dto.Message{
|
||
Role: "system",
|
||
}
|
||
isOpenRouterClaude := isOpenRouter && strings.HasPrefix(info.UpstreamModelName, "anthropic/claude")
|
||
if isOpenRouterClaude {
|
||
systemMediaMessages := make([]dto.MediaContent, 0, len(systems))
|
||
for _, system := range systems {
|
||
message := dto.MediaContent{
|
||
Type: "text",
|
||
Text: system.GetText(),
|
||
CacheControl: system.CacheControl,
|
||
}
|
||
systemMediaMessages = append(systemMediaMessages, message)
|
||
}
|
||
openAIMessage.SetMediaContent(systemMediaMessages)
|
||
} else {
|
||
systemStr := ""
|
||
for _, system := range systems {
|
||
if system.Text != nil {
|
||
systemStr += *system.Text
|
||
}
|
||
}
|
||
openAIMessage.SetStringContent(systemStr)
|
||
}
|
||
openAIMessages = append(openAIMessages, openAIMessage)
|
||
}
|
||
}
|
||
}
|
||
for _, claudeMessage := range claudeRequest.Messages {
|
||
openAIMessage := dto.Message{
|
||
Role: claudeMessage.Role,
|
||
}
|
||
|
||
//log.Printf("claudeMessage.Content: %v", claudeMessage.Content)
|
||
if claudeMessage.IsStringContent() {
|
||
openAIMessage.SetStringContent(claudeMessage.GetStringContent())
|
||
} else {
|
||
content, err := claudeMessage.ParseContent()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
contents := content
|
||
var toolCalls []dto.ToolCallRequest
|
||
mediaMessages := make([]dto.MediaContent, 0, len(contents))
|
||
|
||
for _, mediaMsg := range contents {
|
||
switch mediaMsg.Type {
|
||
case "text":
|
||
message := dto.MediaContent{
|
||
Type: "text",
|
||
Text: mediaMsg.GetText(),
|
||
CacheControl: mediaMsg.CacheControl,
|
||
}
|
||
mediaMessages = append(mediaMessages, message)
|
||
case "image":
|
||
// Handle image conversion (base64 to URL or keep as is)
|
||
imageData := fmt.Sprintf("data:%s;base64,%s", mediaMsg.Source.MediaType, mediaMsg.Source.Data)
|
||
//textContent += fmt.Sprintf("[Image: %s]", imageData)
|
||
mediaMessage := dto.MediaContent{
|
||
Type: "image_url",
|
||
ImageUrl: &dto.MessageImageUrl{Url: imageData},
|
||
}
|
||
mediaMessages = append(mediaMessages, mediaMessage)
|
||
case "tool_use":
|
||
toolCall := dto.ToolCallRequest{
|
||
ID: mediaMsg.Id,
|
||
Type: "function",
|
||
Function: dto.FunctionRequest{
|
||
Name: mediaMsg.Name,
|
||
Arguments: toJSONString(mediaMsg.Input),
|
||
},
|
||
}
|
||
toolCalls = append(toolCalls, toolCall)
|
||
case "tool_result":
|
||
// Add tool result as a separate message
|
||
oaiToolMessage := dto.Message{
|
||
Role: "tool",
|
||
Name: &mediaMsg.Name,
|
||
ToolCallId: mediaMsg.ToolUseId,
|
||
}
|
||
//oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text)
|
||
if mediaMsg.IsStringContent() {
|
||
oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
|
||
} else {
|
||
mediaContents := mediaMsg.ParseMediaContent()
|
||
encodeJson, _ := common.Marshal(mediaContents)
|
||
oaiToolMessage.SetStringContent(string(encodeJson))
|
||
}
|
||
openAIMessages = append(openAIMessages, oaiToolMessage)
|
||
}
|
||
}
|
||
|
||
if len(toolCalls) > 0 {
|
||
openAIMessage.SetToolCalls(toolCalls)
|
||
}
|
||
|
||
if len(mediaMessages) > 0 && len(toolCalls) == 0 {
|
||
openAIMessage.SetMediaContent(mediaMessages)
|
||
}
|
||
}
|
||
if len(openAIMessage.ParseContent()) > 0 || len(openAIMessage.ToolCalls) > 0 {
|
||
openAIMessages = append(openAIMessages, openAIMessage)
|
||
}
|
||
}
|
||
|
||
openAIRequest.Messages = openAIMessages
|
||
|
||
return &openAIRequest, nil
|
||
}
|
||
|
||
func generateStopBlock(index int) *dto.ClaudeResponse {
|
||
return &dto.ClaudeResponse{
|
||
Type: "content_block_stop",
|
||
Index: common.GetPointer[int](index),
|
||
}
|
||
}
|
||
|
||
func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
|
||
var claudeResponses []*dto.ClaudeResponse
|
||
if info.SendResponseCount == 1 {
|
||
msg := &dto.ClaudeMediaMessage{
|
||
Id: openAIResponse.Id,
|
||
Model: openAIResponse.Model,
|
||
Type: "message",
|
||
Role: "assistant",
|
||
Usage: &dto.ClaudeUsage{
|
||
InputTokens: info.PromptTokens,
|
||
OutputTokens: 0,
|
||
},
|
||
}
|
||
msg.SetContent(make([]any, 0))
|
||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||
Type: "message_start",
|
||
Message: msg,
|
||
})
|
||
claudeResponses = append(claudeResponses)
|
||
//claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||
// Type: "ping",
|
||
//})
|
||
if openAIResponse.IsToolCall() {
|
||
resp := &dto.ClaudeResponse{
|
||
Type: "content_block_start",
|
||
ContentBlock: &dto.ClaudeMediaMessage{
|
||
Id: openAIResponse.GetFirstToolCall().ID,
|
||
Type: "tool_use",
|
||
Name: openAIResponse.GetFirstToolCall().Function.Name,
|
||
},
|
||
}
|
||
resp.SetIndex(0)
|
||
claudeResponses = append(claudeResponses, resp)
|
||
} else {
|
||
|
||
}
|
||
// 判断首个响应是否存在内容(非标准的 OpenAI 响应)
|
||
if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.GetContentString()) > 0 {
|
||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||
Index: &info.ClaudeConvertInfo.Index,
|
||
Type: "content_block_start",
|
||
ContentBlock: &dto.ClaudeMediaMessage{
|
||
Type: "text",
|
||
Text: common.GetPointer[string](""),
|
||
},
|
||
})
|
||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||
Type: "content_block_delta",
|
||
Delta: &dto.ClaudeMediaMessage{
|
||
Type: "text",
|
||
Text: common.GetPointer[string](openAIResponse.Choices[0].Delta.GetContentString()),
|
||
},
|
||
})
|
||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
|
||
}
|
||
return claudeResponses
|
||
}
|
||
|
||
if len(openAIResponse.Choices) == 0 {
|
||
// no choices
|
||
// 可能为非标准的 OpenAI 响应,判断是否已经完成
|
||
if info.Done {
|
||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||
oaiUsage := info.ClaudeConvertInfo.Usage
|
||
if oaiUsage != nil {
|
||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||
Type: "message_delta",
|
||
Usage: &dto.ClaudeUsage{
|
||
InputTokens: oaiUsage.PromptTokens,
|
||
OutputTokens: oaiUsage.CompletionTokens,
|
||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||
},
|
||
Delta: &dto.ClaudeMediaMessage{
|
||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||
},
|
||
})
|
||
}
|
||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||
Type: "message_stop",
|
||
})
|
||
}
|
||
return claudeResponses
|
||
} else {
|
||
chosenChoice := openAIResponse.Choices[0]
|
||
if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
|
||
// should be done
|
||
info.FinishReason = *chosenChoice.FinishReason
|
||
return claudeResponses
|
||
}
|
||
if info.Done {
|
||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||
oaiUsage := info.ClaudeConvertInfo.Usage
|
||
if oaiUsage != nil {
|
||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||
Type: "message_delta",
|
||
Usage: &dto.ClaudeUsage{
|
||
InputTokens: oaiUsage.PromptTokens,
|
||
OutputTokens: oaiUsage.CompletionTokens,
|
||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||
},
|
||
Delta: &dto.ClaudeMediaMessage{
|
||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||
},
|
||
})
|
||
}
|
||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||
Type: "message_stop",
|
||
})
|
||
} else {
|
||
var claudeResponse dto.ClaudeResponse
|
||
var isEmpty bool
|
||
claudeResponse.Type = "content_block_delta"
|
||
if len(chosenChoice.Delta.ToolCalls) > 0 {
|
||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
|
||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||
info.ClaudeConvertInfo.Index++
|
||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||
Index: &info.ClaudeConvertInfo.Index,
|
||
Type: "content_block_start",
|
||
ContentBlock: &dto.ClaudeMediaMessage{
|
||
Id: openAIResponse.GetFirstToolCall().ID,
|
||
Type: "tool_use",
|
||
Name: openAIResponse.GetFirstToolCall().Function.Name,
|
||
Input: map[string]interface{}{},
|
||
},
|
||
})
|
||
}
|
||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
|
||
// tools delta
|
||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||
Type: "input_json_delta",
|
||
PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
|
||
}
|
||
} else {
|
||
reasoning := chosenChoice.Delta.GetReasoningContent()
|
||
textContent := chosenChoice.Delta.GetContentString()
|
||
if reasoning != "" || textContent != "" {
|
||
if reasoning != "" {
|
||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
|
||
//info.ClaudeConvertInfo.Index++
|
||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||
Index: &info.ClaudeConvertInfo.Index,
|
||
Type: "content_block_start",
|
||
ContentBlock: &dto.ClaudeMediaMessage{
|
||
Type: "thinking",
|
||
Thinking: "",
|
||
},
|
||
})
|
||
}
|
||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
|
||
// text delta
|
||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||
Type: "thinking_delta",
|
||
Thinking: reasoning,
|
||
}
|
||
} else {
|
||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
|
||
if info.LastMessagesType == relaycommon.LastMessageTypeThinking || info.LastMessagesType == relaycommon.LastMessageTypeTools {
|
||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||
info.ClaudeConvertInfo.Index++
|
||
}
|
||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||
Index: &info.ClaudeConvertInfo.Index,
|
||
Type: "content_block_start",
|
||
ContentBlock: &dto.ClaudeMediaMessage{
|
||
Type: "text",
|
||
Text: common.GetPointer[string](""),
|
||
},
|
||
})
|
||
}
|
||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
|
||
// text delta
|
||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||
Type: "text_delta",
|
||
Text: common.GetPointer[string](textContent),
|
||
}
|
||
}
|
||
} else {
|
||
isEmpty = true
|
||
}
|
||
}
|
||
claudeResponse.Index = &info.ClaudeConvertInfo.Index
|
||
if !isEmpty {
|
||
claudeResponses = append(claudeResponses, &claudeResponse)
|
||
}
|
||
}
|
||
}
|
||
|
||
return claudeResponses
|
||
}
|
||
|
||
func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.ClaudeResponse {
|
||
var stopReason string
|
||
contents := make([]dto.ClaudeMediaMessage, 0)
|
||
claudeResponse := &dto.ClaudeResponse{
|
||
Id: openAIResponse.Id,
|
||
Type: "message",
|
||
Role: "assistant",
|
||
Model: openAIResponse.Model,
|
||
}
|
||
for _, choice := range openAIResponse.Choices {
|
||
stopReason = stopReasonOpenAI2Claude(choice.FinishReason)
|
||
claudeContent := dto.ClaudeMediaMessage{}
|
||
if choice.FinishReason == "tool_calls" {
|
||
claudeContent.Type = "tool_use"
|
||
claudeContent.Id = choice.Message.ToolCallId
|
||
claudeContent.Name = choice.Message.ParseToolCalls()[0].Function.Name
|
||
var mapParams map[string]interface{}
|
||
if err := json.Unmarshal([]byte(choice.Message.ParseToolCalls()[0].Function.Arguments), &mapParams); err == nil {
|
||
claudeContent.Input = mapParams
|
||
} else {
|
||
claudeContent.Input = choice.Message.ParseToolCalls()[0].Function.Arguments
|
||
}
|
||
} else {
|
||
claudeContent.Type = "text"
|
||
claudeContent.SetText(choice.Message.StringContent())
|
||
}
|
||
contents = append(contents, claudeContent)
|
||
}
|
||
claudeResponse.Content = contents
|
||
claudeResponse.StopReason = stopReason
|
||
claudeResponse.Usage = &dto.ClaudeUsage{
|
||
InputTokens: openAIResponse.PromptTokens,
|
||
OutputTokens: openAIResponse.CompletionTokens,
|
||
}
|
||
|
||
return claudeResponse
|
||
}
|
||
|
||
// stopReasonOpenAI2Claude maps an OpenAI finish reason to the matching Claude
// stop reason. OpenAI reports truncation as "length"; "max_tokens" is accepted
// as well for upstreams that already use the Claude spelling. Unknown reasons
// are returned unchanged.
func stopReasonOpenAI2Claude(reason string) string {
	switch reason {
	case "stop":
		return "end_turn"
	case "stop_sequence":
		return "stop_sequence"
	case "length", "max_tokens":
		return "max_tokens"
	case "tool_calls":
		return "tool_use"
	default:
		return reason
	}
}
|
||
|
||
// toJSONString returns v encoded as a JSON string. If v cannot be encoded,
// the empty JSON object "{}" is returned instead of an error.
func toJSONString(v interface{}) string {
	if encoded, err := json.Marshal(v); err == nil {
		return string(encoded)
	}
	return "{}"
}
|
||
|
||
func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
|
||
openaiRequest := &dto.GeneralOpenAIRequest{
|
||
Model: info.UpstreamModelName,
|
||
Stream: info.IsStream,
|
||
}
|
||
|
||
// 转换 messages
|
||
var messages []dto.Message
|
||
for _, content := range geminiRequest.Contents {
|
||
message := dto.Message{
|
||
Role: convertGeminiRoleToOpenAI(content.Role),
|
||
}
|
||
|
||
// 处理 parts
|
||
var mediaContents []dto.MediaContent
|
||
var toolCalls []dto.ToolCallRequest
|
||
for _, part := range content.Parts {
|
||
if part.Text != "" {
|
||
mediaContent := dto.MediaContent{
|
||
Type: "text",
|
||
Text: part.Text,
|
||
}
|
||
mediaContents = append(mediaContents, mediaContent)
|
||
} else if part.InlineData != nil {
|
||
mediaContent := dto.MediaContent{
|
||
Type: "image_url",
|
||
ImageUrl: &dto.MessageImageUrl{
|
||
Url: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data),
|
||
Detail: "auto",
|
||
MimeType: part.InlineData.MimeType,
|
||
},
|
||
}
|
||
mediaContents = append(mediaContents, mediaContent)
|
||
} else if part.FileData != nil {
|
||
mediaContent := dto.MediaContent{
|
||
Type: "image_url",
|
||
ImageUrl: &dto.MessageImageUrl{
|
||
Url: part.FileData.FileUri,
|
||
Detail: "auto",
|
||
MimeType: part.FileData.MimeType,
|
||
},
|
||
}
|
||
mediaContents = append(mediaContents, mediaContent)
|
||
} else if part.FunctionCall != nil {
|
||
// 处理 Gemini 的工具调用
|
||
toolCall := dto.ToolCallRequest{
|
||
ID: fmt.Sprintf("call_%d", len(toolCalls)+1), // 生成唯一ID
|
||
Type: "function",
|
||
Function: dto.FunctionRequest{
|
||
Name: part.FunctionCall.FunctionName,
|
||
Arguments: toJSONString(part.FunctionCall.Arguments),
|
||
},
|
||
}
|
||
toolCalls = append(toolCalls, toolCall)
|
||
} else if part.FunctionResponse != nil {
|
||
// 处理 Gemini 的工具响应,创建单独的 tool 消息
|
||
toolMessage := dto.Message{
|
||
Role: "tool",
|
||
ToolCallId: fmt.Sprintf("call_%d", len(toolCalls)), // 使用对应的调用ID
|
||
}
|
||
toolMessage.SetStringContent(toJSONString(part.FunctionResponse.Response))
|
||
messages = append(messages, toolMessage)
|
||
}
|
||
}
|
||
|
||
// 设置消息内容
|
||
if len(toolCalls) > 0 {
|
||
// 如果有工具调用,设置工具调用
|
||
message.SetToolCalls(toolCalls)
|
||
} else if len(mediaContents) == 1 && mediaContents[0].Type == "text" {
|
||
// 如果只有一个文本内容,直接设置字符串
|
||
message.Content = mediaContents[0].Text
|
||
} else if len(mediaContents) > 0 {
|
||
// 如果有多个内容或包含媒体,设置为数组
|
||
message.SetMediaContent(mediaContents)
|
||
}
|
||
|
||
// 只有当消息有内容或工具调用时才添加
|
||
if len(message.ParseContent()) > 0 || len(message.ToolCalls) > 0 {
|
||
messages = append(messages, message)
|
||
}
|
||
}
|
||
|
||
openaiRequest.Messages = messages
|
||
|
||
if geminiRequest.GenerationConfig.Temperature != nil {
|
||
openaiRequest.Temperature = geminiRequest.GenerationConfig.Temperature
|
||
}
|
||
if geminiRequest.GenerationConfig.TopP > 0 {
|
||
openaiRequest.TopP = geminiRequest.GenerationConfig.TopP
|
||
}
|
||
if geminiRequest.GenerationConfig.TopK > 0 {
|
||
openaiRequest.TopK = int(geminiRequest.GenerationConfig.TopK)
|
||
}
|
||
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
||
openaiRequest.MaxTokens = geminiRequest.GenerationConfig.MaxOutputTokens
|
||
}
|
||
// gemini stop sequences 最多 5 个,openai stop 最多 4 个
|
||
if len(geminiRequest.GenerationConfig.StopSequences) > 0 {
|
||
openaiRequest.Stop = geminiRequest.GenerationConfig.StopSequences[:4]
|
||
}
|
||
if geminiRequest.GenerationConfig.CandidateCount > 0 {
|
||
openaiRequest.N = geminiRequest.GenerationConfig.CandidateCount
|
||
}
|
||
|
||
// 转换工具调用
|
||
if len(geminiRequest.Tools) > 0 {
|
||
var tools []dto.ToolCallRequest
|
||
for _, tool := range geminiRequest.Tools {
|
||
if tool.FunctionDeclarations != nil {
|
||
// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
|
||
functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
|
||
if ok {
|
||
for _, function := range functionDeclarations {
|
||
openAITool := dto.ToolCallRequest{
|
||
Type: "function",
|
||
Function: dto.FunctionRequest{
|
||
Name: function.Name,
|
||
Description: function.Description,
|
||
Parameters: function.Parameters,
|
||
},
|
||
}
|
||
tools = append(tools, openAITool)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
if len(tools) > 0 {
|
||
openaiRequest.Tools = tools
|
||
}
|
||
}
|
||
|
||
// gemini system instructions
|
||
if geminiRequest.SystemInstructions != nil {
|
||
// 将系统指令作为第一条消息插入
|
||
systemMessage := dto.Message{
|
||
Role: "system",
|
||
Content: extractTextFromGeminiParts(geminiRequest.SystemInstructions.Parts),
|
||
}
|
||
openaiRequest.Messages = append([]dto.Message{systemMessage}, openaiRequest.Messages...)
|
||
}
|
||
|
||
return openaiRequest, nil
|
||
}
|
||
|
||
// convertGeminiRoleToOpenAI maps a Gemini role name to an OpenAI role name:
// "model" becomes "assistant", "function" stays "function", and everything
// else (including "user" and unknown roles) becomes "user".
func convertGeminiRoleToOpenAI(geminiRole string) string {
	if geminiRole == "model" {
		return "assistant"
	}
	if geminiRole == "function" {
		return "function"
	}
	return "user"
}
|
||
|
||
func extractTextFromGeminiParts(parts []dto.GeminiPart) string {
|
||
var texts []string
|
||
for _, part := range parts {
|
||
if part.Text != "" {
|
||
texts = append(texts, part.Text)
|
||
}
|
||
}
|
||
return strings.Join(texts, "\n")
|
||
}
|
||
|
||
// ResponseOpenAI2Gemini 将 OpenAI 响应转换为 Gemini 格式
|
||
func ResponseOpenAI2Gemini(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
|
||
geminiResponse := &dto.GeminiChatResponse{
|
||
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
|
||
PromptFeedback: dto.GeminiChatPromptFeedback{
|
||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||
},
|
||
UsageMetadata: dto.GeminiUsageMetadata{
|
||
PromptTokenCount: openAIResponse.PromptTokens,
|
||
CandidatesTokenCount: openAIResponse.CompletionTokens,
|
||
TotalTokenCount: openAIResponse.PromptTokens + openAIResponse.CompletionTokens,
|
||
},
|
||
}
|
||
|
||
for _, choice := range openAIResponse.Choices {
|
||
candidate := dto.GeminiChatCandidate{
|
||
Index: int64(choice.Index),
|
||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||
}
|
||
|
||
// 设置结束原因
|
||
var finishReason string
|
||
switch choice.FinishReason {
|
||
case "stop":
|
||
finishReason = "STOP"
|
||
case "length":
|
||
finishReason = "MAX_TOKENS"
|
||
case "content_filter":
|
||
finishReason = "SAFETY"
|
||
case "tool_calls":
|
||
finishReason = "STOP"
|
||
default:
|
||
finishReason = "STOP"
|
||
}
|
||
candidate.FinishReason = &finishReason
|
||
|
||
// 转换消息内容
|
||
content := dto.GeminiChatContent{
|
||
Role: "model",
|
||
Parts: make([]dto.GeminiPart, 0),
|
||
}
|
||
|
||
// 处理工具调用
|
||
toolCalls := choice.Message.ParseToolCalls()
|
||
if len(toolCalls) > 0 {
|
||
for _, toolCall := range toolCalls {
|
||
// 解析参数
|
||
var args map[string]interface{}
|
||
if toolCall.Function.Arguments != "" {
|
||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
|
||
}
|
||
} else {
|
||
args = make(map[string]interface{})
|
||
}
|
||
|
||
part := dto.GeminiPart{
|
||
FunctionCall: &dto.FunctionCall{
|
||
FunctionName: toolCall.Function.Name,
|
||
Arguments: args,
|
||
},
|
||
}
|
||
content.Parts = append(content.Parts, part)
|
||
}
|
||
} else {
|
||
// 处理文本内容
|
||
textContent := choice.Message.StringContent()
|
||
if textContent != "" {
|
||
part := dto.GeminiPart{
|
||
Text: textContent,
|
||
}
|
||
content.Parts = append(content.Parts, part)
|
||
}
|
||
}
|
||
|
||
candidate.Content = content
|
||
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
|
||
}
|
||
|
||
return geminiResponse
|
||
}
|
||
|
||
// StreamResponseOpenAI2Gemini 将 OpenAI 流式响应转换为 Gemini 格式
|
||
func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
|
||
// 检查是否有实际内容或结束标志
|
||
hasContent := false
|
||
hasFinishReason := false
|
||
for _, choice := range openAIResponse.Choices {
|
||
if len(choice.Delta.GetContentString()) > 0 || (choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0) {
|
||
hasContent = true
|
||
}
|
||
if choice.FinishReason != nil {
|
||
hasFinishReason = true
|
||
}
|
||
}
|
||
|
||
// 如果没有实际内容且没有结束标志,跳过。主要针对 openai 流响应开头的空数据
|
||
if !hasContent && !hasFinishReason {
|
||
return nil
|
||
}
|
||
|
||
geminiResponse := &dto.GeminiChatResponse{
|
||
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
|
||
PromptFeedback: dto.GeminiChatPromptFeedback{
|
||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||
},
|
||
UsageMetadata: dto.GeminiUsageMetadata{
|
||
PromptTokenCount: info.PromptTokens,
|
||
CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息
|
||
TotalTokenCount: info.PromptTokens,
|
||
},
|
||
}
|
||
|
||
for _, choice := range openAIResponse.Choices {
|
||
candidate := dto.GeminiChatCandidate{
|
||
Index: int64(choice.Index),
|
||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||
}
|
||
|
||
// 设置结束原因
|
||
if choice.FinishReason != nil {
|
||
var finishReason string
|
||
switch *choice.FinishReason {
|
||
case "stop":
|
||
finishReason = "STOP"
|
||
case "length":
|
||
finishReason = "MAX_TOKENS"
|
||
case "content_filter":
|
||
finishReason = "SAFETY"
|
||
case "tool_calls":
|
||
finishReason = "STOP"
|
||
default:
|
||
finishReason = "STOP"
|
||
}
|
||
candidate.FinishReason = &finishReason
|
||
}
|
||
|
||
// 转换消息内容
|
||
content := dto.GeminiChatContent{
|
||
Role: "model",
|
||
Parts: make([]dto.GeminiPart, 0),
|
||
}
|
||
|
||
// 处理工具调用
|
||
if choice.Delta.ToolCalls != nil {
|
||
for _, toolCall := range choice.Delta.ToolCalls {
|
||
// 解析参数
|
||
var args map[string]interface{}
|
||
if toolCall.Function.Arguments != "" {
|
||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
|
||
}
|
||
} else {
|
||
args = make(map[string]interface{})
|
||
}
|
||
|
||
part := dto.GeminiPart{
|
||
FunctionCall: &dto.FunctionCall{
|
||
FunctionName: toolCall.Function.Name,
|
||
Arguments: args,
|
||
},
|
||
}
|
||
content.Parts = append(content.Parts, part)
|
||
}
|
||
} else {
|
||
// 处理文本内容
|
||
textContent := choice.Delta.GetContentString()
|
||
if textContent != "" {
|
||
part := dto.GeminiPart{
|
||
Text: textContent,
|
||
}
|
||
content.Parts = append(content.Parts, part)
|
||
}
|
||
}
|
||
|
||
candidate.Content = content
|
||
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
|
||
}
|
||
|
||
return geminiResponse
|
||
}
|