diff --git a/dto/openai_response.go b/dto/openai_response.go index 4097db55..53883bb4 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -1,16 +1,5 @@ package dto -type TextResponseWithError struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Choices []OpenAITextResponseChoice `json:"choices"` - Data []OpenAIEmbeddingResponseItem `json:"data"` - Model string `json:"model"` - Usage `json:"usage"` - Error OpenAIError `json:"error"` -} - type SimpleResponse struct { Usage `json:"usage"` Error OpenAIError `json:"error"` diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index e28278e1..0cbcef44 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -93,7 +93,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } } return diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index 9645bbf5..ec7936dc 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -68,7 +68,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } return } diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index 64d92a48..57accc8f 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } return } diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index a65e820e..ceffb79a 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -8,7 +8,9 @@ import ( "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "one-api/relay/common_handler" "one-api/relay/constant" ) @@ -67,9 +69,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { - err, usage = JinaRerankHandler(c, resp) + err, usage = common_handler.RerankHandler(c, resp) } else if info.RelayMode == constant.RelayModeEmbeddings { - err, usage = jinaEmbeddingHandler(c, resp) + err, usage = openai.OpenaiHandler(c, resp, info) } return } diff --git a/relay/channel/jina/relay-jina.go b/relay/channel/jina/relay-jina.go index aee7b131..d83b5854 100644 --- a/relay/channel/jina/relay-jina.go +++ b/relay/channel/jina/relay-jina.go @@ -1,60 +1 @@ package jina - -import ( - "encoding/json" - "github.com/gin-gonic/gin" - "io" - "net/http" - "one-api/dto" - "one-api/service" -) - -func JinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - var jinaResp dto.RerankResponse - err = json.Unmarshal(responseBody, &jinaResp) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - - jsonResponse, err := json.Marshal(jinaResp) - if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &jinaResp.Usage -} - -func jinaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - var jinaResp dto.OpenAIEmbeddingResponse - err = json.Unmarshal(responseBody, &jinaResp) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - - jsonResponse, err := json.Marshal(jinaResp) - if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &jinaResp.Usage -} diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index 4857209f..82c82496 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -67,7 +67,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } return } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 2101bf70..39e408ab 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -75,7 +75,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.RelayMode == relayconstant.RelayModeEmbeddings { err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } } return diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index c7eb4142..91bc5066 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -13,12 +13,13 @@ import ( "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/ai360" - "one-api/relay/channel/jina" "one-api/relay/channel/lingyiwanwu" "one-api/relay/channel/minimax" "one-api/relay/channel/moonshot" + "one-api/relay/channel/openrouter" "one-api/relay/channel/xinference" relaycommon "one-api/relay/common" + "one-api/relay/common_handler" "one-api/relay/constant" "one-api/service" "strings" @@ -32,7 +33,7 @@ type Adaptor struct { } func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { - if !strings.HasPrefix(request.Model, "claude") { + if !strings.Contains(request.Model, "claude") { return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model) } aiRequest, err := service.ClaudeToOpenAIRequest(*request) @@ -132,10 +133,10 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info * } else { header.Set("Authorization", "Bearer "+info.ApiKey) } - //if info.ChannelType == common.ChannelTypeOpenRouter { - // req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") - // req.Header.Set("X-Title", "One API") - //} + if info.ChannelType == common.ChannelTypeOpenRouter { + header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api") + header.Set("X-Title", "New API") + } return nil } @@ -261,12 +262,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom case constant.RelayModeImagesGenerations: err, usage = OpenaiTTSHandler(c, resp, info) case constant.RelayModeRerank: - err, usage = jina.JinaRerankHandler(c, resp) + err, usage = common_handler.RerankHandler(c, resp) default: if info.IsStream { err, usage = OaiStreamHandler(c, resp, info) } else { - err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = OpenaiHandler(c, resp, info) } } return @@ -284,6 +285,8 @@ func (a *Adaptor) GetModelList() []string { return minimax.ModelList case common.ChannelTypeXinference: return xinference.ModelList + case common.ChannelTypeOpenRouter: + return openrouter.ModelList default: return ModelList } @@ -301,6 +304,8 @@ func (a *Adaptor) GetChannelName() string { return minimax.ChannelName case common.ChannelTypeXinference: return xinference.ChannelName + case common.ChannelTypeOpenRouter: + return openrouter.ChannelName default: return ChannelName } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index faeadead..30f927a7 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -195,7 +195,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel return nil, usage } -func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var simpleResponse dto.SimpleResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -233,13 +233,13 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { completionTokens := 0 for _, choice := range simpleResponse.Choices { - ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, model) + ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) completionTokens += ctkm } simpleResponse.Usage = dto.Usage{ - PromptTokens: promptTokens, + PromptTokens: info.PromptTokens, CompletionTokens: completionTokens, - TotalTokens: promptTokens + completionTokens, + TotalTokens: info.PromptTokens + completionTokens, } } return nil, &simpleResponse.Usage diff --git a/relay/channel/openrouter/adaptor.go b/relay/channel/openrouter/adaptor.go deleted file mode 100644 index f2909b6b..00000000 --- a/relay/channel/openrouter/adaptor.go +++ /dev/null @@ -1,80 +0,0 @@ -package openrouter - -import ( - "errors" - "fmt" - "github.com/gin-gonic/gin" - "io" - "net/http" - "one-api/dto" - "one-api/relay/channel" - "one-api/relay/channel/openai" - relaycommon "one-api/relay/common" -) - -type Adaptor struct { -} - -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { - //TODO implement me - panic("implement me") - return nil, nil -} - -func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { - //TODO implement me - return nil, errors.New("not implemented") -} - -func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") -} - -func (a *Adaptor) Init(info *relaycommon.RelayInfo) { -} - -func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil -} - -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { - channel.SetupApiRequestHeader(info, c, req) - req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) - req.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api") - req.Set("X-Title", "New API") - return nil -} - -func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { - return request, nil -} - -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { - return channel.DoApiRequest(a, c, info, requestBody) -} - -func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { - return nil, errors.New("not implemented") -} - -func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { - return nil, errors.New("not implemented") -} - -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - err, usage = openai.OaiStreamHandler(c, resp, info) - } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) - } - return -} - -func (a *Adaptor) GetModelList() []string { - return ModelList -} - -func (a *Adaptor) GetChannelName() string { - return ChannelName -} diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 32f00047..5727cac7 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -71,7 +71,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } return } diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index 1b319e2a..89602418 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -78,16 +78,16 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } case constant.RelayModeCompletions: if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } case constant.RelayModeEmbeddings: - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } return } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index e09845eb..a49db1ee 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -178,7 +178,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom case RequestModeGemini: err, usage = gemini.GeminiChatHandler(c, resp, info) case RequestModeLlama: - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } } return diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 5e5e276b..277285b7 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -81,10 +81,10 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } case constant.RelayModeEmbeddings: - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } return } diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index ba24814c..8f6aab39 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -72,7 +72,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = openai.OpenaiHandler(c, resp, info) } return } diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go new file mode 100644 index 00000000..f33da85c --- /dev/null +++ b/relay/common_handler/rerank.go @@ -0,0 +1,35 @@ +package common_handler + +import ( + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/service" +) + +func RerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var jinaResp dto.RerankResponse + err = json.Unmarshal(responseBody, &jinaResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + jsonResponse, err := json.Marshal(jinaResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &jinaResp.Usage +} diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index f6d141fa..be7d07e6 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -18,7 +18,6 @@ import ( "one-api/relay/channel/mokaai" "one-api/relay/channel/ollama" "one-api/relay/channel/openai" - "one-api/relay/channel/openrouter" "one-api/relay/channel/palm" "one-api/relay/channel/perplexity" "one-api/relay/channel/siliconflow" @@ -83,7 +82,7 @@ func GetAdaptor(apiType int) channel.Adaptor { case constant.APITypeBaiduV2: return &baidu_v2.Adaptor{} case constant.APITypeOpenRouter: - return &openrouter.Adaptor{} + return &openai.Adaptor{} case constant.APITypeXinference: return &openai.Adaptor{} }