From a981e107120d96eb1775605ecd6c77f1908963c7 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Wed, 12 Mar 2025 17:53:46 +0800 Subject: [PATCH 01/18] feat(relay): Add Xinference channel support --- common/constants.go | 2 ++ relay/channel/openai/adaptor.go | 5 +++++ relay/channel/xinference/constant.go | 7 +++++++ relay/constant/api_type.go | 3 +++ relay/relay_adaptor.go | 4 ++-- web/src/constants/channel.constants.js | 12 +++++++++--- 6 files changed, 28 insertions(+), 5 deletions(-) create mode 100644 relay/channel/xinference/constant.go diff --git a/common/constants.go b/common/constants.go index 9611ed0e..36b6277e 100644 --- a/common/constants.go +++ b/common/constants.go @@ -235,6 +235,7 @@ const ( ChannelTypeMokaAI = 44 ChannelTypeVolcEngine = 45 ChannelTypeBaiduV2 = 46 + ChannelTypeXinference = 47 ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -287,4 +288,5 @@ var ChannelBaseURLs = []string{ "https://api.moka.ai", //44 "https://ark.cn-beijing.volces.com", //45 "https://qianfan.baidubce.com", //46 + "", //47 } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 6dbbb17e..d8a44335 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -18,6 +18,7 @@ import ( "one-api/relay/channel/lingyiwanwu" "one-api/relay/channel/minimax" "one-api/relay/channel/moonshot" + "one-api/relay/channel/xinference" relaycommon "one-api/relay/common" "one-api/relay/constant" "strings" @@ -251,6 +252,8 @@ func (a *Adaptor) GetModelList() []string { return lingyiwanwu.ModelList case common.ChannelTypeMiniMax: return minimax.ModelList + case common.ChannelTypeXinference: + return xinference.ModelList default: return ModelList } @@ -266,6 +269,8 @@ func (a *Adaptor) GetChannelName() string { return lingyiwanwu.ChannelName case common.ChannelTypeMiniMax: return minimax.ChannelName + case common.ChannelTypeXinference: + return xinference.ChannelName default: 
return ChannelName } diff --git a/relay/channel/xinference/constant.go b/relay/channel/xinference/constant.go new file mode 100644 index 00000000..98ec9b04 --- /dev/null +++ b/relay/channel/xinference/constant.go @@ -0,0 +1,7 @@ +package xinference + +var ModelList = []string{ + "bge-reranker-v2-m3", +} + +var ChannelName = "xinference" diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 8ccfee03..2cd0e399 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -31,6 +31,7 @@ const ( APITypeVolcEngine APITypeBaiduV2 APITypeOpenRouter + APITypeXinference APITypeDummy // this one is only for count, do not add any channel after this ) @@ -89,6 +90,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeBaiduV2 case common.ChannelTypeOpenRouter: apiType = APITypeOpenRouter + case common.ChannelTypeXinference: + apiType = APITypeXinference } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 00cff316..f6d141fa 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -34,8 +34,6 @@ import ( func GetAdaptor(apiType int) channel.Adaptor { switch apiType { - //case constant.APITypeAIProxyLibrary: - // return &aiproxy.Adaptor{} case constant.APITypeAli: return &ali.Adaptor{} case constant.APITypeAnthropic: @@ -86,6 +84,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &baidu_v2.Adaptor{} case constant.APITypeOpenRouter: return &openrouter.Adaptor{} + case constant.APITypeXinference: + return &openai.Adaptor{} } return nil } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 5738d656..f1d0c88d 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -80,11 +80,12 @@ export const CHANNEL_OPTIONS = [ label: 'Google PaLM2' }, { - value: 45, + value: 47, color: 'blue', - label: '字节火山方舟、豆包、DeepSeek通用' + label: 'Xinference' }, { value: 25, color: 
'green', label: 'Moonshot' }, + { value: 20, color: 'green', label: 'OpenRouter' }, { value: 19, color: 'blue', label: '360 智脑' }, { value: 23, color: 'teal', label: '腾讯混元' }, { value: 31, color: 'green', label: '零一万物' }, @@ -108,5 +109,10 @@ export const CHANNEL_OPTIONS = [ value: 44, color: 'purple', label: '嵌入模型:MokaAI M3E' - } + }, + { + value: 45, + color: 'blue', + label: '字节火山方舟、豆包、DeepSeek通用' + }, ]; From 39d95172e892be4eead4b0e10d74fee21512c119 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Wed, 12 Mar 2025 18:53:38 +0800 Subject: [PATCH 02/18] fix: claude to openai tools use --- relay/channel/claude/relay-claude.go | 1 - 1 file changed, 1 deletion(-) diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 09154bcb..40659020 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -471,7 +471,6 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. 
usage.CompletionTokens = claudeUsage.OutputTokens usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens } else if claudeResponse.Type == "content_block_start" { - return true } else { return true } From 229738cda9eb93dd23ee6b2be1286ab4b542fab5 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Wed, 12 Mar 2025 19:29:15 +0800 Subject: [PATCH 03/18] fix: claude to openai tools use --- relay/channel/aws/relay-aws.go | 45 +++++------ relay/channel/claude/relay-claude.go | 111 +++++++++++++++------------ 2 files changed, 83 insertions(+), 73 deletions(-) diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index 976f97ce..e1270606 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -144,11 +144,14 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel defer stream.Close() c.Writer.Header().Set("Content-Type", "text/event-stream") - var usage relaymodel.Usage - var id string - var model string + claudeInfo := &claude.ClaudeResponseInfo{ + ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Created: common.GetTimestamp(), + Model: info.UpstreamModelName, + ResponseText: strings.Builder{}, + Usage: &relaymodel.Usage{}, + } isFirst := true - createdTime := common.GetTimestamp() c.Stream(func(w io.Writer) bool { event, ok := <-stream.Events() if !ok { @@ -161,33 +164,19 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel isFirst = false info.FirstResponseTime = time.Now() } - claudeResp := new(claude.ClaudeResponse) - err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp) + claudeResponse := new(claude.ClaudeResponse) + err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) return false } - response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp) - if 
claudeUsage != nil { - usage.PromptTokens += claudeUsage.InputTokens - usage.CompletionTokens += claudeUsage.OutputTokens - } + response := claude.StreamResponseClaude2OpenAI(requestMode, claudeResponse) - if response == nil { + if !claude.FormatClaudeResponseInfo(RequestModeMessage, claudeResponse, response, claudeInfo) { return true } - if response.Id != "" { - id = response.Id - } - if response.Model != "" { - model = response.Model - } - response.Created = createdTime - response.Id = id - response.Model = model - jsonStr, err := json.Marshal(response) if err != nil { common.SysError("error marshalling stream response: " + err.Error()) @@ -203,8 +192,16 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel return false } }) + + if claudeInfo.Usage.PromptTokens == 0 { + //上游出错 + } + if claudeInfo.Usage.CompletionTokens == 0 { + claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) + } + if info.ShouldIncludeUsage { - response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage) + response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage) err := helper.ObjectData(c, response) if err != nil { common.SysError("send final response failed: " + err.Error()) @@ -217,5 +214,5 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil } } - return nil, &usage + return nil, claudeInfo.Usage } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 40659020..fb4f5b7e 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -1,6 +1,7 @@ package claude import ( + "bytes" "encoding/json" "fmt" "io" @@ -290,9 +291,8 @@ func 
RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR return &claudeRequest, nil } -func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) { +func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse { var response dto.ChatCompletionsStreamResponse - var claudeUsage *ClaudeUsage response.Object = "chat.completion.chunk" response.Model = claudeResponse.Model response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0) @@ -308,7 +308,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* if claudeResponse.Type == "message_start" { response.Id = claudeResponse.Message.Id response.Model = claudeResponse.Message.Model - claudeUsage = &claudeResponse.Message.Usage + //claudeUsage = &claudeResponse.Message.Usage choice.Delta.SetContentString("") choice.Delta.Role = "assistant" } else if claudeResponse.Type == "content_block_start" { @@ -325,7 +325,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* }) } } else { - return nil, nil + return nil } } else if claudeResponse.Type == "content_block_delta" { if claudeResponse.Delta != nil { @@ -352,23 +352,20 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* if finishReason != "null" { choice.FinishReason = &finishReason } - claudeUsage = &claudeResponse.Usage + //claudeUsage = &claudeResponse.Usage } else if claudeResponse.Type == "message_stop" { - return nil, nil + return nil } else { - return nil, nil + return nil } } - if claudeUsage == nil { - claudeUsage = &ClaudeUsage{} - } if len(tools) > 0 { choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ... 
choice.Delta.ToolCalls = tools } response.Choices = append(response.Choices, choice) - return &response, claudeUsage + return &response } func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse { @@ -437,48 +434,65 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope return &fullTextResponse } +type ClaudeResponseInfo struct { + ResponseId string + Created int64 + Model string + ResponseText strings.Builder + Usage *dto.Usage +} + +func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool { + if oaiResponse == nil { + return false + } + if requestMode == RequestModeCompletion { + claudeInfo.ResponseText.WriteString(claudeResponse.Completion) + } else { + if claudeResponse.Type == "message_start" { + // message_start, 获取usage + claudeInfo.ResponseId = claudeResponse.Message.Id + claudeInfo.Model = claudeResponse.Message.Model + claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens + } else if claudeResponse.Type == "content_block_delta" { + claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Text) + } else if claudeResponse.Type == "message_delta" { + claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens + claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens + } else if claudeResponse.Type == "content_block_start" { + } else { + return false + } + } + oaiResponse.Id = claudeInfo.ResponseId + oaiResponse.Created = claudeInfo.Created + oaiResponse.Model = claudeInfo.Model + return true +} + func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) - var usage *dto.Usage - usage = &dto.Usage{} - responseText := "" - createdTime := common.GetTimestamp() + 
claudeInfo := &ClaudeResponseInfo{ + ResponseId: responseId, + Created: common.GetTimestamp(), + Model: info.UpstreamModelName, + ResponseText: strings.Builder{}, + Usage: &dto.Usage{}, + } helper.StreamScannerHandler(c, resp, info, func(data string) bool { var claudeResponse ClaudeResponse - err := json.Unmarshal([]byte(data), &claudeResponse) + err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) return true } - response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) - if response == nil { + response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) + + if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) { return true } - if requestMode == RequestModeCompletion { - responseText += claudeResponse.Completion - responseId = response.Id - } else { - if claudeResponse.Type == "message_start" { - // message_start, 获取usage - responseId = claudeResponse.Message.Id - info.UpstreamModelName = claudeResponse.Message.Model - usage.PromptTokens = claudeUsage.InputTokens - } else if claudeResponse.Type == "content_block_delta" { - responseText += claudeResponse.Delta.Text - } else if claudeResponse.Type == "message_delta" { - usage.CompletionTokens = claudeUsage.OutputTokens - usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens - } else if claudeResponse.Type == "content_block_start" { - } else { - return true - } - } - //response.Id = responseId - response.Id = responseId - response.Created = createdTime - response.Model = info.UpstreamModelName err = helper.ObjectData(c, response) if err != nil { @@ -488,25 +502,24 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. 
}) if requestMode == RequestModeCompletion { - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) } else { - if usage.PromptTokens == 0 { - usage.PromptTokens = info.PromptTokens + if claudeInfo.Usage.PromptTokens == 0 { + //上游出错 } - if usage.CompletionTokens == 0 { - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens) + if claudeInfo.Usage.CompletionTokens == 0 { + claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) } } if info.ShouldIncludeUsage { - response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage) + response := helper.GenerateFinalUsageResponse(responseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage) err := helper.ObjectData(c, response) if err != nil { common.SysError("send final response failed: " + err.Error()) } } helper.Done(c) - //resp.Body.Close() - return nil, usage + return nil, claudeInfo.Usage } func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { From c0b93507853e9a11a4e6b043ada9ed712a6dac67 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Wed, 12 Mar 2025 19:46:08 +0800 Subject: [PATCH 04/18] fix: claude to openai tools use --- relay/channel/claude/relay-claude.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index fb4f5b7e..011694df 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -443,9 +443,6 @@ type ClaudeResponseInfo struct { } func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, 
oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool { - if oaiResponse == nil { - return false - } if requestMode == RequestModeCompletion { claudeInfo.ResponseText.WriteString(claudeResponse.Completion) } else { @@ -464,9 +461,11 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, o return false } } - oaiResponse.Id = claudeInfo.ResponseId - oaiResponse.Created = claudeInfo.Created - oaiResponse.Model = claudeInfo.Model + if oaiResponse != nil { + oaiResponse.Id = claudeInfo.ResponseId + oaiResponse.Created = claudeInfo.Created + oaiResponse.Model = claudeInfo.Model + } return true } From c47d8a10f0afa2eff8bc22e5dc52b2b309337b8d Mon Sep 17 00:00:00 2001 From: Seefs Date: Wed, 12 Mar 2025 21:08:47 +0800 Subject: [PATCH 05/18] feat: Support postgresql:// dsn format --- model/main.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/model/main.go b/model/main.go index c0bf927c..649dac54 100644 --- a/model/main.go +++ b/model/main.go @@ -1,16 +1,17 @@ package model import ( - "github.com/glebarez/sqlite" - "gorm.io/driver/mysql" - "gorm.io/driver/postgres" - "gorm.io/gorm" "log" "one-api/common" "os" "strings" "sync" "time" + + "github.com/glebarez/sqlite" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/gorm" ) var groupCol string @@ -60,7 +61,7 @@ func chooseDB(envName string) (*gorm.DB, error) { }() dsn := os.Getenv(envName) if dsn != "" { - if strings.HasPrefix(dsn, "postgres://") { + if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { // Use PostgreSQL common.SysLog("using PostgreSQL as database") common.UsingPostgreSQL = true From bd48f434105a4031bab239d1a89aabeaa68e4d6e Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Wed, 12 Mar 2025 21:31:46 +0800 Subject: [PATCH 06/18] feat: claude relay --- common/custom-event.go | 2 +- controller/relay.go | 51 ++++ dto/claude.go | 186 ++++++++++++++ 
dto/openai_response.go | 14 + dto/realtime.go | 9 +- go.mod | 2 +- middleware/auth.go | 8 + relay/channel/adapter.go | 1 + relay/channel/ali/adaptor.go | 6 + relay/channel/aws/adaptor.go | 6 +- relay/channel/aws/dto.go | 26 +- relay/channel/aws/relay-aws.go | 24 +- relay/channel/baidu/adaptor.go | 6 + relay/channel/baidu_v2/adaptor.go | 6 + relay/channel/claude/adaptor.go | 4 + relay/channel/claude/dto.go | 185 +++++++------- relay/channel/claude/relay-claude.go | 149 +++++++---- relay/channel/cloudflare/adaptor.go | 6 + relay/channel/cohere/adaptor.go | 7 +- relay/channel/deepseek/adaptor.go | 6 + relay/channel/dify/adaptor.go | 6 + relay/channel/gemini/adaptor.go | 6 + relay/channel/jina/adaptor.go | 6 + relay/channel/mistral/adaptor.go | 6 + relay/channel/mokaai/adaptor.go | 10 +- relay/channel/ollama/adaptor.go | 6 + relay/channel/openai/adaptor.go | 6 + relay/channel/openrouter/adaptor.go | 6 + relay/channel/palm/adaptor.go | 7 +- relay/channel/perplexity/adaptor.go | 7 +- relay/channel/siliconflow/adaptor.go | 6 + relay/channel/tencent/adaptor.go | 7 +- relay/channel/vertex/adaptor.go | 3 + relay/channel/vertex/dto.go | 28 +- relay/channel/volcengine/adaptor.go | 6 + relay/channel/xunfei/adaptor.go | 7 +- relay/channel/zhipu/adaptor.go | 7 +- relay/channel/zhipu_4v/adaptor.go | 7 +- relay/claude_handler.go | 162 ++++++++++++ relay/common/relay_info.go | 16 ++ relay/helper/common.go | 8 + relay/helper/price.go | 4 + router/relay-router.go | 1 + service/convert.go | 310 +++++++++++++++++++++++ service/error.go | 24 ++ service/log_info_generate.go | 9 + service/quota.go | 69 +++++ service/token_counter.go | 105 ++++++++ setting/operation_setting/cache_ratio.go | 71 +++--- web/src/components/LogsTable.js | 79 +++++- web/src/helpers/render.js | 197 +++++++++++++- 51 files changed, 1660 insertions(+), 236 deletions(-) create mode 100644 dto/claude.go create mode 100644 relay/claude_handler.go create mode 100644 service/convert.go diff --git a/common/custom-event.go 
b/common/custom-event.go index 69da4bc4..d8f9ec9f 100644 --- a/common/custom-event.go +++ b/common/custom-event.go @@ -44,7 +44,7 @@ var fieldReplacer = strings.NewReplacer( "\r", "\\r") var dataReplacer = strings.NewReplacer( - "\n", "\ndata:", + "\n", "\n", "\r", "\\r") type CustomEvent struct { diff --git a/controller/relay.go b/controller/relay.go index 460599b5..fb4c524f 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -148,6 +148,50 @@ func WssRelay(c *gin.Context) { } } +func RelayClaude(c *gin.Context) { + //relayMode := constant.Path2RelayMode(c.Request.URL.Path) + requestId := c.GetString(common.RequestIdKey) + group := c.GetString("group") + originalModel := c.GetString("original_model") + var claudeErr *dto.ClaudeErrorWithStatusCode + + for i := 0; i <= common.RetryTimes; i++ { + channel, err := getChannel(c, group, originalModel, i) + if err != nil { + common.LogError(c, err.Error()) + claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError) + break + } + + claudeErr = claudeRequest(c, channel) + + if claudeErr == nil { + return // 成功处理请求,直接返回 + } + + openaiErr := service.ClaudeErrorToOpenAIError(claudeErr) + + go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr) + + if !shouldRetry(c, openaiErr, common.RetryTimes-i) { + break + } + } + useChannel := c.GetStringSlice("use_channel") + if len(useChannel) > 1 { + retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) + common.LogInfo(c, retryLogStr) + } + + if claudeErr != nil { + claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId) + c.JSON(claudeErr.StatusCode, gin.H{ + "type": "error", + "error": claudeErr.Error, + }) + } +} + func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { addUsedChannel(c, channel.Id) requestBody, _ := common.GetRequestBody(c) 
@@ -162,6 +206,13 @@ func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *mode return relay.WssHelper(c, ws) } +func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode { + addUsedChannel(c, channel.Id) + requestBody, _ := common.GetRequestBody(c) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + return relay.ClaudeHelper(c) +} + func addUsedChannel(c *gin.Context, channelId int) { useChannel := c.GetStringSlice("use_channel") useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) diff --git a/dto/claude.go b/dto/claude.go new file mode 100644 index 00000000..60f638f6 --- /dev/null +++ b/dto/claude.go @@ -0,0 +1,186 @@ +package dto + +import "encoding/json" + +type ClaudeMetadata struct { + UserId string `json:"user_id"` +} + +type ClaudeMediaMessage struct { + Type string `json:"type"` + Text *string `json:"text,omitempty"` + Model string `json:"model,omitempty"` + Source *ClaudeMessageSource `json:"source,omitempty"` + Usage *ClaudeUsage `json:"usage,omitempty"` + StopReason *string `json:"stop_reason,omitempty"` + PartialJson string `json:"partial_json,omitempty"` + Role string `json:"role,omitempty"` + Thinking string `json:"thinking,omitempty"` + Signature string `json:"signature,omitempty"` + Delta string `json:"delta,omitempty"` + // tool_calls + Id string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input any `json:"input,omitempty"` + Content json.RawMessage `json:"content,omitempty"` + ToolUseId string `json:"tool_use_id,omitempty"` +} + +func (c *ClaudeMediaMessage) SetText(s string) { + c.Text = &s +} + +func (c *ClaudeMediaMessage) GetText() string { + if c.Text == nil { + return "" + } + return *c.Text +} + +type ClaudeMessageSource struct { + Type string `json:"type"` + MediaType string `json:"media_type"` + Data any `json:"data"` +} + +type ClaudeMessage struct { + Role string `json:"role"` + Content any `json:"content"` +} + +func (c *ClaudeMessage) 
IsStringContent() bool { + _, ok := c.Content.(string) + return ok +} + +func (c *ClaudeMessage) GetStringContent() string { + if c.IsStringContent() { + return c.Content.(string) + } + return "" +} + +func (c *ClaudeMessage) SetStringContent(content string) { + c.Content = content +} + +func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) { + // map content to []ClaudeMediaMessage + // parse to json + jsonContent, _ := json.Marshal(c.Content) + var contentList []ClaudeMediaMessage + err := json.Unmarshal(jsonContent, &contentList) + if err != nil { + return make([]ClaudeMediaMessage, 0), err + } + return contentList, nil +} + +type Tool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema map[string]interface{} `json:"input_schema"` +} + +type InputSchema struct { + Type string `json:"type"` + Properties any `json:"properties,omitempty"` + Required any `json:"required,omitempty"` +} + +type ClaudeRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt,omitempty"` + System any `json:"system,omitempty"` + Messages []ClaudeMessage `json:"messages,omitempty"` + MaxTokens uint `json:"max_tokens,omitempty"` + MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + //ClaudeMetadata `json:"metadata,omitempty"` + Stream bool `json:"stream,omitempty"` + Tools any `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + Thinking *Thinking `json:"thinking,omitempty"` +} + +type Thinking struct { + Type string `json:"type"` + BudgetTokens int `json:"budget_tokens"` +} + +func (c *ClaudeRequest) IsStringSystem() bool { + _, ok := c.System.(string) + return ok +} + +func (c *ClaudeRequest) GetStringSystem() string { + if c.IsStringSystem() { + return c.System.(string) + } + return 
"" +} + +func (c *ClaudeRequest) SetStringSystem(system string) { + c.System = system +} + +func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage { + // map content to []ClaudeMediaMessage + // parse to json + jsonContent, _ := json.Marshal(c.System) + var contentList []ClaudeMediaMessage + if err := json.Unmarshal(jsonContent, &contentList); err == nil { + return contentList + } + return make([]ClaudeMediaMessage, 0) +} + +type ClaudeError struct { + Type string `json:"type"` + Message string `json:"message"` +} + +type ClaudeErrorWithStatusCode struct { + Error ClaudeError `json:"error"` + StatusCode int `json:"status_code"` + LocalError bool +} + +type ClaudeResponse struct { + Id string `json:"id,omitempty"` + Type string `json:"type"` + Role string `json:"role,omitempty"` + Content []ClaudeMediaMessage `json:"content,omitempty"` + Completion string `json:"completion,omitempty"` + StopReason string `json:"stop_reason,omitempty"` + Model string `json:"model,omitempty"` + Error *ClaudeError `json:"error,omitempty"` + Usage *ClaudeUsage `json:"usage,omitempty"` + Index *int `json:"index,omitempty"` + ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"` + Delta *ClaudeMediaMessage `json:"delta,omitempty"` + Message *ClaudeMediaMessage `json:"message,omitempty"` +} + +// set index +func (c *ClaudeResponse) SetIndex(i int) { + c.Index = &i +} + +// get index +func (c *ClaudeResponse) GetIndex() int { + if c.Index == nil { + return 0 + } + return *c.Index +} + +type ClaudeUsage struct { + InputTokens int `json:"input_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` + OutputTokens int `json:"output_tokens"` +} diff --git a/dto/openai_response.go b/dto/openai_response.go index 9188fad7..4097db55 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -125,6 +125,20 @@ type ChatCompletionsStreamResponse struct { Usage *Usage `json:"usage"` } +func (c 
*ChatCompletionsStreamResponse) IsToolCall() bool { + if len(c.Choices) == 0 { + return false + } + return len(c.Choices[0].Delta.ToolCalls) > 0 +} + +func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse { + if c.IsToolCall() { + return &c.Choices[0].Delta.ToolCalls[0] + } + return nil +} + func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse { choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices)) copy(choices, c.Choices) diff --git a/dto/realtime.go b/dto/realtime.go index e28d813e..8c6e8932 100644 --- a/dto/realtime.go +++ b/dto/realtime.go @@ -44,10 +44,11 @@ type RealtimeUsage struct { } type InputTokenDetails struct { - CachedTokens int `json:"cached_tokens"` - TextTokens int `json:"text_tokens"` - AudioTokens int `json:"audio_tokens"` - ImageTokens int `json:"image_tokens"` + CachedTokens int `json:"cached_tokens"` + CachedCreationTokens int + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + ImageTokens int `json:"image_tokens"` } type OutputTokenDetails struct { diff --git a/go.mod b/go.mod index ca526466..ce768bf3 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/aws/aws-sdk-go-v2/credentials v1.17.11 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b + github.com/bytedance/sonic v1.11.6 github.com/gin-contrib/cors v1.7.2 github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/sessions v0.0.5 @@ -42,7 +43,6 @@ require ( github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect github.com/aws/smithy-go v1.20.2 // indirect - github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect diff --git a/middleware/auth.go b/middleware/auth.go index 
a589f52c..fece4553 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -174,6 +174,14 @@ func TokenAuth() func(c *gin.Context) { } c.Request.Header.Set("Authorization", "Bearer "+key) } + // 检查path包含/v1/messages + if strings.Contains(c.Request.URL.Path, "/v1/messages") { + // 从x-api-key中获取key + key := c.Request.Header.Get("x-api-key") + if key != "" { + c.Request.Header.Set("Authorization", "Bearer "+key) + } + } key := c.Request.Header.Get("Authorization") parts := make([]string, 0) key = strings.TrimPrefix(key, "Bearer ") diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index c970fd48..9f449b54 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -22,6 +22,7 @@ type Adaptor interface { DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) GetModelList() []string GetChannelName() string + ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) } type TaskAdaptor interface { diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 32be399b..9d3ee99f 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -16,6 +16,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index 7f2a2841..e735ee2b 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -20,6 +20,10 @@ type Adaptor struct { RequestMode int } +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { + return request, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request 
dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -48,7 +52,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re return nil, errors.New("request is nil") } - var claudeReq *claude.ClaudeRequest + var claudeReq *dto.ClaudeRequest var err error claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request) if err != nil { diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go index 3b615134..0188c30a 100644 --- a/relay/channel/aws/dto.go +++ b/relay/channel/aws/dto.go @@ -1,25 +1,25 @@ package aws import ( - "one-api/relay/channel/claude" + "one-api/dto" ) type AwsClaudeRequest struct { // AnthropicVersion should be "bedrock-2023-05-31" - AnthropicVersion string `json:"anthropic_version"` - System string `json:"system,omitempty"` - Messages []claude.ClaudeMessage `json:"messages"` - MaxTokens uint `json:"max_tokens,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` - Tools any `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - Thinking *claude.Thinking `json:"thinking,omitempty"` + AnthropicVersion string `json:"anthropic_version"` + System any `json:"system,omitempty"` + Messages []dto.ClaudeMessage `json:"messages"` + MaxTokens uint `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Tools any `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + Thinking *dto.Thinking `json:"thinking,omitempty"` } -func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest { +func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest { return &AwsClaudeRequest{ AnthropicVersion: "bedrock-2023-05-31", System: req.System, diff 
--git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index e1270606..0d517256 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -9,7 +9,7 @@ import ( "io" "net/http" "one-api/common" - relaymodel "one-api/dto" + "one-api/dto" "one-api/relay/channel/claude" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -39,10 +39,10 @@ func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime. return client, nil } -func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode { - return &relaymodel.OpenAIErrorWithStatusCode{ +func wrapErr(err error) *dto.OpenAIErrorWithStatusCode { + return &dto.OpenAIErrorWithStatusCode{ StatusCode: http.StatusInternalServerError, - Error: relaymodel.OpenAIError{ + Error: dto.OpenAIError{ Message: fmt.Sprintf("%s", err.Error()), }, } @@ -56,7 +56,7 @@ func awsModelID(requestModel string) (string, error) { return requestModel, nil } -func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) { +func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { awsCli, err := newAwsClient(c, info) if err != nil { return wrapErr(errors.Wrap(err, "newAwsClient")), nil @@ -77,7 +77,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (* if !ok { return wrapErr(errors.New("request not found")), nil } - claudeReq := claudeReq_.(*claude.ClaudeRequest) + claudeReq := claudeReq_.(*dto.ClaudeRequest) awsClaudeReq := copyRequest(claudeReq) awsReq.Body, err = json.Marshal(awsClaudeReq) if err != nil { @@ -89,14 +89,14 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (* return wrapErr(errors.Wrap(err, "InvokeModel")), nil } - claudeResponse := new(claude.ClaudeResponse) + claudeResponse := new(dto.ClaudeResponse) err = json.Unmarshal(awsResp.Body, claudeResponse) if err != nil { 
return wrapErr(errors.Wrap(err, "unmarshal response")), nil } openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse) - usage := relaymodel.Usage{ + usage := dto.Usage{ PromptTokens: claudeResponse.Usage.InputTokens, CompletionTokens: claudeResponse.Usage.OutputTokens, TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens, @@ -107,7 +107,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (* return nil, &usage } -func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) { +func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { awsCli, err := newAwsClient(c, info) if err != nil { return wrapErr(errors.Wrap(err, "newAwsClient")), nil @@ -128,7 +128,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel if !ok { return wrapErr(errors.New("request not found")), nil } - claudeReq := claudeReq_.(*claude.ClaudeRequest) + claudeReq := claudeReq_.(*dto.ClaudeRequest) awsClaudeReq := copyRequest(claudeReq) awsReq.Body, err = json.Marshal(awsClaudeReq) @@ -149,7 +149,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel Created: common.GetTimestamp(), Model: info.UpstreamModelName, ResponseText: strings.Builder{}, - Usage: &relaymodel.Usage{}, + Usage: &dto.Usage{}, } isFirst := true c.Stream(func(w io.Writer) bool { @@ -164,7 +164,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel isFirst = false info.FirstResponseTime = time.Now() } - claudeResponse := new(claude.ClaudeResponse) + claudeResponse := new(dto.ClaudeResponse) err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) diff --git 
a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 46a1f964..105f2a9b 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -16,6 +16,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index fd25ecc1..855ed717 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -15,6 +15,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index bf03e5f5..a5c475fa 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -22,6 +22,10 @@ type Adaptor struct { RequestMode int } +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { + return request, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/claude/dto.go b/relay/channel/claude/dto.go index 9532ca74..89415868 100644 --- a/relay/channel/claude/dto.go +++ b/relay/channel/claude/dto.go @@ -1,94 +1,95 
@@ package claude -type ClaudeMetadata struct { - UserId string `json:"user_id"` -} - -type ClaudeMediaMessage struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - Source *ClaudeMessageSource `json:"source,omitempty"` - Usage *ClaudeUsage `json:"usage,omitempty"` - StopReason *string `json:"stop_reason,omitempty"` - PartialJson string `json:"partial_json,omitempty"` - Thinking string `json:"thinking,omitempty"` - Signature string `json:"signature,omitempty"` - Delta string `json:"delta,omitempty"` - // tool_calls - Id string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Input any `json:"input,omitempty"` - Content string `json:"content,omitempty"` - ToolUseId string `json:"tool_use_id,omitempty"` -} - -type ClaudeMessageSource struct { - Type string `json:"type"` - MediaType string `json:"media_type"` - Data string `json:"data"` -} - -type ClaudeMessage struct { - Role string `json:"role"` - Content any `json:"content"` -} - -type Tool struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - InputSchema map[string]interface{} `json:"input_schema"` -} - -type InputSchema struct { - Type string `json:"type"` - Properties any `json:"properties,omitempty"` - Required any `json:"required,omitempty"` -} - -type ClaudeRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt,omitempty"` - System string `json:"system,omitempty"` - Messages []ClaudeMessage `json:"messages,omitempty"` - MaxTokens uint `json:"max_tokens,omitempty"` - MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - //ClaudeMetadata `json:"metadata,omitempty"` - Stream bool `json:"stream,omitempty"` - Tools any `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - Thinking *Thinking 
`json:"thinking,omitempty"` -} - -type Thinking struct { - Type string `json:"type"` - BudgetTokens int `json:"budget_tokens"` -} - -type ClaudeError struct { - Type string `json:"type"` - Message string `json:"message"` -} - -type ClaudeResponse struct { - Id string `json:"id"` - Type string `json:"type"` - Content []ClaudeMediaMessage `json:"content"` - Completion string `json:"completion"` - StopReason string `json:"stop_reason"` - Model string `json:"model"` - Error ClaudeError `json:"error"` - Usage ClaudeUsage `json:"usage"` - Index int `json:"index"` // stream only - ContentBlock *ClaudeMediaMessage `json:"content_block"` - Delta *ClaudeMediaMessage `json:"delta"` // stream only - Message *ClaudeResponse `json:"message"` // stream only: message_start -} - -type ClaudeUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` -} +// +//type ClaudeMetadata struct { +// UserId string `json:"user_id"` +//} +// +//type ClaudeMediaMessage struct { +// Type string `json:"type"` +// Text string `json:"text,omitempty"` +// Source *ClaudeMessageSource `json:"source,omitempty"` +// Usage *ClaudeUsage `json:"usage,omitempty"` +// StopReason *string `json:"stop_reason,omitempty"` +// PartialJson string `json:"partial_json,omitempty"` +// Thinking string `json:"thinking,omitempty"` +// Signature string `json:"signature,omitempty"` +// Delta string `json:"delta,omitempty"` +// // tool_calls +// Id string `json:"id,omitempty"` +// Name string `json:"name,omitempty"` +// Input any `json:"input,omitempty"` +// Content string `json:"content,omitempty"` +// ToolUseId string `json:"tool_use_id,omitempty"` +//} +// +//type ClaudeMessageSource struct { +// Type string `json:"type"` +// MediaType string `json:"media_type"` +// Data string `json:"data"` +//} +// +//type ClaudeMessage struct { +// Role string `json:"role"` +// Content any `json:"content"` +//} +// +//type Tool struct { +// Name string `json:"name"` +// Description string 
`json:"description,omitempty"` +// InputSchema map[string]interface{} `json:"input_schema"` +//} +// +//type InputSchema struct { +// Type string `json:"type"` +// Properties any `json:"properties,omitempty"` +// Required any `json:"required,omitempty"` +//} +// +//type ClaudeRequest struct { +// Model string `json:"model"` +// Prompt string `json:"prompt,omitempty"` +// System string `json:"system,omitempty"` +// Messages []ClaudeMessage `json:"messages,omitempty"` +// MaxTokens uint `json:"max_tokens,omitempty"` +// MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"` +// StopSequences []string `json:"stop_sequences,omitempty"` +// Temperature *float64 `json:"temperature,omitempty"` +// TopP float64 `json:"top_p,omitempty"` +// TopK int `json:"top_k,omitempty"` +// //ClaudeMetadata `json:"metadata,omitempty"` +// Stream bool `json:"stream,omitempty"` +// Tools any `json:"tools,omitempty"` +// ToolChoice any `json:"tool_choice,omitempty"` +// Thinking *Thinking `json:"thinking,omitempty"` +//} +// +//type Thinking struct { +// Type string `json:"type"` +// BudgetTokens int `json:"budget_tokens"` +//} +// +//type ClaudeError struct { +// Type string `json:"type"` +// Message string `json:"message"` +//} +// +//type ClaudeResponse struct { +// Id string `json:"id"` +// Type string `json:"type"` +// Content []ClaudeMediaMessage `json:"content"` +// Completion string `json:"completion"` +// StopReason string `json:"stop_reason"` +// Model string `json:"model"` +// Error ClaudeError `json:"error"` +// Usage ClaudeUsage `json:"usage"` +// Index int `json:"index"` // stream only +// ContentBlock *ClaudeMediaMessage `json:"content_block"` +// Delta *ClaudeMediaMessage `json:"delta"` // stream only +// Message *ClaudeResponse `json:"message"` // stream only: message_start +//} +// +//type ClaudeUsage struct { +// InputTokens int `json:"input_tokens"` +// OutputTokens int `json:"output_tokens"` +//} diff --git a/relay/channel/claude/relay-claude.go 
b/relay/channel/claude/relay-claude.go index 011694df..205e0b61 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -30,9 +30,9 @@ func stopReasonClaude2OpenAI(reason string) string { } } -func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { +func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.ClaudeRequest { - claudeRequest := ClaudeRequest{ + claudeRequest := dto.ClaudeRequest{ Model: textRequest.Model, Prompt: "", StopSequences: nil, @@ -61,12 +61,12 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR return &claudeRequest } -func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) { - claudeTools := make([]Tool, 0, len(textRequest.Tools)) +func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) { + claudeTools := make([]dto.Tool, 0, len(textRequest.Tools)) for _, tool := range textRequest.Tools { if params, ok := tool.Function.Parameters.(map[string]any); ok { - claudeTool := Tool{ + claudeTool := dto.Tool{ Name: tool.Function.Name, Description: tool.Function.Description, } @@ -84,7 +84,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR } } - claudeRequest := ClaudeRequest{ + claudeRequest := dto.ClaudeRequest{ Model: textRequest.Model, MaxTokens: textRequest.MaxTokens, StopSequences: nil, @@ -108,7 +108,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR } // BudgetTokens 为 max_tokens 的 80% - claudeRequest.Thinking = &Thinking{ + claudeRequest.Thinking = &dto.Thinking{ Type: "enabled", BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage), } @@ -166,7 +166,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR lastMessage = fmtMessage } - claudeMessages := make([]ClaudeMessage, 
0) + claudeMessages := make([]dto.ClaudeMessage, 0) isFirstMessage := true for _, message := range formatMessages { if message.Role == "system" { @@ -187,63 +187,63 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR isFirstMessage = false if message.Role != "user" { // fix: first message is assistant, add user message - claudeMessage := ClaudeMessage{ + claudeMessage := dto.ClaudeMessage{ Role: "user", - Content: []ClaudeMediaMessage{ + Content: []dto.ClaudeMediaMessage{ { Type: "text", - Text: "...", + Text: common.GetPointer[string]("..."), }, }, } claudeMessages = append(claudeMessages, claudeMessage) } } - claudeMessage := ClaudeMessage{ + claudeMessage := dto.ClaudeMessage{ Role: message.Role, } if message.Role == "tool" { if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" { lastMessage := claudeMessages[len(claudeMessages)-1] if content, ok := lastMessage.Content.(string); ok { - lastMessage.Content = []ClaudeMediaMessage{ + lastMessage.Content = []dto.ClaudeMediaMessage{ { Type: "text", - Text: content, + Text: common.GetPointer[string](content), }, } } - lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{ + lastMessage.Content = append(lastMessage.Content.([]dto.ClaudeMediaMessage), dto.ClaudeMediaMessage{ Type: "tool_result", ToolUseId: message.ToolCallId, - Content: message.StringContent(), + Content: message.Content, }) claudeMessages[len(claudeMessages)-1] = lastMessage continue } else { claudeMessage.Role = "user" - claudeMessage.Content = []ClaudeMediaMessage{ + claudeMessage.Content = []dto.ClaudeMediaMessage{ { Type: "tool_result", ToolUseId: message.ToolCallId, - Content: message.StringContent(), + Content: message.Content, }, } } } else if message.IsStringContent() && message.ToolCalls == nil { claudeMessage.Content = message.StringContent() } else { - claudeMediaMessages := make([]ClaudeMediaMessage, 0) + claudeMediaMessages := 
make([]dto.ClaudeMediaMessage, 0) for _, mediaMessage := range message.ParseContent() { - claudeMediaMessage := ClaudeMediaMessage{ + claudeMediaMessage := dto.ClaudeMediaMessage{ Type: mediaMessage.Type, } if mediaMessage.Type == "text" { - claudeMediaMessage.Text = mediaMessage.Text + claudeMediaMessage.Text = common.GetPointer[string](mediaMessage.Text) } else { imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl) claudeMediaMessage.Type = "image" - claudeMediaMessage.Source = &ClaudeMessageSource{ + claudeMediaMessage.Source = &dto.ClaudeMessageSource{ Type: "base64", } // 判断是否是url @@ -273,7 +273,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) continue } - claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{ + claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{ Type: "tool_use", Id: toolCall.ID, Name: toolCall.Function.Name, @@ -291,7 +291,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR return &claudeRequest, nil } -func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse { +func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse { var response dto.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" response.Model = claudeResponse.Model @@ -329,8 +329,8 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *d } } else if claudeResponse.Type == "content_block_delta" { if claudeResponse.Delta != nil { - choice.Index = claudeResponse.Index - choice.Delta.SetContentString(claudeResponse.Delta.Text) + choice.Index = *claudeResponse.Index + choice.Delta.SetContentString(*claudeResponse.Delta.Text) switch claudeResponse.Delta.Type { case "input_json_delta": tools = 
append(tools, dto.ToolCallResponse{ @@ -368,7 +368,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *d return &response } -func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse { +func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse { choices := make([]dto.OpenAITextResponseChoice, 0) fullTextResponse := dto.OpenAITextResponse{ Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), @@ -377,7 +377,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope } var responseText string if len(claudeResponse.Content) > 0 { - responseText = claudeResponse.Content[0].Text + responseText = *claudeResponse.Content[0].Text } tools := make([]dto.ToolCallResponse, 0) thinkingContent := "" @@ -412,7 +412,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope // 加密的不管, 只输出明文的推理过程 thinkingContent = message.Thinking case "text": - responseText = message.Text + responseText = *message.Text } } } @@ -442,7 +442,7 @@ type ClaudeResponseInfo struct { Usage *dto.Usage } -func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool { +func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool { if requestMode == RequestModeCompletion { claudeInfo.ResponseText.WriteString(claudeResponse.Completion) } else { @@ -452,7 +452,7 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, o claudeInfo.Model = claudeResponse.Message.Model claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens } else if claudeResponse.Type == "content_block_delta" { - claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Text) + claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text) 
} else if claudeResponse.Type == "message_delta" { claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens @@ -470,6 +470,61 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, o } func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + + if info.RelayFormat == relaycommon.RelayFormatOpenAI { + return toOpenAIStreamHandler(c, resp, info, requestMode) + } + + usage := &dto.Usage{} + responseText := strings.Builder{} + + helper.StreamScannerHandler(c, resp, info, func(data string) bool { + var claudeResponse dto.ClaudeResponse + err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + if requestMode == RequestModeCompletion { + responseText.WriteString(claudeResponse.Completion) + } else { + if claudeResponse.Type == "message_start" { + // message_start, 获取usage + info.UpstreamModelName = claudeResponse.Message.Model + usage.PromptTokens = claudeResponse.Message.Usage.InputTokens + usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens + usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens + usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens + } else if claudeResponse.Type == "content_block_delta" { + responseText.WriteString(claudeResponse.Delta.GetText()) + } else if claudeResponse.Type == "message_delta" { + if claudeResponse.Usage.InputTokens > 0 { + // 不叠加,只取最新的 + usage.PromptTokens = claudeResponse.Usage.InputTokens + } + usage.CompletionTokens = claudeResponse.Usage.OutputTokens + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + } + 
helper.ClaudeChunkData(c, claudeResponse, data) + return true + }) + + if requestMode == RequestModeCompletion { + usage, _ = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens) + } else { + // 说明流模式建立失败,可能为官方出错 + if usage.PromptTokens == 0 { + //usage.PromptTokens = info.PromptTokens + } + if usage.CompletionTokens == 0 { + usage, _ = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, usage.PromptTokens) + } + } + return nil, usage +} + +func toOpenAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) claudeInfo := &ClaudeResponseInfo{ ResponseId: responseId, @@ -480,7 +535,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } helper.StreamScannerHandler(c, resp, info, func(data string) bool { - var claudeResponse ClaudeResponse + var claudeResponse dto.ClaudeResponse err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) @@ -530,7 +585,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r if err != nil { return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - var claudeResponse ClaudeResponse + var claudeResponse dto.ClaudeResponse err = json.Unmarshal(responseBody, &claudeResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil @@ -546,13 +601,12 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r StatusCode: resp.StatusCode, }, nil } - fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) - completionTokens, err := service.CountTextToken(claudeResponse.Completion, 
info.OriginModelName) - if err != nil { - return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil - } usage := dto.Usage{} if requestMode == RequestModeCompletion { + completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) + if err != nil { + return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil + } usage.PromptTokens = info.PromptTokens usage.CompletionTokens = completionTokens usage.TotalTokens = info.PromptTokens + completionTokens @@ -560,14 +614,23 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r usage.PromptTokens = claudeResponse.Usage.InputTokens usage.CompletionTokens = claudeResponse.Usage.OutputTokens usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens + usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens + usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens } - fullTextResponse.Usage = usage - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + var responseData []byte + switch info.RelayFormat { + case relaycommon.RelayFormatOpenAI: + openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) + openaiResponse.Usage = usage + responseData, err = json.Marshal(openaiResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + case relaycommon.RelayFormatClaude: + responseData = responseBody } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) + _, err = c.Writer.Write(responseData) return nil, &usage } diff --git a/relay/channel/cloudflare/adaptor.go 
b/relay/channel/cloudflare/adaptor.go index 5c2eadc2..b21e25f3 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -17,6 +17,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index d552a53b..7675d546 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -15,6 +15,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -59,7 +65,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { err, usage = cohereRerankHandler(c, resp, info) diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index d779ee65..ad01b8f4 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -16,6 +16,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) 
{ //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 2626dd7d..96aff447 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -23,6 +23,12 @@ type Adaptor struct { BotType int } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 1b7131dc..a629968b 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -21,6 +21,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index 77076bd4..bcfc8dea 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -15,6 +15,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index fcea169a..80547346 
100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -14,6 +14,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go index 9670ec94..151072cb 100644 --- a/relay/channel/mokaai/adaptor.go +++ b/relay/channel/mokaai/adaptor.go @@ -16,6 +16,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -73,13 +79,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { - + switch info.RelayMode { case constant.RelayModeEmbeddings: err, usage = mokaEmbeddingHandler(c, resp) default: // err, usage = mokaHandler(c, resp) - + } return } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 7e1c6237..4190dd3f 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -15,6 +15,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c 
*gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index d8a44335..196343e8 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -29,6 +29,12 @@ type Adaptor struct { ResponseFormat string } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType } diff --git a/relay/channel/openrouter/adaptor.go b/relay/channel/openrouter/adaptor.go index 83afb6af..aef5afeb 100644 --- a/relay/channel/openrouter/adaptor.go +++ b/relay/channel/openrouter/adaptor.go @@ -15,6 +15,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index f38fa95b..69ef5001 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -15,6 +15,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -54,7 +60,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info 
*relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 2b27bdb1..de84406c 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -15,6 +15,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -57,7 +63,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index 797f0244..754a1f00 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -16,6 +16,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 768ef646..28a02aae 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ 
-23,6 +23,12 @@ type Adaptor struct { Timestamp int64 } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -78,7 +84,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 7ccd3f30..2f348e46 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -38,6 +38,9 @@ type Adaptor struct { AccountCredentials Credentials } +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { + return request, nil +} func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/vertex/dto.go b/relay/channel/vertex/dto.go index 4ba570de..4a571612 100644 --- a/relay/channel/vertex/dto.go +++ b/relay/channel/vertex/dto.go @@ -1,25 +1,25 @@ package vertex import ( - "one-api/relay/channel/claude" + "one-api/dto" ) type VertexAIClaudeRequest struct { - AnthropicVersion string `json:"anthropic_version"` - Messages []claude.ClaudeMessage `json:"messages"` - System any `json:"system,omitempty"` - MaxTokens uint `json:"max_tokens,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP 
float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Tools any `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - Thinking *claude.Thinking `json:"thinking,omitempty"` + AnthropicVersion string `json:"anthropic_version"` + Messages []dto.ClaudeMessage `json:"messages"` + System any `json:"system,omitempty"` + MaxTokens uint `json:"max_tokens,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Tools any `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + Thinking *dto.Thinking `json:"thinking,omitempty"` } -func copyRequest(req *claude.ClaudeRequest, version string) *VertexAIClaudeRequest { +func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest { return &VertexAIClaudeRequest{ AnthropicVersion: version, System: req.System, diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 3b57c67c..f423d587 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -17,6 +17,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 71fd1367..d66f3732 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -16,6 +16,12 @@ type Adaptor struct { request *dto.GeneralOpenAIRequest } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, 
*dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -55,7 +61,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { // xunfei's request is not http request, so we don't need to do anything here dummyResp := &http.Response{} diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 87ff20d5..aa612f0c 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -14,6 +14,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -61,7 +67,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 5983c1d9..7a23e212 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -15,6 +15,12 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + //TODO implement me + panic("implement me") + return nil, 
nil +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") @@ -58,7 +64,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/claude_handler.go b/relay/claude_handler.go new file mode 100644 index 00000000..97de772b --- /dev/null +++ b/relay/claude_handler.go @@ -0,0 +1,162 @@ +package relay + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/relay/helper" + "one-api/service" + "one-api/setting/model_setting" + "strings" +) + +func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) { + textRequest = &dto.ClaudeRequest{} + err = c.ShouldBindJSON(textRequest) + if err != nil { + return nil, err + } + if textRequest.Messages == nil || len(textRequest.Messages) == 0 { + return nil, errors.New("field messages is required") + } + if textRequest.Model == "" { + return nil, errors.New("field model is required") + } + return textRequest, nil +} + +func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) { + + relayInfo := relaycommon.GenRelayInfoClaude(c) + + // get & validate textRequest 获取并验证文本请求 + textRequest, err := getAndValidateClaudeRequest(c) + if err != nil { + return service.ClaudeErrorWrapperLocal(err, "invalid_claude_request", http.StatusBadRequest) + } + + if textRequest.Stream { + relayInfo.IsStream = true + } + + err = helper.ModelMappedHelper(c, relayInfo) + if err != nil { + return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", 
http.StatusInternalServerError) + } + + textRequest.Model = relayInfo.UpstreamModelName + + promptTokens, err := getClaudePromptTokens(textRequest, relayInfo) + // count messages token error 计算promptTokens错误 + if err != nil { + return service.ClaudeErrorWrapperLocal(err, "count_token_messages_failed", http.StatusInternalServerError) + } + + priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens)) + if err != nil { + return service.ClaudeErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + } + + // pre-consume quota 预消耗配额 + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + + if openaiErr != nil { + return service.OpenAIErrorToClaudeError(openaiErr) + } + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() + + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.ClaudeErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + } + adaptor.Init(relayInfo) + var requestBody io.Reader + + if textRequest.MaxTokens == 0 { + textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model)) + } + + if model_setting.GetClaudeSettings().ThinkingAdapterEnabled && + strings.HasSuffix(textRequest.Model, "-thinking") { + if textRequest.Thinking == nil { + // 因为BudgetTokens 必须大于1024 + if textRequest.MaxTokens < 1280 { + textRequest.MaxTokens = 1280 + } + + // BudgetTokens 为 max_tokens 的 80% + textRequest.Thinking = &dto.Thinking{ + Type: "enabled", + BudgetTokens: int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage), + } + // TODO: 临时处理 + // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking + textRequest.TopP = 0 + textRequest.Temperature = 
common.GetPointer[float64](1.0) + } + textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking") + relayInfo.UpstreamModelName = textRequest.Model + } + + convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest) + if err != nil { + return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) + } + jsonData, err := json.Marshal(convertedRequest) + if err != nil { + return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonData) + + //log.Printf("requestBody: %s", requestBody) + + statusCodeMappingStr := c.GetString("status_code_mapping") + var httpResp *http.Response + resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + if err != nil { + return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError) + } + + if resp != nil { + httpResp = resp.(*http.Response) + relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + if httpResp.StatusCode != http.StatusOK { + openaiErr = service.RelayErrorHandler(httpResp, false) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return service.OpenAIErrorToClaudeError(openaiErr) + } + } + + usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) + //log.Printf("usage: %v", usage) + if openaiErr != nil { + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return service.OpenAIErrorToClaudeError(openaiErr) + } + service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + return nil +} + +func getClaudePromptTokens(textRequest *dto.ClaudeRequest, info *relaycommon.RelayInfo) (int, error) { + var promptTokens int + var err error + switch info.RelayMode { + default: + promptTokens, err = service.CountTokenClaudeRequest(*textRequest, 
info.UpstreamModelName) + } + info.PromptTokens = promptTokens + return promptTokens, err +} diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index c1d3f4a4..3b5ef795 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -17,6 +17,11 @@ type ThinkingContentInfo struct { SendLastThinkingContent bool } +const ( + RelayFormatOpenAI = "openai" + RelayFormatClaude = "claude" +) + type RelayInfo struct { ChannelType int ChannelId int @@ -58,6 +63,8 @@ type RelayInfo struct { UserSetting map[string]interface{} UserEmail string UserQuota int + RelayFormat string + ResponseTimes int64 ThinkingContentInfo } @@ -82,6 +89,13 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { return info } +func GenRelayInfoClaude(c *gin.Context) *RelayInfo { + info := GenRelayInfo(c) + info.RelayFormat = RelayFormatClaude + info.ShouldIncludeUsage = false + return info +} + func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := c.GetInt("channel_type") channelId := c.GetInt("channel_id") @@ -123,6 +137,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), Organization: c.GetString("channel_organization"), ChannelSetting: channelSetting, + RelayFormat: RelayFormatOpenAI, ThinkingContentInfo: ThinkingContentInfo{ IsFirstThinkingContent: true, SendLastThinkingContent: false, @@ -157,6 +172,7 @@ func (info *RelayInfo) SetIsStream(isStream bool) { } func (info *RelayInfo) SetFirstResponseTime() { + info.ResponseTimes++ if info.isFirstResponse { info.FirstResponseTime = time.Now() info.isFirstResponse = false diff --git a/relay/helper/common.go b/relay/helper/common.go index 2a72d30a..6af55a86 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -19,6 +19,14 @@ func SetEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("X-Accel-Buffering", "no") } +func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) { + 
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)}) + c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)}) + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } +} + func StringData(c *gin.Context, str string) error { //str = strings.TrimPrefix(str, "data: ") //str = strings.TrimSuffix(str, "\r") diff --git a/relay/helper/price.go b/relay/helper/price.go index b169df98..1ae3d2fc 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -16,6 +16,7 @@ type PriceData struct { CacheRatio float64 GroupRatio float64 UsePrice bool + CacheCreationRatio float64 ShouldPreConsumedQuota int } @@ -26,6 +27,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens var modelRatio float64 var completionRatio float64 var cacheRatio float64 + var cacheCreationRatio float64 if !usePrice { preConsumedTokens := common.PreConsumedQuota if maxTokens != 0 { @@ -42,6 +44,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens } completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName) cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName) + cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName) ratio := modelRatio * groupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { @@ -54,6 +57,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens GroupRatio: groupRatio, UsePrice: usePrice, CacheRatio: cacheRatio, + CacheCreationRatio: cacheCreationRatio, ShouldPreConsumedQuota: preConsumedQuota, }, nil } diff --git a/router/relay-router.go b/router/relay-router.go index 32e0c682..3a9122d4 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -35,6 +35,7 @@ func SetRelayRouter(router *gin.Engine) { //http router httpRouter := relayV1Router.Group("") httpRouter.Use(middleware.Distribute()) + httpRouter.POST("/messages", controller.RelayClaude) 
httpRouter.POST("/completions", controller.Relay) httpRouter.POST("/chat/completions", controller.Relay) httpRouter.POST("/edits", controller.Relay) diff --git a/service/convert.go b/service/convert.go new file mode 100644 index 00000000..c4916df2 --- /dev/null +++ b/service/convert.go @@ -0,0 +1,310 @@ +package service + +import ( + "encoding/json" + "fmt" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" +) + +func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIRequest, error) { + openAIRequest := dto.GeneralOpenAIRequest{ + Model: claudeRequest.Model, + MaxTokens: claudeRequest.MaxTokens, + Temperature: claudeRequest.Temperature, + TopP: claudeRequest.TopP, + Stream: claudeRequest.Stream, + } + + // Convert stop sequences + if len(claudeRequest.StopSequences) == 1 { + openAIRequest.Stop = claudeRequest.StopSequences[0] + } else if len(claudeRequest.StopSequences) > 1 { + openAIRequest.Stop = claudeRequest.StopSequences + } + + // Convert tools + tools, _ := common.Any2Type[[]dto.Tool](claudeRequest.Tools) + openAITools := make([]dto.ToolCallRequest, 0) + for _, claudeTool := range tools { + openAITool := dto.ToolCallRequest{ + Type: "function", + Function: dto.FunctionRequest{ + Name: claudeTool.Name, + Description: claudeTool.Description, + Parameters: claudeTool.InputSchema, + }, + } + openAITools = append(openAITools, openAITool) + } + openAIRequest.Tools = openAITools + + // Convert messages + openAIMessages := make([]dto.Message, 0) + + // Add system message if present + if claudeRequest.IsStringSystem() { + openAIMessage := dto.Message{ + Role: "system", + } + openAIMessage.SetStringContent(claudeRequest.GetStringSystem()) + openAIMessages = append(openAIMessages, openAIMessage) + } else { + systems := claudeRequest.ParseSystem() + if len(systems) > 0 { + systemStr := "" + openAIMessage := dto.Message{ + Role: "system", + } + for _, system := range systems { + systemStr += system.Type + } + 
openAIMessage.SetStringContent(systemStr) + openAIMessages = append(openAIMessages, openAIMessage) + } + } + for _, claudeMessage := range claudeRequest.Messages { + openAIMessage := dto.Message{ + Role: claudeMessage.Role, + } + + //log.Printf("claudeMessage.Content: %v", claudeMessage.Content) + if claudeMessage.IsStringContent() { + openAIMessage.SetStringContent(claudeMessage.GetStringContent()) + } else { + content, err := claudeMessage.ParseContent() + if err != nil { + return nil, err + } + contents := content + var toolCalls []dto.ToolCallRequest + mediaMessages := make([]dto.MediaContent, 0, len(contents)) + + for _, mediaMsg := range contents { + switch mediaMsg.Type { + case "text": + message := dto.MediaContent{ + Type: "text", + Text: mediaMsg.GetText(), + } + mediaMessages = append(mediaMessages, message) + case "image": + // Handle image conversion (base64 to URL or keep as is) + imageData := fmt.Sprintf("data:%s;base64,%s", mediaMsg.Source.MediaType, mediaMsg.Source.Data) + //textContent += fmt.Sprintf("[Image: %s]", imageData) + mediaMessage := dto.MediaContent{ + Type: "image_url", + ImageUrl: &dto.MessageImageUrl{Url: imageData}, + } + mediaMessages = append(mediaMessages, mediaMessage) + case "tool_use": + toolCall := dto.ToolCallRequest{ + ID: mediaMsg.Id, + Function: dto.FunctionRequest{ + Name: mediaMsg.Name, + Arguments: toJSONString(mediaMsg.Input), + }, + } + toolCalls = append(toolCalls, toolCall) + case "tool_result": + // Add tool result as a separate message + oaiToolMessage := dto.Message{ + Role: "tool", + ToolCallId: mediaMsg.ToolUseId, + } + oaiToolMessage.Content = mediaMsg.Content + } + } + + openAIMessage.SetMediaContent(mediaMessages) + + if len(toolCalls) > 0 { + openAIMessage.SetToolCalls(toolCalls) + } + } + + openAIMessages = append(openAIMessages, openAIMessage) + } + + openAIRequest.Messages = openAIMessages + + return &openAIRequest, nil +} + +func OpenAIErrorToClaudeError(openAIError *dto.OpenAIErrorWithStatusCode) 
*dto.ClaudeErrorWithStatusCode { + claudeError := dto.ClaudeError{ + Type: "new_api_error", + Message: openAIError.Error.Message, + } + return &dto.ClaudeErrorWithStatusCode{ + Error: claudeError, + StatusCode: openAIError.StatusCode, + } +} + +func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.OpenAIErrorWithStatusCode { + openAIError := dto.OpenAIError{ + Message: claudeError.Error.Message, + Type: "new_api_error", + } + return &dto.OpenAIErrorWithStatusCode{ + Error: openAIError, + StatusCode: claudeError.StatusCode, + } +} + +func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse { + var claudeResponses []*dto.ClaudeResponse + if info.ResponseTimes == 1 { + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Type: "message_start", + Message: &dto.ClaudeMediaMessage{ + Id: openAIResponse.Id, + Model: openAIResponse.Model, + Type: "message", + Role: "assistant", + Usage: &dto.ClaudeUsage{ + InputTokens: info.PromptTokens, + OutputTokens: 0, + }, + }, + }) + if openAIResponse.IsToolCall() { + resp := &dto.ClaudeResponse{ + Type: "content_block_start", + ContentBlock: &dto.ClaudeMediaMessage{ + Id: openAIResponse.GetFirstToolCall().ID, + Type: "tool_use", + Name: openAIResponse.GetFirstToolCall().Function.Name, + }, + } + resp.SetIndex(0) + claudeResponses = append(claudeResponses, resp) + } else { + resp := &dto.ClaudeResponse{ + Type: "content_block_start", + ContentBlock: &dto.ClaudeMediaMessage{ + Type: "text", + Text: common.GetPointer[string](""), + }, + } + resp.SetIndex(0) + claudeResponses = append(claudeResponses, resp) + } + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Type: "ping", + }) + return claudeResponses + } + + if len(openAIResponse.Choices) == 0 { + // no choices + // TODO: handle this case + } else { + chosenChoice := openAIResponse.Choices[0] + if chosenChoice.FinishReason != nil && 
*chosenChoice.FinishReason != "" { + // should be done + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Type: "content_block_stop", + Index: common.GetPointer[int](0), + }) + if openAIResponse.Usage != nil { + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Type: "message_delta", + Usage: &dto.ClaudeUsage{ + InputTokens: openAIResponse.Usage.PromptTokens, + OutputTokens: openAIResponse.Usage.CompletionTokens, + }, + Delta: &dto.ClaudeMediaMessage{ + StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(*chosenChoice.FinishReason)), + }, + }) + } + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Type: "message_stop", + }) + } else { + var claudeResponse dto.ClaudeResponse + claudeResponse.SetIndex(0) + claudeResponse.Type = "content_block_delta" + if len(chosenChoice.Delta.ToolCalls) > 0 { + // tools delta + claudeResponse.Delta = &dto.ClaudeMediaMessage{ + Type: "input_json_delta", + PartialJson: chosenChoice.Delta.ToolCalls[0].Function.Arguments, + } + } else { + // text delta + claudeResponse.Delta = &dto.ClaudeMediaMessage{ + Type: "text_delta", + Text: common.GetPointer[string](chosenChoice.Delta.GetContentString()), + } + } + claudeResponses = append(claudeResponses, &claudeResponse) + } + } + + return claudeResponses +} + +func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.ClaudeResponse { + var stopReason string + contents := make([]dto.ClaudeMediaMessage, 0) + claudeResponse := &dto.ClaudeResponse{ + Id: openAIResponse.Id, + Type: "message", + Role: "assistant", + Model: openAIResponse.Model, + } + for _, choice := range openAIResponse.Choices { + stopReason = stopReasonOpenAI2Claude(choice.FinishReason) + claudeContent := dto.ClaudeMediaMessage{} + if choice.FinishReason == "tool_calls" { + claudeContent.Type = "tool_use" + claudeContent.Id = choice.Message.ToolCallId + claudeContent.Name = choice.Message.ParseToolCalls()[0].Function.Name + var 
mapParams map[string]interface{} + if err := json.Unmarshal([]byte(choice.Message.ParseToolCalls()[0].Function.Arguments), &mapParams); err == nil { + claudeContent.Input = mapParams + } else { + claudeContent.Input = choice.Message.ParseToolCalls()[0].Function.Arguments + } + } else { + claudeContent.Type = "text" + claudeContent.SetText(choice.Message.StringContent()) + } + contents = append(contents, claudeContent) + } + claudeResponse.Content = contents + claudeResponse.StopReason = stopReason + claudeResponse.Usage = &dto.ClaudeUsage{ + InputTokens: openAIResponse.PromptTokens, + OutputTokens: openAIResponse.CompletionTokens, + } + + return claudeResponse +} + +func stopReasonOpenAI2Claude(reason string) string { + switch reason { + case "stop": + return "end_turn" + case "stop_sequence": + return "stop_sequence" + case "max_tokens": + return "max_tokens" + case "tool_calls": + return "tool_use" + default: + return reason + } +} + +func toJSONString(v interface{}) string { + b, err := json.Marshal(v) + if err != nil { + return "{}" + } + return string(b) +} diff --git a/service/error.go b/service/error.go index 82fbda18..9824a853 100644 --- a/service/error.go +++ b/service/error.go @@ -50,6 +50,30 @@ func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAI return openaiErr } +func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode { + text := err.Error() + lowerText := strings.ToLower(text) + if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { + common.SysLog(fmt.Sprintf("error: %s", text)) + text = "请求上游地址失败" + } + claudeError := dto.ClaudeError{ + Message: text, + Type: "new_api_error", + //Code: code, + } + return &dto.ClaudeErrorWithStatusCode{ + Error: claudeError, + StatusCode: statusCode, + } +} + +func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode { + claudeErr := 
ClaudeErrorWrapper(err, code, statusCode) + claudeErr.LocalError = true + return claudeErr +} + func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) { errWithStatusCode = &dto.OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode, diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 6406cbe1..75457b97 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -53,3 +53,12 @@ func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, info["audio_completion_ratio"] = audioCompletionRatio return info } + +func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64, + cacheTokens int, cacheRatio float64, cacheCreationTokens int, cacheCreationRatio float64, modelPrice float64) map[string]interface{} { + info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice) + info["claude"] = true + info["cache_creation_tokens"] = cacheCreationTokens + info["cache_creation_ratio"] = cacheCreationRatio + return info +} diff --git a/service/quota.go b/service/quota.go index e19f1b82..ec5af57a 100644 --- a/service/quota.go +++ b/service/quota.go @@ -194,6 +194,75 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) } +func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, + usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { + + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() + promptTokens := usage.PromptTokens + completionTokens := usage.CompletionTokens + modelName := relayInfo.OriginModelName + + tokenName := ctx.GetString("token_name") + completionRatio := 
priceData.CompletionRatio + modelRatio := priceData.ModelRatio + groupRatio := priceData.GroupRatio + modelPrice := priceData.ModelPrice + + cacheRatio := priceData.CacheRatio + cacheTokens := usage.PromptTokensDetails.CachedTokens + + cacheCreationRatio := priceData.CacheCreationRatio + cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens + + calculateQuota := 0.0 + if !priceData.UsePrice { + calculateQuota = float64(promptTokens) + calculateQuota += float64(cacheTokens) * cacheRatio + calculateQuota += float64(cacheCreationTokens) * cacheCreationRatio + calculateQuota += float64(completionTokens) * completionRatio + calculateQuota = calculateQuota * groupRatio * modelRatio + } else { + calculateQuota = modelPrice * common.QuotaPerUnit * groupRatio + } + + if modelRatio != 0 && calculateQuota <= 0 { + calculateQuota = 1 + } + + quota := int(calculateQuota) + + totalTokens := promptTokens + completionTokens + + var logContent string + // record all the consume log even if quota is 0 + if totalTokens == 0 { + // in this case, must be some error happened + // we cannot just return, because we may have to return the pre-consumed quota + quota = 0 + logContent += fmt.Sprintf("(可能是上游出错)") + common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + } else { + //if sensitiveResp != nil { + // logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) + //} + quotaDelta := quota - preConsumedQuota + if quotaDelta != 0 { + err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + if err != nil { + common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + } + } + model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) + model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) + } + + other := 
GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, + cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice) + model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName, + tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) +} + func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { diff --git a/service/token_counter.go b/service/token_counter.go index a6b8e86a..98386f85 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -1,6 +1,7 @@ package service import ( + "encoding/json" "errors" "fmt" "image" @@ -192,6 +193,110 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA return tkm, nil } +func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) { + tkm := 0 + + // Count tokens in messages + msgTokens, err := CountTokenClaudeMessages(request.Messages, model, request.Stream) + if err != nil { + return 0, err + } + tkm += msgTokens + + // Count tokens in system message + if request.System != "" { + systemTokens, err := CountTokenInput(request.System, model) + if err != nil { + return 0, err + } + tkm += systemTokens + } + + if request.Tools != nil { + // check is array + if tools, ok := request.Tools.([]any); ok { + if len(tools) > 0 { + parsedTools, err1 := common.Any2Type[[]dto.Tool](request.Tools) + if err1 != nil { + return 0, fmt.Errorf("tools: Input should be a valid list: %v", err1) + } + toolTokens, err2 := CountTokenClaudeTools(parsedTools, model) + if err2 != nil { + return 0, fmt.Errorf("tools: %v", err2) + } + tkm += toolTokens + } + } else { + return 0, errors.New("tools: Input should be a valid list") + } + } + + return tkm, nil +} + +func CountTokenClaudeMessages(messages
[]dto.ClaudeMessage, model string, stream bool) (int, error) { + tokenEncoder := getTokenEncoder(model) + tokenNum := 0 + + for _, message := range messages { + // Count tokens for role + tokenNum += getTokenNum(tokenEncoder, message.Role) + if message.IsStringContent() { + tokenNum += getTokenNum(tokenEncoder, message.GetStringContent()) + } else { + content, err := message.ParseContent() + if err != nil { + return 0, err + } + for _, mediaMessage := range content { + switch mediaMessage.Type { + case "text": + tokenNum += getTokenNum(tokenEncoder, mediaMessage.GetText()) + case "image": + //imageTokenNum, err := getClaudeImageToken(mediaMsg.Source, model, stream) + //if err != nil { + // return 0, err + //} + tokenNum += 1000 + case "tool_use": + tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name) + inputJSON, _ := json.Marshal(mediaMessage.Input) + tokenNum += getTokenNum(tokenEncoder, string(inputJSON)) + case "tool_result": + contentJSON, _ := json.Marshal(mediaMessage.Content) + tokenNum += getTokenNum(tokenEncoder, string(contentJSON)) + } + } + } + } + + // Add a constant for message formatting (this may need adjustment based on Claude's exact formatting) + tokenNum += len(messages) * 2 // Assuming 2 tokens per message for formatting + + return tokenNum, nil +} + +func CountTokenClaudeTools(tools []dto.Tool, model string) (int, error) { + tokenEncoder := getTokenEncoder(model) + tokenNum := 0 + + for _, tool := range tools { + tokenNum += getTokenNum(tokenEncoder, tool.Name) + tokenNum += getTokenNum(tokenEncoder, tool.Description) + + schemaJSON, err := json.Marshal(tool.InputSchema) + if err != nil { + return 0, errors.New(fmt.Sprintf("marshal_tool_schema_fail: %s", err.Error())) + } + tokenNum += getTokenNum(tokenEncoder, string(schemaJSON)) + } + + // Add a constant for tool formatting (this may need adjustment based on Claude's exact formatting) + tokenNum += len(tools) * 3 // Assuming 3 tokens per tool for formatting + + return tokenNum, nil +} + 
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) { audioToken := 0 textToken := 0 diff --git a/setting/operation_setting/cache_ratio.go b/setting/operation_setting/cache_ratio.go index 98f022ed..01d79c10 100644 --- a/setting/operation_setting/cache_ratio.go +++ b/setting/operation_setting/cache_ratio.go @@ -7,26 +7,45 @@ import ( ) var defaultCacheRatio = map[string]float64{ - "gpt-4": 0.5, - "o1": 0.5, - "o1-2024-12-17": 0.5, - "o1-preview-2024-09-12": 0.5, - "o1-preview": 0.5, - "o1-mini-2024-09-12": 0.5, - "o1-mini": 0.5, - "gpt-4o-2024-11-20": 0.5, - "gpt-4o-2024-08-06": 0.5, - "gpt-4o": 0.5, - "gpt-4o-mini-2024-07-18": 0.5, - "gpt-4o-mini": 0.5, - "gpt-4o-realtime-preview": 0.5, - "gpt-4o-mini-realtime-preview": 0.5, - "deepseek-chat": 0.1, - "deepseek-reasoner": 0.1, - "deepseek-coder": 0.1, + "gpt-4": 0.5, + "o1": 0.5, + "o1-2024-12-17": 0.5, + "o1-preview-2024-09-12": 0.5, + "o1-preview": 0.5, + "o1-mini-2024-09-12": 0.5, + "o1-mini": 0.5, + "gpt-4o-2024-11-20": 0.5, + "gpt-4o-2024-08-06": 0.5, + "gpt-4o": 0.5, + "gpt-4o-mini-2024-07-18": 0.5, + "gpt-4o-mini": 0.5, + "gpt-4o-realtime-preview": 0.5, + "gpt-4o-mini-realtime-preview": 0.5, + "deepseek-chat": 0.25, + "deepseek-reasoner": 0.25, + "deepseek-coder": 0.25, + "claude-3-sonnet-20240229": 0.1, + "claude-3-opus-20240229": 0.1, + "claude-3-haiku-20240307": 0.1, + "claude-3-5-haiku-20241022": 0.1, + "claude-3-5-sonnet-20240620": 0.1, + "claude-3-5-sonnet-20241022": 0.1, + "claude-3-7-sonnet-20250219": 0.1, + "claude-3-7-sonnet-20250219-thinking": 0.1, } -var defaultCreateCacheRatio = map[string]float64{} +var defaultCreateCacheRatio = map[string]float64{ + "claude-3-sonnet-20240229": 1.25, + "claude-3-opus-20240229": 1.25, + "claude-3-haiku-20240307": 1.25, + "claude-3-5-haiku-20241022": 1.25, + "claude-3-5-sonnet-20240620": 1.25, + "claude-3-5-sonnet-20241022": 1.25, + "claude-3-7-sonnet-20250219": 1.25, + 
"claude-3-7-sonnet-20250219-thinking": 1.25, +} + +//var defaultCreateCacheRatio = map[string]float64{} var cacheRatioMap map[string]float64 var cacheRatioMapMutex sync.RWMutex @@ -69,16 +88,10 @@ func GetCacheRatio(name string) (float64, bool) { return ratio, true } -// DefaultCacheRatio2JSONString converts the default cache ratio map to a JSON string -func DefaultCacheRatio2JSONString() string { - jsonBytes, err := json.Marshal(defaultCacheRatio) - if err != nil { - common.SysError("error marshalling default cache ratio: " + err.Error()) +func GetCreateCacheRatio(name string) (float64, bool) { + ratio, ok := defaultCreateCacheRatio[name] + if !ok { + return 1.25, false // Default to 1.25 if not found } - return string(jsonBytes) -} - -// GetDefaultCacheRatioMap returns the default cache ratio map -func GetDefaultCacheRatioMap() map[string]float64 { - return defaultCacheRatio + return ratio, true } diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index fb87d54f..21d0a979 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -26,8 +26,14 @@ import { } from '@douyinfe/semi-ui'; import { ITEMS_PER_PAGE } from '../constants'; import { - renderAudioModelPrice, renderGroup, - renderModelPrice, renderModelPriceSimple, + renderAudioModelPrice, + renderClaudeLogContent, + renderClaudeModelPrice, + renderClaudeModelPriceSimple, + renderGroup, + renderLogContent, + renderModelPrice, + renderModelPriceSimple, renderNumber, renderQuota, stringToColor @@ -564,13 +570,23 @@ const LogsTable = () => { ); } - let content = renderModelPriceSimple( - other.model_ratio, - other.model_price, - other.group_ratio, - other.cache_tokens || 0, - other.cache_ratio || 1.0, - ); + let content = other?.claude + ? 
renderClaudeModelPriceSimple( + other.model_ratio, + other.model_price, + other.group_ratio, + other.cache_tokens || 0, + other.cache_ratio || 1.0, + other.cache_creation_tokens || 0, + other.cache_creation_ratio || 1.0, + ) + : renderModelPriceSimple( + other.model_ratio, + other.model_price, + other.group_ratio, + other.cache_tokens || 0, + other.cache_ratio || 1.0, + ); return ( { value: other.cache_tokens, }); } - expandDataLocal.push({ - key: t('日志详情'), - value: logs[i].content, - }); + if (other?.cache_creation_tokens > 0) { + expandDataLocal.push({ + key: t('缓存创建 Tokens'), + value: other.cache_creation_tokens, + }); + } + if (logs[i].type === 2) { + expandDataLocal.push({ + key: t('日志详情'), + value: other?.claude + ? renderClaudeLogContent( + other?.model_ratio, + other.completion_ratio, + other.model_price, + other.group_ratio, + other.cache_ratio || 1.0, + other.cache_creation_ratio || 1.0 + ) + : renderLogContent( + other?.model_ratio, + other.completion_ratio, + other.model_price, + other.group_ratio, + other.user_group_ratio + ), + }); + } if (logs[i].type === 2) { let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== ''; if (modelMapped) { @@ -850,6 +890,19 @@ const LogsTable = () => { other?.cache_tokens || 0, other?.cache_ratio || 1.0, ); + } else if (other?.claude) { + content = renderClaudeModelPrice( + logs[i].prompt_tokens, + logs[i].completion_tokens, + other.model_ratio, + other.model_price, + other.completion_ratio, + other.group_ratio, + other.cache_tokens || 0, + other.cache_ratio || 1.0, + other.cache_creation_tokens || 0, + other.cache_creation_ratio || 1.0, + ); } else { content = renderModelPrice( logs[i].prompt_tokens, diff --git a/web/src/helpers/render.js b/web/src/helpers/render.js index 3ac81420..d1396191 100644 --- a/web/src/helpers/render.js +++ b/web/src/helpers/render.js @@ -325,9 +325,8 @@ export function renderModelPrice( return ( <>
-

{i18next.t('提示价格:${{price}} = ${{total}} / 1M tokens', { +

{i18next.t('提示价格:${{price}} / 1M tokens', { price: inputRatioPrice, - total: inputRatioPrice })}

{i18next.t('补全价格:${{price}} * {{completionRatio}} = ${{total}} / 1M tokens (补全倍率: {{completionRatio}})', { price: inputRatioPrice, @@ -445,9 +444,8 @@ export function renderAudioModelPrice( return ( <>

-

{i18next.t('提示价格:${{price}} = ${{total}} / 1M tokens', { +

{i18next.t('提示价格:${{price}} / 1M tokens', { price: inputRatioPrice, - total: inputRatioPrice })}

{i18next.t('补全价格:${{price}} * {{completionRatio}} = ${{total}} / 1M tokens (补全倍率: {{completionRatio}})', { price: inputRatioPrice, @@ -654,3 +652,194 @@ export function stringToColor(str) { let i = sum % colors.length; return colors[i]; } + +export function renderClaudeModelPrice( + inputTokens, + completionTokens, + modelRatio, + modelPrice = -1, + completionRatio, + groupRatio, + cacheTokens = 0, + cacheRatio = 1.0, + cacheCreationTokens = 0, + cacheCreationRatio = 1.0, +) { + const ratioLabel = false ? i18next.t('专属倍率') : i18next.t('分组倍率'); + + if (modelPrice !== -1) { + return i18next.t('模型价格:${{price}} * {{ratioType}}:{{ratio}} = ${{total}}', { + price: modelPrice, + ratioType: ratioLabel, + ratio: groupRatio, + total: modelPrice * groupRatio + }); + } else { + if (completionRatio === undefined) { + completionRatio = 0; + } + + const completionRatioValue = completionRatio || 0; + const inputRatioPrice = modelRatio * 2.0; + const completionRatioPrice = modelRatio * 2.0 * completionRatioValue; + let cacheRatioPrice = (modelRatio * 2.0 * cacheRatio).toFixed(2); + let cacheCreationRatioPrice = modelRatio * 2.0 * cacheCreationRatio; + + // Calculate effective input tokens (non-cached + cached with ratio applied + cache creation with ratio applied) + const nonCachedTokens = inputTokens; + const effectiveInputTokens = nonCachedTokens + + (cacheTokens * cacheRatio) + + (cacheCreationTokens * cacheCreationRatio); + + let price = + (effectiveInputTokens / 1000000) * inputRatioPrice * groupRatio + + (completionTokens / 1000000) * completionRatioPrice * groupRatio; + + return ( + <> +

+

{i18next.t('提示价格:${{price}} / 1M tokens', { + price: inputRatioPrice, + })}

+

{i18next.t('补全价格:${{price}} * {{ratio}} = ${{total}} / 1M tokens', { + price: inputRatioPrice, + ratio: completionRatio, + total: completionRatioPrice + })}

+ {cacheTokens > 0 && ( +

{i18next.t('缓存价格:${{price}} * {{ratio}} = ${{total}} / 1M tokens (缓存倍率: {{cacheRatio}})', { + price: inputRatioPrice, + ratio: cacheRatio, + total: cacheRatioPrice, + cacheRatio: cacheRatio + })}

+ )} + {cacheCreationTokens > 0 && ( +

{i18next.t('缓存创建价格:${{price}} * {{ratio}} = ${{total}} / 1M tokens (缓存创建倍率: {{cacheCreationRatio}})', { + price: inputRatioPrice, + ratio: cacheCreationRatio, + total: cacheCreationRatioPrice, + cacheCreationRatio: cacheCreationRatio + })}

+ )} +

+

+ {(cacheTokens > 0 || cacheCreationTokens > 0) ? + i18next.t('提示 {{nonCacheInput}} tokens / 1M tokens * ${{price}} + 缓存 {{cacheInput}} tokens / 1M tokens * ${{cachePrice}} + 缓存创建 {{cacheCreationInput}} tokens / 1M tokens * ${{cacheCreationPrice}} + 补全 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', { + nonCacheInput: nonCachedTokens, + cacheInput: cacheTokens, + cacheRatio: cacheRatio, + cacheCreationInput: cacheCreationTokens, + cacheCreationRatio: cacheCreationRatio, + cachePrice: cacheRatioPrice, + cacheCreationPrice: cacheCreationRatioPrice, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + ratio: groupRatio, + total: price.toFixed(6) + }) : + i18next.t('提示 {{input}} tokens / 1M tokens * ${{price}} + 补全 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', { + input: inputTokens, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + ratio: groupRatio, + total: price.toFixed(6) + }) + } +

+

{i18next.t('仅供参考,以实际扣费为准')}

+
+ + ); + } +} + +export function renderClaudeLogContent( + modelRatio, + completionRatio, + modelPrice = -1, + groupRatio, + cacheRatio = 1.0, + cacheCreationRatio = 1.0, +) { + const ratioLabel = false ? i18next.t('专属倍率') : i18next.t('分组倍率'); + + if (modelPrice !== -1) { + return i18next.t('模型价格 ${{price}},{{ratioType}} {{ratio}}', { + price: modelPrice, + ratioType: ratioLabel, + ratio: groupRatio + }); + } else { + return i18next.t('模型倍率 {{modelRatio}},补全倍率 {{completionRatio}},缓存倍率 {{cacheRatio}},缓存创建倍率 {{cacheCreationRatio}},{{ratioType}} {{ratio}}', { + modelRatio: modelRatio, + completionRatio: completionRatio, + cacheRatio: cacheRatio, + cacheCreationRatio: cacheCreationRatio, + ratioType: ratioLabel, + ratio: groupRatio + }); + } +} + +export function renderClaudeModelPriceSimple( + modelRatio, + modelPrice = -1, + groupRatio, + cacheTokens = 0, + cacheRatio = 1.0, + cacheCreationTokens = 0, + cacheCreationRatio = 1.0, +) { + const ratioLabel = false ? i18next.t('专属倍率') : i18next.t('分组'); + + if (modelPrice !== -1) { + return i18next.t('价格:${{price}} * {{ratioType}}:{{ratio}}', { + price: modelPrice, + ratioType: ratioLabel, + ratio: groupRatio + }); + } else { + if (cacheTokens !== 0 || cacheCreationTokens !== 0) { + return i18next.t('模型: {{ratio}} * {{ratioType}}: {{groupRatio}} * 缓存: {{cacheRatio}}', { + ratio: modelRatio, + ratioType: ratioLabel, + groupRatio: groupRatio, + cacheRatio: cacheRatio, + cacheCreationRatio: cacheCreationRatio + }); + } else { + return i18next.t('模型: {{ratio}} * {{ratioType}}: {{groupRatio}}', { + ratio: modelRatio, + ratioType: ratioLabel, + groupRatio: groupRatio + }); + } + } +} + +export function renderLogContent( + modelRatio, + completionRatio, + modelPrice = -1, + groupRatio +) { + const ratioLabel = false ? 
i18next.t('专属倍率') : i18next.t('分组倍率'); + + if (modelPrice !== -1) { + return i18next.t('模型价格 ${{price}},{{ratioType}} {{ratio}}', { + price: modelPrice, + ratioType: ratioLabel, + ratio: groupRatio + }); + } else { + return i18next.t('模型倍率 {{modelRatio}},补全倍率 {{completionRatio}},{{ratioType}} {{ratio}}', { + modelRatio: modelRatio, + completionRatio: completionRatio, + ratioType: ratioLabel, + ratio: groupRatio + }); + } +} From 2048b451bf47a3ce2f390656e3c2655db3f30916 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Wed, 12 Mar 2025 21:35:57 +0800 Subject: [PATCH 07/18] fix panic --- relay/channel/claude/relay-claude.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 205e0b61..74b73454 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -330,7 +330,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse } else if claudeResponse.Type == "content_block_delta" { if claudeResponse.Delta != nil { choice.Index = *claudeResponse.Index - choice.Delta.SetContentString(*claudeResponse.Delta.Text) + choice.Delta.Content = claudeResponse.Delta.Text switch claudeResponse.Delta.Type { case "input_json_delta": tools = append(tools, dto.ToolCallResponse{ @@ -452,7 +452,9 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons claudeInfo.Model = claudeResponse.Message.Model claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens } else if claudeResponse.Type == "content_block_delta" { - claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text) + if claudeResponse.Delta.Text != nil { + claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text) + } } else if claudeResponse.Type == "message_delta" { claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens claudeInfo.Usage.TotalTokens = 
claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens From e68edf81f76edab15cb23c4448298eca4ce8bb46 Mon Sep 17 00:00:00 2001 From: Calcium-Ion <61247483+Calcium-Ion@users.noreply.github.com> Date: Wed, 12 Mar 2025 22:12:09 +0800 Subject: [PATCH 08/18] Update README.md --- README.md | 216 +++++++++++++++++++++++------------------------------- 1 file changed, 90 insertions(+), 126 deletions(-) diff --git a/README.md b/README.md index 0a0ff71b..6c1054f9 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,6 @@ # New API - 🍥新一代大模型网关与AI资产管理系统 Calcium-Ion%2Fnew-api | Trendshift @@ -41,39 +40,40 @@ > - 本项目仅供个人学习使用,不保证稳定性,且不提供任何技术支持。 > - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 +## 📚 文档 + +详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/) + ## ✨ 主要特性 -1. 🎨 全新的UI界面(部分界面还待更新) -2. 🌍 多语言支持(待完善) -3. 🎨 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口支持,[对接文档](Midjourney.md) -4. 💰 支持在线充值功能,可在系统设置中设置: - - [x] 易支付 -5. 🔍 支持用key查询使用额度: - - 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用 +New API提供了丰富的功能,详细特性请参考[Wiki-特性说明](https://docs.newapi.pro/wiki/features-introduction): + +1. 🎨 全新的UI界面 +2. 🌍 多语言支持 +3. 🎨 支持[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](https://docs.newapi.pro/api/relay/image/midjourney) +4. 💰 支持在线充值功能(易支付) +5. 🔍 支持用key查询使用额度(配合[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)) 6. 📑 分页支持选择每页显示数量 -7. 🔄 兼容原版One API的数据库,可直接使用原版数据库(one-api.db) -8. 💵 支持模型按次数收费,可在 系统设置-运营设置 中设置 -9. ⚖️ 支持渠道**加权随机** +7. 🔄 兼容原版One API的数据库 +8. 💵 支持模型按次数收费 +9. ⚖️ 支持渠道加权随机 10. 📈 数据看板(控制台) 11. 🔒 可设置令牌能调用的模型 -12. 🤖 支持Telegram授权登录: - 1. 系统设置-配置登录注册-允许通过Telegram登录 - 2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain - 3. 选择你的bot,然后输入http(s)://你的网站地址/login - 4. Telegram Bot 名称是bot username 去掉@后的字符串 -13. 🎵 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口支持,[对接文档](Suno.md) -14.
🔄 支持Rerank模型,目前兼容Cohere和Jina,可接入Dify,[对接文档](Rerank.md) -15. ⚡ **[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - 支持OpenAI的Realtime API,支持Azure渠道 -16. 支持使用路由/chat2link 进入聊天界面 -17. 🧠 支持通过模型名称后缀设置 reasoning effort: +12. 🤖 支持Telegram授权登录 +13. 🎵 支持[Suno API](https://github.com/Suno-API/Suno-API)接口,[接口文档](https://docs.newapi.pro/api/suno-music) +14. 🔄 支持Rerank模型(Cohere和Jina),[接口文档](https://docs.newapi.pro/api/jinaai-rerank) +15. ⚡ 支持OpenAI Realtime API(包括Azure渠道),[接口文档](https://docs.newapi.pro/api/openai-realtime) +16. ⚡ 支持Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat) +17. 支持使用路由/chat2link进入聊天界面 +18. 🧠 支持通过模型名称后缀设置 reasoning effort: 1. OpenAI o系列模型 - 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`) - 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`) - 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`) 2. Claude 思考模型 - 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`) -18. 🔄 思考转内容,支持在 `渠道-编辑-渠道额外设置` 中设置 `thinking_to_content` 选项,默认`false`,开启后会将思考内容`reasoning_content`转换为``标签拼接到内容中返回。 -19. 🔄 模型限流,支持在 `系统设置-速率限制设置` 中设置模型限流,支持设置总请求数限制和成功请求数限制 +19. 🔄 思考转内容功能 +20. 🔄 模型限流功能 20. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费: 1. 在 `系统设置-运营设置` 中设置 `提示缓存倍率` 选项 2. 在渠道中设置 `提示缓存倍率`,范围 0-1,例如设置为 0.5 表示缓存命中时按照 50% 计费 @@ -81,155 +81,119 @@ - [x] OpenAI - [x] Azure - [x] DeepSeek - - [ ] Claude + - [x] Claude ## 模型支持 -此版本额外支持以下模型: + +此版本支持多种模型,详情请参考[接口文档-中继接口](https://docs.newapi.pro/api): + 1. 第三方模型 **gpts** (gpt-4-gizmo-*) -2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md) +2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[接口文档](https://docs.newapi.pro/api/midjourney-proxy-image) 3. 自定义渠道,支持填入完整调用地址 -4. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md) -5. Rerank模型,目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/),[对接文档](Rerank.md) -6. Dify +4. 
[Suno API](https://github.com/Suno-API/Suno-API)接口,[接口文档](https://docs.newapi.pro/api/suno-music) +5. Rerank模型([Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)),[接口文档](https://docs.newapi.pro/api/jinaai-rerank) +6. Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat) +7. Dify -您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。 +## 环境变量配置 -## 比原版One API多出的配置 -- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`。 -- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 60 秒。 -- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`。 -- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,请求上游返回流模式usage,默认为 `true`,建议开启,不影响客户端传入stream_options参数返回结果。 -- `GET_MEDIA_TOKEN`:是否统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。 -- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。 -- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。 -- `COHERE_SAFETY_SETTING`:Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认为 `NONE`。 -- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认为 `16`,设置为 `-1` 则不限制。 -- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`。 -- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。 -- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,如果渠道设置中未指定API版本,则使用此版本,默认为 `2024-12-01-preview` -- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`。 -- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`。 +详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables): -## 已废弃的环境变量 -- ~~`GEMINI_MODEL_MAP`(已废弃)~~:改为到`设置-模型相关设置`中设置 -- ~~`GEMINI_SAFETY_SETTING`(已废弃)~~:改为到`设置-模型相关设置`中设置 +- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false` +- `STREAMING_TIMEOUT`:流式回复超时时间,默认60秒 +- `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true` +- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true` +- `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true` +- `GET_MEDIA_TOKEN_NOT_STREAM`:非流情况下是否统计图片token,默认 `true` +- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认 `true` +- 
`COHERE_SAFETY_SETTING`:Cohere模型安全设置,可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认 `NONE` +- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认 `16` +- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位MB,默认 `20` +- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容 +- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2024-12-01-preview` +- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟 +- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2` ## 部署 +详细部署指南请参考[安装指南-部署方式](https://docs.newapi.pro/installation): + > [!TIP] > 最新版Docker镜像:`calciumion/new-api:latest` > 默认账号root 密码123456 -### 多机部署 -- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致。 -- 如果公用Redis,必须设置 `CRYPTO_SECRET`,否则会导致多机部署时Redis内容无法获取。 +### 多机部署注意事项 +- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致 +- 如果公用Redis,必须设置 `CRYPTO_SECRET`,否则会导致多机部署时Redis内容无法获取 ### 部署要求 -- 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机) -- 远程数据库:MySQL 版本 >= 5.7.8,PgSQL 版本 >= 9.6 +- 本地数据库(默认):SQLite(Docker部署必须挂载`/data`目录) +- 远程数据库:MySQL版本 >= 5.7.8,PgSQL版本 >= 9.6 -### 使用宝塔面板Docker功能部署 -安装宝塔面板 (**9.2.0版本**及以上),前往 [宝塔面板](https://www.bt.cn/new/download.html) 官网,选择正式版的脚本下载安装 -安装后登录宝塔面板,在菜单栏中点击 Docker ,首次进入会提示安装 Docker 服务,点击立即安装,按提示完成安装 -安装完成后在应用商店中找到 **New-API** ,点击安装,配置基本选项 即可完成安装 +### 部署方式 + +#### 使用宝塔面板Docker功能部署 +安装宝塔面板(**9.2.0版本**及以上),在应用商店中找到**New-API**安装即可。 [图文教程](BT.md) -### 基于 Docker 进行部署 - -> [!TIP] -> 默认管理员账号root 密码123456 - -### 使用 Docker Compose 部署(推荐) +#### 使用Docker Compose部署(推荐) ```shell # 下载项目 git clone https://github.com/Calcium-Ion/new-api.git cd new-api -# 按需编辑 docker-compose.yml -# nano docker-compose.yml -# vim docker-compose.yml +# 按需编辑docker-compose.yml # 启动 docker-compose up -d ``` -#### 更新版本 +#### 直接使用Docker镜像 ```shell -docker-compose pull -docker-compose up -d -``` - -### 直接使用 Docker 镜像 -```shell -# 使用 SQLite 的部署命令: +# 使用SQLite docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest -# 使用 MySQL 的部署命令,在上面的基础上添加 `-e 
SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。 -# 例如: +# 使用MySQL docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest ``` -#### 更新版本 -```shell -# 拉取最新镜像 -docker pull calciumion/new-api:latest -# 停止并删除旧容器 -docker stop new-api -docker rm new-api -# 使用相同参数运行新容器 -docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest -``` +## 渠道重试与缓存 -或者使用 Watchtower 自动更新(不推荐,可能会导致数据库不兼容): -```shell -docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR -``` +详细说明请参考[配置与维护-缓存设置](https://docs.newapi.pro/installation/configuration/cache): + +渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。 -## 渠道重试 -渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。 -如果开启了重试功能,重试使用下一个优先级,以此类推。 ### 缓存设置方法 -1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 - + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` -2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 - + 例子:`MEMORY_CACHE_ENABLED=true` -### 为什么有的时候没有重试 -这些错误码不会重试:400,504,524 -### 我想让400也重试 -在`渠道->编辑`中,将`状态码复写`改为 -```json -{ - "400": "500" -} -``` -可以实现400错误转为500错误,从而重试 +1. `REDIS_CONN_STRING`:设置Redis作为缓存 +2. 
`MEMORY_CACHE_ENABLED`:启用内存缓存(设置了Redis则无需手动设置) -## Midjourney接口设置文档 -[对接文档](Midjourney.md) +## 接口文档 -## Suno接口设置文档 -[对接文档](Suno.md) +详细接口文档请参考[接口文档](https://docs.newapi.pro/api): -## 界面截图 -![image](https://github.com/user-attachments/assets/a0dcd349-5df8-4dc8-9acf-ca272b239919) - - -![image](https://github.com/user-attachments/assets/c7d0f7e1-729c-43e2-ac7c-2cb73b0afc8e) - -![image](https://github.com/user-attachments/assets/29f81de5-33fc-4fc5-a5ff-f9b54b653c7c) - -![image](https://github.com/user-attachments/assets/4fa53e18-d2c5-477a-9b26-b86e44c71e35) - -## 交流群 - +- [聊天接口(Chat)](https://docs.newapi.pro/api/openai-chat) +- [图像接口(Image)](https://docs.newapi.pro/api/openai-image) +- [Midjourney接口](https://docs.newapi.pro/api/midjourney-proxy-image) +- [音乐接口(Music)](https://docs.newapi.pro/api/relay/music) +- [Suno接口](https://docs.newapi.pro/api/suno-music) +- [重排序接口(Rerank)](https://docs.newapi.pro/api/jinaai-rerank) +- [实时对话接口(Realtime)](https://docs.newapi.pro/api/openai-realtime) +- [Claude聊天接口(messages)](https://docs.newapi.pro/api/anthropic-chat) ## 相关项目 - [One API](https://github.com/songquanpeng/one-api):原版项目 - [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy):Midjourney接口支持 -- [chatnio](https://github.com/Deeptrain-Community/chatnio):下一代 AI 一站式 B/C 端解决方案 +- [chatnio](https://github.com/Deeptrain-Community/chatnio):下一代AI一站式B/C端解决方案 - [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool):用key查询使用额度 其他基于New API的项目: -- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版,专注于高并发优化,并支持Claude格式 -- [VoAPI](https://github.com/VoAPI/VoAPI):基于New API的前端美化版本,闭源免费 +- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版 +- [VoAPI](https://github.com/VoAPI/VoAPI):基于New API的前端美化版本 + +## 帮助支持 + +如有问题,请参考[帮助支持](https://docs.newapi.pro/support): +- [社区交流](https://docs.newapi.pro/support/community) +- [反馈问题](https://docs.newapi.pro/support/feedback) +- 
[常见问题](https://docs.newapi.pro/support/faq) ## 🌟 Star History From b291fbff6b94136e19135726b73e68188479be6e Mon Sep 17 00:00:00 2001 From: Calcium-Ion <61247483+Calcium-Ion@users.noreply.github.com> Date: Wed, 12 Mar 2025 22:13:35 +0800 Subject: [PATCH 09/18] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6c1054f9..bc417946 100644 --- a/README.md +++ b/README.md @@ -191,8 +191,8 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234 ## 帮助支持 如有问题,请参考[帮助支持](https://docs.newapi.pro/support): -- [社区交流](https://docs.newapi.pro/support/community) -- [反馈问题](https://docs.newapi.pro/support/feedback) +- [社区交流](https://docs.newapi.pro/support/community-interaction) +- [反馈问题](https://docs.newapi.pro/support/feedback-issues) - [常见问题](https://docs.newapi.pro/support/faq) ## 🌟 Star History From c25d4d8d239222368314215abc4f1911733f0691 Mon Sep 17 00:00:00 2001 From: Calcium-Ion <61247483+Calcium-Ion@users.noreply.github.com> Date: Wed, 12 Mar 2025 22:22:21 +0800 Subject: [PATCH 10/18] Update README.md --- README.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/README.md b/README.md index bc417946..d9209f6a 100644 --- a/README.md +++ b/README.md @@ -156,9 +156,6 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234 ``` ## 渠道重试与缓存 - -详细说明请参考[配置与维护-缓存设置](https://docs.newapi.pro/installation/configuration/cache): - 渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。 ### 缓存设置方法 From 23596d22c93d96b91c494725aa2dbdfecf61a236 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9C=8D=E9=9B=A8=E4=BD=B3?= Date: Thu, 13 Mar 2025 08:54:45 +0800 Subject: [PATCH 11/18] Refactor: Optimize the ImageHandler under the Alibaba large model to retrieve the key from the header. Reason: The info parameter already includes the key, so there is no need to retrieve it again from the header. Solution: Delete the code for obtaining the key and directly use info.ApiKey. 
--- relay/channel/ali/image.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go index 5cbf16b5..44203583 100644 --- a/relay/channel/ali/image.go +++ b/relay/channel/ali/image.go @@ -26,7 +26,7 @@ func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest { return &imageRequest } -func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, error, []byte) { +func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) { url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID) var aliResponse AliResponse @@ -36,7 +36,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliRes return &aliResponse, err, nil } - req.Header.Set("Authorization", "Bearer "+key) + req.Header.Set("Authorization", "Bearer "+info.ApiKey) client := &http.Client{} resp, err := client.Do(req) @@ -58,7 +58,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliRes return &response, nil, responseBody } -func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, []byte, error) { +func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) { waitSeconds := 3 step := 0 maxStep := 20 @@ -68,7 +68,7 @@ func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*Ali for { step++ - rsp, err, body := updateTask(info, taskID, key) + rsp, err, body := updateTask(info, taskID) responseBody = body if err != nil { return &taskResponse, responseBody, err @@ -125,8 +125,6 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc } func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") responseFormat := c.GetString("response_format") 
var aliTaskResponse AliResponse @@ -148,7 +146,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil } - aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId, apiKey) + aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId) if err != nil { return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil } From 7e46d4217d45e10df5c07430aafcf11e17bf1221 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Thu, 13 Mar 2025 19:32:08 +0800 Subject: [PATCH 12/18] =?UTF-8?q?feat:=20=E5=88=9D=E6=AD=A5=E5=85=BC?= =?UTF-8?q?=E5=AE=B9=E6=B5=81=E6=A8=A1=E5=BC=8F=E4=B8=8Bopenai=E6=B8=A0?= =?UTF-8?q?=E9=81=93=E7=B1=BB=E5=9E=8B=E8=BD=AC=E4=B8=BAclaude=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F=E8=AE=BF=E9=97=AE=20#862?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel-test.go | 2 +- dto/claude.go | 28 +++- relay/channel/adapter.go | 2 +- relay/channel/ali/adaptor.go | 2 +- relay/channel/api_request.go | 4 + relay/channel/aws/adaptor.go | 2 +- relay/channel/baidu/adaptor.go | 2 +- relay/channel/baidu_v2/adaptor.go | 2 +- relay/channel/claude/adaptor.go | 2 +- relay/channel/claude/relay-claude.go | 2 +- relay/channel/cloudflare/adaptor.go | 2 +- relay/channel/cohere/adaptor.go | 2 +- relay/channel/deepseek/adaptor.go | 2 +- relay/channel/dify/adaptor.go | 2 +- relay/channel/gemini/adaptor.go | 2 +- relay/channel/jina/adaptor.go | 2 +- relay/channel/mistral/adaptor.go | 2 +- relay/channel/mokaai/adaptor.go | 2 +- relay/channel/ollama/adaptor.go | 2 +- relay/channel/openai/adaptor.go | 24 +++- relay/channel/openai/helper.go | 188 +++++++++++++++++++++++++++ relay/channel/openai/relay-openai.go | 95 +------------- relay/channel/openrouter/adaptor.go | 2 +- 
relay/channel/palm/adaptor.go | 2 +- relay/channel/perplexity/adaptor.go | 2 +- relay/channel/siliconflow/adaptor.go | 2 +- relay/channel/tencent/adaptor.go | 2 +- relay/channel/vertex/adaptor.go | 2 +- relay/channel/volcengine/adaptor.go | 2 +- relay/channel/xunfei/adaptor.go | 2 +- relay/channel/zhipu/adaptor.go | 2 +- relay/channel/zhipu_4v/adaptor.go | 2 +- relay/claude_handler.go | 5 +- relay/common/relay_info.go | 17 ++- relay/helper/common.go | 16 +++ relay/relay-text.go | 5 +- service/convert.go | 119 +++++++++++------ 37 files changed, 390 insertions(+), 165 deletions(-) create mode 100644 relay/channel/openai/helper.go diff --git a/controller/channel-test.go b/controller/channel-test.go index 39af95e1..8ecbde3f 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -107,7 +107,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr adaptor.Init(info) - convertedRequest, err := adaptor.ConvertRequest(c, info, request) + convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request) if err != nil { return err, nil } diff --git a/dto/claude.go b/dto/claude.go index 60f638f6..f7354230 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -13,7 +13,7 @@ type ClaudeMediaMessage struct { Source *ClaudeMessageSource `json:"source,omitempty"` Usage *ClaudeUsage `json:"usage,omitempty"` StopReason *string `json:"stop_reason,omitempty"` - PartialJson string `json:"partial_json,omitempty"` + PartialJson *string `json:"partial_json,omitempty"` Role string `json:"role,omitempty"` Thinking string `json:"thinking,omitempty"` Signature string `json:"signature,omitempty"` @@ -37,6 +37,32 @@ func (c *ClaudeMediaMessage) GetText() string { return *c.Text } +func (c *ClaudeMediaMessage) IsStringContent() bool { + var content string + return json.Unmarshal(c.Content, &content) == nil +} + +func (c *ClaudeMediaMessage) GetStringContent() string { + var content string + if err := json.Unmarshal(c.Content, &content); err == 
nil { + return content + } + return "" +} + +func (c *ClaudeMediaMessage) SetContent(content any) { + jsonContent, _ := json.Marshal(content) + c.Content = jsonContent +} + +func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage { + var mediaContent []ClaudeMediaMessage + if err := json.Unmarshal(c.Content, &mediaContent); err == nil { + return mediaContent + } + return make([]ClaudeMediaMessage, 0) +} + type ClaudeMessageSource struct { Type string `json:"type"` MediaType string `json:"media_type"` diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index 9f449b54..e097dbe6 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -13,7 +13,7 @@ type Adaptor interface { Init(info *relaycommon.RelayInfo) GetRequestURL(info *relaycommon.RelayInfo) (string, error) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error - ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) + ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 9d3ee99f..e28278e1 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -50,7 +50,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if 
request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index a60bc6f1..8b2ca889 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -7,6 +7,7 @@ import ( "github.com/gorilla/websocket" "io" "net/http" + common2 "one-api/common" "one-api/relay/common" "one-api/relay/constant" "one-api/service" @@ -31,6 +32,9 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody if err != nil { return nil, fmt.Errorf("get request url failed: %w", err) } + if common2.DebugEnabled { + println("fullRequestURL:", fullRequestURL) + } req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { return nil, fmt.Errorf("new request failed: %w", err) diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index e735ee2b..94edda33 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -47,7 +47,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 105f2a9b..eecb0bac 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -110,7 +110,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return 
nil, errors.New("request is nil") } diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index 855ed717..9645bbf5 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index a5c475fa..6d65d6d4 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -64,7 +64,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 74b73454..8607f77d 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -335,7 +335,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse case "input_json_delta": tools = append(tools, dto.ToolCallResponse{ Function: dto.FunctionResponse{ - Arguments: claudeResponse.Delta.PartialJson, + Arguments: *claudeResponse.Delta.PartialJson, }, }) case "signature_delta": diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index b21e25f3..3d5a5a8a 100644 --- a/relay/channel/cloudflare/adaptor.go +++ 
b/relay/channel/cloudflare/adaptor.go @@ -43,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index 7675d546..53a357ad 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -48,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { return requestOpenAI2Cohere(*request), nil } diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index ad01b8f4..64d92a48 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -50,7 +50,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 96aff447..003b5f83 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -70,7 +70,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel 
return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index a629968b..c5a547ba 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -95,7 +95,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index bcfc8dea..a65e820e 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -49,7 +49,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { return request, nil } diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index 80547346..4857209f 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -43,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) 
ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go index 151072cb..304351fd 100644 --- a/relay/channel/mokaai/adaptor.go +++ b/relay/channel/mokaai/adaptor.go @@ -57,7 +57,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 4190dd3f..2101bf70 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -49,7 +49,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 196343e8..d8bc808e 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -21,6 +21,7 @@ import ( "one-api/relay/channel/xinference" relaycommon "one-api/relay/common" "one-api/relay/constant" + "one-api/service" "strings" ) @@ -29,10 +30,20 @@ type Adaptor struct { ResponseFormat string } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { - //TODO implement me - panic("implement me") - 
return nil, nil +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { + if !strings.HasPrefix(request.Model, "claude") { + return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model) + } + aiRequest, err := service.ClaudeToOpenAIRequest(*request) + if err != nil { + return nil, err + } + if info.SupportStreamOptions { + aiRequest.StreamOptions = &dto.StreamOptions{ + IncludeUsage: true, + } + } + return a.ConvertOpenAIRequest(c, info, aiRequest) } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { @@ -40,6 +51,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + if info.RelayFormat == relaycommon.RelayFormatClaude { + return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + } if info.RelayMode == constant.RelayModeRealtime { if strings.HasPrefix(info.BaseUrl, "https://") { baseUrl := strings.TrimPrefix(info.BaseUrl, "https://") @@ -115,7 +129,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info * return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go new file mode 100644 index 00000000..a6d0aed8 --- /dev/null +++ b/relay/channel/openai/helper.go @@ -0,0 +1,188 @@ +package openai + +import ( + "encoding/json" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + relayconstant "one-api/relay/constant" + "one-api/relay/helper" + "one-api/service" + "strings" + + "github.com/gin-gonic/gin" +) + +// 辅助函数 
+func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { + info.SendResponseCount++ + switch info.RelayFormat { + case relaycommon.RelayFormatOpenAI: + return sendStreamData(c, info, data, forceFormat, thinkToContent) + case relaycommon.RelayFormatClaude: + return handleClaudeFormat(c, data, info) + } + return nil +} + +func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error { + var streamResponse dto.ChatCompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil { + return err + } + + claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info) + for _, resp := range claudeResponses { + helper.ClaudeData(c, *resp) + } + return nil +} + +func processStreamResponse(item string, responseTextBuilder *strings.Builder, toolCount *int) error { + var streamResponse dto.ChatCompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { + return err + } + + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Delta.GetContentString()) + responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) + if choice.Delta.ToolCalls != nil { + if len(choice.Delta.ToolCalls) > *toolCount { + *toolCount = len(choice.Delta.ToolCalls) + } + for _, tool := range choice.Delta.ToolCalls { + responseTextBuilder.WriteString(tool.Function.Name) + responseTextBuilder.WriteString(tool.Function.Arguments) + } + } + } + return nil +} + +func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error { + streamResp := "[" + strings.Join(streamItems, ",") + "]" + + switch relayMode { + case relayconstant.RelayModeChatCompletions: + return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount) + case relayconstant.RelayModeCompletions: + return 
processCompletions(streamResp, streamItems, responseTextBuilder) + } + return nil +} + +func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error { + var streamResponses []dto.ChatCompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { + // 一次性解析失败,逐个解析 + common.SysError("error unmarshalling stream response: " + err.Error()) + for _, item := range streamItems { + if err := processStreamResponse(item, responseTextBuilder, toolCount); err != nil { + common.SysError("error processing stream response: " + err.Error()) + } + } + return nil + } + + // 批量处理所有响应 + for _, streamResponse := range streamResponses { + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Delta.GetContentString()) + responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) + if choice.Delta.ToolCalls != nil { + if len(choice.Delta.ToolCalls) > *toolCount { + *toolCount = len(choice.Delta.ToolCalls) + } + for _, tool := range choice.Delta.ToolCalls { + responseTextBuilder.WriteString(tool.Function.Name) + responseTextBuilder.WriteString(tool.Function.Arguments) + } + } + } + } + return nil +} + +func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error { + var streamResponses []dto.CompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { + // 一次性解析失败,逐个解析 + common.SysError("error unmarshalling stream response: " + err.Error()) + for _, item := range streamItems { + var streamResponse dto.CompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { + continue + } + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Text) + } + } + return nil + } + + // 批量处理所有响应 + for _, streamResponse := range streamResponses { 
+ for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Text) + } + } + return nil +} + +func handleLastResponse(lastStreamData string, responseId *string, createAt *int64, + systemFingerprint *string, model *string, usage **dto.Usage, + containStreamUsage *bool, info *relaycommon.RelayInfo, + shouldSendLastResp *bool) error { + + var lastStreamResponse dto.ChatCompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil { + return err + } + + *responseId = lastStreamResponse.Id + *createAt = lastStreamResponse.Created + *systemFingerprint = lastStreamResponse.GetSystemFingerprint() + *model = lastStreamResponse.Model + + if service.ValidUsage(lastStreamResponse.Usage) { + *containStreamUsage = true + *usage = lastStreamResponse.Usage + if !info.ShouldIncludeUsage { + *shouldSendLastResp = false + } + } + + return nil +} + +func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string, + responseId string, createAt int64, model string, systemFingerprint string, + usage *dto.Usage, containStreamUsage bool) { + + switch info.RelayFormat { + case relaycommon.RelayFormatOpenAI: + if info.ShouldIncludeUsage && !containStreamUsage { + response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage) + response.SetSystemFingerprint(systemFingerprint) + helper.ObjectData(c, response) + } + helper.Done(c) + + case relaycommon.RelayFormatClaude: + var streamResponse dto.ChatCompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return + } + + if !containStreamUsage { + streamResponse.Usage = usage + } + + claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info) + for _, resp := range claudeResponses { + helper.ClaudeData(c, *resp) + } + } +} diff --git 
a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index ffd36d3c..2d1ad53e 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -12,7 +12,6 @@ import ( "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" "os" @@ -137,10 +136,11 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel helper.StreamScannerHandler(c, resp, info, func(data string) bool { if lastStreamData != "" { - err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) + err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent) if err != nil { - common.LogError(c, "streaming error: "+err.Error()) + common.SysError("error handling stream format: " + err.Error()) } + info.SetFirstResponseTime() } lastStreamData = data streamItems = append(streamItems, data) @@ -172,83 +172,9 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) } - // 计算token - streamResp := "[" + strings.Join(streamItems, ",") + "]" - switch info.RelayMode { - case relayconstant.RelayModeChatCompletions: - var streamResponses []dto.ChatCompletionsStreamResponse - err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) - if err != nil { - // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) - for _, item := range streamItems { - var streamResponse dto.ChatCompletionsStreamResponse - err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) - if err == nil { - //if service.ValidUsage(streamResponse.Usage) { - // usage = streamResponse.Usage - //} - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Delta.GetContentString()) - - // handle both reasoning_content and reasoning - 
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) - - if choice.Delta.ToolCalls != nil { - if len(choice.Delta.ToolCalls) > toolCount { - toolCount = len(choice.Delta.ToolCalls) - } - for _, tool := range choice.Delta.ToolCalls { - responseTextBuilder.WriteString(tool.Function.Name) - responseTextBuilder.WriteString(tool.Function.Arguments) - } - } - } - } - } - } else { - for _, streamResponse := range streamResponses { - //if service.ValidUsage(streamResponse.Usage) { - // usage = streamResponse.Usage - // containStreamUsage = true - //} - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Delta.GetContentString()) - responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) // This will handle both reasoning_content and reasoning - if choice.Delta.ToolCalls != nil { - if len(choice.Delta.ToolCalls) > toolCount { - toolCount = len(choice.Delta.ToolCalls) - } - for _, tool := range choice.Delta.ToolCalls { - responseTextBuilder.WriteString(tool.Function.Name) - responseTextBuilder.WriteString(tool.Function.Arguments) - } - } - } - } - } - case relayconstant.RelayModeCompletions: - var streamResponses []dto.CompletionsStreamResponse - err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) - if err != nil { - // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) - for _, item := range streamItems { - var streamResponse dto.CompletionsStreamResponse - err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) - if err == nil { - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Text) - } - } - } - } else { - for _, streamResponse := range streamResponses { - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Text) - } - } - } + // 处理token计算 + if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil { + 
common.SysError("error processing tokens: " + err.Error()) } if !containStreamUsage { @@ -262,15 +188,8 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } } - if info.ShouldIncludeUsage && !containStreamUsage { - response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage) - response.SetSystemFingerprint(systemFingerprint) - helper.ObjectData(c, response) - } + handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage) - helper.Done(c) - - //resp.Body.Close() return nil, usage } diff --git a/relay/channel/openrouter/adaptor.go b/relay/channel/openrouter/adaptor.go index aef5afeb..f2909b6b 100644 --- a/relay/channel/openrouter/adaptor.go +++ b/relay/channel/openrouter/adaptor.go @@ -46,7 +46,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { return request, nil } diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 69ef5001..f0220f4f 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index de84406c..32f00047 100644 --- a/relay/channel/perplexity/adaptor.go +++ 
b/relay/channel/perplexity/adaptor.go @@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index 754a1f00..1b319e2a 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -54,7 +54,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { return request, nil } diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 28a02aae..f2b51ee9 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -58,7 +58,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 2f348e46..e09845eb 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -122,7 +122,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel 
return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index f423d587..5e5e276b 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -56,7 +56,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index d66f3732..9521bb47 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index aa612f0c..04369001 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -48,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request 
*dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 7a23e212..ba24814c 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -45,7 +45,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 97de772b..fb68a88a 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -114,13 +114,14 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) { return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) } jsonData, err := json.Marshal(convertedRequest) + if common.DebugEnabled { + println("requestBody: ", string(jsonData)) + } if err != nil { return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonData) - //log.Printf("requestBody: %s", requestBody) - statusCodeMappingStr := c.GetString("status_code_mapping") var httpResp *http.Response resp, err := adaptor.DoRequest(c, relayInfo, requestBody) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 3b5ef795..5075d07d 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -17,6 +17,16 @@ type ThinkingContentInfo struct { SendLastThinkingContent bool } +const ( + LastMessageTypeText = 
"text" + LastMessageTypeTools = "tools" +) + +type ClaudeConvertInfo struct { + LastMessagesType string + Index int +} + const ( RelayFormatOpenAI = "openai" RelayFormatClaude = "claude" @@ -64,8 +74,9 @@ type RelayInfo struct { UserEmail string UserQuota int RelayFormat string - ResponseTimes int64 + SendResponseCount int ThinkingContentInfo + ClaudeConvertInfo } // 定义支持流式选项的通道类型 @@ -93,6 +104,9 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo { info := GenRelayInfo(c) info.RelayFormat = RelayFormatClaude info.ShouldIncludeUsage = false + info.ClaudeConvertInfo = ClaudeConvertInfo{ + LastMessagesType: LastMessageTypeText, + } return info } @@ -172,7 +186,6 @@ func (info *RelayInfo) SetIsStream(isStream bool) { } func (info *RelayInfo) SetFirstResponseTime() { - info.ResponseTimes++ if info.isFirstResponse { info.FirstResponseTime = time.Now() info.isFirstResponse = false diff --git a/relay/helper/common.go b/relay/helper/common.go index 6af55a86..13fc85ab 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -19,6 +19,22 @@ func SetEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("X-Accel-Buffering", "no") } +func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error { + jsonData, err := json.Marshal(resp) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + } else { + c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)}) + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)}) + } + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } else { + return errors.New("streaming error: flusher not found") + } + return nil +} + func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) { c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)}) c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)}) diff --git a/relay/relay-text.go b/relay/relay-text.go index a0a97617..a61718fc 100644 --- 
a/relay/relay-text.go +++ b/relay/relay-text.go @@ -160,7 +160,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest) + convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) } @@ -168,6 +168,9 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { if err != nil { return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) } + if common.DebugEnabled { + println("requestBody: ", string(jsonData)) + } requestBody = bytes.NewBuffer(jsonData) } diff --git a/service/convert.go b/service/convert.go index c4916df2..dbaae654 100644 --- a/service/convert.go +++ b/service/convert.go @@ -44,24 +44,26 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR openAIMessages := make([]dto.Message, 0) // Add system message if present - if claudeRequest.IsStringSystem() { - openAIMessage := dto.Message{ - Role: "system", - } - openAIMessage.SetStringContent(claudeRequest.GetStringSystem()) - openAIMessages = append(openAIMessages, openAIMessage) - } else { - systems := claudeRequest.ParseSystem() - if len(systems) > 0 { - systemStr := "" + if claudeRequest.System != nil { + if claudeRequest.IsStringSystem() { openAIMessage := dto.Message{ Role: "system", } - for _, system := range systems { - systemStr += system.Type - } - openAIMessage.SetStringContent(systemStr) + openAIMessage.SetStringContent(claudeRequest.GetStringSystem()) openAIMessages = append(openAIMessages, openAIMessage) + } else { + systems := claudeRequest.ParseSystem() + if len(systems) > 0 { + systemStr := "" + openAIMessage := dto.Message{ + Role: "system", + } + for _, system := range systems { + systemStr += system.Type + } + 
openAIMessage.SetStringContent(systemStr) + openAIMessages = append(openAIMessages, openAIMessage) + } } } for _, claudeMessage := range claudeRequest.Messages { @@ -100,7 +102,8 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR mediaMessages = append(mediaMessages, mediaMessage) case "tool_use": toolCall := dto.ToolCallRequest{ - ID: mediaMsg.Id, + ID: mediaMsg.Id, + Type: "function", Function: dto.FunctionRequest{ Name: mediaMsg.Name, Arguments: toJSONString(mediaMsg.Input), @@ -111,20 +114,33 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR // Add tool result as a separate message oaiToolMessage := dto.Message{ Role: "tool", + Name: &mediaMsg.Name, ToolCallId: mediaMsg.ToolUseId, } - oaiToolMessage.Content = mediaMsg.Content + //oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text) + if mediaMsg.IsStringContent() { + oaiToolMessage.SetStringContent(mediaMsg.GetStringContent()) + } else { + mediaContents := mediaMsg.ParseMediaContent() + if len(mediaContents) > 0 && mediaContents[0].Text != nil { + oaiToolMessage.SetStringContent(*mediaContents[0].Text) + } + } + openAIMessages = append(openAIMessages, oaiToolMessage) } } - openAIMessage.SetMediaContent(mediaMessages) + if len(mediaMessages) > 0 { + openAIMessage.SetMediaContent(mediaMessages) + } if len(toolCalls) > 0 { openAIMessage.SetToolCalls(toolCalls) } } - - openAIMessages = append(openAIMessages, openAIMessage) + if len(openAIMessage.ParseContent()) > 0 { + openAIMessages = append(openAIMessages, openAIMessage) + } } openAIRequest.Messages = openAIMessages @@ -154,22 +170,35 @@ func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.O } } +func generateStopBlock(index int) *dto.ClaudeResponse { + return &dto.ClaudeResponse{ + Type: "content_block_stop", + Index: common.GetPointer[int](index), + } +} + func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info 
*relaycommon.RelayInfo) []*dto.ClaudeResponse { var claudeResponses []*dto.ClaudeResponse - if info.ResponseTimes == 1 { - claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ - Type: "message_start", - Message: &dto.ClaudeMediaMessage{ - Id: openAIResponse.Id, - Model: openAIResponse.Model, - Type: "message", - Role: "assistant", - Usage: &dto.ClaudeUsage{ - InputTokens: info.PromptTokens, - OutputTokens: 0, - }, + if info.SendResponseCount == 1 { + msg := &dto.ClaudeMediaMessage{ + Id: openAIResponse.Id, + Model: openAIResponse.Model, + Type: "message", + Role: "assistant", + Usage: &dto.ClaudeUsage{ + InputTokens: info.PromptTokens, + OutputTokens: 0, }, + } + msg.SetContent(make([]any, 0)) + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Type: "message_start", + Message: msg, }) + claudeResponses = append(claudeResponses) + //claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + // Type: "ping", + //}) if openAIResponse.IsToolCall() { resp := &dto.ClaudeResponse{ Type: "content_block_start", @@ -192,23 +221,18 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon resp.SetIndex(0) claudeResponses = append(claudeResponses, resp) } - claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ - Type: "ping", - }) return claudeResponses } if len(openAIResponse.Choices) == 0 { // no choices // TODO: handle this case + return claudeResponses } else { chosenChoice := openAIResponse.Choices[0] if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" { // should be done - claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ - Type: "content_block_stop", - Index: common.GetPointer[int](0), - }) + claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index)) if openAIResponse.Usage != nil { claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ Type: "message_delta", @@ -229,18 +253,35 @@ func StreamResponseOpenAI2Claude(openAIResponse 
*dto.ChatCompletionsStreamRespon claudeResponse.SetIndex(0) claudeResponse.Type = "content_block_delta" if len(chosenChoice.Delta.ToolCalls) > 0 { + if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeText { + claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index)) + info.ClaudeConvertInfo.Index++ + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Index: &info.ClaudeConvertInfo.Index, + Type: "content_block_start", + ContentBlock: &dto.ClaudeMediaMessage{ + Id: openAIResponse.GetFirstToolCall().ID, + Type: "tool_use", + Name: openAIResponse.GetFirstToolCall().Function.Name, + Input: map[string]interface{}{}, + }, + }) + } + info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools // tools delta claudeResponse.Delta = &dto.ClaudeMediaMessage{ Type: "input_json_delta", - PartialJson: chosenChoice.Delta.ToolCalls[0].Function.Arguments, + PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments, } } else { + info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText // text delta claudeResponse.Delta = &dto.ClaudeMediaMessage{ Type: "text_delta", Text: common.GetPointer[string](chosenChoice.Delta.GetContentString()), } } + claudeResponse.Index = &info.ClaudeConvertInfo.Index claudeResponses = append(claudeResponses, &claudeResponse) } } From 6187656aa985b455c75a4323ee3a5368c93e0726 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Thu, 13 Mar 2025 21:10:39 +0800 Subject: [PATCH 13/18] chore: Update GitHub Actions workflows and refactor adaptor logic for Docker image builds --- .github/workflows/docker-image-amd64.yml | 10 +++++----- .github/workflows/docker-image-arm64.yml | 15 +++++++-------- relay/channel/dify/adaptor.go | 20 ++++++++++---------- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/.github/workflows/docker-image-amd64.yml b/.github/workflows/docker-image-amd64.yml index 36236df2..a823151c 100644 --- 
a/.github/workflows/docker-image-amd64.yml +++ b/.github/workflows/docker-image-amd64.yml @@ -18,20 +18,20 @@ jobs: contents: read steps: - name: Check out the repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Save version info run: | git describe --tags > VERSION - name: Log in to Docker Hub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Log in to the Container registry - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: registry: ghcr.io username: ${{ github.actor }} @@ -39,14 +39,14 @@ jobs: - name: Extract metadata (tags, labels) for Docker id: meta - uses: docker/metadata-action@v4 + uses: docker/metadata-action@v5 with: images: | calciumion/new-api ghcr.io/${{ github.repository }} - name: Build and push Docker images - uses: docker/build-push-action@v3 + uses: docker/build-push-action@v5 with: context: . push: true diff --git a/.github/workflows/docker-image-arm64.yml b/.github/workflows/docker-image-arm64.yml index 44aec807..d7468c8e 100644 --- a/.github/workflows/docker-image-arm64.yml +++ b/.github/workflows/docker-image-arm64.yml @@ -4,7 +4,6 @@ on: push: tags: - '*' - - '!*-alpha*' workflow_dispatch: inputs: name: @@ -19,26 +18,26 @@ jobs: contents: read steps: - name: Check out the repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Save version info run: | git describe --tags > VERSION - name: Set up QEMU - uses: docker/setup-qemu-action@v2 + uses: docker/setup-qemu-action@v3 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 + uses: docker/setup-buildx-action@v3 - name: Log in to Docker Hub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Log in to the Container registry - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: registry: ghcr.io 
username: ${{ github.actor }} @@ -46,14 +45,14 @@ jobs: - name: Extract metadata (tags, labels) for Docker id: meta - uses: docker/metadata-action@v4 + uses: docker/metadata-action@v5 with: images: | calciumion/new-api ghcr.io/${{ github.repository }} - name: Build and push Docker images - uses: docker/build-push-action@v3 + uses: docker/build-push-action@v5 with: context: . platforms: linux/amd64,linux/arm64 diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 003b5f83..54d10c97 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -9,7 +9,6 @@ import ( "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" - "strings" ) const ( @@ -40,15 +39,16 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { - if strings.HasPrefix(info.UpstreamModelName, "agent") { - a.BotType = BotTypeAgent - } else if strings.HasPrefix(info.UpstreamModelName, "workflow") { - a.BotType = BotTypeWorkFlow - } else if strings.HasPrefix(info.UpstreamModelName, "chat") { - a.BotType = BotTypeCompletion - } else { - a.BotType = BotTypeChatFlow - } + //if strings.HasPrefix(info.UpstreamModelName, "agent") { + // a.BotType = BotTypeAgent + //} else if strings.HasPrefix(info.UpstreamModelName, "workflow") { + // a.BotType = BotTypeWorkFlow + //} else if strings.HasPrefix(info.UpstreamModelName, "chat") { + // a.BotType = BotTypeCompletion + //} else { + //} + a.BotType = BotTypeChatFlow + } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { From cc1400e939b83df18ff1b46eed85448577ee8a63 Mon Sep 17 00:00:00 2001 From: Sh1n3zZ Date: Fri, 14 Mar 2025 03:13:52 +0800 Subject: [PATCH 14/18] fix: wrong thinking labels appear in non-thinking models (#861) --- relay/channel/openai/adaptor.go | 12 +++++++++++- relay/channel/openai/relay-openai.go | 6 ++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git 
a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index d8bc808e..c7eb4142 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "mime/multipart" "net/http" @@ -23,6 +22,8 @@ import ( "one-api/relay/constant" "one-api/service" "strings" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -48,6 +49,15 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn func (a *Adaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType + + // initialize ThinkingContentInfo when thinking_to_content is enabled + if think2Content, ok := info.ChannelSetting[constant2.ChannelSettingThinkingToContent].(bool); ok && think2Content { + info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{ + IsFirstThinkingContent: true, + SendLastThinkingContent: false, + HasSentThinkingContent: false, + } + } } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 2d1ad53e..faeadead 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -65,6 +65,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo response.Choices[i].Delta.Reasoning = nil } info.ThinkingContentInfo.IsFirstThinkingContent = false + info.ThinkingContentInfo.HasSentThinkingContent = true return helper.ObjectData(c, response) } } @@ -76,7 +77,8 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo // Process each choice for i, choice := range lastStreamResponse.Choices { // Handle transition from thinking to content - if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent { + // only send `` tag when previous thinking content has been sent + if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && 
info.ThinkingContentInfo.HasSentThinkingContent { response := lastStreamResponse.Copy() for j := range response.Choices { response.Choices[j].Delta.SetContentString("\n\n") @@ -87,7 +89,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo helper.ObjectData(c, response) } - // Convert reasoning content to regular content + // Convert reasoning content to regular content if any if len(choice.Delta.GetReasoningContent()) > 0 { lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent()) lastStreamResponse.Choices[i].Delta.ReasoningContent = nil From 1644dbc8640c0ebeff0ac9008528561946a82830 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Fri, 14 Mar 2025 17:00:39 +0800 Subject: [PATCH 15/18] refactor: Update token usage calculation in FormatClaudeResponseInfo #865 --- relay/channel/claude/relay-claude.go | 5 +- web/src/components/SafetySetting.js | 790 --------------------------- 2 files changed, 4 insertions(+), 791 deletions(-) delete mode 100644 web/src/components/SafetySetting.js diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 8607f77d..7e84f70d 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -457,7 +457,10 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons } } else if claudeResponse.Type == "message_delta" { claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens - claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens + if claudeResponse.Usage.InputTokens > 0 { + claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens + } + claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens } else if claudeResponse.Type == "content_block_start" { } else { return false diff --git a/web/src/components/SafetySetting.js b/web/src/components/SafetySetting.js deleted 
file mode 100644 index 7d08838e..00000000 --- a/web/src/components/SafetySetting.js +++ /dev/null @@ -1,790 +0,0 @@ -import React, { useEffect, useState } from 'react'; -import { - Button, - Divider, - Form, - Grid, - Header, - Message, - Modal, -} from 'semantic-ui-react'; -import { API, removeTrailingSlash, showError, verifyJSON } from '../helpers'; - -import { useTheme } from '../context/Theme'; - -const SafetySetting = () => { - let [inputs, setInputs] = useState({ - PasswordLoginEnabled: '', - PasswordRegisterEnabled: '', - EmailVerificationEnabled: '', - GitHubOAuthEnabled: '', - GitHubClientId: '', - GitHubClientSecret: '', - Notice: '', - SMTPServer: '', - SMTPPort: '', - SMTPAccount: '', - SMTPFrom: '', - SMTPToken: '', - ServerAddress: '', - WorkerUrl: '', - WorkerValidKey: '', - EpayId: '', - EpayKey: '', - Price: 7.3, - MinTopUp: 1, - TopupGroupRatio: '', - PayAddress: '', - CustomCallbackAddress: '', - Footer: '', - WeChatAuthEnabled: '', - WeChatServerAddress: '', - WeChatServerToken: '', - WeChatAccountQRCodeImageURL: '', - TurnstileCheckEnabled: '', - TurnstileSiteKey: '', - TurnstileSecretKey: '', - RegisterEnabled: '', - EmailDomainRestrictionEnabled: '', - EmailAliasRestrictionEnabled: '', - SMTPSSLEnabled: '', - EmailDomainWhitelist: [], - // telegram login - TelegramOAuthEnabled: '', - TelegramBotToken: '', - TelegramBotName: '', - }); - const [originInputs, setOriginInputs] = useState({}); - let [loading, setLoading] = useState(false); - const [EmailDomainWhitelist, setEmailDomainWhitelist] = useState([]); - const [restrictedDomainInput, setRestrictedDomainInput] = useState(''); - const [showPasswordWarningModal, setShowPasswordWarningModal] = - useState(false); - - const theme = useTheme(); - const isDark = theme === 'dark'; - - const getOptions = async () => { - const res = await API.get('/api/option/'); - const { success, message, data } = res.data; - if (success) { - let newInputs = {}; - data.forEach((item) => { - if (item.key === 
'TopupGroupRatio') { - item.value = JSON.stringify(JSON.parse(item.value), null, 2); - } - newInputs[item.key] = item.value; - }); - setInputs({ - ...newInputs, - EmailDomainWhitelist: newInputs.EmailDomainWhitelist.split(','), - }); - setOriginInputs(newInputs); - - setEmailDomainWhitelist( - newInputs.EmailDomainWhitelist.split(',').map((item) => { - return { key: item, text: item, value: item }; - }), - ); - } else { - showError(message); - } - }; - - useEffect(() => { - getOptions().then(); - }, []); - useEffect(() => {}, [inputs.EmailDomainWhitelist]); - - const updateOption = async (key, value) => { - setLoading(true); - switch (key) { - case 'PasswordLoginEnabled': - case 'PasswordRegisterEnabled': - case 'EmailVerificationEnabled': - case 'GitHubOAuthEnabled': - case 'WeChatAuthEnabled': - case 'TelegramOAuthEnabled': - case 'TurnstileCheckEnabled': - case 'EmailDomainRestrictionEnabled': - case 'EmailAliasRestrictionEnabled': - case 'SMTPSSLEnabled': - case 'RegisterEnabled': - value = inputs[key] === 'true' ? 
'false' : 'true'; - break; - default: - break; - } - const res = await API.put('/api/option/', { - key, - value, - }); - const { success, message } = res.data; - if (success) { - if (key === 'EmailDomainWhitelist') { - value = value.split(','); - } - if (key === 'Price') { - value = parseFloat(value); - } - setInputs((inputs) => ({ - ...inputs, - [key]: value, - })); - } else { - showError(message); - } - setLoading(false); - }; - - const handleInputChange = async (e, { name, value }) => { - if (name === 'PasswordLoginEnabled' && inputs[name] === 'true') { - // block disabling password login - setShowPasswordWarningModal(true); - return; - } - if ( - name === 'Notice' || - (name.startsWith('SMTP') && name !== 'SMTPSSLEnabled') || - name === 'ServerAddress' || - name === 'WorkerUrl' || - name === 'WorkerValidKey' || - name === 'EpayId' || - name === 'EpayKey' || - name === 'Price' || - name === 'PayAddress' || - name === 'GitHubClientId' || - name === 'GitHubClientSecret' || - name === 'WeChatServerAddress' || - name === 'WeChatServerToken' || - name === 'WeChatAccountQRCodeImageURL' || - name === 'TurnstileSiteKey' || - name === 'TurnstileSecretKey' || - name === 'EmailDomainWhitelist' || - name === 'TopupGroupRatio' || - name === 'TelegramBotToken' || - name === 'TelegramBotName' - ) { - setInputs((inputs) => ({ ...inputs, [name]: value })); - } else { - await updateOption(name, value); - } - }; - - const submitServerAddress = async () => { - let ServerAddress = removeTrailingSlash(inputs.ServerAddress); - await updateOption('ServerAddress', ServerAddress); - }; - - const submitWorker = async () => { - let WorkerUrl = removeTrailingSlash(inputs.WorkerUrl); - await updateOption('WorkerUrl', WorkerUrl); - if (inputs.WorkerValidKey !== '') { - await updateOption('WorkerValidKey', inputs.WorkerValidKey); - } - } - - const submitPayAddress = async () => { - if (inputs.ServerAddress === '') { - showError('请先填写服务器地址'); - return; - } - if (originInputs['TopupGroupRatio'] 
!== inputs.TopupGroupRatio) { - if (!verifyJSON(inputs.TopupGroupRatio)) { - showError('充值分组倍率不是合法的 JSON 字符串'); - return; - } - await updateOption('TopupGroupRatio', inputs.TopupGroupRatio); - } - let PayAddress = removeTrailingSlash(inputs.PayAddress); - await updateOption('PayAddress', PayAddress); - if (inputs.EpayId !== '') { - await updateOption('EpayId', inputs.EpayId); - } - if (inputs.EpayKey !== undefined && inputs.EpayKey !== '') { - await updateOption('EpayKey', inputs.EpayKey); - } - await updateOption('Price', '' + inputs.Price); - }; - - const submitSMTP = async () => { - if (originInputs['SMTPServer'] !== inputs.SMTPServer) { - await updateOption('SMTPServer', inputs.SMTPServer); - } - if (originInputs['SMTPAccount'] !== inputs.SMTPAccount) { - await updateOption('SMTPAccount', inputs.SMTPAccount); - } - if (originInputs['SMTPFrom'] !== inputs.SMTPFrom) { - await updateOption('SMTPFrom', inputs.SMTPFrom); - } - if ( - originInputs['SMTPPort'] !== inputs.SMTPPort && - inputs.SMTPPort !== '' - ) { - await updateOption('SMTPPort', inputs.SMTPPort); - } - if ( - originInputs['SMTPToken'] !== inputs.SMTPToken && - inputs.SMTPToken !== '' - ) { - await updateOption('SMTPToken', inputs.SMTPToken); - } - }; - - const submitEmailDomainWhitelist = async () => { - if ( - originInputs['EmailDomainWhitelist'] !== - inputs.EmailDomainWhitelist.join(',') && - inputs.SMTPToken !== '' - ) { - await updateOption( - 'EmailDomainWhitelist', - inputs.EmailDomainWhitelist.join(','), - ); - } - }; - - const submitWeChat = async () => { - if (originInputs['WeChatServerAddress'] !== inputs.WeChatServerAddress) { - await updateOption( - 'WeChatServerAddress', - removeTrailingSlash(inputs.WeChatServerAddress), - ); - } - if ( - originInputs['WeChatAccountQRCodeImageURL'] !== - inputs.WeChatAccountQRCodeImageURL - ) { - await updateOption( - 'WeChatAccountQRCodeImageURL', - inputs.WeChatAccountQRCodeImageURL, - ); - } - if ( - originInputs['WeChatServerToken'] !== 
inputs.WeChatServerToken && - inputs.WeChatServerToken !== '' - ) { - await updateOption('WeChatServerToken', inputs.WeChatServerToken); - } - }; - - const submitGitHubOAuth = async () => { - if (originInputs['GitHubClientId'] !== inputs.GitHubClientId) { - await updateOption('GitHubClientId', inputs.GitHubClientId); - } - if ( - originInputs['GitHubClientSecret'] !== inputs.GitHubClientSecret && - inputs.GitHubClientSecret !== '' - ) { - await updateOption('GitHubClientSecret', inputs.GitHubClientSecret); - } - }; - - const submitTelegramSettings = async () => { - // await updateOption('TelegramOAuthEnabled', inputs.TelegramOAuthEnabled); - await updateOption('TelegramBotToken', inputs.TelegramBotToken); - await updateOption('TelegramBotName', inputs.TelegramBotName); - }; - - const submitTurnstile = async () => { - if (originInputs['TurnstileSiteKey'] !== inputs.TurnstileSiteKey) { - await updateOption('TurnstileSiteKey', inputs.TurnstileSiteKey); - } - if ( - originInputs['TurnstileSecretKey'] !== inputs.TurnstileSecretKey && - inputs.TurnstileSecretKey !== '' - ) { - await updateOption('TurnstileSecretKey', inputs.TurnstileSecretKey); - } - }; - - const submitNewRestrictedDomain = () => { - const localDomainList = inputs.EmailDomainWhitelist; - if ( - restrictedDomainInput !== '' && - !localDomainList.includes(restrictedDomainInput) - ) { - setRestrictedDomainInput(''); - setInputs({ - ...inputs, - EmailDomainWhitelist: [...localDomainList, restrictedDomainInput], - }); - setEmailDomainWhitelist([ - ...EmailDomainWhitelist, - { - key: restrictedDomainInput, - text: restrictedDomainInput, - value: restrictedDomainInput, - }, - ]); - } - }; - - return ( - - -
-
- 通用设置 -
- - - - - 更新服务器地址 - -
- 代理设置(支持 new-api-worker) -
- - - - - - 更新Worker设置 - - -
- 支付设置(当前仅支持易支付接口,默认使用上方服务器地址作为回调地址!) -
- - - - - - - - - - - - - - 更新支付设置 - -
- 配置登录注册 -
- - - {showPasswordWarningModal && ( - setShowPasswordWarningModal(false)} - size={'tiny'} - style={{ maxWidth: '450px' }} - > - 警告 - -

- 取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消? -

-
- - - - -
- )} - - - - - -
- - - - - -
- 配置邮箱域名白名单 - - 用以防止恶意用户利用临时邮箱批量注册 - -
- - - - - - - - - { - submitNewRestrictedDomain(); - }} - > - 填入 - - } - onKeyDown={(e) => { - if (e.key === 'Enter') { - submitNewRestrictedDomain(); - } - }} - autoComplete='new-password' - placeholder='输入新的允许的邮箱域名' - value={restrictedDomainInput} - onChange={(e, { value }) => { - setRestrictedDomainInput(value); - }} - /> - - - 保存邮箱域名白名单设置 - - -
- 配置 SMTP - 用以支持系统的邮件发送 -
- - - - - - - - - - - - - 保存 SMTP 设置 - -
- 配置 GitHub OAuth App - - 用以支持通过 GitHub 进行登录注册, - - 点击此处 - - 管理你的 GitHub OAuth App - -
- - Homepage URL 填 {inputs.ServerAddress} - ,Authorization callback URL 填{' '} - {`${inputs.ServerAddress}/oauth/github`} - - - - - - - 保存 GitHub OAuth 设置 - - -
- 配置 WeChat Server - - 用以支持通过微信进行登录注册, - - 点击此处 - - 了解 WeChat Server - -
- - - - - - - 保存 WeChat Server 设置 - - -
- 配置 Telegram 登录 -
- - - - - - 保存 Telegram 登录设置 - - -
- 配置 Turnstile - - 用以支持用户校验, - - 点击此处 - - 管理你的 Turnstile Sites,推荐选择 Invisible Widget Type - -
- - - - - - 保存 Turnstile 设置 - - -
-
- ); -}; - -export default SystemSetting; From 9a78db84840050ac08d254895f2ee4128b93102b Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Fri, 14 Mar 2025 17:09:40 +0800 Subject: [PATCH 16/18] feat: Add HasSentThinkingContent field to ThinkingContentInfo struct --- relay/common/relay_info.go | 1 + 1 file changed, 1 insertion(+) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 5075d07d..baabd3e7 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -15,6 +15,7 @@ import ( type ThinkingContentInfo struct { IsFirstThinkingContent bool SendLastThinkingContent bool + HasSentThinkingContent bool } const ( From 69e44a03b1b466fcca5e34c31322ce3dd3e6697b Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Fri, 14 Mar 2025 17:31:05 +0800 Subject: [PATCH 17/18] refactor: Simplify OpenAI handler function signature and remove unused TextResponseWithError struct; introduce common_handler for rerank functionality --- dto/openai_response.go | 11 ---- relay/channel/ali/adaptor.go | 2 +- relay/channel/baidu_v2/adaptor.go | 2 +- relay/channel/deepseek/adaptor.go | 2 +- relay/channel/jina/adaptor.go | 6 ++- relay/channel/jina/relay-jina.go | 59 -------------------- relay/channel/mistral/adaptor.go | 2 +- relay/channel/ollama/adaptor.go | 2 +- relay/channel/openai/adaptor.go | 21 +++++--- relay/channel/openai/relay-openai.go | 8 +-- relay/channel/openrouter/adaptor.go | 80 ---------------------------- relay/channel/perplexity/adaptor.go | 2 +- relay/channel/siliconflow/adaptor.go | 6 +-- relay/channel/vertex/adaptor.go | 2 +- relay/channel/volcengine/adaptor.go | 4 +- relay/channel/zhipu_4v/adaptor.go | 2 +- relay/common_handler/rerank.go | 35 ++++++++++++ relay/relay_adaptor.go | 3 +- 18 files changed, 70 insertions(+), 179 deletions(-) delete mode 100644 relay/channel/openrouter/adaptor.go create mode 100644 relay/common_handler/rerank.go diff --git a/dto/openai_response.go b/dto/openai_response.go index 
4097db55..53883bb4 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -1,16 +1,5 @@ package dto -type TextResponseWithError struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Choices []OpenAITextResponseChoice `json:"choices"` - Data []OpenAIEmbeddingResponseItem `json:"data"` - Model string `json:"model"` - Usage `json:"usage"` - Error OpenAIError `json:"error"` -} - type SimpleResponse struct { Usage `json:"usage"` Error OpenAIError `json:"error"` diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index e28278e1..0cbcef44 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -93,7 +93,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } } return diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index 9645bbf5..ec7936dc 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -68,7 +68,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } return } diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index 64d92a48..57accc8f 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, 
info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } return } diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index a65e820e..ceffb79a 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -8,7 +8,9 @@ import ( "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "one-api/relay/common_handler" "one-api/relay/constant" ) @@ -67,9 +69,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { - err, usage = JinaRerankHandler(c, resp) + err, usage = common_handler.RerankHandler(c, resp) } else if info.RelayMode == constant.RelayModeEmbeddings { - err, usage = jinaEmbeddingHandler(c, resp) + err, usage = openai.OpenaiHandler(c, resp, info) } return } diff --git a/relay/channel/jina/relay-jina.go b/relay/channel/jina/relay-jina.go index aee7b131..d83b5854 100644 --- a/relay/channel/jina/relay-jina.go +++ b/relay/channel/jina/relay-jina.go @@ -1,60 +1 @@ package jina - -import ( - "encoding/json" - "github.com/gin-gonic/gin" - "io" - "net/http" - "one-api/dto" - "one-api/service" -) - -func JinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - var jinaResp dto.RerankResponse - err = json.Unmarshal(responseBody, &jinaResp) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", 
http.StatusInternalServerError), nil - } - - jsonResponse, err := json.Marshal(jinaResp) - if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &jinaResp.Usage -} - -func jinaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - var jinaResp dto.OpenAIEmbeddingResponse - err = json.Unmarshal(responseBody, &jinaResp) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - - jsonResponse, err := json.Marshal(jinaResp) - if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &jinaResp.Usage -} diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index 4857209f..82c82496 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -67,7 +67,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } return } diff --git a/relay/channel/ollama/adaptor.go 
b/relay/channel/ollama/adaptor.go index 2101bf70..39e408ab 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -75,7 +75,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.RelayMode == relayconstant.RelayModeEmbeddings { err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } } return diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index c7eb4142..91bc5066 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -13,12 +13,13 @@ import ( "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/ai360" - "one-api/relay/channel/jina" "one-api/relay/channel/lingyiwanwu" "one-api/relay/channel/minimax" "one-api/relay/channel/moonshot" + "one-api/relay/channel/openrouter" "one-api/relay/channel/xinference" relaycommon "one-api/relay/common" + "one-api/relay/common_handler" "one-api/relay/constant" "one-api/service" "strings" @@ -32,7 +33,7 @@ type Adaptor struct { } func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { - if !strings.HasPrefix(request.Model, "claude") { + if !strings.Contains(request.Model, "claude") { return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model) } aiRequest, err := service.ClaudeToOpenAIRequest(*request) @@ -132,10 +133,10 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info * } else { header.Set("Authorization", "Bearer "+info.ApiKey) } - //if info.ChannelType == common.ChannelTypeOpenRouter { - // req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") - // req.Header.Set("X-Title", "One API") - //} + 
if info.ChannelType == common.ChannelTypeOpenRouter { + header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api") + header.Set("X-Title", "New API") + } return nil } @@ -261,12 +262,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom case constant.RelayModeImagesGenerations: err, usage = OpenaiTTSHandler(c, resp, info) case constant.RelayModeRerank: - err, usage = jina.JinaRerankHandler(c, resp) + err, usage = common_handler.RerankHandler(c, resp) default: if info.IsStream { err, usage = OaiStreamHandler(c, resp, info) } else { - err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = OpenaiHandler(c, resp, info) } } return @@ -284,6 +285,8 @@ func (a *Adaptor) GetModelList() []string { return minimax.ModelList case common.ChannelTypeXinference: return xinference.ModelList + case common.ChannelTypeOpenRouter: + return openrouter.ModelList default: return ModelList } @@ -301,6 +304,8 @@ func (a *Adaptor) GetChannelName() string { return minimax.ChannelName case common.ChannelTypeXinference: return xinference.ChannelName + case common.ChannelTypeOpenRouter: + return openrouter.ChannelName default: return ChannelName } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index faeadead..30f927a7 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -195,7 +195,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel return nil, usage } -func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var simpleResponse dto.SimpleResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -233,13 +233,13 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, 
promptTokens int, model if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { completionTokens := 0 for _, choice := range simpleResponse.Choices { - ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, model) + ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) completionTokens += ctkm } simpleResponse.Usage = dto.Usage{ - PromptTokens: promptTokens, + PromptTokens: info.PromptTokens, CompletionTokens: completionTokens, - TotalTokens: promptTokens + completionTokens, + TotalTokens: info.PromptTokens + completionTokens, } } return nil, &simpleResponse.Usage diff --git a/relay/channel/openrouter/adaptor.go b/relay/channel/openrouter/adaptor.go deleted file mode 100644 index f2909b6b..00000000 --- a/relay/channel/openrouter/adaptor.go +++ /dev/null @@ -1,80 +0,0 @@ -package openrouter - -import ( - "errors" - "fmt" - "github.com/gin-gonic/gin" - "io" - "net/http" - "one-api/dto" - "one-api/relay/channel" - "one-api/relay/channel/openai" - relaycommon "one-api/relay/common" -) - -type Adaptor struct { -} - -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { - //TODO implement me - panic("implement me") - return nil, nil -} - -func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { - //TODO implement me - return nil, errors.New("not implemented") -} - -func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") -} - -func (a *Adaptor) Init(info *relaycommon.RelayInfo) { -} - -func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return 
fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil -} - -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { - channel.SetupApiRequestHeader(info, c, req) - req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) - req.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api") - req.Set("X-Title", "New API") - return nil -} - -func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { - return request, nil -} - -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { - return channel.DoApiRequest(a, c, info, requestBody) -} - -func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { - return nil, errors.New("not implemented") -} - -func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { - return nil, errors.New("not implemented") -} - -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - err, usage = openai.OaiStreamHandler(c, resp, info) - } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) - } - return -} - -func (a *Adaptor) GetModelList() []string { - return ModelList -} - -func (a *Adaptor) GetChannelName() string { - return ChannelName -} diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 32f00047..5727cac7 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -71,7 +71,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, 
info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } return } diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index 1b319e2a..89602418 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -78,16 +78,16 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } case constant.RelayModeCompletions: if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } case constant.RelayModeEmbeddings: - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } return } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index e09845eb..a49db1ee 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -178,7 +178,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom case RequestModeGemini: err, usage = gemini.GeminiChatHandler(c, resp, info) case RequestModeLlama: - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } } return diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 5e5e276b..277285b7 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -81,10 +81,10 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - 
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } case constant.RelayModeEmbeddings: - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } return } diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index ba24814c..8f6aab39 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -72,7 +72,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } return } diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go new file mode 100644 index 00000000..f33da85c --- /dev/null +++ b/relay/common_handler/rerank.go @@ -0,0 +1,35 @@ +package common_handler + +import ( + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/service" +) + +func RerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var jinaResp dto.RerankResponse + err = json.Unmarshal(responseBody, &jinaResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + jsonResponse, err := json.Marshal(jinaResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", 
http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &jinaResp.Usage +} diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index f6d141fa..be7d07e6 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -18,7 +18,6 @@ import ( "one-api/relay/channel/mokaai" "one-api/relay/channel/ollama" "one-api/relay/channel/openai" - "one-api/relay/channel/openrouter" "one-api/relay/channel/palm" "one-api/relay/channel/perplexity" "one-api/relay/channel/siliconflow" @@ -83,7 +82,7 @@ func GetAdaptor(apiType int) channel.Adaptor { case constant.APITypeBaiduV2: return &baidu_v2.Adaptor{} case constant.APITypeOpenRouter: - return &openrouter.Adaptor{} + return &openai.Adaptor{} case constant.APITypeXinference: return &openai.Adaptor{} } From 19bfa158cce25d5c52f5bc390e34fdf99689a7be Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Fri, 14 Mar 2025 17:48:26 +0800 Subject: [PATCH 18/18] refactor: Change ClaudeError field type to non-pointer and enhance response handling with reasoning content --- dto/claude.go | 2 +- relay/channel/claude/relay-claude.go | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/dto/claude.go b/dto/claude.go index f7354230..f9a6024a 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -183,7 +183,7 @@ type ClaudeResponse struct { Completion string `json:"completion,omitempty"` StopReason string `json:"stop_reason,omitempty"` Model string `json:"model,omitempty"` - Error *ClaudeError `json:"error,omitempty"` + Error ClaudeError `json:"error,omitempty"` Usage *ClaudeUsage `json:"usage,omitempty"` Index *int `json:"index,omitempty"` ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"` diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 7e84f70d..3dbca4a9 100644 --- 
a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -376,8 +376,10 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto Created: common.GetTimestamp(), } var responseText string + var responseThinking string if len(claudeResponse.Content) > 0 { - responseText = *claudeResponse.Content[0].Text + responseText = claudeResponse.Content[0].GetText() + responseThinking = claudeResponse.Content[0].Thinking } tools := make([]dto.ToolCallResponse, 0) thinkingContent := "" @@ -424,6 +426,9 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), } choice.SetStringContent(responseText) + if len(responseThinking) > 0 { + choice.ReasoningContent = responseThinking + } if len(tools) > 0 { choice.Message.SetToolCalls(tools) } @@ -590,6 +595,9 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r if err != nil { return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + if common.DebugEnabled { + println("responseBody: ", string(responseBody)) + } var claudeResponse dto.ClaudeResponse err = json.Unmarshal(responseBody, &claudeResponse) if err != nil {