diff --git a/Dockerfile b/Dockerfile index 9a1d9b5f..b341b22d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM oven/bun:latest as builder +FROM oven/bun:latest AS builder WORKDIR /build COPY web/package.json . diff --git a/common/constants.go b/common/constants.go index e2acf83b..3c8d262a 100644 --- a/common/constants.go +++ b/common/constants.go @@ -231,7 +231,7 @@ const ( ChannelTypeVertexAi = 41 ChannelTypeMistral = 42 ChannelTypeDeepSeek = 43 - + ChannelTypeMokaAI = 47 ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -281,4 +281,5 @@ var ChannelBaseURLs = []string{ "", //41 "https://api.mistral.ai", //42 "https://api.deepseek.com", //43 + "https://api.moka.ai", //43 } diff --git a/controller/channel-test.go b/controller/channel-test.go index d9083618..93f92f4c 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -41,14 +41,27 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) + + requestPath := "/v1/chat/completions" + + // 先判断是否为 Embedding 模型 + if strings.Contains(strings.ToLower(testModel), "embedding") || + strings.HasPrefix(testModel, "m3e") || // m3e 系列模型 + strings.Contains(testModel, "bge-") || // bge 系列模型 + testModel == "text-embedding-v1" || + channel.Type == common.ChannelTypeMokaAI{ // 其他 embedding 模型 + requestPath = "/v1/embeddings" // 修改请求路径 + } + c.Request = &http.Request{ Method: "POST", - URL: &url.URL{Path: "/v1/chat/completions"}, + URL: &url.URL{Path: requestPath}, // 使用动态路径 Body: nil, Header: make(http.Header), } if testModel == "" { + common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 是 %s", string(*channel.TestModel))) if channel.TestModel != nil && *channel.TestModel != "" { testModel = *channel.TestModel } else { @@ -57,6 +70,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } else { testModel = "gpt-3.5-turbo" } + common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 为空:", string(testModel))) } } @@ -88,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr request := buildTestRequest(testModel) meta.UpstreamModelName = testModel - common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel)) + common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %s ", channel.Id, testModel, meta)) adaptor.Init(meta) @@ -156,6 +170,17 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest { Model: "", // this will be set later Stream: false, } + + // 先判断是否为 Embedding 模型 + if strings.Contains(strings.ToLower(model), "embedding") || + strings.HasPrefix(model, "m3e") || // m3e 系列模型 + strings.Contains(model, "bge-") || // bge 系列模型 + model == "text-embedding-v1" { // 其他 embedding 模型 + // Embedding 请求 + testRequest.Input = []string{"hello world"} + return testRequest + } + // 并非Embedding 模型 if strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") { testRequest.MaxCompletionTokens = 10 } else { diff --git a/controller/relay.go b/controller/relay.go index 72d421e3..25af7e20 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -33,6 +33,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode err = relay.AudioHelper(c) case relayconstant.RelayModeRerank: err = relay.RerankHelper(c, relayMode) + case relayconstant.RelayModeEmbeddings: + err = relay.EmbeddingHelper(c,relayMode) default: err = relay.TextHelper(c) } @@ -55,6 +57,11 @@ func Relay(c *gin.Context) { originalModel := c.GetString("original_model") var openaiErr *dto.OpenAIErrorWithStatusCode + //获取request body 并输出到日志 + requestBody, _ := common.GetRequestBody(c) + common.LogInfo(c, fmt.Sprintf("relayMode: %d ,request body: %s",relayMode, string(requestBody))) + + for i := 0; i <= common.RetryTimes; i++ { channel, err := getChannel(c, group, originalModel, i) if err != nil { @@ -154,6 +161,7 @@ func WssRelay(c *gin.Context) { } func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { + common.LogInfo(c, fmt.Sprintf("relayMode: %d ,channel Id : %s",relayMode, string(channel.Id))) addUsedChannel(c, channel.Id) requestBody, _ := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) diff --git a/dto/embedding.go b/dto/embedding.go new file mode 100644 index 00000000..828faaab --- /dev/null +++ b/dto/embedding.go @@ -0,0 +1,30 @@ +package dto + +type EmbeddingOptions struct { + Seed int `json:"seed,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + NumPredict int `json:"num_predict,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` +} + +type EmbeddingRequest struct { + Model string `json:"model"` + Input []string `json:"input"` +} + +type EmbeddingResponseItem struct { + Object string `json:"object"` + Index int `json:"index"` + Embedding []float64 `json:"embedding"` +} + +type EmbeddingResponse struct { + Object string `json:"object"` + Data []EmbeddingResponseItem `json:"data"` + Model string `json:"model"` + Usage `json:"usage"` +} \ No newline at end of file diff --git a/middleware/distributor.go b/middleware/distributor.go index 49cca260..c90f3e5e 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -239,5 +239,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("plugin", channel.Other) case common.ChannelCloudflare: c.Set("api_version", channel.Other) + case common.ChannelTypeMokaAI: + c.Set("api_version", channel.Other) } } diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index d72db6e4..c970fd48 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -15,6 +15,7 @@ type Adaptor interface { SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) + ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index aa01ca66..c4974a62 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -67,6 +67,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index be72c04c..5a3d09b9 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -59,6 +59,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return nil, nil } diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 3991a5e9..35271b41 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -122,6 +122,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 488d87dc..83168382 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -73,6 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index fc0ec271..cf41d9d7 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -56,6 +56,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return request, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { // 添加文件字段 file, _, err := c.Request.FormFile("file") diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index f8b190ec..d552a53b 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -54,6 +54,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return requestConvertRerank2Cohere(request), nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { err, usage = cohereRerankHandler(c, resp, info) diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index cc94a58f..1682dc3f 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -49,6 +49,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 53ba26e6..ce73c78c 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -48,6 +48,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 9a5bc251..681e9988 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -68,6 +68,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index ad488f28..98fb073d 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -55,6 +55,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return request, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { err, usage = jinaRerankHandler(c, resp) diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index 4ab1a35a..c99e5396 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go new file mode 100644 index 00000000..9670ec94 --- /dev/null +++ b/relay/channel/mokaai/adaptor.go @@ -0,0 +1,93 @@ +package mokaai + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/relay/constant" + "strings" +) + +type Adaptor struct { +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return request, nil +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { + +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t + suffix := "chat/" + if strings.HasPrefix(info.UpstreamModelName, "m3e") { + suffix = "embeddings" + } + fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix) + return fullRequestURL, nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch info.RelayMode { + case constant.RelayModeEmbeddings: + baiduEmbeddingRequest := embeddingRequestOpenAI2Moka(*request) + return baiduEmbeddingRequest, nil + default: + return nil, errors.New("not implemented") + } +} + +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { + + switch info.RelayMode { + case constant.RelayModeEmbeddings: + err, usage = mokaEmbeddingHandler(c, resp) + default: + // err, usage = mokaHandler(c, resp) + + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/mokaai/constants.go b/relay/channel/mokaai/constants.go new file mode 100644 index 00000000..415d83b7 --- /dev/null +++ b/relay/channel/mokaai/constants.go @@ -0,0 +1,9 @@ +package mokaai + +var ModelList = []string{ + "m3e-large", + "m3e-base", + "m3e-small", +} + +var ChannelName = "mokaai" \ No newline at end of file diff --git a/relay/channel/mokaai/relay-mokaai.go b/relay/channel/mokaai/relay-mokaai.go new file mode 100644 index 00000000..d7580d7a --- /dev/null +++ b/relay/channel/mokaai/relay-mokaai.go @@ -0,0 +1,83 @@ +package mokaai + +import ( + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/service" +) + +func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest { + var input []string // Change input to []string + + switch v := request.Input.(type) { + case string: + input = []string{v} // Convert string to []string + case []string: + input = v // Already a []string, no conversion needed + case []interface{}: + for _, part := range v { + if str, ok := part.(string); ok { + input = append(input, str) // Append each string to the slice + } + } + } + return &dto.EmbeddingRequest{ + Input: input, + Model: request.Model, + } +} + +func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEmbeddingResponse { + openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{ + Object: "list", + Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)), + Model: "baidu-embedding", + Usage: response.Usage, + } + for _, item := range response.Data { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{ + Object: item.Object, + Index: item.Index, + Embedding: item.Embedding, + }) + } + return &openAIEmbeddingResponse +} + +func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var baiduResponse dto.EmbeddingResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &baiduResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + // if baiduResponse.ErrorMsg != "" { + // return &dto.OpenAIErrorWithStatusCode{ + // Error: dto.OpenAIError{ + // Type: "baidu_error", + // Param: "", + // }, + // StatusCode: resp.StatusCode, + // }, nil + // } + fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 30798402..d5185084 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -58,6 +58,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 68c36528..d86b33e0 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -149,6 +149,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { a.ResponseFormat = request.ResponseFormat if info.RelayMode == constant.RelayModeAudioSpeech { diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 91272337..f38fa95b 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -49,6 +49,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 18b66a9a..2b27bdb1 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -52,6 +52,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index ac722b22..f9ddedeb 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -58,6 +58,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return request, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { switch info.RelayMode { case constant.RelayModeRerank: diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index d831cc83..768ef646 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -73,6 +73,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 764e5c4b..07659c20 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -151,6 +151,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 31d426a6..71fd1367 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { // xunfei's request is not http request, so we don't need to do anything here dummyResp := &http.Response{} diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index f0538edc..87ff20d5 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -56,6 +56,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 3d46b799..5983c1d9 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -53,6 +53,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index c5c1d089..1a40a6ee 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -27,7 +27,7 @@ const ( APITypeVertexAi APITypeMistral APITypeDeepSeek - + APITypeMokaAI APITypeDummy // this one is only for count, do not add any channel after this ) @@ -78,6 +78,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeMistral case common.ChannelTypeDeepSeek: apiType = APITypeDeepSeek + case common.ChannelTypeMokaAI: + apiType = APITypeMokaAI } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 1c7d11e9..9304bd6d 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -14,6 +14,7 @@ import ( "one-api/relay/channel/gemini" "one-api/relay/channel/jina" "one-api/relay/channel/mistral" + "one-api/relay/channel/mokaai" "one-api/relay/channel/ollama" "one-api/relay/channel/openai" "one-api/relay/channel/palm" @@ -74,6 +75,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &mistral.Adaptor{} case constant.APITypeDeepSeek: return &deepseek.Adaptor{} + case constant.APITypeMokaAI: + return &mokaai.Adaptor{} } return nil } diff --git a/relay/relay_embedding.go b/relay/relay_embedding.go new file mode 100644 index 00000000..a3304d8e --- /dev/null +++ b/relay/relay_embedding.go @@ -0,0 +1,127 @@ +package relay + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + relayconstant "one-api/relay/constant" + "one-api/service" + "one-api/setting" +) + +func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { + token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model) + return token +} + +func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) { + relayInfo := relaycommon.GenRelayInfo(c) + + var embeddingRequest *dto.EmbeddingRequest + err := common.UnmarshalBodyReusable(c, &embeddingRequest) + if err != nil { + common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) + return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) + } + if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" { + embeddingRequest.Model = "m3e-base" + } + if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" { + embeddingRequest.Model = c.Param("model") + } + if embeddingRequest.Input == nil || len(embeddingRequest.Input) == 0 { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("input is empty"), "invalid_input", http.StatusBadRequest) + } + // map model name + modelMapping := c.GetString("model_mapping") + //isModelMapped := false + if modelMapping != "" && modelMapping != "{}" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap[embeddingRequest.Model] != "" { + embeddingRequest.Model = modelMap[embeddingRequest.Model] + // set upstream model name + //isModelMapped = true + } + } + + relayInfo.UpstreamModelName = embeddingRequest.Model + modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false) + groupRatio := setting.GetGroupRatio(relayInfo.Group) + + var preConsumedQuota int + var ratio float64 + var modelRatio float64 + + promptToken := getEmbeddingPromptToken(*embeddingRequest) + if !success { + preConsumedTokens := promptToken + modelRatio = common.GetModelRatio(embeddingRequest.Model) + ratio = modelRatio * groupRatio + preConsumedQuota = int(float64(preConsumedTokens) * ratio) + } else { + preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) + } + relayInfo.PromptTokens = promptToken + + // pre-consume quota 预消耗配额 + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) + if openaiErr != nil { + return openaiErr + } + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() + + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + } + adaptor.Init(relayInfo) + + convertedRequest, err := adaptor.ConvertEmbeddingRequest(c,relayInfo,*embeddingRequest) + + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) + } + jsonData, err := json.Marshal(convertedRequest) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) + } + requestBody := bytes.NewBuffer(jsonData) + statusCodeMappingStr := c.GetString("status_code_mapping") + resp, err := adaptor.DoRequest(c,relayInfo, requestBody) + if err != nil { + return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + + var httpResp *http.Response + if resp != nil { + httpResp = resp.(*http.Response) + if httpResp.StatusCode != http.StatusOK { + openaiErr = service.RelayErrorHandler(httpResp) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } + } + + usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) + if openaiErr != nil { + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } + postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") + return nil +} diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 890e32ba..d62c2f13 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -44,7 +44,7 @@ function renderTimestamp(timestamp) { const ChannelsTable = () => { const { t } = useTranslation(); - + let type2label = undefined; const renderType = (type) => { @@ -559,7 +559,7 @@ const ChannelsTable = () => { if (!enableTagMode) { channelDates.push(channels[i]); } else { - let tag = channels[i].tag?channels[i].tag:""; + let tag = channels[i].tag ? channels[i].tag : ""; // find from channelTags let tagIndex = channelTags[tag]; let tagChannelDates = undefined; @@ -805,6 +805,9 @@ const ChannelsTable = () => { record.response_time = time * 1000; record.test_time = Date.now() / 1000; showInfo(t('通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。').replace('${name}', record.name).replace('${time.toFixed(2)}', time.toFixed(2))); + + // 刷新列表 + await refresh(); } else { showError(message); } @@ -838,6 +841,8 @@ const ChannelsTable = () => { record.balance = balance; record.balance_updated_time = Date.now() / 1000; showInfo(t('通道 ${name} 余额更新成功!').replace('${name}', record.name)); + // 刷新列表 + await refresh(); } else { showError(message); } @@ -1186,7 +1191,7 @@ const ChannelsTable = () => {
- + {t('标签聚合模式')} { }} /> + disabled={!enableBatchDelete} + theme="light" + type="primary" + style={{ marginRight: 8 }} + onClick={() => setShowBatchSetTag(true)} + > + {t('批量设置标签')} +
diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 32bf8bce..c1be95b7 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -125,5 +125,12 @@ export const CHANNEL_OPTIONS = [ value: 21, color: 'purple', label: '知识库:AI Proxy' + }, + { + key: 47, + text: '嵌入模型:MokaAI M3E', + value: 47, + color: 'purple', + label: '嵌入模型:MokaAI M3E' } ];