From 126f04e08fe3272bcad37ff04cf0756c5d4b522a Mon Sep 17 00:00:00 2001 From: Jerry Date: Wed, 22 Jan 2025 04:21:08 +0800 Subject: [PATCH 1/3] Support for MokaAI M3E --- Dockerfile | 2 +- common/constants.go | 3 +- controller/channel-test.go | 28 ++++- middleware/distributor.go | 2 + relay/channel/mokaai/adaptor.go | 104 +++++++++++++++++ relay/channel/mokaai/constants.go | 9 ++ relay/channel/mokaai/dto.go | 30 +++++ relay/channel/mokaai/relay-mokaai.go | 154 +++++++++++++++++++++++++ relay/constant/api_type.go | 4 +- relay/relay_adaptor.go | 3 + web/src/constants/channel.constants.js | 7 ++ 11 files changed, 341 insertions(+), 5 deletions(-) create mode 100644 relay/channel/mokaai/adaptor.go create mode 100644 relay/channel/mokaai/constants.go create mode 100644 relay/channel/mokaai/dto.go create mode 100644 relay/channel/mokaai/relay-mokaai.go diff --git a/Dockerfile b/Dockerfile index 44a7837a..4e0d0511 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM oven/bun:latest as builder +FROM oven/bun:latest AS builder WORKDIR /build COPY web/package.json . diff --git a/common/constants.go b/common/constants.go index e2acf83b..3c8d262a 100644 --- a/common/constants.go +++ b/common/constants.go @@ -231,7 +231,7 @@ const ( ChannelTypeVertexAi = 41 ChannelTypeMistral = 42 ChannelTypeDeepSeek = 43 - + ChannelTypeMokaAI = 47 ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -281,4 +281,5 @@ var ChannelBaseURLs = []string{ "", //41 "https://api.mistral.ai", //42 "https://api.deepseek.com", //43 + "https://api.moka.ai", //43 } diff --git a/controller/channel-test.go b/controller/channel-test.go index 5f9c990f..ea325b47 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -41,14 +41,27 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) + + requestPath := "/v1/chat/completions" + + // 先判断是否为 Embedding 模型 + if strings.Contains(strings.ToLower(testModel), "embedding") || + strings.HasPrefix(testModel, "m3e") || // m3e 系列模型 + strings.Contains(testModel, "bge-") || // bge 系列模型 + testModel == "text-embedding-v1" || + channel.Type == common.ChannelTypeMokaAI{ // 其他 embedding 模型 + requestPath = "/v1/embeddings" // 修改请求路径 + } + c.Request = &http.Request{ Method: "POST", - URL: &url.URL{Path: "/v1/chat/completions"}, + URL: &url.URL{Path: requestPath}, // 使用动态路径 Body: nil, Header: make(http.Header), } if testModel == "" { + common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 是 %s", string(*channel.TestModel))) if channel.TestModel != nil && *channel.TestModel != "" { testModel = *channel.TestModel } else { @@ -57,6 +70,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } else { testModel = "gpt-3.5-turbo" } + common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 为空:", string(testModel))) } } else { modelMapping := *channel.ModelMapping @@ -88,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr request := buildTestRequest(testModel) meta.UpstreamModelName = testModel - common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel)) + common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %s ", channel.Id, testModel, meta)) adaptor.Init(meta) @@ -156,6 +170,16 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest { Model: "", // this will be set later Stream: false, } + // 先判断是否为 Embedding 模型 + if strings.Contains(strings.ToLower(model), "embedding") || + strings.HasPrefix(model, "m3e") || // m3e 系列模型 + strings.Contains(model, "bge-") || // bge 系列模型 + model == "text-embedding-v1" { // 其他 embedding 模型 + // Embedding 请求 + testRequest.Input = []string{"hello world"} + return testRequest + } + // 并非Embedding 模型 if strings.HasPrefix(model, "o1") { testRequest.MaxCompletionTokens = 10 } else if strings.HasPrefix(model, "gemini-2.0-flash-thinking") { diff --git a/middleware/distributor.go b/middleware/distributor.go index 49cca260..c90f3e5e 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -239,5 +239,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("plugin", channel.Other) case common.ChannelCloudflare: c.Set("api_version", channel.Other) + case common.ChannelTypeMokaAI: + c.Set("api_version", channel.Other) } } diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go new file mode 100644 index 00000000..e9fe34cf --- /dev/null +++ b/relay/channel/mokaai/adaptor.go @@ -0,0 +1,104 @@ +package mokaai + +import ( + "errors" + "fmt" + "io" + "net/http" + + "github.com/gin-gonic/gin" + // "one-api/relay/adaptor" + // "one-api/relay/meta" + // "one-api/relay/model" + // "one-api/relay/constant" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/relay/constant" +) + +type Adaptor struct { +} + +// ConvertImageRequest implements adaptor.Adaptor. +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { +} + + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + + var urlPrefix = info.BaseUrl + + switch info.RelayMode { + case constant.RelayModeChatCompletions: + return fmt.Sprintf("%s/chat/completions", urlPrefix), nil + case constant.RelayModeEmbeddings: + return fmt.Sprintf("%s/embeddings", urlPrefix), nil + default: + return fmt.Sprintf("%s/run/%s", urlPrefix, info.UpstreamModelName), nil + } +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch info.RelayMode { + case constant.RelayModeChatCompletions: + return nil, errors.New("not implemented") + case constant.RelayModeEmbeddings: + // return ConvertCompletionsRequest(*request), nil + return ConvertEmbeddingRequest(*request), nil + default: + return nil, errors.New("not implemented") + } +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { + switch info.RelayMode { + + case constant.RelayModeAudioTranscription: + case constant.RelayModeAudioTranslation: + case constant.RelayModeChatCompletions: + fallthrough + case constant.RelayModeEmbeddings: + if info.IsStream { + err, usage = StreamHandler(c, resp, info) + } else { + err, usage = Handler(c, resp, info) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/mokaai/constants.go b/relay/channel/mokaai/constants.go new file mode 100644 index 00000000..415d83b7 --- /dev/null +++ b/relay/channel/mokaai/constants.go @@ -0,0 +1,9 @@ +package mokaai + +var ModelList = []string{ + "m3e-large", + "m3e-base", + "m3e-small", +} + +var ChannelName = "mokaai" \ No newline at end of file diff --git a/relay/channel/mokaai/dto.go b/relay/channel/mokaai/dto.go new file mode 100644 index 00000000..18d2dccb --- /dev/null +++ b/relay/channel/mokaai/dto.go @@ -0,0 +1,30 @@ +package mokaai + +import "one-api/dto" + + +type Request struct { + Messages []dto.Message `json:"messages,omitempty"` + Lora string `json:"lora,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Prompt string `json:"prompt,omitempty"` + Raw bool `json:"raw,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` +} + +type Options struct { + Seed int `json:"seed,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + NumPredict int `json:"num_predict,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` +} + +type EmbeddingRequest struct { + Model string `json:"model"` + Input []string `json:"input"` +} \ No newline at end of file diff --git a/relay/channel/mokaai/relay-mokaai.go b/relay/channel/mokaai/relay-mokaai.go new file mode 100644 index 00000000..44d7f7c2 --- /dev/null +++ b/relay/channel/mokaai/relay-mokaai.go @@ -0,0 +1,154 @@ +package mokaai + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "strings" + + // "one-api/common/ctxkey" + // "one-api/common/render" + + // "github.com/gin-gonic/gin" + // "one-api/common" + // "one-api/common/helper" + // "one-api/common/logger" + // "one-api/relay/adaptor/openai" + // "one-api/relay/model" + + "github.com/gin-gonic/gin" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/service" + "time" +) + +func ConvertCompletionsRequest(textRequest dto.GeneralOpenAIRequest) *Request { + p, _ := textRequest.Prompt.(string) + return &Request{ + Prompt: p, + MaxTokens: textRequest.GetMaxTokens(), + Stream: textRequest.Stream, + Temperature: textRequest.Temperature, + } +} + +func ConvertEmbeddingRequest(request dto.GeneralOpenAIRequest) *EmbeddingRequest { + var input []string // Change input to []string + + switch v := request.Input.(type) { + case string: + input = []string{v} // Convert string to []string + case []string: + input = v // Already a []string, no conversion needed + case []interface{}: + for _, part := range v { + if str, ok := part.(string); ok { + input = append(input, str) // Append each string to the slice + } + } + } + + return &EmbeddingRequest{ + Model: request.Model, + Input: input, // Assign []string to Input + } +} + +func StreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + service.SetEventStreamHeaders(c) + id := service.GetResponseID(c) + var responseText string + isFirst := true + + for scanner.Scan() { + data := scanner.Text() + if len(data) < len("data: ") { + continue + } + data = strings.TrimPrefix(data, "data: ") + data = strings.TrimSuffix(data, "\r") + + if data == "[DONE]" { + break + } + + var response dto.ChatCompletionsStreamResponse + err := json.Unmarshal([]byte(data), &response) + if err != nil { + common.LogError(c, "error_unmarshalling_stream_response: "+err.Error()) + continue + } + for _, choice := range response.Choices { + choice.Delta.Role = "assistant" + responseText += choice.Delta.GetContentString() + } + response.Id = id + response.Model = info.UpstreamModelName + err = service.ObjectData(c, response) + if isFirst { + isFirst = false + info.FirstResponseTime = time.Now() + } + if err != nil { + common.LogError(c, "error_rendering_stream_response: "+err.Error()) + } + } + + if err := scanner.Err(); err != nil { + common.LogError(c, "error_scanning_stream_response: "+err.Error()) + } + usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + if info.ShouldIncludeUsage { + response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) + err := service.ObjectData(c, response) + if err != nil { + common.LogError(c, "error_rendering_final_usage_response: "+err.Error()) + } + } + service.Done(c) + + err := resp.Body.Close() + if err != nil { + common.LogError(c, "close_response_body_failed: "+err.Error()) + } + + return nil, usage +} + +func Handler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var response dto.TextResponse + err = json.Unmarshal(responseBody, &response) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + response.Model = info.UpstreamModelName + var responseText string + for _, choice := range response.Choices { + responseText += choice.Message.StringContent() + } + usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + response.Usage = *usage + response.Id = service.GetResponseID(c) + jsonResponse, err := json.Marshal(response) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + return nil, usage +} diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index c5c1d089..1a40a6ee 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -27,7 +27,7 @@ const ( APITypeVertexAi APITypeMistral APITypeDeepSeek - + APITypeMokaAI APITypeDummy // this one is only for count, do not add any channel after this ) @@ -78,6 +78,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeMistral case common.ChannelTypeDeepSeek: apiType = APITypeDeepSeek + case common.ChannelTypeMokaAI: + apiType = APITypeMokaAI } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 1c7d11e9..9304bd6d 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -14,6 +14,7 @@ import ( "one-api/relay/channel/gemini" "one-api/relay/channel/jina" "one-api/relay/channel/mistral" + "one-api/relay/channel/mokaai" "one-api/relay/channel/ollama" "one-api/relay/channel/openai" "one-api/relay/channel/palm" @@ -74,6 +75,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &mistral.Adaptor{} case constant.APITypeDeepSeek: return &deepseek.Adaptor{} + case constant.APITypeMokaAI: + return &mokaai.Adaptor{} } return nil } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 32bf8bce..c1be95b7 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -125,5 +125,12 @@ export const CHANNEL_OPTIONS = [ value: 21, color: 'purple', label: '知识库:AI Proxy' + }, + { + key: 47, + text: '嵌入模型:MokaAI M3E', + value: 47, + color: 'purple', + label: '嵌入模型:MokaAI M3E' } ]; From 8a2d220cf4762e87f08eb574d6e2b953395e0070 Mon Sep 17 00:00:00 2001 From: Jerry Date: Wed, 22 Jan 2025 13:16:06 +0800 Subject: [PATCH 2/3] fix : chanel test did not refresh --- web/src/components/ChannelsTable.js | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 890e32ba..d62c2f13 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -44,7 +44,7 @@ function renderTimestamp(timestamp) { const ChannelsTable = () => { const { t } = useTranslation(); - + let type2label = undefined; const renderType = (type) => { @@ -559,7 +559,7 @@ const ChannelsTable = () => { if (!enableTagMode) { channelDates.push(channels[i]); } else { - let tag = channels[i].tag?channels[i].tag:""; + let tag = channels[i].tag ? channels[i].tag : ""; // find from channelTags let tagIndex = channelTags[tag]; let tagChannelDates = undefined; @@ -805,6 +805,9 @@ const ChannelsTable = () => { record.response_time = time * 1000; record.test_time = Date.now() / 1000; showInfo(t('通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。').replace('${name}', record.name).replace('${time.toFixed(2)}', time.toFixed(2))); + + // 刷新列表 + await refresh(); } else { showError(message); } @@ -838,6 +841,8 @@ const ChannelsTable = () => { record.balance = balance; record.balance_updated_time = Date.now() / 1000; showInfo(t('通道 ${name} 余额更新成功!').replace('${name}', record.name)); + // 刷新列表 + await refresh(); } else { showError(message); } @@ -1186,7 +1191,7 @@ const ChannelsTable = () => {
- + {t('标签聚合模式')} { }} /> + disabled={!enableBatchDelete} + theme="light" + type="primary" + style={{ marginRight: 8 }} + onClick={() => setShowBatchSetTag(true)} + > + {t('批量设置标签')} +
From 7588c42b420b3a21901396d6fc46e044ebe4b590 Mon Sep 17 00:00:00 2001 From: Jerry Date: Thu, 23 Jan 2025 05:54:39 +0800 Subject: [PATCH 3/3] Fix M3E not working --- controller/relay.go | 8 + .../channel/mokaai/dto.go => dto/embedding.go | 30 ++-- relay/channel/adapter.go | 1 + relay/channel/ali/adaptor.go | 5 + relay/channel/aws/adaptor.go | 6 + relay/channel/baidu/adaptor.go | 5 + relay/channel/claude/adaptor.go | 5 + relay/channel/cloudflare/adaptor.go | 6 + relay/channel/cohere/adaptor.go | 6 + relay/channel/deepseek/adaptor.go | 6 + relay/channel/dify/adaptor.go | 6 + relay/channel/gemini/adaptor.go | 6 + relay/channel/jina/adaptor.go | 6 + relay/channel/mistral/adaptor.go | 6 + relay/channel/mokaai/adaptor.go | 69 ++++----- relay/channel/mokaai/relay-mokaai.go | 143 +++++------------- relay/channel/ollama/adaptor.go | 6 + relay/channel/openai/adaptor.go | 5 + relay/channel/palm/adaptor.go | 6 + relay/channel/perplexity/adaptor.go | 6 + relay/channel/siliconflow/adaptor.go | 6 + relay/channel/tencent/adaptor.go | 6 + relay/channel/vertex/adaptor.go | 6 + relay/channel/xunfei/adaptor.go | 6 + relay/channel/zhipu/adaptor.go | 6 + relay/channel/zhipu_4v/adaptor.go | 6 + relay/relay_embedding.go | 127 ++++++++++++++++ 27 files changed, 338 insertions(+), 162 deletions(-) rename relay/channel/mokaai/dto.go => dto/embedding.go (54%) create mode 100644 relay/relay_embedding.go diff --git a/controller/relay.go b/controller/relay.go index 72d421e3..25af7e20 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -33,6 +33,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode err = relay.AudioHelper(c) case relayconstant.RelayModeRerank: err = relay.RerankHelper(c, relayMode) + case relayconstant.RelayModeEmbeddings: + err = relay.EmbeddingHelper(c,relayMode) default: err = relay.TextHelper(c) } @@ -55,6 +57,11 @@ func Relay(c *gin.Context) { originalModel := c.GetString("original_model") var openaiErr *dto.OpenAIErrorWithStatusCode + //获取request body 并输出到日志 + requestBody, _ := common.GetRequestBody(c) + common.LogInfo(c, fmt.Sprintf("relayMode: %d ,request body: %s",relayMode, string(requestBody))) + + for i := 0; i <= common.RetryTimes; i++ { channel, err := getChannel(c, group, originalModel, i) if err != nil { @@ -154,6 +161,7 @@ func WssRelay(c *gin.Context) { } func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { + common.LogInfo(c, fmt.Sprintf("relayMode: %d ,channel Id : %s",relayMode, string(channel.Id))) addUsedChannel(c, channel.Id) requestBody, _ := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) diff --git a/relay/channel/mokaai/dto.go b/dto/embedding.go similarity index 54% rename from relay/channel/mokaai/dto.go rename to dto/embedding.go index 18d2dccb..828faaab 100644 --- a/relay/channel/mokaai/dto.go +++ b/dto/embedding.go @@ -1,19 +1,6 @@ -package mokaai +package dto -import "one-api/dto" - - -type Request struct { - Messages []dto.Message `json:"messages,omitempty"` - Lora string `json:"lora,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Prompt string `json:"prompt,omitempty"` - Raw bool `json:"raw,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` -} - -type Options struct { +type EmbeddingOptions struct { Seed int `json:"seed,omitempty"` Temperature *float64 `json:"temperature,omitempty"` TopK int `json:"top_k,omitempty"` @@ -27,4 +14,17 @@ type Options struct { type EmbeddingRequest struct { Model string `json:"model"` Input []string `json:"input"` +} + +type EmbeddingResponseItem struct { + Object string `json:"object"` + Index int `json:"index"` + Embedding []float64 `json:"embedding"` +} + +type EmbeddingResponse struct { + Object string `json:"object"` + Data []EmbeddingResponseItem `json:"data"` + Model string `json:"model"` + Usage `json:"usage"` } \ No newline at end of file diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index d72db6e4..c970fd48 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -15,6 +15,7 @@ type Adaptor interface { SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) + ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index aa01ca66..c4974a62 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -67,6 +67,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index be72c04c..5a3d09b9 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -59,6 +59,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return nil, nil } diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 3991a5e9..35271b41 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -122,6 +122,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 488d87dc..83168382 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -73,6 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index fc0ec271..cf41d9d7 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -56,6 +56,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return request, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { // 添加文件字段 file, _, err := c.Request.FormFile("file") diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index f8b190ec..d552a53b 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -54,6 +54,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return requestConvertRerank2Cohere(request), nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { err, usage = cohereRerankHandler(c, resp, info) diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index cc94a58f..1682dc3f 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -49,6 +49,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 53ba26e6..ce73c78c 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -48,6 +48,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 9a5bc251..681e9988 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -68,6 +68,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index ad488f28..98fb073d 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -55,6 +55,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return request, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { err, usage = jinaRerankHandler(c, resp) diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index 4ab1a35a..c99e5396 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go index e9fe34cf..9670ec94 100644 --- a/relay/channel/mokaai/adaptor.go +++ b/relay/channel/mokaai/adaptor.go @@ -3,54 +3,46 @@ package mokaai import ( "errors" "fmt" + "github.com/gin-gonic/gin" "io" "net/http" - - "github.com/gin-gonic/gin" - // "one-api/relay/adaptor" - // "one-api/relay/meta" - // "one-api/relay/model" - // "one-api/relay/constant" "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/relay/constant" + "strings" ) type Adaptor struct { } -// ConvertImageRequest implements adaptor.Adaptor. -func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") -} - func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") } -func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { //TODO implement me return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo) { + +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return request, nil } +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { -func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - - var urlPrefix = info.BaseUrl - - switch info.RelayMode { - case constant.RelayModeChatCompletions: - return fmt.Sprintf("%s/chat/completions", urlPrefix), nil - case constant.RelayModeEmbeddings: - return fmt.Sprintf("%s/embeddings", urlPrefix), nil - default: - return fmt.Sprintf("%s/run/%s", urlPrefix, info.UpstreamModelName), nil +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t + suffix := "chat/" + if strings.HasPrefix(info.UpstreamModelName, "m3e") { + suffix = "embeddings" } + fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix) + return fullRequestURL, nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { @@ -64,33 +56,30 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re return nil, errors.New("request is nil") } switch info.RelayMode { - case constant.RelayModeChatCompletions: - return nil, errors.New("not implemented") - case constant.RelayModeEmbeddings: - // return ConvertCompletionsRequest(*request), nil - return ConvertEmbeddingRequest(*request), nil + case constant.RelayModeEmbeddings: + baiduEmbeddingRequest := embeddingRequestOpenAI2Moka(*request) + return baiduEmbeddingRequest, nil default: return nil, errors.New("not implemented") } } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { - switch info.RelayMode { - case constant.RelayModeAudioTranscription: - case constant.RelayModeAudioTranslation: - case constant.RelayModeChatCompletions: - fallthrough + switch info.RelayMode { case constant.RelayModeEmbeddings: - if info.IsStream { - err, usage = StreamHandler(c, resp, info) - } else { - err, usage = Handler(c, resp, info) - } + err, usage = mokaEmbeddingHandler(c, resp) + default: + // err, usage = mokaHandler(c, resp) + } return } diff --git a/relay/channel/mokaai/relay-mokaai.go b/relay/channel/mokaai/relay-mokaai.go index 44d7f7c2..d7580d7a 100644 --- a/relay/channel/mokaai/relay-mokaai.go +++ b/relay/channel/mokaai/relay-mokaai.go @@ -1,41 +1,15 @@ package mokaai import ( - "bufio" "encoding/json" + "github.com/gin-gonic/gin" "io" "net/http" - "strings" - - // "one-api/common/ctxkey" - // "one-api/common/render" - - // "github.com/gin-gonic/gin" - // "one-api/common" - // "one-api/common/helper" - // "one-api/common/logger" - // "one-api/relay/adaptor/openai" - // "one-api/relay/model" - - "github.com/gin-gonic/gin" - "one-api/common" "one-api/dto" - relaycommon "one-api/relay/common" "one-api/service" - "time" ) -func ConvertCompletionsRequest(textRequest dto.GeneralOpenAIRequest) *Request { - p, _ := textRequest.Prompt.(string) - return &Request{ - Prompt: p, - MaxTokens: textRequest.GetMaxTokens(), - Stream: textRequest.Stream, - Temperature: textRequest.Temperature, - } -} - -func ConvertEmbeddingRequest(request dto.GeneralOpenAIRequest) *EmbeddingRequest { +func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest { var input []string // Change input to []string switch v := request.Input.(type) { @@ -50,105 +24,60 @@ func ConvertEmbeddingRequest(request dto.GeneralOpenAIRequest) *EmbeddingRequest } } } - - return &EmbeddingRequest{ - Model: request.Model, - Input: input, // Assign []string to Input + return &dto.EmbeddingRequest{ + Input: input, + Model: request.Model, } } -func StreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - scanner := bufio.NewScanner(resp.Body) - scanner.Split(bufio.ScanLines) - - service.SetEventStreamHeaders(c) - id := service.GetResponseID(c) - var responseText string - isFirst := true - - for scanner.Scan() { - data := scanner.Text() - if len(data) < len("data: ") { - continue - } - data = strings.TrimPrefix(data, "data: ") - data = strings.TrimSuffix(data, "\r") - - if data == "[DONE]" { - break - } - - var response dto.ChatCompletionsStreamResponse - err := json.Unmarshal([]byte(data), &response) - if err != nil { - common.LogError(c, "error_unmarshalling_stream_response: "+err.Error()) - continue - } - for _, choice := range response.Choices { - choice.Delta.Role = "assistant" - responseText += choice.Delta.GetContentString() - } - response.Id = id - response.Model = info.UpstreamModelName - err = service.ObjectData(c, response) - if isFirst { - isFirst = false - info.FirstResponseTime = time.Now() - } - if err != nil { - common.LogError(c, "error_rendering_stream_response: "+err.Error()) - } +func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEmbeddingResponse { + openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{ + Object: "list", + Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)), + Model: "baidu-embedding", + Usage: response.Usage, } - - if err := scanner.Err(); err != nil { - common.LogError(c, "error_scanning_stream_response: "+err.Error()) + for _, item := range response.Data { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{ + Object: item.Object, + Index: item.Index, + Embedding: item.Embedding, + }) } - usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) - if info.ShouldIncludeUsage { - response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) - err := service.ObjectData(c, response) - if err != nil { - common.LogError(c, "error_rendering_final_usage_response: "+err.Error()) - } - } - service.Done(c) - - err := resp.Body.Close() - if err != nil { - common.LogError(c, "close_response_body_failed: "+err.Error()) - } - - return nil, usage + return &openAIEmbeddingResponse } -func Handler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var baiduResponse dto.EmbeddingResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - var response dto.TextResponse - err = json.Unmarshal(responseBody, &response) + err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } - response.Model = info.UpstreamModelName - var responseText string - for _, choice := range response.Choices { - responseText += choice.Message.StringContent() - } - usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) - response.Usage = *usage - response.Id = service.GetResponseID(c) - jsonResponse, err := json.Marshal(response) + // if baiduResponse.ErrorMsg != "" { + // return &dto.OpenAIErrorWithStatusCode{ + // Error: dto.OpenAIError{ + // Type: "baidu_error", + // Param: "", + // }, + // StatusCode: resp.StatusCode, + // }, nil + // } + fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse) + jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, _ = c.Writer.Write(jsonResponse) - return nil, usage + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage } + diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 30798402..d5185084 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -58,6 +58,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 230ab55c..718b26a1 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -129,6 +129,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { a.ResponseFormat = request.ResponseFormat if info.RelayMode == constant.RelayModeAudioSpeech { diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 91272337..f38fa95b 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -49,6 +49,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 18b66a9a..2b27bdb1 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -52,6 +52,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index ac722b22..f9ddedeb 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -58,6 +58,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return request, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { switch info.RelayMode { case constant.RelayModeRerank: diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index d831cc83..768ef646 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -73,6 +73,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 764e5c4b..07659c20 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -151,6 +151,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 31d426a6..71fd1367 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { // xunfei's request is not http request, so we don't need to do anything here dummyResp := &http.Response{} diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index f0538edc..87ff20d5 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -56,6 +56,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 3d46b799..5983c1d9 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -53,6 +53,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/relay_embedding.go b/relay/relay_embedding.go new file mode 100644 index 00000000..a3304d8e --- /dev/null +++ b/relay/relay_embedding.go @@ -0,0 +1,127 @@ +package relay + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + relayconstant "one-api/relay/constant" + "one-api/service" + "one-api/setting" +) + +func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { + token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model) + return token +} + +func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) { + relayInfo := relaycommon.GenRelayInfo(c) + + var embeddingRequest *dto.EmbeddingRequest + err := common.UnmarshalBodyReusable(c, &embeddingRequest) + if err != nil { + common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) + return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) + } + if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" { + embeddingRequest.Model = "m3e-base" + } + if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" { + embeddingRequest.Model = c.Param("model") + } + if embeddingRequest.Input == nil || len(embeddingRequest.Input) == 0 { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("input is empty"), "invalid_input", http.StatusBadRequest) + } + // map model name + modelMapping := c.GetString("model_mapping") + //isModelMapped := false + if modelMapping != "" && modelMapping != "{}" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap[embeddingRequest.Model] != "" { + embeddingRequest.Model = modelMap[embeddingRequest.Model] + // set upstream model name + //isModelMapped = true + } + } + + relayInfo.UpstreamModelName = embeddingRequest.Model + modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false) + groupRatio := setting.GetGroupRatio(relayInfo.Group) + + var preConsumedQuota int + var ratio float64 + var modelRatio float64 + + promptToken := getEmbeddingPromptToken(*embeddingRequest) + if !success { + preConsumedTokens := promptToken + modelRatio = common.GetModelRatio(embeddingRequest.Model) + ratio = modelRatio * groupRatio + preConsumedQuota = int(float64(preConsumedTokens) * ratio) + } else { + preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) + } + relayInfo.PromptTokens = promptToken + + // pre-consume quota 预消耗配额 + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) + if openaiErr != nil { + return openaiErr + } + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() + + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + } + adaptor.Init(relayInfo) + + convertedRequest, err := adaptor.ConvertEmbeddingRequest(c,relayInfo,*embeddingRequest) + + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) + } + jsonData, err := json.Marshal(convertedRequest) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) + } + requestBody := bytes.NewBuffer(jsonData) + statusCodeMappingStr := c.GetString("status_code_mapping") + resp, err := adaptor.DoRequest(c,relayInfo, requestBody) + if err != nil { + return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + + var httpResp *http.Response + if resp != nil { + httpResp = resp.(*http.Response) + if httpResp.StatusCode != http.StatusOK { + openaiErr = service.RelayErrorHandler(httpResp) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } + } + + usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) + if openaiErr != nil { + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } + postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") + return nil +}