From d9c1fb52444f446f45aa1289179b5fe700bc92c1 Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 7 Aug 2025 16:15:59 +0800 Subject: [PATCH] feat: update MaxTokens handling --- controller/channel-test.go | 2 +- dto/openai_request.go | 7 ++++-- relay/channel/baidu/relay-baidu.go | 6 ++--- relay/channel/claude/relay-claude.go | 2 +- relay/channel/cloudflare/dto.go | 2 +- relay/channel/cohere/dto.go | 2 +- relay/channel/gemini/relay-gemini.go | 2 +- relay/channel/mistral/text.go | 2 +- relay/channel/ollama/relay-ollama.go | 2 +- relay/channel/palm/relay-palm.go | 24 -------------------- relay/channel/perplexity/relay-perplexity.go | 2 +- relay/channel/xunfei/relay-xunfei.go | 2 +- relay/channel/zhipu_4v/relay-zhipu_v4.go | 2 +- 13 files changed, 18 insertions(+), 39 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 3a7c582b..a83d7d2a 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -161,7 +161,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { logInfo.ApiKey = "" common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) - priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens)) + priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens())) if err != nil { return testResult{ context: c, diff --git a/dto/openai_request.go b/dto/openai_request.go index 29076ef6..fcd47d07 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -99,8 +99,11 @@ type StreamOptions struct { IncludeUsage bool `json:"include_usage,omitempty"` } -func (r *GeneralOpenAIRequest) GetMaxTokens() int { - return int(r.MaxTokens) +func (r *GeneralOpenAIRequest) GetMaxTokens() uint { + if r.MaxCompletionTokens != 0 { + return r.MaxCompletionTokens + } + return r.MaxTokens } func (r *GeneralOpenAIRequest) ParseInput() []string { diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index 06b48c20..a7cd5996 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -34,9 +34,9 @@ func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest { EnableCitation: false, UserId: request.User, } - if request.MaxTokens != 0 { - maxTokens := int(request.MaxTokens) - if request.MaxTokens == 1 { + if request.GetMaxTokens() != 0 { + maxTokens := int(request.GetMaxTokens()) + if request.GetMaxTokens() == 1 { maxTokens = 2 } baiduRequest.MaxOutputTokens = &maxTokens diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 64739aa9..2cbac7b7 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -149,7 +149,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla claudeRequest := dto.ClaudeRequest{ Model: textRequest.Model, - MaxTokens: textRequest.MaxTokens, + MaxTokens: textRequest.GetMaxTokens(), StopSequences: nil, Temperature: textRequest.Temperature, TopP: textRequest.TopP, diff --git a/relay/channel/cloudflare/dto.go b/relay/channel/cloudflare/dto.go index 62a45c40..72b40615 100644 --- a/relay/channel/cloudflare/dto.go +++ b/relay/channel/cloudflare/dto.go @@ -5,7 +5,7 @@ import "one-api/dto" type CfRequest struct { Messages []dto.Message `json:"messages,omitempty"` Lora string `json:"lora,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` + MaxTokens uint `json:"max_tokens,omitempty"` Prompt string `json:"prompt,omitempty"` Raw bool `json:"raw,omitempty"` Stream bool `json:"stream,omitempty"` diff --git a/relay/channel/cohere/dto.go b/relay/channel/cohere/dto.go index 410540c0..d5127963 100644 --- a/relay/channel/cohere/dto.go +++ b/relay/channel/cohere/dto.go @@ -7,7 +7,7 @@ type CohereRequest struct { ChatHistory []ChatHistory `json:"chat_history"` Message string `json:"message"` Stream bool `json:"stream"` - MaxTokens int `json:"max_tokens"` + MaxTokens uint `json:"max_tokens"` SafetyMode string `json:"safety_mode,omitempty"` } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 18524afb..698a972c 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -184,7 +184,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon GenerationConfig: dto.GeminiChatGenerationConfig{ Temperature: textRequest.Temperature, TopP: textRequest.TopP, - MaxOutputTokens: textRequest.MaxTokens, + MaxOutputTokens: textRequest.GetMaxTokens(), Seed: int64(textRequest.Seed), }, } diff --git a/relay/channel/mistral/text.go b/relay/channel/mistral/text.go index e26c6101..aa925781 100644 --- a/relay/channel/mistral/text.go +++ b/relay/channel/mistral/text.go @@ -71,7 +71,7 @@ func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAI Messages: messages, Temperature: request.Temperature, TopP: request.TopP, - MaxTokens: request.MaxTokens, + MaxTokens: request.GetMaxTokens(), Tools: request.Tools, ToolChoice: request.ToolChoice, } diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index f98dfc73..d4686ce3 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -60,7 +60,7 @@ func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, er TopK: request.TopK, Stop: Stop, Tools: request.Tools, - MaxTokens: request.MaxTokens, + MaxTokens: request.GetMaxTokens(), ResponseFormat: request.ResponseFormat, FrequencyPenalty: request.FrequencyPenalty, PresencePenalty: request.PresencePenalty, diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index cbd60f5e..9b8bce7d 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -18,30 +18,6 @@ import ( // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body -func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest { - palmRequest := PaLMChatRequest{ - Prompt: PaLMPrompt{ - Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), - }, - Temperature: textRequest.Temperature, - CandidateCount: textRequest.N, - TopP: textRequest.TopP, - TopK: textRequest.MaxTokens, - } - for _, message := range textRequest.Messages { - palmMessage := PaLMChatMessage{ - Content: message.StringContent(), - } - if message.Role == "user" { - palmMessage.Author = "0" - } else { - palmMessage.Author = "1" - } - palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage) - } - return &palmRequest -} - func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse { fullTextResponse := dto.OpenAITextResponse{ Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)), diff --git a/relay/channel/perplexity/relay-perplexity.go b/relay/channel/perplexity/relay-perplexity.go index 9772aead..7ebadd0f 100644 --- a/relay/channel/perplexity/relay-perplexity.go +++ b/relay/channel/perplexity/relay-perplexity.go @@ -16,6 +16,6 @@ func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpen Messages: messages, Temperature: request.Temperature, TopP: request.TopP, - MaxTokens: request.MaxTokens, + MaxTokens: request.GetMaxTokens(), } } diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index 373ad605..1a426d50 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -48,7 +48,7 @@ func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, xunfeiRequest.Parameter.Chat.Domain = domain xunfeiRequest.Parameter.Chat.Temperature = request.Temperature xunfeiRequest.Parameter.Chat.TopK = request.N - xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens + xunfeiRequest.Parameter.Chat.MaxTokens = request.GetMaxTokens() xunfeiRequest.Payload.Message.Text = messages return &xunfeiRequest } diff --git a/relay/channel/zhipu_4v/relay-zhipu_v4.go b/relay/channel/zhipu_4v/relay-zhipu_v4.go index 271dda8f..98a852f5 100644 --- a/relay/channel/zhipu_4v/relay-zhipu_v4.go +++ b/relay/channel/zhipu_4v/relay-zhipu_v4.go @@ -105,7 +105,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq Messages: messages, Temperature: request.Temperature, TopP: request.TopP, - MaxTokens: request.MaxTokens, + MaxTokens: request.GetMaxTokens(), Stop: Stop, Tools: request.Tools, ToolChoice: request.ToolChoice,