From 195ab1fdd5f51b86cecba3d22b0b55a5a0538da1 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Fri, 6 Dec 2024 14:31:27 +0800 Subject: [PATCH] feat: add gemini tool_calls finish reason --- constant/finish_reason.go | 6 ++++++ relay/channel/baidu/relay-baidu.go | 4 ++-- relay/channel/gemini/relay-gemini.go | 6 ++++-- relay/channel/palm/relay-palm.go | 4 ++-- relay/channel/tencent/relay-tencent.go | 4 ++-- relay/channel/xunfei/relay-xunfei.go | 6 +++--- relay/channel/zhipu/relay-zhipu.go | 4 ++-- relay/common/relay_utils.go | 2 -- 8 files changed, 21 insertions(+), 15 deletions(-) create mode 100644 constant/finish_reason.go diff --git a/constant/finish_reason.go b/constant/finish_reason.go new file mode 100644 index 00000000..8d6289a6 --- /dev/null +++ b/constant/finish_reason.go @@ -0,0 +1,6 @@ +package constant + +var ( + FinishReasonStop = "stop" + FinishReasonToolCalls = "tool_calls" +) diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index 918108c5..09a99e4d 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -9,8 +9,8 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" - relaycommon "one-api/relay/common" "one-api/service" "strings" "sync" @@ -75,7 +75,7 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.Cha var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.SetContentString(baiduResponse.Result) if baiduResponse.IsEnd { - choice.FinishReason = &relaycommon.StopFinishReason + choice.FinishReason = &constant.FinishReasonStop } response := dto.ChatCompletionsStreamResponse{ Id: baiduResponse.Id, diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index f6dba5e4..df262976 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" @@ -186,10 +187,11 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp Role: "assistant", Content: content, }, - FinishReason: relaycommon.StopFinishReason, + FinishReason: constant.FinishReasonStop, } if len(candidate.Content.Parts) > 0 { if candidate.Content.Parts[0].FunctionCall != nil { + choice.FinishReason = constant.FinishReasonToolCalls choice.Message.ToolCalls = getToolCalls(&candidate) } else { choice.Message.SetStringContent(candidate.Content.Parts[0].Text) @@ -262,7 +264,7 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom } } - response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, relaycommon.StopFinishReason) + response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop) service.ObjectData(c, response) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index dfde59f5..02a3e382 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -7,8 +7,8 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" - relaycommon "one-api/relay/common" "one-api/service" ) @@ -63,7 +63,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompleti if len(palmResponse.Candidates) > 0 { choice.Delta.SetContentString(palmResponse.Candidates[0].Content) } - choice.FinishReason = &relaycommon.StopFinishReason + choice.FinishReason = &constant.FinishReasonStop var response dto.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" response.Model = "palm2" diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index 3ea23767..6d0f3471 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -12,8 +12,8 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" - relaycommon "one-api/relay/common" "one-api/service" "strconv" "strings" @@ -81,7 +81,7 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.SetContentString(TencentResponse.Choices[0].Delta.Content) if TencentResponse.Choices[0].FinishReason == "stop" { - choice.FinishReason = &relaycommon.StopFinishReason + choice.FinishReason = &constant.FinishReasonStop } response.Choices = append(response.Choices, choice) } diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index fdcda5d3..f9236973 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -12,8 +12,8 @@ import ( "net/http" "net/url" "one-api/common" + "one-api/constant" "one-api/dto" - relaycommon "one-api/relay/common" "one-api/service" "strings" "time" @@ -67,7 +67,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *dto.OpenAITextResponse Role: "assistant", Content: content, }, - FinishReason: relaycommon.StopFinishReason, + FinishReason: constant.FinishReasonStop, } fullTextResponse := dto.OpenAITextResponse{ Object: "chat.completion", @@ -89,7 +89,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCo var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.SetContentString(xunfeiResponse.Payload.Choices.Text[0].Content) if xunfeiResponse.Payload.Choices.Status == 2 { - choice.FinishReason = &relaycommon.StopFinishReason + choice.FinishReason = &constant.FinishReasonStop } response := dto.ChatCompletionsStreamResponse{ Object: "chat.completion.chunk", diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index aaf3c5dd..6bdd1c2a 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -8,8 +8,8 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" - relaycommon "one-api/relay/common" "one-api/service" "strings" "sync" @@ -139,7 +139,7 @@ func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStream func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) { var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.SetContentString("") - choice.FinishReason = &relaycommon.StopFinishReason + choice.FinishReason = &constant.FinishReasonStop response := dto.ChatCompletionsStreamResponse{ Id: zhipuResponse.RequestId, Object: "chat.completion.chunk", diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index 6daf003a..7a4f44bb 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -10,8 +10,6 @@ import ( "strings" ) -var StopFinishReason = "stop" - func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)