fix: claude tool call format #795 #766

This commit is contained in:
1808837298@qq.com
2025-02-26 23:40:16 +08:00
parent 287caf8e38
commit 6d8d40e67b
4 changed files with 14 additions and 8 deletions

View File

@@ -107,7 +107,7 @@ type FunctionCall struct {
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
// call function with arguments in JSON format // call function with arguments in JSON format
Parameters any `json:"parameters,omitempty"` // request Parameters any `json:"parameters,omitempty"` // request
Arguments string `json:"arguments,omitempty"` Arguments string `json:"arguments"`
} }
type ChatCompletionsStreamResponse struct { type ChatCompletionsStreamResponse struct {

View File

@@ -368,7 +368,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)), Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
} }
content, _ := json.Marshal("") content, _ := json.Marshal("")
is_tool_call := false isToolCall := false
for _, candidate := range response.Candidates { for _, candidate := range response.Candidates {
choice := dto.OpenAITextResponseChoice{ choice := dto.OpenAITextResponseChoice{
Index: int(candidate.Index), Index: int(candidate.Index),
@@ -380,12 +380,12 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
} }
if len(candidate.Content.Parts) > 0 { if len(candidate.Content.Parts) > 0 {
var texts []string var texts []string
var tool_calls []dto.ToolCall var toolCalls []dto.ToolCall
for _, part := range candidate.Content.Parts { for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil { if part.FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls choice.FinishReason = constant.FinishReasonToolCalls
if call := getToolCall(&part); call != nil { if call := getToolCall(&part); call != nil {
tool_calls = append(tool_calls, *call) toolCalls = append(toolCalls, *call)
} }
} else { } else {
if part.ExecutableCode != nil { if part.ExecutableCode != nil {
@@ -400,9 +400,9 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
} }
} }
} }
if len(tool_calls) > 0 { if len(toolCalls) > 0 {
choice.Message.SetToolCalls(tool_calls) choice.Message.SetToolCalls(toolCalls)
is_tool_call = true isToolCall = true
} }
choice.Message.SetStringContent(strings.Join(texts, "\n")) choice.Message.SetStringContent(strings.Join(texts, "\n"))
@@ -418,7 +418,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
choice.FinishReason = constant.FinishReasonContentFilter choice.FinishReason = constant.FinishReasonContentFilter
} }
} }
if is_tool_call { if isToolCall {
choice.FinishReason = constant.FinishReasonToolCalls choice.FinishReason = constant.FinishReasonToolCalls
} }

View File

@@ -30,6 +30,7 @@ const (
APITypeMokaAI APITypeMokaAI
APITypeVolcEngine APITypeVolcEngine
APITypeBaiduV2 APITypeBaiduV2
APITypeOpenRouter
APITypeDummy // this one is only for count, do not add any channel after this APITypeDummy // this one is only for count, do not add any channel after this
) )
@@ -86,6 +87,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeVolcEngine apiType = APITypeVolcEngine
case common.ChannelTypeBaiduV2: case common.ChannelTypeBaiduV2:
apiType = APITypeBaiduV2 apiType = APITypeBaiduV2
case common.ChannelTypeOpenRouter:
apiType = APITypeOpenRouter
} }
if apiType == -1 { if apiType == -1 {
return APITypeOpenAI, false return APITypeOpenAI, false

View File

@@ -18,6 +18,7 @@ import (
"one-api/relay/channel/mokaai" "one-api/relay/channel/mokaai"
"one-api/relay/channel/ollama" "one-api/relay/channel/ollama"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
"one-api/relay/channel/openrouter"
"one-api/relay/channel/palm" "one-api/relay/channel/palm"
"one-api/relay/channel/perplexity" "one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow" "one-api/relay/channel/siliconflow"
@@ -83,6 +84,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &volcengine.Adaptor{} return &volcengine.Adaptor{}
case constant.APITypeBaiduV2: case constant.APITypeBaiduV2:
return &baidu_v2.Adaptor{} return &baidu_v2.Adaptor{}
case constant.APITypeOpenRouter:
return &openrouter.Adaptor{}
} }
return nil return nil
} }