From bdc65bdba20e2d5123f2678265cb67676c09de9d Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Mon, 22 Apr 2024 16:35:56 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=90=AF=E7=94=A8=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E8=AE=A1=E8=B4=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/model-ratio.go | 31 ++++++++++++++-------------- dto/text_response.go | 30 ++++++++++++++++++++------- relay/channel/openai/relay-openai.go | 12 +++++++++++ service/token_counter.go | 17 +++++++++++++++ 4 files changed, 68 insertions(+), 22 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index e7ffe421..ba9c71cf 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -12,23 +12,24 @@ import ( // TODO: when a new api is enabled, check the pricing here // 1 === $0.002 / 1K tokens // 1 === ¥0.014 / 1k tokens + var DefaultModelRatio = map[string]float64{ //"midjourney": 50, - "gpt-4-gizmo-*": 15, - "gpt-4": 15, - "gpt-4-0314": 15, - "gpt-4-0613": 15, - "gpt-4-32k": 30, - "gpt-4-32k-0314": 30, - "gpt-4-32k-0613": 30, - "gpt-4-1106-preview": 5, // $0.01 / 1K tokens - "gpt-4-0125-preview": 5, // $0.01 / 1K tokens - "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens - "gpt-4-vision-preview": 5, // $0.01 / 1K tokens - "gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens - "gpt-4-turbo": 5, // $0.01 / 1K tokens - "gpt-3.5-turbo": 0.25, // $0.0015 / 1K tokens - "gpt-3.5-turbo-0301": 0.75, + "gpt-4-gizmo-*": 15, + "gpt-4": 15, + //"gpt-4-0314": 15, //deprecated + "gpt-4-0613": 15, + "gpt-4-32k": 30, + //"gpt-4-32k-0314": 30, //deprecated + "gpt-4-32k-0613": 30, + "gpt-4-1106-preview": 5, // $0.01 / 1K tokens + "gpt-4-0125-preview": 5, // $0.01 / 1K tokens + "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens + "gpt-4-vision-preview": 5, // $0.01 / 1K tokens + "gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens + "gpt-4-turbo": 5, // $0.01 / 1K tokens + "gpt-3.5-turbo": 0.25, // $0.0015 / 1K tokens + //"gpt-3.5-turbo-0301": 0.75, //deprecated "gpt-3.5-turbo-0613": 0.75, "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens "gpt-3.5-turbo-16k-0613": 1.5, diff --git a/dto/text_response.go b/dto/text_response.go index 98275fe4..a589d75e 100644 --- a/dto/text_response.go +++ b/dto/text_response.go @@ -54,13 +54,29 @@ type OpenAIEmbeddingResponse struct { } type ChatCompletionsStreamResponseChoice struct { - Delta struct { - Content string `json:"content"` - Role string `json:"role,omitempty"` - ToolCalls any `json:"tool_calls,omitempty"` - } `json:"delta"` - FinishReason *string `json:"finish_reason,omitempty"` - Index int `json:"index,omitempty"` + Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta"` + FinishReason *string `json:"finish_reason,omitempty"` + Index int `json:"index,omitempty"` +} + +type ChatCompletionsStreamResponseChoiceDelta struct { + Content string `json:"content"` + Role string `json:"role,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` +} + +type ToolCall struct { + // Index is not nil only in chat completion chunk object + Index *int `json:"index,omitempty"` + ID string `json:"id"` + Type any `json:"type"` + Function FunctionCall `json:"function"` +} + +type FunctionCall struct { + Name string `json:"name,omitempty"` + // call function with arguments in JSON format + Arguments string `json:"arguments,omitempty"` } type ChatCompletionsStreamResponse struct { diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index fe5cd48f..dae9fd55 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -68,6 +68,12 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d if err == nil { for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.Content) + if choice.Delta.ToolCalls != nil { + for _, tool := range choice.Delta.ToolCalls { + responseTextBuilder.WriteString(tool.Function.Name) + responseTextBuilder.WriteString(tool.Function.Arguments) + } + } } } } @@ -75,6 +81,12 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d for _, streamResponse := range streamResponses { for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.Content) + if choice.Delta.ToolCalls != nil { + for _, tool := range choice.Delta.ToolCalls { + responseTextBuilder.WriteString(tool.Function.Name) + responseTextBuilder.WriteString(tool.Function.Arguments) + } + } } } } diff --git a/service/token_counter.go b/service/token_counter.go index 5255c80f..897f49c1 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -211,6 +211,23 @@ func CountTokenInput(input any, model string, check bool) (int, error, bool) { return CountTokenInput(fmt.Sprintf("%v", input), model, check) } +func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int { + tokens := 0 + for _, message := range messages { + tkm, _, _ := CountTokenInput(message.Delta.Content, model, false) + tokens += tkm + if message.Delta.ToolCalls != nil { + for _, tool := range message.Delta.ToolCalls { + tkm, _, _ := CountTokenInput(tool.Function.Name, model, false) + tokens += tkm + tkm, _, _ = CountTokenInput(tool.Function.Arguments, model, false) + tokens += tkm + } + } + } + return tokens +} + func CountAudioToken(text string, model string, check bool) (int, error, bool) { if strings.HasPrefix(model, "tts") { contains, words := SensitiveWordContains(text)