From 38bff1a0e041c02a6ef6692d8b62c2b50676b8a3 Mon Sep 17 00:00:00 2001
From: RedwindA
Date: Thu, 7 Aug 2025 00:54:48 +0800
Subject: [PATCH 01/12] refactor: remove the GoogleOpenAI-compatible model
 structs and simplify the FetchUpstreamModels logic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 controller/channel.go | 57 ++++++-------------------------------------
 1 file changed, 8 insertions(+), 49 deletions(-)

diff --git a/controller/channel.go b/controller/channel.go
index 9f46ca35..3361cbf5 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -36,30 +36,11 @@ type OpenAIModel struct {
 	Parent string `json:"parent"`
 }
 
-type GoogleOpenAICompatibleModels []struct {
-	Name string `json:"name"`
-	Version string `json:"version"`
-	DisplayName string `json:"displayName"`
-	Description string `json:"description,omitempty"`
-	InputTokenLimit int `json:"inputTokenLimit"`
-	OutputTokenLimit int `json:"outputTokenLimit"`
-	SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
-	Temperature float64 `json:"temperature,omitempty"`
-	TopP float64 `json:"topP,omitempty"`
-	TopK int `json:"topK,omitempty"`
-	MaxTemperature int `json:"maxTemperature,omitempty"`
-}
-
 type OpenAIModelsResponse struct {
 	Data []OpenAIModel `json:"data"`
 	Success bool `json:"success"`
 }
 
-type GoogleOpenAICompatibleResponse struct {
-	Models []GoogleOpenAICompatibleModels `json:"models"`
-	NextPageToken string `json:"nextPageToken"`
-}
-
 func parseStatusFilter(statusParam string) int {
 	switch strings.ToLower(statusParam) {
 	case "enabled", "1":
@@ -203,7 +184,7 @@ func FetchUpstreamModels(c *gin.Context) {
 	switch channel.Type {
 	case constant.ChannelTypeGemini:
 		// curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
-		url = fmt.Sprintf("%s/v1beta/openai/models?key=%s", baseURL, channel.Key)
+		url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remember key in url since we need to use AuthHeader
 	case constant.ChannelTypeAli:
 		url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
 	default:
@@ -213,7 +194,7 @@
 	// 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
 	var body []byte
 	if channel.Type == constant.ChannelTypeGemini {
-		body, err = GetResponseBody("GET", url, channel, nil) // I don't know why, but Gemini requires no AuthHeader
+		body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) // Use AuthHeader since Gemini now forces it
 	} else {
 		body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 	}
 	if err != nil {
 		common.ApiError(c, err)
 		return
 	}
 
 	var result OpenAIModelsResponse
-	var parseSuccess bool
-
-	// 适配特殊格式
-	switch channel.Type {
-	case constant.ChannelTypeGemini:
-		var googleResult GoogleOpenAICompatibleResponse
-		if err = json.Unmarshal(body, &googleResult); err == nil {
-			// 转换Google格式到OpenAI格式
-			for _, model := range googleResult.Models {
-				for _, gModel := range model {
-					result.Data = append(result.Data, OpenAIModel{
-						ID: gModel.Name,
-					})
-				}
-			}
-			parseSuccess = true
-		}
-	}
-
-	// 如果解析失败,尝试OpenAI格式
-	if !parseSuccess {
-		if err = json.Unmarshal(body, &result); err != nil {
-			c.JSON(http.StatusOK, gin.H{
-				"success": false,
-				"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
-			})
-			return
-		}
+	if err = json.Unmarshal(body, &result); err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
+		})
+		return
 	}
 
 	var ids []string

From 76d71a032acf65e1433a54e7e4b10ea28dc7b162 Mon Sep 17 00:00:00 2001
From: RedwindA
Date: Thu, 7 Aug 2025 01:01:45 +0800
Subject: [PATCH 02/12] fix: correct the AuthHeader usage in
 FetchUpstreamModels to properly handle multi-key aggregation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 controller/channel.go | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/controller/channel.go b/controller/channel.go
index 3361cbf5..284597c3 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -193,10 +193,11 @@ func FetchUpstreamModels(c *gin.Context) {
 
 	// 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
 	var body []byte
+	key := strings.Split(channel.Key, "\n")[0]
 	if channel.Type == constant.ChannelTypeGemini {
-		body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) // Use AuthHeader since Gemini now forces it
+		body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key)) // Use AuthHeader since Gemini now forces it
 	} else {
-		body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+		body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key))
 	}
 	if err != nil {
 		common.ApiError(c, err)

From ed95a9f2b27124de1f20adc6b876ffc6502fbbce Mon Sep 17 00:00:00 2001
From: RedwindA
Date: Thu, 7 Aug 2025 01:06:50 +0800
Subject: [PATCH 03/12] fix a typo in comment

---
 controller/channel.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/controller/channel.go b/controller/channel.go
index 284597c3..020a3327 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -184,7 +184,7 @@ func FetchUpstreamModels(c *gin.Context) {
 	switch channel.Type {
 	case constant.ChannelTypeGemini:
 		// curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
-		url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remember key in url since we need to use AuthHeader
+		url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
 	case constant.ChannelTypeAli:
 		url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
 	default:
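
Taken together, patches 01-03 reduce the Gemini branch of FetchUpstreamModels to one ordinary code path: hit the OpenAI-compatible model-list endpoint, send the key through the same AuthHeader mechanism as every other channel, and, for multi-key channels, probe with only the first newline-separated key. A minimal standalone sketch of the resulting request, with plain net/http standing in for the project's GetResponseBody/GetAuthHeader helpers (the Bearer scheme and the base URL in main are assumptions for illustration):

package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

// fetchGeminiModels sketches the request these patches converge on: the key
// moves out of the query string and into a Bearer header, and a multi-key
// channel (keys separated by newlines) probes with its first key only.
func fetchGeminiModels(baseURL, aggregatedKey string) ([]byte, error) {
	key := strings.Split(aggregatedKey, "\n")[0]
	url := fmt.Sprintf("%s/v1beta/openai/models", baseURL)
	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+key)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	return io.ReadAll(resp.Body)
}

func main() {
	// The upstream replies in OpenAI list format ({"data": [{"id": ...}]}),
	// which is why the Google-specific response structs could be dropped.
	body, err := fetchGeminiModels("https://generativelanguage.googleapis.com", "key-1\nkey-2")
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	fmt.Println(string(body))
}
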
+ "success": false, + "message": fmt.Sprintf("解析响应失败: %s", err.Error()), + }) + return } var ids []string From 76d71a032acf65e1433a54e7e4b10ea28dc7b162 Mon Sep 17 00:00:00 2001 From: RedwindA Date: Thu, 7 Aug 2025 01:01:45 +0800 Subject: [PATCH 02/12] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20FetchUpstream?= =?UTF-8?q?Models=20=E5=87=BD=E6=95=B0=E4=B8=AD=20AuthHeader=20=E7=9A=84?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=EF=BC=8C=E7=A1=AE=E4=BF=9D=E6=AD=A3=E7=A1=AE?= =?UTF-8?q?=E5=A4=84=E7=90=86=20=E5=A4=9Akey=E8=81=9A=E5=90=88=E7=9A=84?= =?UTF-8?q?=E6=83=85=E5=86=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/controller/channel.go b/controller/channel.go index 3361cbf5..284597c3 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -193,10 +193,11 @@ func FetchUpstreamModels(c *gin.Context) { // 获取响应体 - 根据渠道类型决定是否添加 AuthHeader var body []byte + key := strings.Split(channel.Key, "\n")[0] if channel.Type == constant.ChannelTypeGemini { - body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) // Use AuthHeader since Gemini now forces it + body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key)) // Use AuthHeader since Gemini now forces it } else { - body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key)) } if err != nil { common.ApiError(c, err) From ed95a9f2b27124de1f20adc6b876ffc6502fbbce Mon Sep 17 00:00:00 2001 From: RedwindA Date: Thu, 7 Aug 2025 01:06:50 +0800 Subject: [PATCH 03/12] fix a typo in comment --- controller/channel.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/controller/channel.go b/controller/channel.go index 284597c3..020a3327 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -184,7 +184,7 @@ func FetchUpstreamModels(c *gin.Context) { switch channel.Type { case constant.ChannelTypeGemini: // curl https://example.com/v1beta/models?key=$GEMINI_API_KEY - url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remember key in url since we need to use AuthHeader + url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader case constant.ChannelTypeAli: url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL) default: From 1cea7a0314da30f362e0cbfff5fd411e0ffaaa6f Mon Sep 17 00:00:00 2001 From: RedwindA Date: Thu, 7 Aug 2025 06:18:22 +0800 Subject: [PATCH 04/12] =?UTF-8?q?fix:=20=E8=B0=83=E6=95=B4Disable=20Ping?= =?UTF-8?q?=E6=A0=87=E5=BF=97=E7=9A=84=E8=AE=BE=E7=BD=AE=E4=BD=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- relay/channel/gemini/adaptor.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 14fd278d..01dfea2c 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -120,6 +120,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { action := "generateContent" if info.IsStream { action = "streamGenerateContent?alt=sse" + if info.RelayMode == constant.RelayModeGemini { + info.DisablePing = true + } } return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil } @@ -193,7 +196,6 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a 
*Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeGemini { if info.IsStream { - info.DisablePing = true return GeminiTextGenerationStreamHandler(c, info, resp) } else { return GeminiTextGenerationHandler(c, info, resp) From 38067f1ddc3239f63b0787734f6cfa0fe01805a4 Mon Sep 17 00:00:00 2001 From: feitianbubu Date: Wed, 6 Aug 2025 22:58:36 +0800 Subject: [PATCH 05/12] feat: enable thinking mode on ali thinking model --- controller/channel-test.go | 2 +- relay/channel/ali/adaptor.go | 12 +++++++++--- relay/common/relay_info.go | 3 +++ 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 3a7c582b..1be36808 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -275,7 +275,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { Quota: quota, Content: "模型测试", UseTimeSeconds: int(consumedTime), - IsStream: false, + IsStream: info.IsStream, Group: info.UsingGroup, Other: other, }) diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 067fac37..35fe73c2 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -3,6 +3,7 @@ package ali import ( "errors" "fmt" + "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" @@ -11,8 +12,7 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/constant" "one-api/types" - - "github.com/gin-gonic/gin" + "strings" ) type Adaptor struct { @@ -65,7 +65,13 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } - + // docs: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712216 + // fix: InternalError.Algo.InvalidParameter: The value of the enable_thinking parameter is restricted to True. 
+ if strings.Contains(request.Model, "thinking") { + request.EnableThinking = true + request.Stream = true + info.IsStream = true + } // fix: ali parameter.enable_thinking must be set to false for non-streaming calls if !info.IsStream { request.EnableThinking = false diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 266486c4..743070ca 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -225,6 +225,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { userId := common.GetContextKeyInt(c, constant.ContextKeyUserId) tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited) startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime) + if startTime.IsZero() { + startTime = time.Now() + } // firstResponseTime = time.Now() - 1 second apiType, _ := common.ChannelType2APIType(channelType) From 71c39c98936417a6b2fd38f5a635d1d2bad11c24 Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 7 Aug 2025 15:40:12 +0800 Subject: [PATCH 06/12] feat: update Usage struct to support dynamic token handling with ceil function #1503 --- dto/openai_response.go | 119 +++++++++++++++++++++++- relay/channel/openai/relay-openai.go | 8 +- relay/channel/openai/relay_responses.go | 8 +- 3 files changed, 124 insertions(+), 11 deletions(-) diff --git a/dto/openai_response.go b/dto/openai_response.go index b050cd03..7e6ee584 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -3,6 +3,8 @@ package dto import ( "encoding/json" "fmt" + "math" + "one-api/common" "one-api/types" ) @@ -202,13 +204,124 @@ type Usage struct { PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"` CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"` - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - InputTokensDetails *InputTokenDetails `json:"input_tokens_details"` + InputTokens any `json:"input_tokens"` + OutputTokens any `json:"output_tokens"` + //CacheReadInputTokens any `json:"cache_read_input_tokens,omitempty"` + InputTokensDetails *InputTokenDetails `json:"input_tokens_details"` // OpenRouter Params Cost any `json:"cost,omitempty"` } +func (u *Usage) UnmarshalJSON(data []byte) error { + // first normal unmarshal + if err := common.Unmarshal(data, u); err != nil { + return fmt.Errorf("unmarshal Usage failed: %w", err) + } + + // then ceil the input and output tokens + ceil := func(val any) int { + switch v := val.(type) { + case float64: + return int(math.Ceil(v)) + case int: + return v + case string: + var intVal int + _, err := fmt.Sscanf(v, "%d", &intVal) + if err != nil { + return 0 // or handle error appropriately + } + return intVal + default: + return 0 // or handle error appropriately + } + } + + // input_tokens must be int + if u.InputTokens != nil { + u.InputTokens = ceil(u.InputTokens) + } + if u.OutputTokens != nil { + u.OutputTokens = ceil(u.OutputTokens) + } + return nil +} + +func (u *Usage) GetInputTokens() int { + if u.InputTokens == nil { + return 0 + } + + switch v := u.InputTokens.(type) { + case int: + return v + case float64: + return int(math.Ceil(v)) + case string: + var intVal int + _, err := fmt.Sscanf(v, "%d", &intVal) + if err != nil { + return 0 // or handle error appropriately + } + return intVal + default: + return 0 // or handle error appropriately + } +} + +func (u *Usage) GetOutputTokens() int { + if u.OutputTokens == nil { + return 0 + } + + switch v := u.OutputTokens.(type) { + case int: + return v + case float64: + return int(math.Ceil(v)) + 
case string: + var intVal int + _, err := fmt.Sscanf(v, "%d", &intVal) + if err != nil { + return 0 // or handle error appropriately + } + return intVal + default: + return 0 // or handle error appropriately + } +} + +//func (u *Usage) MarshalJSON() ([]byte, error) { +// ceil := func(val any) int { +// switch v := val.(type) { +// case float64: +// return int(math.Ceil(v)) +// case int: +// return v +// case string: +// var intVal int +// _, err := fmt.Sscanf(v, "%d", &intVal) +// if err != nil { +// return 0 // or handle error appropriately +// } +// return intVal +// default: +// return 0 // or handle error appropriately +// } +// } +// +// // input_tokens must be int +// if u.InputTokens != nil { +// u.InputTokens = ceil(u.InputTokens) +// } +// if u.OutputTokens != nil { +// u.OutputTokens = ceil(u.OutputTokens) +// } +// +// // done +// return common.Marshal(u) +//} + type InputTokenDetails struct { CachedTokens int `json:"cached_tokens"` CachedCreationTokens int `json:"-"` diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 9ae0a200..f5e29209 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -570,11 +570,11 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h // because the upstream has already consumed resources and returned content // We should still perform billing even if parsing fails // format - if usageResp.InputTokens > 0 { - usageResp.PromptTokens += usageResp.InputTokens + if usageResp.GetInputTokens() > 0 { + usageResp.PromptTokens += usageResp.GetInputTokens() } - if usageResp.OutputTokens > 0 { - usageResp.CompletionTokens += usageResp.OutputTokens + if usageResp.GetOutputTokens() > 0 { + usageResp.CompletionTokens += usageResp.GetOutputTokens() } if usageResp.InputTokensDetails != nil { usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index bae6fcb6..2c996f91 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -38,8 +38,8 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http // compute usage usage := dto.Usage{} if responsesResponse.Usage != nil { - usage.PromptTokens = responsesResponse.Usage.InputTokens - usage.CompletionTokens = responsesResponse.Usage.OutputTokens + usage.PromptTokens = responsesResponse.Usage.GetInputTokens() + usage.CompletionTokens = responsesResponse.Usage.GetOutputTokens() usage.TotalTokens = responsesResponse.Usage.TotalTokens if responsesResponse.Usage.InputTokensDetails != nil { usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens @@ -70,8 +70,8 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp switch streamResponse.Type { case "response.completed": if streamResponse.Response.Usage != nil { - usage.PromptTokens = streamResponse.Response.Usage.InputTokens - usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens + usage.PromptTokens = streamResponse.Response.Usage.GetInputTokens() + usage.CompletionTokens = streamResponse.Response.Usage.GetOutputTokens() usage.TotalTokens = streamResponse.Response.Usage.TotalTokens if streamResponse.Response.Usage.InputTokensDetails != nil { usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens From d9c1fb52444f446f45aa1289179b5fe700bc92c1 Mon Sep 
17 00:00:00 2001 From: CaIon Date: Thu, 7 Aug 2025 16:15:59 +0800 Subject: [PATCH 07/12] feat: update MaxTokens handling --- controller/channel-test.go | 2 +- dto/openai_request.go | 7 ++++-- relay/channel/baidu/relay-baidu.go | 6 ++--- relay/channel/claude/relay-claude.go | 2 +- relay/channel/cloudflare/dto.go | 2 +- relay/channel/cohere/dto.go | 2 +- relay/channel/gemini/relay-gemini.go | 2 +- relay/channel/mistral/text.go | 2 +- relay/channel/ollama/relay-ollama.go | 2 +- relay/channel/palm/relay-palm.go | 24 -------------------- relay/channel/perplexity/relay-perplexity.go | 2 +- relay/channel/xunfei/relay-xunfei.go | 2 +- relay/channel/zhipu_4v/relay-zhipu_v4.go | 2 +- 13 files changed, 18 insertions(+), 39 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 3a7c582b..a83d7d2a 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -161,7 +161,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { logInfo.ApiKey = "" common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) - priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens)) + priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens())) if err != nil { return testResult{ context: c, diff --git a/dto/openai_request.go b/dto/openai_request.go index 29076ef6..fcd47d07 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -99,8 +99,11 @@ type StreamOptions struct { IncludeUsage bool `json:"include_usage,omitempty"` } -func (r *GeneralOpenAIRequest) GetMaxTokens() int { - return int(r.MaxTokens) +func (r *GeneralOpenAIRequest) GetMaxTokens() uint { + if r.MaxCompletionTokens != 0 { + return r.MaxCompletionTokens + } + return r.MaxTokens } func (r *GeneralOpenAIRequest) ParseInput() []string { diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index 06b48c20..a7cd5996 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -34,9 +34,9 @@ func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest { EnableCitation: false, UserId: request.User, } - if request.MaxTokens != 0 { - maxTokens := int(request.MaxTokens) - if request.MaxTokens == 1 { + if request.GetMaxTokens() != 0 { + maxTokens := int(request.GetMaxTokens()) + if request.GetMaxTokens() == 1 { maxTokens = 2 } baiduRequest.MaxOutputTokens = &maxTokens diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 64739aa9..2cbac7b7 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -149,7 +149,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla claudeRequest := dto.ClaudeRequest{ Model: textRequest.Model, - MaxTokens: textRequest.MaxTokens, + MaxTokens: textRequest.GetMaxTokens(), StopSequences: nil, Temperature: textRequest.Temperature, TopP: textRequest.TopP, diff --git a/relay/channel/cloudflare/dto.go b/relay/channel/cloudflare/dto.go index 62a45c40..72b40615 100644 --- a/relay/channel/cloudflare/dto.go +++ b/relay/channel/cloudflare/dto.go @@ -5,7 +5,7 @@ import "one-api/dto" type CfRequest struct { Messages []dto.Message `json:"messages,omitempty"` Lora string `json:"lora,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` + MaxTokens uint `json:"max_tokens,omitempty"` Prompt string `json:"prompt,omitempty"` Raw bool `json:"raw,omitempty"` Stream bool `json:"stream,omitempty"` diff --git 
a/relay/channel/cohere/dto.go b/relay/channel/cohere/dto.go index 410540c0..d5127963 100644 --- a/relay/channel/cohere/dto.go +++ b/relay/channel/cohere/dto.go @@ -7,7 +7,7 @@ type CohereRequest struct { ChatHistory []ChatHistory `json:"chat_history"` Message string `json:"message"` Stream bool `json:"stream"` - MaxTokens int `json:"max_tokens"` + MaxTokens uint `json:"max_tokens"` SafetyMode string `json:"safety_mode,omitempty"` } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 18524afb..698a972c 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -184,7 +184,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon GenerationConfig: dto.GeminiChatGenerationConfig{ Temperature: textRequest.Temperature, TopP: textRequest.TopP, - MaxOutputTokens: textRequest.MaxTokens, + MaxOutputTokens: textRequest.GetMaxTokens(), Seed: int64(textRequest.Seed), }, } diff --git a/relay/channel/mistral/text.go b/relay/channel/mistral/text.go index e26c6101..aa925781 100644 --- a/relay/channel/mistral/text.go +++ b/relay/channel/mistral/text.go @@ -71,7 +71,7 @@ func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAI Messages: messages, Temperature: request.Temperature, TopP: request.TopP, - MaxTokens: request.MaxTokens, + MaxTokens: request.GetMaxTokens(), Tools: request.Tools, ToolChoice: request.ToolChoice, } diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index f98dfc73..d4686ce3 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -60,7 +60,7 @@ func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, er TopK: request.TopK, Stop: Stop, Tools: request.Tools, - MaxTokens: request.MaxTokens, + MaxTokens: request.GetMaxTokens(), ResponseFormat: request.ResponseFormat, FrequencyPenalty: request.FrequencyPenalty, PresencePenalty: request.PresencePenalty, diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index cbd60f5e..9b8bce7d 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -18,30 +18,6 @@ import ( // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body -func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest { - palmRequest := PaLMChatRequest{ - Prompt: PaLMPrompt{ - Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), - }, - Temperature: textRequest.Temperature, - CandidateCount: textRequest.N, - TopP: textRequest.TopP, - TopK: textRequest.MaxTokens, - } - for _, message := range textRequest.Messages { - palmMessage := PaLMChatMessage{ - Content: message.StringContent(), - } - if message.Role == "user" { - palmMessage.Author = "0" - } else { - palmMessage.Author = "1" - } - palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage) - } - return &palmRequest -} - func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse { fullTextResponse := dto.OpenAITextResponse{ Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)), diff --git a/relay/channel/perplexity/relay-perplexity.go b/relay/channel/perplexity/relay-perplexity.go index 9772aead..7ebadd0f 100644 --- a/relay/channel/perplexity/relay-perplexity.go +++ 
b/relay/channel/perplexity/relay-perplexity.go
@@ -16,6 +16,6 @@ func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpen
 		Messages: messages,
 		Temperature: request.Temperature,
 		TopP: request.TopP,
-		MaxTokens: request.MaxTokens,
+		MaxTokens: request.GetMaxTokens(),
 	}
 }
diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go
index 373ad605..1a426d50 100644
--- a/relay/channel/xunfei/relay-xunfei.go
+++ b/relay/channel/xunfei/relay-xunfei.go
@@ -48,7 +48,7 @@ func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string,
 	xunfeiRequest.Parameter.Chat.Domain = domain
 	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
 	xunfeiRequest.Parameter.Chat.TopK = request.N
-	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
+	xunfeiRequest.Parameter.Chat.MaxTokens = request.GetMaxTokens()
 	xunfeiRequest.Payload.Message.Text = messages
 	return &xunfeiRequest
 }
diff --git a/relay/channel/zhipu_4v/relay-zhipu_v4.go b/relay/channel/zhipu_4v/relay-zhipu_v4.go
index 271dda8f..98a852f5 100644
--- a/relay/channel/zhipu_4v/relay-zhipu_v4.go
+++ b/relay/channel/zhipu_4v/relay-zhipu_v4.go
@@ -105,7 +105,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
 		Messages: messages,
 		Temperature: request.Temperature,
 		TopP: request.TopP,
-		MaxTokens: request.MaxTokens,
+		MaxTokens: request.GetMaxTokens(),
 		Stop: Stop,
 		Tools: request.Tools,
 		ToolChoice: request.ToolChoice,
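
The call-site churn in patch 07 funnels through a single accessor: GetMaxTokens now prefers the newer max_completion_tokens field and falls back to the legacy max_tokens. A self-contained sketch of that fallback, with a stripped-down stand-in for dto.GeneralOpenAIRequest (the json tag on MaxCompletionTokens is assumed to mirror the real dto):

package main

import "fmt"

// Request is a stripped-down stand-in for dto.GeneralOpenAIRequest,
// keeping only the two limit fields that GetMaxTokens consults.
type Request struct {
	MaxTokens           uint `json:"max_tokens,omitempty"`
	MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
}

// GetMaxTokens mirrors the accessor patch 07 routes every channel through:
// prefer the newer max_completion_tokens field, fall back to max_tokens.
func (r *Request) GetMaxTokens() uint {
	if r.MaxCompletionTokens != 0 {
		return r.MaxCompletionTokens
	}
	return r.MaxTokens
}

func main() {
	legacy := Request{MaxTokens: 1024}
	both := Request{MaxTokens: 1024, MaxCompletionTokens: 4096}
	fmt.Println(legacy.GetMaxTokens()) // 1024 - only the legacy field is set
	fmt.Println(both.GetMaxTokens())   // 4096 - the newer field wins
}
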
From 865bb7aad85e045e22cc437e5ca4b2d438f1a2b1 Mon Sep 17 00:00:00 2001
From: CaIon
Date: Thu, 7 Aug 2025 16:22:40 +0800
Subject: [PATCH 08/12] Revert "feat: update Usage struct to support dynamic
 token handling with ceil function #1503"

This reverts commit 71c39c98936417a6b2fd38f5a635d1d2bad11c24.

---
 dto/openai_response.go | 119 +-----------------------
 relay/channel/openai/relay-openai.go | 8 +-
 relay/channel/openai/relay_responses.go | 8 +-
 3 files changed, 11 insertions(+), 124 deletions(-)

diff --git a/dto/openai_response.go b/dto/openai_response.go
index 7e6ee584..b050cd03 100644
--- a/dto/openai_response.go
+++ b/dto/openai_response.go
@@ -3,8 +3,6 @@ package dto
 import (
 	"encoding/json"
 	"fmt"
-	"math"
-	"one-api/common"
 	"one-api/types"
 )
@@ -204,124 +202,13 @@ type Usage struct {
 	PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
 	CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
 
-	InputTokens any `json:"input_tokens"`
-	OutputTokens any `json:"output_tokens"`
-	//CacheReadInputTokens any `json:"cache_read_input_tokens,omitempty"`
-	InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
+	InputTokens int `json:"input_tokens"`
+	OutputTokens int `json:"output_tokens"`
+	InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
 
 	// OpenRouter Params
 	Cost any `json:"cost,omitempty"`
 }
 
-func (u *Usage) UnmarshalJSON(data []byte) error {
-	// first normal unmarshal
-	if err := common.Unmarshal(data, u); err != nil {
-		return fmt.Errorf("unmarshal Usage failed: %w", err)
-	}
-
-	// then ceil the input and output tokens
-	ceil := func(val any) int {
-		switch v := val.(type) {
-		case float64:
-			return int(math.Ceil(v))
-		case int:
-			return v
-		case string:
-			var intVal int
-			_, err := fmt.Sscanf(v, "%d", &intVal)
-			if err != nil {
-				return 0 // or handle error appropriately
-			}
-			return intVal
-		default:
-			return 0 // or handle error appropriately
-		}
-	}
-
-	// input_tokens must be int
-	if u.InputTokens != nil {
-		u.InputTokens = ceil(u.InputTokens)
-	}
-	if u.OutputTokens != nil {
-		u.OutputTokens = ceil(u.OutputTokens)
-	}
-	return nil
-}
-
-func (u *Usage) GetInputTokens() int {
-	if u.InputTokens == nil {
-		return 0
-	}
-
-	switch v := u.InputTokens.(type) {
-	case int:
-		return v
-	case float64:
-		return int(math.Ceil(v))
-	case string:
-		var intVal int
-		_, err := fmt.Sscanf(v, "%d", &intVal)
-		if err != nil {
-			return 0 // or handle error appropriately
-		}
-		return intVal
-	default:
-		return 0 // or handle error appropriately
-	}
-}
-
-func (u *Usage) GetOutputTokens() int {
-	if u.OutputTokens == nil {
-		return 0
-	}
-
-	switch v := u.OutputTokens.(type) {
-	case int:
-		return v
-	case float64:
-		return int(math.Ceil(v))
-	case string:
-		var intVal int
-		_, err := fmt.Sscanf(v, "%d", &intVal)
-		if err != nil {
-			return 0 // or handle error appropriately
-		}
-		return intVal
-	default:
-		return 0 // or handle error appropriately
-	}
-}
-
-//func (u *Usage) MarshalJSON() ([]byte, error) {
-//	ceil := func(val any) int {
-//		switch v := val.(type) {
-//		case float64:
-//			return int(math.Ceil(v))
-//		case int:
-//			return v
-//		case string:
-//			var intVal int
-//			_, err := fmt.Sscanf(v, "%d", &intVal)
-//			if err != nil {
-//				return 0 // or handle error appropriately
-//			}
-//			return intVal
-//		default:
-//			return 0 // or handle error appropriately
-//		}
-//	}
-//
-//	// input_tokens must be int
-//	if u.InputTokens != nil {
-//		u.InputTokens = ceil(u.InputTokens)
-//	}
-//	if u.OutputTokens != nil {
-//		u.OutputTokens = ceil(u.OutputTokens)
-//	}
-//
-//	// done
-//	return common.Marshal(u)
-//}
-
 type InputTokenDetails struct {
 	CachedTokens int `json:"cached_tokens"`
 	CachedCreationTokens int `json:"-"`
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 
f5e29209..9ae0a200 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -570,11 +570,11 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h // because the upstream has already consumed resources and returned content // We should still perform billing even if parsing fails // format - if usageResp.GetInputTokens() > 0 { - usageResp.PromptTokens += usageResp.GetInputTokens() + if usageResp.InputTokens > 0 { + usageResp.PromptTokens += usageResp.InputTokens } - if usageResp.GetOutputTokens() > 0 { - usageResp.CompletionTokens += usageResp.GetOutputTokens() + if usageResp.OutputTokens > 0 { + usageResp.CompletionTokens += usageResp.OutputTokens } if usageResp.InputTokensDetails != nil { usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index 2c996f91..bae6fcb6 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -38,8 +38,8 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http // compute usage usage := dto.Usage{} if responsesResponse.Usage != nil { - usage.PromptTokens = responsesResponse.Usage.GetInputTokens() - usage.CompletionTokens = responsesResponse.Usage.GetOutputTokens() + usage.PromptTokens = responsesResponse.Usage.InputTokens + usage.CompletionTokens = responsesResponse.Usage.OutputTokens usage.TotalTokens = responsesResponse.Usage.TotalTokens if responsesResponse.Usage.InputTokensDetails != nil { usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens @@ -70,8 +70,8 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp switch streamResponse.Type { case "response.completed": if streamResponse.Response.Usage != nil { - usage.PromptTokens = streamResponse.Response.Usage.GetInputTokens() - usage.CompletionTokens = streamResponse.Response.Usage.GetOutputTokens() + usage.PromptTokens = streamResponse.Response.Usage.InputTokens + usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens usage.TotalTokens = streamResponse.Response.Usage.TotalTokens if streamResponse.Response.Usage.InputTokensDetails != nil { usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens From 0ea0a432bfbe2de436362043de3646ef56834fb3 Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 7 Aug 2025 18:32:31 +0800 Subject: [PATCH 09/12] feat: support qwen claude format --- relay/channel/ali/adaptor.go | 62 +++++++++++++++++----------- relay/channel/claude/adaptor.go | 2 +- relay/channel/claude/relay-claude.go | 2 +- relay/channel/vertex/adaptor.go | 2 +- service/error.go | 3 ++ 5 files changed, 44 insertions(+), 27 deletions(-) diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 35fe73c2..f3c5cee6 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/claude" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" @@ -23,10 +24,8 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt return nil, errors.New("not implemented") } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { - //TODO implement me - panic("implement me") - return nil, nil +func (a 
*Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + return req, nil } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { @@ -34,18 +33,24 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { var fullRequestURL string - switch info.RelayMode { - case constant.RelayModeEmbeddings: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl) - case constant.RelayModeRerank: - fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl) - case constant.RelayModeImagesGenerations: - fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl) - case constant.RelayModeCompletions: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl) + switch info.RelayFormat { + case relaycommon.RelayFormatClaude: + fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.BaseUrl) default: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) + switch info.RelayMode { + case constant.RelayModeEmbeddings: + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl) + case constant.RelayModeRerank: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl) + case constant.RelayModeImagesGenerations: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl) + case constant.RelayModeCompletions: + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl) + default: + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) + } } + return fullRequestURL, nil } @@ -112,18 +117,27 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - switch info.RelayMode { - case constant.RelayModeImagesGenerations: - err, usage = aliImageHandler(c, resp, info) - case constant.RelayModeEmbeddings: - err, usage = aliEmbeddingHandler(c, resp) - case constant.RelayModeRerank: - err, usage = RerankHandler(c, resp, info) - default: + switch info.RelayFormat { + case relaycommon.RelayFormatClaude: if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) + err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) } else { - usage, err = openai.OpenaiHandler(c, info, resp) + err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) + } + default: + switch info.RelayMode { + case constant.RelayModeImagesGenerations: + err, usage = aliImageHandler(c, resp, info) + case constant.RelayModeEmbeddings: + err, usage = aliEmbeddingHandler(c, resp) + case constant.RelayModeRerank: + err, usage = RerankHandler(c, resp, info) + default: + if info.IsStream { + usage, err = openai.OaiStreamHandler(c, info, resp) + } else { + usage, err = openai.OpenaiHandler(c, info, resp) + } } } return diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 0f7a9414..39b8ce2f 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -104,7 +104,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode) } else { - 
err, usage = ClaudeHandler(c, resp, a.RequestMode, info) + err, usage = ClaudeHandler(c, resp, info, a.RequestMode) } return } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 2cbac7b7..e4d3975e 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -740,7 +740,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud return nil } -func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { +func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) { defer common.CloseResponseBodyGracefully(resp) claudeInfo := &ClaudeResponseInfo{ diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 4648a384..35e4490b 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -238,7 +238,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom } else { switch a.RequestMode { case RequestModeClaude: - err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info) + err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) case RequestModeGemini: if info.RelayMode == constant.RelayModeGemini { usage, err = gemini.GeminiTextGenerationHandler(c, info, resp) diff --git a/service/error.go b/service/error.go index ad28c90f..9672402d 100644 --- a/service/error.go +++ b/service/error.go @@ -93,6 +93,9 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t if showBodyWhenFail { newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)) } else { + if common.DebugEnabled { + println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))) + } newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode) } return From 18c630e5e416e727e500b5f025e2c01edabb0498 Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 7 Aug 2025 19:01:49 +0800 Subject: [PATCH 10/12] feat: support deepseek claude format (convert) --- dto/claude.go | 2 +- relay/channel/deepseek/adaptor.go | 7 +++---- service/convert.go | 6 +++++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/dto/claude.go b/dto/claude.go index ea099df4..7b5f348e 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -361,7 +361,7 @@ type ClaudeUsage struct { CacheCreationInputTokens int `json:"cache_creation_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"` OutputTokens int `json:"output_tokens"` - ServerToolUse *ClaudeServerToolUse `json:"server_tool_use"` + ServerToolUse *ClaudeServerToolUse `json:"server_tool_use,omitempty"` } type ClaudeServerToolUse struct { diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index ac8ea18f..be8de0c8 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -24,10 +24,9 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt return nil, errors.New("not implemented") } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { - //TODO implement me - panic("implement me") - return nil, nil +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + adaptor := openai.Adaptor{} + return 
adaptor.ConvertClaudeRequest(c, info, req) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { diff --git a/service/convert.go b/service/convert.go index ee8ecee5..967e4682 100644 --- a/service/convert.go +++ b/service/convert.go @@ -283,7 +283,9 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" { // should be done info.FinishReason = *chosenChoice.FinishReason - return claudeResponses + if !info.Done { + return claudeResponses + } } if info.Done { claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index)) @@ -432,6 +434,8 @@ func stopReasonOpenAI2Claude(reason string) string { return "end_turn" case "stop_sequence": return "stop_sequence" + case "length": + fallthrough case "max_tokens": return "max_tokens" case "tool_calls": From 7ddd3140151718d13b2d64c97136fd96ddeb0f35 Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 7 Aug 2025 19:19:59 +0800 Subject: [PATCH 11/12] feat: implement ConvertClaudeRequest method in baidu_v2 Adaptor --- relay/channel/baidu_v2/adaptor.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index b8a4ac2f..c0ea0e60 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -23,10 +23,9 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt return nil, errors.New("not implemented") } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { - //TODO implement me - panic("implement me") - return nil, nil +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + adaptor := openai.Adaptor{} + return adaptor.ConvertClaudeRequest(c, info, req) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { From c4dcc6df9c4bc0bc242abfbd80b4a86bf6b5bbbb Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 7 Aug 2025 19:30:42 +0800 Subject: [PATCH 12/12] feat: enhance Adaptor to support multiple relay modes in request handling --- relay/channel/baidu_v2/adaptor.go | 23 +++++++++++++++++------ relay/channel/volcengine/adaptor.go | 25 +++++++++---------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index c0ea0e60..ba59e307 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -9,6 +9,7 @@ import ( "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "one-api/relay/constant" "one-api/types" "strings" @@ -42,7 +43,20 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil + switch info.RelayMode { + case constant.RelayModeChatCompletions: + return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil + case constant.RelayModeEmbeddings: + return fmt.Sprintf("%s/v2/embeddings", info.BaseUrl), nil + case constant.RelayModeImagesGenerations: + return fmt.Sprintf("%s/v2/images/generations", info.BaseUrl), nil + case constant.RelayModeImagesEdits: + return fmt.Sprintf("%s/v2/images/edits", info.BaseUrl), nil + case 
constant.RelayModeRerank: + return fmt.Sprintf("%s/v2/rerank", info.BaseUrl), nil + default: + } + return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { @@ -98,11 +112,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) - } else { - usage, err = openai.OpenaiHandler(c, info, resp) - } + adaptor := openai.Adaptor{} + usage, err = adaptor.DoResponse(c, resp, info) return } diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 225b3895..2cc4f663 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -28,10 +28,9 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt return nil, errors.New("not implemented") } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { - //TODO implement me - panic("implement me") - return nil, nil +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + adaptor := openai.Adaptor{} + return adaptor.ConvertClaudeRequest(c, info, req) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -196,6 +195,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil case constant.RelayModeImagesGenerations: return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil + case constant.RelayModeImagesEdits: + return fmt.Sprintf("%s/api/v3/images/edits", info.BaseUrl), nil + case constant.RelayModeRerank: + return fmt.Sprintf("%s/api/v3/rerank", info.BaseUrl), nil default: } return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) @@ -232,18 +235,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - switch info.RelayMode { - case constant.RelayModeChatCompletions: - if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) - } else { - usage, err = openai.OpenaiHandler(c, info, resp) - } - case constant.RelayModeEmbeddings: - usage, err = openai.OpenaiHandler(c, info, resp) - case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits: - usage, err = openai.OpenaiHandlerWithUsage(c, info, resp) - } + adaptor := openai.Adaptor{} + usage, err = adaptor.DoResponse(c, resp, info) return }
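
Patches 11-12 end on a shared pattern: GetRequestURL routes each relay mode to its own upstream path and rejects unknown modes, while DoResponse stops re-implementing the mode switch and delegates wholesale to openai.Adaptor. A minimal sketch of the URL-routing half, using local stand-ins for the relay/constant mode values (the base URL in main is illustrative):

package main

import "fmt"

// Local stand-ins for relay/constant's mode values; the real constants
// live in one-api/relay/constant and cover more modes.
const (
	RelayModeChatCompletions = iota
	RelayModeEmbeddings
	RelayModeImagesGenerations
	RelayModeImagesEdits
	RelayModeRerank
)

// requestURL mirrors the shape of the patched GetRequestURL in baidu_v2:
// each supported relay mode maps to its own upstream path, and anything
// unrecognized is rejected instead of silently hitting chat completions.
func requestURL(baseURL string, mode int) (string, error) {
	switch mode {
	case RelayModeChatCompletions:
		return fmt.Sprintf("%s/v2/chat/completions", baseURL), nil
	case RelayModeEmbeddings:
		return fmt.Sprintf("%s/v2/embeddings", baseURL), nil
	case RelayModeImagesGenerations:
		return fmt.Sprintf("%s/v2/images/generations", baseURL), nil
	case RelayModeImagesEdits:
		return fmt.Sprintf("%s/v2/images/edits", baseURL), nil
	case RelayModeRerank:
		return fmt.Sprintf("%s/v2/rerank", baseURL), nil
	}
	return "", fmt.Errorf("unsupported relay mode: %d", mode)
}

func main() {
	url, err := requestURL("https://qianfan.baidubce.com", RelayModeEmbeddings)
	fmt.Println(url, err) // https://qianfan.baidubce.com/v2/embeddings <nil>
}
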