Merge pull request #723 from kuwork/main

Support for MokaAI M3E
This commit is contained in:
Calcium-Ion
2025-02-11 22:16:18 +07:00
committed by GitHub
36 changed files with 534 additions and 16 deletions

View File

@@ -41,14 +41,27 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
requestPath := "/v1/chat/completions"
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(testModel), "embedding") ||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
strings.Contains(testModel, "bge-") || // bge 系列模型
testModel == "text-embedding-v1" ||
channel.Type == common.ChannelTypeMokaAI{ // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径
}
c.Request = &http.Request{
Method: "POST",
URL: &url.URL{Path: "/v1/chat/completions"},
URL: &url.URL{Path: requestPath}, // 使用动态路径
Body: nil,
Header: make(http.Header),
}
if testModel == "" {
common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 是 %s", string(*channel.TestModel)))
if channel.TestModel != nil && *channel.TestModel != "" {
testModel = *channel.TestModel
} else {
@@ -57,6 +70,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
} else {
testModel = "gpt-3.5-turbo"
}
common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 为空:", string(testModel)))
}
}
@@ -88,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
request := buildTestRequest(testModel)
meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %s ", channel.Id, testModel, meta))
adaptor.Init(meta)
@@ -156,6 +170,17 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
Model: "", // this will be set later
Stream: false,
}
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(model), "embedding") ||
strings.HasPrefix(model, "m3e") || // m3e 系列模型
strings.Contains(model, "bge-") || // bge 系列模型
model == "text-embedding-v1" { // 其他 embedding 模型
// Embedding 请求
testRequest.Input = []string{"hello world"}
return testRequest
}
// 并非Embedding 模型
if strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") {
testRequest.MaxCompletionTokens = 10
} else {

View File

@@ -33,6 +33,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
err = relay.AudioHelper(c)
case relayconstant.RelayModeRerank:
err = relay.RerankHelper(c, relayMode)
case relayconstant.RelayModeEmbeddings:
err = relay.EmbeddingHelper(c,relayMode)
default:
err = relay.TextHelper(c)
}
@@ -55,6 +57,11 @@ func Relay(c *gin.Context) {
originalModel := c.GetString("original_model")
var openaiErr *dto.OpenAIErrorWithStatusCode
//获取request body 并输出到日志
requestBody, _ := common.GetRequestBody(c)
common.LogInfo(c, fmt.Sprintf("relayMode: %d ,request body: %s",relayMode, string(requestBody)))
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
@@ -154,6 +161,7 @@ func WssRelay(c *gin.Context) {
}
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
common.LogInfo(c, fmt.Sprintf("relayMode: %d ,channel Id : %s",relayMode, string(channel.Id)))
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))