Support for MokaAI M3E

This commit is contained in:
Jerry
2025-01-22 04:21:08 +08:00
parent 53a941a6c0
commit 126f04e08f
11 changed files with 341 additions and 5 deletions

View File

@@ -1,4 +1,4 @@
FROM oven/bun:latest as builder
FROM oven/bun:latest AS builder
WORKDIR /build
COPY web/package.json .

View File

@@ -231,7 +231,7 @@ const (
ChannelTypeVertexAi = 41
ChannelTypeMistral = 42
ChannelTypeDeepSeek = 43
ChannelTypeMokaAI = 47
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
@@ -281,4 +281,5 @@ var ChannelBaseURLs = []string{
"", //41
"https://api.mistral.ai", //42
"https://api.deepseek.com", //43
"https://api.moka.ai", //43
}

View File

@@ -41,14 +41,27 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
requestPath := "/v1/chat/completions"
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(testModel), "embedding") ||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
strings.Contains(testModel, "bge-") || // bge 系列模型
testModel == "text-embedding-v1" ||
channel.Type == common.ChannelTypeMokaAI{ // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径
}
c.Request = &http.Request{
Method: "POST",
URL: &url.URL{Path: "/v1/chat/completions"},
URL: &url.URL{Path: requestPath}, // 使用动态路径
Body: nil,
Header: make(http.Header),
}
if testModel == "" {
common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 是 %s", string(*channel.TestModel)))
if channel.TestModel != nil && *channel.TestModel != "" {
testModel = *channel.TestModel
} else {
@@ -57,6 +70,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
} else {
testModel = "gpt-3.5-turbo"
}
common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 为空:", string(testModel)))
}
} else {
modelMapping := *channel.ModelMapping
@@ -88,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
request := buildTestRequest(testModel)
meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %s ", channel.Id, testModel, meta))
adaptor.Init(meta)
@@ -156,6 +170,16 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
Model: "", // this will be set later
Stream: false,
}
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(model), "embedding") ||
strings.HasPrefix(model, "m3e") || // m3e 系列模型
strings.Contains(model, "bge-") || // bge 系列模型
model == "text-embedding-v1" { // 其他 embedding 模型
// Embedding 请求
testRequest.Input = []string{"hello world"}
return testRequest
}
// 并非Embedding 模型
if strings.HasPrefix(model, "o1") {
testRequest.MaxCompletionTokens = 10
} else if strings.HasPrefix(model, "gemini-2.0-flash-thinking") {

View File

@@ -239,5 +239,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("plugin", channel.Other)
case common.ChannelCloudflare:
c.Set("api_version", channel.Other)
case common.ChannelTypeMokaAI:
c.Set("api_version", channel.Other)
}
}

View File

@@ -0,0 +1,104 @@
package mokaai
import (
"errors"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
// "one-api/relay/adaptor"
// "one-api/relay/meta"
// "one-api/relay/model"
// "one-api/relay/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
// ConvertImageRequest implements adaptor.Adaptor.
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
var urlPrefix = info.BaseUrl
switch info.RelayMode {
case constant.RelayModeChatCompletions:
return fmt.Sprintf("%s/chat/completions", urlPrefix), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/embeddings", urlPrefix), nil
default:
return fmt.Sprintf("%s/run/%s", urlPrefix, info.UpstreamModelName), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch info.RelayMode {
case constant.RelayModeChatCompletions:
return nil, errors.New("not implemented")
case constant.RelayModeEmbeddings:
// return ConvertCompletionsRequest(*request), nil
return ConvertEmbeddingRequest(*request), nil
default:
return nil, errors.New("not implemented")
}
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeAudioTranscription:
case constant.RelayModeAudioTranslation:
case constant.RelayModeChatCompletions:
fallthrough
case constant.RelayModeEmbeddings:
if info.IsStream {
err, usage = StreamHandler(c, resp, info)
} else {
err, usage = Handler(c, resp, info)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,9 @@
package mokaai
var ModelList = []string{
"m3e-large",
"m3e-base",
"m3e-small",
}
var ChannelName = "mokaai"

View File

@@ -0,0 +1,30 @@
package mokaai
import "one-api/dto"
type Request struct {
Messages []dto.Message `json:"messages,omitempty"`
Lora string `json:"lora,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
}
type Options struct {
Seed int `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input []string `json:"input"`
}

View File

@@ -0,0 +1,154 @@
package mokaai
import (
"bufio"
"encoding/json"
"io"
"net/http"
"strings"
// "one-api/common/ctxkey"
// "one-api/common/render"
// "github.com/gin-gonic/gin"
// "one-api/common"
// "one-api/common/helper"
// "one-api/common/logger"
// "one-api/relay/adaptor/openai"
// "one-api/relay/model"
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"time"
)
func ConvertCompletionsRequest(textRequest dto.GeneralOpenAIRequest) *Request {
p, _ := textRequest.Prompt.(string)
return &Request{
Prompt: p,
MaxTokens: textRequest.GetMaxTokens(),
Stream: textRequest.Stream,
Temperature: textRequest.Temperature,
}
}
func ConvertEmbeddingRequest(request dto.GeneralOpenAIRequest) *EmbeddingRequest {
var input []string // Change input to []string
switch v := request.Input.(type) {
case string:
input = []string{v} // Convert string to []string
case []string:
input = v // Already a []string, no conversion needed
case []interface{}:
for _, part := range v {
if str, ok := part.(string); ok {
input = append(input, str) // Append each string to the slice
}
}
}
return &EmbeddingRequest{
Model: request.Model,
Input: input, // Assign []string to Input
}
}
func StreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
id := service.GetResponseID(c)
var responseText string
isFirst := true
for scanner.Scan() {
data := scanner.Text()
if len(data) < len("data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\r")
if data == "[DONE]" {
break
}
var response dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &response)
if err != nil {
common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
continue
}
for _, choice := range response.Choices {
choice.Delta.Role = "assistant"
responseText += choice.Delta.GetContentString()
}
response.Id = id
response.Model = info.UpstreamModelName
err = service.ObjectData(c, response)
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
if err != nil {
common.LogError(c, "error_rendering_stream_response: "+err.Error())
}
}
if err := scanner.Err(); err != nil {
common.LogError(c, "error_scanning_stream_response: "+err.Error())
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
if err != nil {
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
}
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
common.LogError(c, "close_response_body_failed: "+err.Error())
}
return nil, usage
}
func Handler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var response dto.TextResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
response.Model = info.UpstreamModelName
var responseText string
for _, choice := range response.Choices {
responseText += choice.Message.StringContent()
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
response.Usage = *usage
response.Id = service.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
return nil, usage
}

View File

@@ -27,7 +27,7 @@ const (
APITypeVertexAi
APITypeMistral
APITypeDeepSeek
APITypeMokaAI
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -78,6 +78,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeMistral
case common.ChannelTypeDeepSeek:
apiType = APITypeDeepSeek
case common.ChannelTypeMokaAI:
apiType = APITypeMokaAI
}
if apiType == -1 {
return APITypeOpenAI, false

View File

@@ -14,6 +14,7 @@ import (
"one-api/relay/channel/gemini"
"one-api/relay/channel/jina"
"one-api/relay/channel/mistral"
"one-api/relay/channel/mokaai"
"one-api/relay/channel/ollama"
"one-api/relay/channel/openai"
"one-api/relay/channel/palm"
@@ -74,6 +75,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &mistral.Adaptor{}
case constant.APITypeDeepSeek:
return &deepseek.Adaptor{}
case constant.APITypeMokaAI:
return &mokaai.Adaptor{}
}
return nil
}

View File

@@ -125,5 +125,12 @@ export const CHANNEL_OPTIONS = [
value: 21,
color: 'purple',
label: '知识库AI Proxy'
},
{
key: 47,
text: '嵌入模型MokaAI M3E',
value: 47,
color: 'purple',
label: '嵌入模型MokaAI M3E'
}
];