fix: Gemini & Claude tool call format (#795, #766)

This commit is contained in:
1808837298@qq.com
2025-02-26 23:56:10 +08:00
parent 6d8d40e67b
commit 13ab0f8e4f
6 changed files with 85 additions and 89 deletions

View File

@@ -18,50 +18,52 @@ type FormatJsonSchema struct {
} }
type GeneralOpenAIRequest struct { type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"` Prompt any `json:"prompt,omitempty"`
Prefix any `json:"prefix,omitempty"` Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"` Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"` StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"` MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"` ReasoningEffort string `json:"reasoning_effort,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"` Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"` N int `json:"n,omitempty"`
Input any `json:"input,omitempty"` Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"` Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"` Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"` Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"` EncodingFormat any `json:"encoding_format,omitempty"`
Seed float64 `json:"seed,omitempty"` Seed float64 `json:"seed,omitempty"`
Tools []ToolCall `json:"tools,omitempty"` Tools []ToolCallRequest `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"` User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"` LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"` TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"` Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"` Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"` Audio any `json:"audio,omitempty"`
ExtraBody any `json:"extra_body,omitempty"` ExtraBody any `json:"extra_body,omitempty"`
} }
type OpenAITools struct { type ToolCallRequest struct {
Type string `json:"type"` ID string `json:"id,omitempty"`
Function OpenAIFunction `json:"function"` Type string `json:"type"`
Function FunctionRequest `json:"function"`
} }
type OpenAIFunction struct { type FunctionRequest struct {
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
Name string `json:"name"` Name string `json:"name"`
Parameters any `json:"parameters,omitempty"` Parameters any `json:"parameters,omitempty"`
Arguments string `json:"arguments,omitempty"`
} }
type StreamOptions struct { type StreamOptions struct {
@@ -137,11 +139,11 @@ func (m *Message) SetPrefix(prefix bool) {
m.Prefix = &prefix m.Prefix = &prefix
} }
func (m *Message) ParseToolCalls() []ToolCall { func (m *Message) ParseToolCalls() []ToolCallRequest {
if m.ToolCalls == nil { if m.ToolCalls == nil {
return nil return nil
} }
var toolCalls []ToolCall var toolCalls []ToolCallRequest
if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil { if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
return toolCalls return toolCalls
} }

View File

@@ -62,10 +62,10 @@ type ChatCompletionsStreamResponseChoice struct {
} }
type ChatCompletionsStreamResponseChoiceDelta struct { type ChatCompletionsStreamResponseChoiceDelta struct {
Content *string `json:"content,omitempty"` Content *string `json:"content,omitempty"`
ReasoningContent *string `json:"reasoning_content,omitempty"` ReasoningContent *string `json:"reasoning_content,omitempty"`
Role string `json:"role,omitempty"` Role string `json:"role,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCallResponse `json:"tool_calls,omitempty"`
} }
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) { func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
@@ -90,24 +90,24 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string)
c.ReasoningContent = &s c.ReasoningContent = &s
} }
type ToolCall struct { type ToolCallResponse struct {
// Index is not nil only in chat completion chunk object // Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"` Index *int `json:"index,omitempty"`
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
Type any `json:"type"` Type any `json:"type"`
Function FunctionCall `json:"function"` Function FunctionResponse `json:"function"`
} }
func (c *ToolCall) SetIndex(i int) { func (c *ToolCallResponse) SetIndex(i int) {
c.Index = &i c.Index = &i
} }
type FunctionCall struct { type FunctionResponse struct {
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
// call function with arguments in JSON format // call function with arguments in JSON format
Parameters any `json:"parameters,omitempty"` // request Parameters any `json:"parameters,omitempty"` // request
Arguments string `json:"arguments"` Arguments string `json:"arguments"` // response
} }
type ChatCompletionsStreamResponse struct { type ChatCompletionsStreamResponse struct {

View File

@@ -296,7 +296,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0) response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
tools := make([]dto.ToolCall, 0) tools := make([]dto.ToolCallResponse, 0)
var choice dto.ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion { if reqMode == RequestModeCompletion {
choice.Delta.SetContentString(claudeResponse.Completion) choice.Delta.SetContentString(claudeResponse.Completion)
@@ -315,10 +315,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeResponse.ContentBlock != nil { if claudeResponse.ContentBlock != nil {
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text) //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
if claudeResponse.ContentBlock.Type == "tool_use" { if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, dto.ToolCall{ tools = append(tools, dto.ToolCallResponse{
ID: claudeResponse.ContentBlock.Id, ID: claudeResponse.ContentBlock.Id,
Type: "function", Type: "function",
Function: dto.FunctionCall{ Function: dto.FunctionResponse{
Name: claudeResponse.ContentBlock.Name, Name: claudeResponse.ContentBlock.Name,
Arguments: "", Arguments: "",
}, },
@@ -333,8 +333,8 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
choice.Delta.SetContentString(claudeResponse.Delta.Text) choice.Delta.SetContentString(claudeResponse.Delta.Text)
switch claudeResponse.Delta.Type { switch claudeResponse.Delta.Type {
case "input_json_delta": case "input_json_delta":
tools = append(tools, dto.ToolCall{ tools = append(tools, dto.ToolCallResponse{
Function: dto.FunctionCall{ Function: dto.FunctionResponse{
Arguments: claudeResponse.Delta.PartialJson, Arguments: claudeResponse.Delta.PartialJson,
}, },
}) })
@@ -382,7 +382,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
if len(claudeResponse.Content) > 0 { if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text responseText = claudeResponse.Content[0].Text
} }
tools := make([]dto.ToolCall, 0) tools := make([]dto.ToolCallResponse, 0)
thinkingContent := "" thinkingContent := ""
if reqMode == RequestModeCompletion { if reqMode == RequestModeCompletion {
@@ -403,10 +403,10 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
switch message.Type { switch message.Type {
case "tool_use": case "tool_use":
args, _ := json.Marshal(message.Input) args, _ := json.Marshal(message.Input)
tools = append(tools, dto.ToolCall{ tools = append(tools, dto.ToolCallResponse{
ID: message.Id, ID: message.Id,
Type: "function", // compatible with other OpenAI derivative applications Type: "function", // compatible with other OpenAI derivative applications
Function: dto.FunctionCall{ Function: dto.FunctionResponse{
Name: message.Name, Name: message.Name,
Arguments: string(args), Arguments: string(args),
}, },

View File

@@ -43,7 +43,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
// openaiContent.FuncToToolCalls() // openaiContent.FuncToToolCalls()
if textRequest.Tools != nil { if textRequest.Tools != nil {
functions := make([]dto.FunctionCall, 0, len(textRequest.Tools)) functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
googleSearch := false googleSearch := false
codeExecution := false codeExecution := false
for _, tool := range textRequest.Tools { for _, tool := range textRequest.Tools {
@@ -338,7 +338,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
return data return data
} }
func getToolCall(item *GeminiPart) *dto.ToolCall { func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
var argsBytes []byte var argsBytes []byte
var err error var err error
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok { if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
@@ -350,10 +350,10 @@ func getToolCall(item *GeminiPart) *dto.ToolCall {
if err != nil { if err != nil {
return nil return nil
} }
return &dto.ToolCall{ return &dto.ToolCallResponse{
ID: fmt.Sprintf("call_%s", common.GetUUID()), ID: fmt.Sprintf("call_%s", common.GetUUID()),
Type: "function", Type: "function",
Function: dto.FunctionCall{ Function: dto.FunctionResponse{
Arguments: string(argsBytes), Arguments: string(argsBytes),
Name: item.FunctionCall.FunctionName, Name: item.FunctionCall.FunctionName,
}, },
@@ -380,11 +380,11 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
} }
if len(candidate.Content.Parts) > 0 { if len(candidate.Content.Parts) > 0 {
var texts []string var texts []string
var toolCalls []dto.ToolCall var toolCalls []dto.ToolCallResponse
for _, part := range candidate.Content.Parts { for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil { if part.FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls choice.FinishReason = constant.FinishReasonToolCalls
if call := getToolCall(&part); call != nil { if call := getResponseToolCall(&part); call != nil {
toolCalls = append(toolCalls, *call) toolCalls = append(toolCalls, *call)
} }
} else { } else {
@@ -457,7 +457,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
for _, part := range candidate.Content.Parts { for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil { if part.FunctionCall != nil {
isTools = true isTools = true
if call := getToolCall(&part); call != nil { if call := getResponseToolCall(&part); call != nil {
call.SetIndex(len(choice.Delta.ToolCalls)) call.SetIndex(len(choice.Delta.ToolCalls))
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call) choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
} }

View File

@@ -3,22 +3,22 @@ package ollama
import "one-api/dto" import "one-api/dto"
type OllamaRequest struct { type OllamaRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"` Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"` Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"` Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"` Stop any `json:"stop,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"` Tools []dto.ToolCallRequest `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"` ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
Suffix any `json:"suffix,omitempty"` Suffix any `json:"suffix,omitempty"`
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"` StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
Prompt any `json:"prompt,omitempty"` Prompt any `json:"prompt,omitempty"`
} }
type Options struct { type Options struct {

View File

@@ -1,7 +1,6 @@
package service package service
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"image" "image"
@@ -170,12 +169,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
} }
tkm += msgTokens tkm += msgTokens
if request.Tools != nil { if request.Tools != nil {
toolsData, _ := json.Marshal(request.Tools) openaiTools := request.Tools
var openaiTools []dto.OpenAITools
err := json.Unmarshal(toolsData, &openaiTools)
if err != nil {
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
}
countStr := "" countStr := ""
for _, tool := range openaiTools { for _, tool := range openaiTools {
countStr = tool.Function.Name countStr = tool.Function.Name