diff --git a/relay/channel/gemini/dto.go b/dto/gemini.go similarity index 98% rename from relay/channel/gemini/dto.go rename to dto/gemini.go index a5e41c83..f7acd355 100644 --- a/relay/channel/gemini/dto.go +++ b/dto/gemini.go @@ -1,4 +1,4 @@ -package gemini +package dto import ( "encoding/json" @@ -56,7 +56,7 @@ type FunctionCall struct { Arguments any `json:"args"` } -type FunctionResponse struct { +type GeminiFunctionResponse struct { Name string `json:"name"` Response map[string]interface{} `json:"response"` } @@ -81,7 +81,7 @@ type GeminiPart struct { Thought bool `json:"thought,omitempty"` InlineData *GeminiInlineData `json:"inlineData,omitempty"` FunctionCall *FunctionCall `json:"functionCall,omitempty"` - FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` + FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"` FileData *GeminiFileData `json:"fileData,omitempty"` ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"` CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"` diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index ab8836ba..ec749133 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -26,6 +26,7 @@ type Adaptor interface { GetModelList() []string GetChannelName() string ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) + ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) } type TaskAdaptor interface { diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index d941a1bc..067fac37 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -18,6 +18,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index d3354f00..d7910725 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -22,6 +22,11 @@ type Adaptor struct { RequestMode int } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { c.Set("request_model", request.Model) c.Set("converted_request", request) diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 22443354..8396a844 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -18,6 +18,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index 375fd531..b8a4ac2f 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -18,6 +18,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") @@ -43,15 +48,15 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - keyParts := strings.Split(info.ApiKey, "|") + keyParts := strings.Split(info.ApiKey, "|") if len(keyParts) == 0 || keyParts[0] == "" { - return errors.New("invalid API key: authorization token is required") - } - if len(keyParts) > 1 { - if keyParts[1] != "" { - req.Set("appid", keyParts[1]) - } - } + return errors.New("invalid API key: authorization token is required") + } + if len(keyParts) > 1 { + if keyParts[1] != "" { + req.Set("appid", keyParts[1]) + } + } req.Set("Authorization", "Bearer "+keyParts[0]) return nil } diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 540742d6..0f7a9414 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -24,6 +24,11 @@ type Adaptor struct { RequestMode int } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { return request, nil } diff --git a/relay/channel/claude_code/adaptor.go b/relay/channel/claude_code/adaptor.go index 7a0be927..a5926f9d 100644 --- a/relay/channel/claude_code/adaptor.go +++ b/relay/channel/claude_code/adaptor.go @@ -25,6 +25,11 @@ type Adaptor struct { RequestMode int } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { // Use configured system prompt if available, otherwise use default if info.ChannelSetting.SystemPrompt != "" { diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index 6e59ad71..74a65ba4 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -18,6 +18,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index 4f3a96c3..887f9efd 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -17,6 +17,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go index fe5f5f00..658c6193 100644 --- a/relay/channel/coze/adaptor.go +++ b/relay/channel/coze/adaptor.go @@ -18,6 +18,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *common.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + // ConvertAudioRequest implements channel.Adaptor. func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) { return nil, errors.New("not implemented") diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index edfc7fd3..ac8ea18f 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -19,6 +19,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 4ad16766..8c7898c9 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -24,6 +24,11 @@ type Adaptor struct { BotType int } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 2b7b7e39..20d43020 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -20,6 +20,10 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) { + return request, nil +} + func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { adaptor := openai.Adaptor{} oaiReq, err := adaptor.ConvertClaudeRequest(c, info, req) @@ -51,13 +55,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } // build gemini imagen request - geminiRequest := GeminiImageRequest{ - Instances: []GeminiImageInstance{ + geminiRequest := dto.GeminiImageRequest{ + Instances: []dto.GeminiImageInstance{ { Prompt: request.Prompt, }, }, - Parameters: GeminiImageParameters{ + Parameters: dto.GeminiImageParameters{ SampleCount: request.N, AspectRatio: aspectRatio, PersonGeneration: "allow_adult", // default allow adult @@ -138,9 +142,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela } // only process the first input - geminiRequest := GeminiEmbeddingRequest{ - Content: GeminiChatContent{ - Parts: []GeminiPart{ + geminiRequest := dto.GeminiEmbeddingRequest{ + Content: dto.GeminiChatContent{ + Parts: []dto.GeminiPart{ { Text: inputs[0], }, diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 7d459cc2..2060fd8c 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -28,7 +28,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re } // 解析为 Gemini 原生响应格式 - var geminiResponse GeminiChatResponse + var geminiResponse dto.GeminiChatResponse err = common.Unmarshal(responseBody, &geminiResponse) if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) @@ -71,7 +71,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn responseText := strings.Builder{} helper.StreamScannerHandler(c, resp, info, func(data string) bool { - var geminiResponse GeminiChatResponse + var geminiResponse dto.GeminiChatResponse err := common.UnmarshalJsonStr(data, &geminiResponse) if err != nil { common.LogError(c, "error unmarshalling stream response: "+err.Error()) diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 5dac0ce5..4065259f 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -81,7 +81,7 @@ func clampThinkingBudget(modelName string, budget int) int { return budget } -func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayInfo) { +func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) { if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { modelName := info.UpstreamModelName isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") && @@ -93,7 +93,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn if len(parts) == 2 && parts[1] != "" { if budgetTokens, err := strconv.Atoi(parts[1]); err == nil { clampedBudget := clampThinkingBudget(modelName, budgetTokens) - geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{ ThinkingBudget: common.GetPointer(clampedBudget), IncludeThoughts: true, } @@ -113,11 +113,11 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn } if isUnsupported { - geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{ IncludeThoughts: true, } } else { - geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{ IncludeThoughts: true, } if geminiRequest.GenerationConfig.MaxOutputTokens > 0 { @@ -128,7 +128,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn } } else if strings.HasSuffix(modelName, "-nothinking") { if !isNew25Pro { - geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{ ThinkingBudget: common.GetPointer(0), } } @@ -137,11 +137,11 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn } // Setting safety to the lowest possible values since Gemini is already powerless enough -func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) { +func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) { - geminiRequest := GeminiChatRequest{ - Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), - GenerationConfig: GeminiChatGenerationConfig{ + geminiRequest := dto.GeminiChatRequest{ + Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)), + GenerationConfig: dto.GeminiChatGenerationConfig{ Temperature: textRequest.Temperature, TopP: textRequest.TopP, MaxOutputTokens: textRequest.MaxTokens, @@ -158,9 +158,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon ThinkingAdaptor(&geminiRequest, info) - safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList)) + safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList)) for _, category := range SafetySettingList { - safetySettings = append(safetySettings, GeminiChatSafetySettings{ + safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{ Category: category, Threshold: model_setting.GetGeminiSafetySetting(category), }) @@ -198,17 +198,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon functions = append(functions, tool.Function) } if codeExecution { - geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{ + geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{ CodeExecution: make(map[string]string), }) } if googleSearch { - geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{ + geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{ GoogleSearch: make(map[string]string), }) } if len(functions) > 0 { - geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{ + geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{ FunctionDeclarations: functions, }) } @@ -238,7 +238,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon continue } else if message.Role == "tool" || message.Role == "function" { if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" { - geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{ + geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{ Role: "user", }) } @@ -265,18 +265,18 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon } } - functionResp := &FunctionResponse{ + functionResp := &dto.GeminiFunctionResponse{ Name: name, Response: contentMap, } - *parts = append(*parts, GeminiPart{ + *parts = append(*parts, dto.GeminiPart{ FunctionResponse: functionResp, }) continue } - var parts []GeminiPart - content := GeminiChatContent{ + var parts []dto.GeminiPart + content := dto.GeminiChatContent{ Role: message.Role, } // isToolCall := false @@ -290,8 +290,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments) } } - toolCall := GeminiPart{ - FunctionCall: &FunctionCall{ + toolCall := dto.GeminiPart{ + FunctionCall: &dto.FunctionCall{ FunctionName: call.Function.Name, Arguments: args, }, @@ -308,7 +308,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon if part.Text == "" { continue } - parts = append(parts, GeminiPart{ + parts = append(parts, dto.GeminiPart{ Text: part.Text, }) } else if part.Type == dto.ContentTypeImageURL { @@ -331,8 +331,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList()) } - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ + parts = append(parts, dto.GeminiPart{ + InlineData: &dto.GeminiInlineData{ MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义 Data: fileData.Base64Data, }, @@ -342,8 +342,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon if err != nil { return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error()) } - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ + parts = append(parts, dto.GeminiPart{ + InlineData: &dto.GeminiInlineData{ MimeType: format, Data: base64String, }, @@ -357,8 +357,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon if err != nil { return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error()) } - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ + parts = append(parts, dto.GeminiPart{ + InlineData: &dto.GeminiInlineData{ MimeType: format, Data: base64String, }, @@ -371,8 +371,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon if err != nil { return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error()) } - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ + parts = append(parts, dto.GeminiPart{ + InlineData: &dto.GeminiInlineData{ MimeType: "audio/" + part.GetInputAudio().Format, Data: base64String, }, @@ -392,8 +392,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon } if len(system_content) > 0 { - geminiRequest.SystemInstructions = &GeminiChatContent{ - Parts: []GeminiPart{ + geminiRequest.SystemInstructions = &dto.GeminiChatContent{ + Parts: []dto.GeminiPart{ { Text: strings.Join(system_content, "\n"), }, @@ -636,7 +636,7 @@ func unescapeMapOrSlice(data interface{}) interface{} { return data } -func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse { +func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse { var argsBytes []byte var err error if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok { @@ -658,7 +658,7 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse { } } -func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse { +func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse { fullTextResponse := dto.OpenAITextResponse{ Id: helper.GetResponseID(c), Object: "chat.completion", @@ -725,7 +725,7 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dt return &fullTextResponse } -func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) { +func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) { choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates)) isStop := false hasImage := false @@ -830,7 +830,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * respCount := 0 helper.StreamScannerHandler(c, resp, info, func(data string) bool { - var geminiResponse GeminiChatResponse + var geminiResponse dto.GeminiChatResponse err := common.UnmarshalJsonStr(data, &geminiResponse) if err != nil { common.LogError(c, "error unmarshalling stream response: "+err.Error()) @@ -913,7 +913,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R if common.DebugEnabled { println(string(responseBody)) } - var geminiResponse GeminiChatResponse + var geminiResponse dto.GeminiChatResponse err = common.Unmarshal(responseBody, &geminiResponse) if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) @@ -959,7 +959,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - var geminiResponse GeminiEmbeddingResponse + var geminiResponse dto.GeminiEmbeddingResponse if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } @@ -1005,7 +1005,7 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. } _ = resp.Body.Close() - var geminiResponse GeminiImageResponse + var geminiResponse dto.GeminiImageResponse if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } diff --git a/relay/channel/jimeng/adaptor.go b/relay/channel/jimeng/adaptor.go index 0b743879..ff9ac678 100644 --- a/relay/channel/jimeng/adaptor.go +++ b/relay/channel/jimeng/adaptor.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" @@ -13,11 +12,18 @@ import ( relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/types" + + "github.com/gin-gonic/gin" ) type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { return nil, errors.New("not implemented") } diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index 408a5c6e..bf318aa7 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -19,6 +19,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index 434a1031..45cb3290 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -16,6 +16,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go index b0b54b0c..37db2aec 100644 --- a/relay/channel/mokaai/adaptor.go +++ b/relay/channel/mokaai/adaptor.go @@ -18,6 +18,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index ff88de8b..1f3fda8d 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -17,6 +17,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { openaiAdaptor := openai.Adaptor{} openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request) diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index efd22878..df858ea2 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -34,6 +34,15 @@ type Adaptor struct { ResponseFormat string } +func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) { + // 使用 service.GeminiToOpenAIRequest 转换请求格式 + openaiRequest, err := service.GeminiToOpenAIRequest(request, info) + if err != nil { + return nil, err + } + return a.ConvertOpenAIRequest(c, info, openaiRequest) +} + func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { //if !strings.Contains(request.Model, "claude") { // return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model) @@ -64,7 +73,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - if info.RelayFormat == relaycommon.RelayFormatClaude { + if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini { return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil } if info.RelayMode == relayconstant.RelayModeRealtime { diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go index 1681c9ff..528f1276 100644 --- a/relay/channel/openai/helper.go +++ b/relay/channel/openai/helper.go @@ -2,6 +2,8 @@ package openai import ( "encoding/json" + "errors" + "net/http" "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" @@ -16,11 +18,14 @@ import ( // 辅助函数 func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { info.SendResponseCount++ + switch info.RelayFormat { case relaycommon.RelayFormatOpenAI: return sendStreamData(c, info, data, forceFormat, thinkToContent) case relaycommon.RelayFormatClaude: return handleClaudeFormat(c, data, info) + case relaycommon.RelayFormatGemini: + return handleGeminiFormat(c, data, info) } return nil } @@ -41,6 +46,46 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo return nil } +func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error { + // 截取前50个字符用于调试 + debugData := data + if len(data) > 50 { + debugData = data[:50] + "..." + } + common.LogInfo(c, "handleGeminiFormat called with data: "+debugData) + + var streamResponse dto.ChatCompletionsStreamResponse + if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil { + common.LogError(c, "failed to unmarshal stream response: "+err.Error()) + return err + } + + common.LogInfo(c, "successfully unmarshaled stream response") + geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info) + + // 如果返回 nil,表示没有实际内容,跳过发送 + if geminiResponse == nil { + common.LogInfo(c, "handleGeminiFormat: no content to send, skipping") + return nil + } + + geminiResponseStr, err := common.Marshal(geminiResponse) + if err != nil { + common.LogError(c, "failed to marshal gemini response: "+err.Error()) + return err + } + + common.LogInfo(c, "sending gemini format response") + // send gemini format response + c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)}) + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } else { + return errors.New("streaming error: flusher not found") + } + return nil +} + func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error { for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString()) @@ -185,6 +230,37 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream for _, resp := range claudeResponses { _ = helper.ClaudeData(c, *resp) } + + case relaycommon.RelayFormatGemini: + var streamResponse dto.ChatCompletionsStreamResponse + if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return + } + + // 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段 + // 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空,finishReason 为 STOP 的响应 + // 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null + // 暂不知是否有程序会不兼容。 + + geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info) + + // openai 流响应开头的空数据 + if geminiResponse == nil { + return + } + + geminiResponseStr, err := common.Marshal(geminiResponse) + if err != nil { + common.SysError("error marshalling gemini response: " + err.Error()) + return + } + + // 发送最终的 Gemini 响应 + c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)}) + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } } } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index f6a04f3a..9ae0a200 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -223,6 +223,13 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } responseBody = claudeRespStr + case relaycommon.RelayFormatGemini: + geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info) + geminiRespStr, err := common.Marshal(geminiResp) + if err != nil { + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + } + responseBody = geminiRespStr } common.IOCopyBytesGracefully(c, resp, responseBody) diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index a60dc4b2..4d1ab783 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -17,6 +17,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 19830aca..92cb08a2 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -17,6 +17,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index c80e9ea1..05e6d453 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -18,6 +18,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { adaptor := openai.Adaptor{} return adaptor.ConvertClaudeRequest(c, info, req) diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 520276a7..b86d8a16 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -25,6 +25,11 @@ type Adaptor struct { Timestamp int64 } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index c88b4359..39be998e 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -44,6 +44,10 @@ type Adaptor struct { AccountCredentials Credentials } +func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) { + return request, nil +} + func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { if v, ok := claudeModelMap[info.UpstreamModelName]; ok { c.Set("request_model", v) diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index af15d636..225b3895 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -23,6 +23,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go index 8d880137..6a3a5370 100644 --- a/relay/channel/xai/adaptor.go +++ b/relay/channel/xai/adaptor.go @@ -19,6 +19,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me //panic("implement me") diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 0d218ada..7ee76f1a 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -17,6 +17,11 @@ type Adaptor struct { request *dto.GeneralOpenAIRequest } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 43344428..e3be0e8e 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -16,6 +16,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index edd7a534..83070fe5 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -18,6 +18,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 43c7ca58..862630ea 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -20,8 +20,8 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) { - request := &gemini.GeminiChatRequest{} +func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) { + request := &dto.GeminiChatRequest{} err := common.UnmarshalBodyReusable(c, request) if err != nil { return nil, err @@ -44,7 +44,7 @@ func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) { // } } -func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) { +func checkGeminiInputSensitive(textRequest *dto.GeminiChatRequest) ([]string, error) { var inputTexts []string for _, content := range textRequest.Contents { for _, part := range content.Parts { @@ -61,7 +61,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, return sensitiveWords, err } -func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int { +func getGeminiInputTokens(req *dto.GeminiChatRequest, info *relaycommon.RelayInfo) int { // 计算输入 token 数量 var inputTexts []string for _, content := range req.Contents { @@ -78,7 +78,7 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay return inputTokens } -func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool { +func isNoThinkingRequest(req *dto.GeminiChatRequest) bool { if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil { return *req.GenerationConfig.ThinkingConfig.ThinkingBudget == 0 } @@ -202,7 +202,12 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } requestBody = bytes.NewReader(body) } else { - jsonData, err := common.Marshal(req) + // 使用 ConvertGeminiRequest 转换请求格式 + convertedRequest, err := adaptor.ConvertGeminiRequest(c, relayInfo, req) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + jsonData, err := common.Marshal(convertedRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } diff --git a/service/convert.go b/service/convert.go index 787cc79d..ee8ecee5 100644 --- a/service/convert.go +++ b/service/convert.go @@ -448,3 +448,353 @@ func toJSONString(v interface{}) string { } return string(b) } + +func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) { + openaiRequest := &dto.GeneralOpenAIRequest{ + Model: info.UpstreamModelName, + Stream: info.IsStream, + } + + // 转换 messages + var messages []dto.Message + for _, content := range geminiRequest.Contents { + message := dto.Message{ + Role: convertGeminiRoleToOpenAI(content.Role), + } + + // 处理 parts + var mediaContents []dto.MediaContent + var toolCalls []dto.ToolCallRequest + for _, part := range content.Parts { + if part.Text != "" { + mediaContent := dto.MediaContent{ + Type: "text", + Text: part.Text, + } + mediaContents = append(mediaContents, mediaContent) + } else if part.InlineData != nil { + mediaContent := dto.MediaContent{ + Type: "image_url", + ImageUrl: &dto.MessageImageUrl{ + Url: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data), + Detail: "auto", + MimeType: part.InlineData.MimeType, + }, + } + mediaContents = append(mediaContents, mediaContent) + } else if part.FileData != nil { + mediaContent := dto.MediaContent{ + Type: "image_url", + ImageUrl: &dto.MessageImageUrl{ + Url: part.FileData.FileUri, + Detail: "auto", + MimeType: part.FileData.MimeType, + }, + } + mediaContents = append(mediaContents, mediaContent) + } else if part.FunctionCall != nil { + // 处理 Gemini 的工具调用 + toolCall := dto.ToolCallRequest{ + ID: fmt.Sprintf("call_%d", len(toolCalls)+1), // 生成唯一ID + Type: "function", + Function: dto.FunctionRequest{ + Name: part.FunctionCall.FunctionName, + Arguments: toJSONString(part.FunctionCall.Arguments), + }, + } + toolCalls = append(toolCalls, toolCall) + } else if part.FunctionResponse != nil { + // 处理 Gemini 的工具响应,创建单独的 tool 消息 + toolMessage := dto.Message{ + Role: "tool", + ToolCallId: fmt.Sprintf("call_%d", len(toolCalls)), // 使用对应的调用ID + } + toolMessage.SetStringContent(toJSONString(part.FunctionResponse.Response)) + messages = append(messages, toolMessage) + } + } + + // 设置消息内容 + if len(toolCalls) > 0 { + // 如果有工具调用,设置工具调用 + message.SetToolCalls(toolCalls) + } else if len(mediaContents) == 1 && mediaContents[0].Type == "text" { + // 如果只有一个文本内容,直接设置字符串 + message.Content = mediaContents[0].Text + } else if len(mediaContents) > 0 { + // 如果有多个内容或包含媒体,设置为数组 + message.SetMediaContent(mediaContents) + } + + // 只有当消息有内容或工具调用时才添加 + if len(message.ParseContent()) > 0 || len(message.ToolCalls) > 0 { + messages = append(messages, message) + } + } + + openaiRequest.Messages = messages + + if geminiRequest.GenerationConfig.Temperature != nil { + openaiRequest.Temperature = geminiRequest.GenerationConfig.Temperature + } + if geminiRequest.GenerationConfig.TopP > 0 { + openaiRequest.TopP = geminiRequest.GenerationConfig.TopP + } + if geminiRequest.GenerationConfig.TopK > 0 { + openaiRequest.TopK = int(geminiRequest.GenerationConfig.TopK) + } + if geminiRequest.GenerationConfig.MaxOutputTokens > 0 { + openaiRequest.MaxTokens = geminiRequest.GenerationConfig.MaxOutputTokens + } + // gemini stop sequences 最多 5 个,openai stop 最多 4 个 + if len(geminiRequest.GenerationConfig.StopSequences) > 0 { + openaiRequest.Stop = geminiRequest.GenerationConfig.StopSequences[:4] + } + if geminiRequest.GenerationConfig.CandidateCount > 0 { + openaiRequest.N = geminiRequest.GenerationConfig.CandidateCount + } + + // 转换工具调用 + if len(geminiRequest.Tools) > 0 { + var tools []dto.ToolCallRequest + for _, tool := range geminiRequest.Tools { + if tool.FunctionDeclarations != nil { + // 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest + functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest) + if ok { + for _, function := range functionDeclarations { + openAITool := dto.ToolCallRequest{ + Type: "function", + Function: dto.FunctionRequest{ + Name: function.Name, + Description: function.Description, + Parameters: function.Parameters, + }, + } + tools = append(tools, openAITool) + } + } + } + } + if len(tools) > 0 { + openaiRequest.Tools = tools + } + } + + // gemini system instructions + if geminiRequest.SystemInstructions != nil { + // 将系统指令作为第一条消息插入 + systemMessage := dto.Message{ + Role: "system", + Content: extractTextFromGeminiParts(geminiRequest.SystemInstructions.Parts), + } + openaiRequest.Messages = append([]dto.Message{systemMessage}, openaiRequest.Messages...) + } + + return openaiRequest, nil +} + +func convertGeminiRoleToOpenAI(geminiRole string) string { + switch geminiRole { + case "user": + return "user" + case "model": + return "assistant" + case "function": + return "function" + default: + return "user" + } +} + +func extractTextFromGeminiParts(parts []dto.GeminiPart) string { + var texts []string + for _, part := range parts { + if part.Text != "" { + texts = append(texts, part.Text) + } + } + return strings.Join(texts, "\n") +} + +// ResponseOpenAI2Gemini 将 OpenAI 响应转换为 Gemini 格式 +func ResponseOpenAI2Gemini(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse { + geminiResponse := &dto.GeminiChatResponse{ + Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)), + PromptFeedback: dto.GeminiChatPromptFeedback{ + SafetyRatings: []dto.GeminiChatSafetyRating{}, + }, + UsageMetadata: dto.GeminiUsageMetadata{ + PromptTokenCount: openAIResponse.PromptTokens, + CandidatesTokenCount: openAIResponse.CompletionTokens, + TotalTokenCount: openAIResponse.PromptTokens + openAIResponse.CompletionTokens, + }, + } + + for _, choice := range openAIResponse.Choices { + candidate := dto.GeminiChatCandidate{ + Index: int64(choice.Index), + SafetyRatings: []dto.GeminiChatSafetyRating{}, + } + + // 设置结束原因 + var finishReason string + switch choice.FinishReason { + case "stop": + finishReason = "STOP" + case "length": + finishReason = "MAX_TOKENS" + case "content_filter": + finishReason = "SAFETY" + case "tool_calls": + finishReason = "STOP" + default: + finishReason = "STOP" + } + candidate.FinishReason = &finishReason + + // 转换消息内容 + content := dto.GeminiChatContent{ + Role: "model", + Parts: make([]dto.GeminiPart, 0), + } + + // 处理工具调用 + toolCalls := choice.Message.ParseToolCalls() + if len(toolCalls) > 0 { + for _, toolCall := range toolCalls { + // 解析参数 + var args map[string]interface{} + if toolCall.Function.Arguments != "" { + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil { + args = map[string]interface{}{"arguments": toolCall.Function.Arguments} + } + } else { + args = make(map[string]interface{}) + } + + part := dto.GeminiPart{ + FunctionCall: &dto.FunctionCall{ + FunctionName: toolCall.Function.Name, + Arguments: args, + }, + } + content.Parts = append(content.Parts, part) + } + } else { + // 处理文本内容 + textContent := choice.Message.StringContent() + if textContent != "" { + part := dto.GeminiPart{ + Text: textContent, + } + content.Parts = append(content.Parts, part) + } + } + + candidate.Content = content + geminiResponse.Candidates = append(geminiResponse.Candidates, candidate) + } + + return geminiResponse +} + +// StreamResponseOpenAI2Gemini 将 OpenAI 流式响应转换为 Gemini 格式 +func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse { + // 检查是否有实际内容或结束标志 + hasContent := false + hasFinishReason := false + for _, choice := range openAIResponse.Choices { + if len(choice.Delta.GetContentString()) > 0 || (choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0) { + hasContent = true + } + if choice.FinishReason != nil { + hasFinishReason = true + } + } + + // 如果没有实际内容且没有结束标志,跳过。主要针对 openai 流响应开头的空数据 + if !hasContent && !hasFinishReason { + return nil + } + + geminiResponse := &dto.GeminiChatResponse{ + Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)), + PromptFeedback: dto.GeminiChatPromptFeedback{ + SafetyRatings: []dto.GeminiChatSafetyRating{}, + }, + UsageMetadata: dto.GeminiUsageMetadata{ + PromptTokenCount: info.PromptTokens, + CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息 + TotalTokenCount: info.PromptTokens, + }, + } + + for _, choice := range openAIResponse.Choices { + candidate := dto.GeminiChatCandidate{ + Index: int64(choice.Index), + SafetyRatings: []dto.GeminiChatSafetyRating{}, + } + + // 设置结束原因 + if choice.FinishReason != nil { + var finishReason string + switch *choice.FinishReason { + case "stop": + finishReason = "STOP" + case "length": + finishReason = "MAX_TOKENS" + case "content_filter": + finishReason = "SAFETY" + case "tool_calls": + finishReason = "STOP" + default: + finishReason = "STOP" + } + candidate.FinishReason = &finishReason + } + + // 转换消息内容 + content := dto.GeminiChatContent{ + Role: "model", + Parts: make([]dto.GeminiPart, 0), + } + + // 处理工具调用 + if choice.Delta.ToolCalls != nil { + for _, toolCall := range choice.Delta.ToolCalls { + // 解析参数 + var args map[string]interface{} + if toolCall.Function.Arguments != "" { + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil { + args = map[string]interface{}{"arguments": toolCall.Function.Arguments} + } + } else { + args = make(map[string]interface{}) + } + + part := dto.GeminiPart{ + FunctionCall: &dto.FunctionCall{ + FunctionName: toolCall.Function.Name, + Arguments: args, + }, + } + content.Parts = append(content.Parts, part) + } + } else { + // 处理文本内容 + textContent := choice.Delta.GetContentString() + if textContent != "" { + part := dto.GeminiPart{ + Text: textContent, + } + content.Parts = append(content.Parts, part) + } + } + + candidate.Content = content + geminiResponse.Candidates = append(geminiResponse.Candidates, candidate) + } + + return geminiResponse +}