fix: gemini&claude tool call format #795 #766

This commit is contained in:
1808837298@qq.com
2025-02-26 23:56:10 +08:00
parent 6d8d40e67b
commit 13ab0f8e4f
6 changed files with 85 additions and 89 deletions

View File

@@ -42,7 +42,7 @@ type GeneralOpenAIRequest struct {
ResponseFormat *ResponseFormat `json:"response_format,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"` EncodingFormat any `json:"encoding_format,omitempty"`
Seed float64 `json:"seed,omitempty"` Seed float64 `json:"seed,omitempty"`
Tools []ToolCall `json:"tools,omitempty"` Tools []ToolCallRequest `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"` User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"` LogProbs bool `json:"logprobs,omitempty"`
@@ -53,15 +53,17 @@ type GeneralOpenAIRequest struct {
ExtraBody any `json:"extra_body,omitempty"` ExtraBody any `json:"extra_body,omitempty"`
} }
type OpenAITools struct { type ToolCallRequest struct {
ID string `json:"id,omitempty"`
Type string `json:"type"` Type string `json:"type"`
Function OpenAIFunction `json:"function"` Function FunctionRequest `json:"function"`
} }
type OpenAIFunction struct { type FunctionRequest struct {
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
Name string `json:"name"` Name string `json:"name"`
Parameters any `json:"parameters,omitempty"` Parameters any `json:"parameters,omitempty"`
Arguments string `json:"arguments,omitempty"`
} }
type StreamOptions struct { type StreamOptions struct {
@@ -137,11 +139,11 @@ func (m *Message) SetPrefix(prefix bool) {
m.Prefix = &prefix m.Prefix = &prefix
} }
func (m *Message) ParseToolCalls() []ToolCall { func (m *Message) ParseToolCalls() []ToolCallRequest {
if m.ToolCalls == nil { if m.ToolCalls == nil {
return nil return nil
} }
var toolCalls []ToolCall var toolCalls []ToolCallRequest
if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil { if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
return toolCalls return toolCalls
} }

View File

@@ -65,7 +65,7 @@ type ChatCompletionsStreamResponseChoiceDelta struct {
Content *string `json:"content,omitempty"` Content *string `json:"content,omitempty"`
ReasoningContent *string `json:"reasoning_content,omitempty"` ReasoningContent *string `json:"reasoning_content,omitempty"`
Role string `json:"role,omitempty"` Role string `json:"role,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCallResponse `json:"tool_calls,omitempty"`
} }
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) { func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
@@ -90,24 +90,24 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string)
c.ReasoningContent = &s c.ReasoningContent = &s
} }
type ToolCall struct { type ToolCallResponse struct {
// Index is not nil only in chat completion chunk object // Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"` Index *int `json:"index,omitempty"`
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
Type any `json:"type"` Type any `json:"type"`
Function FunctionCall `json:"function"` Function FunctionResponse `json:"function"`
} }
func (c *ToolCall) SetIndex(i int) { func (c *ToolCallResponse) SetIndex(i int) {
c.Index = &i c.Index = &i
} }
type FunctionCall struct { type FunctionResponse struct {
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
// call function with arguments in JSON format // call function with arguments in JSON format
Parameters any `json:"parameters,omitempty"` // request Parameters any `json:"parameters,omitempty"` // request
Arguments string `json:"arguments"` Arguments string `json:"arguments"` // response
} }
type ChatCompletionsStreamResponse struct { type ChatCompletionsStreamResponse struct {

View File

@@ -296,7 +296,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0) response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
tools := make([]dto.ToolCall, 0) tools := make([]dto.ToolCallResponse, 0)
var choice dto.ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion { if reqMode == RequestModeCompletion {
choice.Delta.SetContentString(claudeResponse.Completion) choice.Delta.SetContentString(claudeResponse.Completion)
@@ -315,10 +315,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeResponse.ContentBlock != nil { if claudeResponse.ContentBlock != nil {
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text) //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
if claudeResponse.ContentBlock.Type == "tool_use" { if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, dto.ToolCall{ tools = append(tools, dto.ToolCallResponse{
ID: claudeResponse.ContentBlock.Id, ID: claudeResponse.ContentBlock.Id,
Type: "function", Type: "function",
Function: dto.FunctionCall{ Function: dto.FunctionResponse{
Name: claudeResponse.ContentBlock.Name, Name: claudeResponse.ContentBlock.Name,
Arguments: "", Arguments: "",
}, },
@@ -333,8 +333,8 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
choice.Delta.SetContentString(claudeResponse.Delta.Text) choice.Delta.SetContentString(claudeResponse.Delta.Text)
switch claudeResponse.Delta.Type { switch claudeResponse.Delta.Type {
case "input_json_delta": case "input_json_delta":
tools = append(tools, dto.ToolCall{ tools = append(tools, dto.ToolCallResponse{
Function: dto.FunctionCall{ Function: dto.FunctionResponse{
Arguments: claudeResponse.Delta.PartialJson, Arguments: claudeResponse.Delta.PartialJson,
}, },
}) })
@@ -382,7 +382,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
if len(claudeResponse.Content) > 0 { if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text responseText = claudeResponse.Content[0].Text
} }
tools := make([]dto.ToolCall, 0) tools := make([]dto.ToolCallResponse, 0)
thinkingContent := "" thinkingContent := ""
if reqMode == RequestModeCompletion { if reqMode == RequestModeCompletion {
@@ -403,10 +403,10 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
switch message.Type { switch message.Type {
case "tool_use": case "tool_use":
args, _ := json.Marshal(message.Input) args, _ := json.Marshal(message.Input)
tools = append(tools, dto.ToolCall{ tools = append(tools, dto.ToolCallResponse{
ID: message.Id, ID: message.Id,
Type: "function", // compatible with other OpenAI derivative applications Type: "function", // compatible with other OpenAI derivative applications
Function: dto.FunctionCall{ Function: dto.FunctionResponse{
Name: message.Name, Name: message.Name,
Arguments: string(args), Arguments: string(args),
}, },

View File

@@ -43,7 +43,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
// openaiContent.FuncToToolCalls() // openaiContent.FuncToToolCalls()
if textRequest.Tools != nil { if textRequest.Tools != nil {
functions := make([]dto.FunctionCall, 0, len(textRequest.Tools)) functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
googleSearch := false googleSearch := false
codeExecution := false codeExecution := false
for _, tool := range textRequest.Tools { for _, tool := range textRequest.Tools {
@@ -338,7 +338,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
return data return data
} }
func getToolCall(item *GeminiPart) *dto.ToolCall { func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
var argsBytes []byte var argsBytes []byte
var err error var err error
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok { if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
@@ -350,10 +350,10 @@ func getToolCall(item *GeminiPart) *dto.ToolCall {
if err != nil { if err != nil {
return nil return nil
} }
return &dto.ToolCall{ return &dto.ToolCallResponse{
ID: fmt.Sprintf("call_%s", common.GetUUID()), ID: fmt.Sprintf("call_%s", common.GetUUID()),
Type: "function", Type: "function",
Function: dto.FunctionCall{ Function: dto.FunctionResponse{
Arguments: string(argsBytes), Arguments: string(argsBytes),
Name: item.FunctionCall.FunctionName, Name: item.FunctionCall.FunctionName,
}, },
@@ -380,11 +380,11 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
} }
if len(candidate.Content.Parts) > 0 { if len(candidate.Content.Parts) > 0 {
var texts []string var texts []string
var toolCalls []dto.ToolCall var toolCalls []dto.ToolCallResponse
for _, part := range candidate.Content.Parts { for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil { if part.FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls choice.FinishReason = constant.FinishReasonToolCalls
if call := getToolCall(&part); call != nil { if call := getResponseToolCall(&part); call != nil {
toolCalls = append(toolCalls, *call) toolCalls = append(toolCalls, *call)
} }
} else { } else {
@@ -457,7 +457,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
for _, part := range candidate.Content.Parts { for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil { if part.FunctionCall != nil {
isTools = true isTools = true
if call := getToolCall(&part); call != nil { if call := getResponseToolCall(&part); call != nil {
call.SetIndex(len(choice.Delta.ToolCalls)) call.SetIndex(len(choice.Delta.ToolCalls))
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call) choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
} }

View File

@@ -12,7 +12,7 @@ type OllamaRequest struct {
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"` Stop any `json:"stop,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"` Tools []dto.ToolCallRequest `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"` ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`

View File

@@ -1,7 +1,6 @@
package service package service
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"image" "image"
@@ -170,12 +169,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
} }
tkm += msgTokens tkm += msgTokens
if request.Tools != nil { if request.Tools != nil {
toolsData, _ := json.Marshal(request.Tools) openaiTools := request.Tools
var openaiTools []dto.OpenAITools
err := json.Unmarshal(toolsData, &openaiTools)
if err != nil {
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
}
countStr := "" countStr := ""
for _, tool := range openaiTools { for _, tool := range openaiTools {
countStr = tool.Function.Name countStr = tool.Function.Name