refactor: Simplify OpenAI handler function signature and remove unused TextResponseWithError struct; introduce common_handler for rerank functionality
This commit is contained in:
@@ -1,16 +1,5 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
type TextResponseWithError struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
|
||||||
Data []OpenAIEmbeddingResponseItem `json:"data"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Usage `json:"usage"`
|
|
||||||
Error OpenAIError `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type SimpleResponse struct {
|
type SimpleResponse struct {
|
||||||
Usage `json:"usage"`
|
Usage `json:"usage"`
|
||||||
Error OpenAIError `json:"error"`
|
Error OpenAIError `json:"error"`
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,9 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
|
"one-api/relay/channel/openai"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/relay/common_handler"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -67,9 +69,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
|||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.RelayMode == constant.RelayModeRerank {
|
||||||
err, usage = JinaRerankHandler(c, resp)
|
err, usage = common_handler.RerankHandler(c, resp)
|
||||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||||
err, usage = jinaEmbeddingHandler(c, resp)
|
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,60 +1 @@
|
|||||||
package jina
|
package jina
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/dto"
|
|
||||||
"one-api/service"
|
|
||||||
)
|
|
||||||
|
|
||||||
func JinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
var jinaResp dto.RerankResponse
|
|
||||||
err = json.Unmarshal(responseBody, &jinaResp)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonResponse, err := json.Marshal(jinaResp)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &jinaResp.Usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func jinaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
var jinaResp dto.OpenAIEmbeddingResponse
|
|
||||||
err = json.Unmarshal(responseBody, &jinaResp)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonResponse, err := json.Marshal(jinaResp)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &jinaResp.Usage
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||||
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
|
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -13,12 +13,13 @@ import (
|
|||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
"one-api/relay/channel/ai360"
|
"one-api/relay/channel/ai360"
|
||||||
"one-api/relay/channel/jina"
|
|
||||||
"one-api/relay/channel/lingyiwanwu"
|
"one-api/relay/channel/lingyiwanwu"
|
||||||
"one-api/relay/channel/minimax"
|
"one-api/relay/channel/minimax"
|
||||||
"one-api/relay/channel/moonshot"
|
"one-api/relay/channel/moonshot"
|
||||||
|
"one-api/relay/channel/openrouter"
|
||||||
"one-api/relay/channel/xinference"
|
"one-api/relay/channel/xinference"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/relay/common_handler"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -32,7 +33,7 @@ type Adaptor struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||||
if !strings.HasPrefix(request.Model, "claude") {
|
if !strings.Contains(request.Model, "claude") {
|
||||||
return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
|
return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
|
||||||
}
|
}
|
||||||
aiRequest, err := service.ClaudeToOpenAIRequest(*request)
|
aiRequest, err := service.ClaudeToOpenAIRequest(*request)
|
||||||
@@ -132,10 +133,10 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
|
|||||||
} else {
|
} else {
|
||||||
header.Set("Authorization", "Bearer "+info.ApiKey)
|
header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
}
|
}
|
||||||
//if info.ChannelType == common.ChannelTypeOpenRouter {
|
if info.ChannelType == common.ChannelTypeOpenRouter {
|
||||||
// req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
|
header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
|
||||||
// req.Header.Set("X-Title", "One API")
|
header.Set("X-Title", "New API")
|
||||||
//}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,12 +262,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
case constant.RelayModeImagesGenerations:
|
case constant.RelayModeImagesGenerations:
|
||||||
err, usage = OpenaiTTSHandler(c, resp, info)
|
err, usage = OpenaiTTSHandler(c, resp, info)
|
||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
err, usage = jina.JinaRerankHandler(c, resp)
|
err, usage = common_handler.RerankHandler(c, resp)
|
||||||
default:
|
default:
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = OaiStreamHandler(c, resp, info)
|
err, usage = OaiStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -284,6 +285,8 @@ func (a *Adaptor) GetModelList() []string {
|
|||||||
return minimax.ModelList
|
return minimax.ModelList
|
||||||
case common.ChannelTypeXinference:
|
case common.ChannelTypeXinference:
|
||||||
return xinference.ModelList
|
return xinference.ModelList
|
||||||
|
case common.ChannelTypeOpenRouter:
|
||||||
|
return openrouter.ModelList
|
||||||
default:
|
default:
|
||||||
return ModelList
|
return ModelList
|
||||||
}
|
}
|
||||||
@@ -301,6 +304,8 @@ func (a *Adaptor) GetChannelName() string {
|
|||||||
return minimax.ChannelName
|
return minimax.ChannelName
|
||||||
case common.ChannelTypeXinference:
|
case common.ChannelTypeXinference:
|
||||||
return xinference.ChannelName
|
return xinference.ChannelName
|
||||||
|
case common.ChannelTypeOpenRouter:
|
||||||
|
return openrouter.ChannelName
|
||||||
default:
|
default:
|
||||||
return ChannelName
|
return ChannelName
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -195,7 +195,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
var simpleResponse dto.SimpleResponse
|
var simpleResponse dto.SimpleResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -233,13 +233,13 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
|||||||
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||||
completionTokens := 0
|
completionTokens := 0
|
||||||
for _, choice := range simpleResponse.Choices {
|
for _, choice := range simpleResponse.Choices {
|
||||||
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, model)
|
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
||||||
completionTokens += ctkm
|
completionTokens += ctkm
|
||||||
}
|
}
|
||||||
simpleResponse.Usage = dto.Usage{
|
simpleResponse.Usage = dto.Usage{
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: info.PromptTokens,
|
||||||
CompletionTokens: completionTokens,
|
CompletionTokens: completionTokens,
|
||||||
TotalTokens: promptTokens + completionTokens,
|
TotalTokens: info.PromptTokens + completionTokens,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, &simpleResponse.Usage
|
return nil, &simpleResponse.Usage
|
||||||
|
|||||||
@@ -1,80 +0,0 @@
|
|||||||
package openrouter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/dto"
|
|
||||||
"one-api/relay/channel"
|
|
||||||
"one-api/relay/channel/openai"
|
|
||||||
relaycommon "one-api/relay/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Adaptor struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
|
||||||
//TODO implement me
|
|
||||||
panic("implement me")
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
|
||||||
//TODO implement me
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
|
||||||
//TODO implement me
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
||||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
|
||||||
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
|
||||||
req.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
|
|
||||||
req.Set("X-Title", "New API")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
|
||||||
return request, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
|
||||||
if info.IsStream {
|
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
|
||||||
} else {
|
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) GetModelList() []string {
|
|
||||||
return ModelList
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) GetChannelName() string {
|
|
||||||
return ChannelName
|
|
||||||
}
|
|
||||||
@@ -71,7 +71,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -78,16 +78,16 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
case constant.RelayModeCompletions:
|
case constant.RelayModeCompletions:
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
case RequestModeGemini:
|
case RequestModeGemini:
|
||||||
err, usage = gemini.GeminiChatHandler(c, resp, info)
|
err, usage = gemini.GeminiChatHandler(c, resp, info)
|
||||||
case RequestModeLlama:
|
case RequestModeLlama:
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName)
|
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -81,10 +81,10 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
35
relay/common_handler/rerank.go
Normal file
35
relay/common_handler/rerank.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package common_handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
var jinaResp dto.RerankResponse
|
||||||
|
err = json.Unmarshal(responseBody, &jinaResp)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonResponse, err := json.Marshal(jinaResp)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &jinaResp.Usage
|
||||||
|
}
|
||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
"one-api/relay/channel/mokaai"
|
"one-api/relay/channel/mokaai"
|
||||||
"one-api/relay/channel/ollama"
|
"one-api/relay/channel/ollama"
|
||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
"one-api/relay/channel/openrouter"
|
|
||||||
"one-api/relay/channel/palm"
|
"one-api/relay/channel/palm"
|
||||||
"one-api/relay/channel/perplexity"
|
"one-api/relay/channel/perplexity"
|
||||||
"one-api/relay/channel/siliconflow"
|
"one-api/relay/channel/siliconflow"
|
||||||
@@ -83,7 +82,7 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
|||||||
case constant.APITypeBaiduV2:
|
case constant.APITypeBaiduV2:
|
||||||
return &baidu_v2.Adaptor{}
|
return &baidu_v2.Adaptor{}
|
||||||
case constant.APITypeOpenRouter:
|
case constant.APITypeOpenRouter:
|
||||||
return &openrouter.Adaptor{}
|
return &openai.Adaptor{}
|
||||||
case constant.APITypeXinference:
|
case constant.APITypeXinference:
|
||||||
return &openai.Adaptor{}
|
return &openai.Adaptor{}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user