Support for MokaAI M3E
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
FROM oven/bun:latest as builder
|
||||
FROM oven/bun:latest AS builder
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/package.json .
|
||||
|
||||
@@ -231,7 +231,7 @@ const (
|
||||
ChannelTypeVertexAi = 41
|
||||
ChannelTypeMistral = 42
|
||||
ChannelTypeDeepSeek = 43
|
||||
|
||||
ChannelTypeMokaAI = 47
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
@@ -281,4 +281,5 @@ var ChannelBaseURLs = []string{
|
||||
"", //41
|
||||
"https://api.mistral.ai", //42
|
||||
"https://api.deepseek.com", //43
|
||||
"https://api.moka.ai", //43
|
||||
}
|
||||
|
||||
@@ -41,14 +41,27 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
requestPath := "/v1/chat/completions"
|
||||
|
||||
// 先判断是否为 Embedding 模型
|
||||
if strings.Contains(strings.ToLower(testModel), "embedding") ||
|
||||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
||||
strings.Contains(testModel, "bge-") || // bge 系列模型
|
||||
testModel == "text-embedding-v1" ||
|
||||
channel.Type == common.ChannelTypeMokaAI{ // 其他 embedding 模型
|
||||
requestPath = "/v1/embeddings" // 修改请求路径
|
||||
}
|
||||
|
||||
c.Request = &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: "/v1/chat/completions"},
|
||||
URL: &url.URL{Path: requestPath}, // 使用动态路径
|
||||
Body: nil,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
|
||||
if testModel == "" {
|
||||
common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 是 %s", string(*channel.TestModel)))
|
||||
if channel.TestModel != nil && *channel.TestModel != "" {
|
||||
testModel = *channel.TestModel
|
||||
} else {
|
||||
@@ -57,6 +70,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
} else {
|
||||
testModel = "gpt-3.5-turbo"
|
||||
}
|
||||
common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 为空:", string(testModel)))
|
||||
}
|
||||
} else {
|
||||
modelMapping := *channel.ModelMapping
|
||||
@@ -88,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
|
||||
request := buildTestRequest(testModel)
|
||||
meta.UpstreamModelName = testModel
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %s ", channel.Id, testModel, meta))
|
||||
|
||||
adaptor.Init(meta)
|
||||
|
||||
@@ -156,6 +170,16 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
Model: "", // this will be set later
|
||||
Stream: false,
|
||||
}
|
||||
// 先判断是否为 Embedding 模型
|
||||
if strings.Contains(strings.ToLower(model), "embedding") ||
|
||||
strings.HasPrefix(model, "m3e") || // m3e 系列模型
|
||||
strings.Contains(model, "bge-") || // bge 系列模型
|
||||
model == "text-embedding-v1" { // 其他 embedding 模型
|
||||
// Embedding 请求
|
||||
testRequest.Input = []string{"hello world"}
|
||||
return testRequest
|
||||
}
|
||||
// 并非Embedding 模型
|
||||
if strings.HasPrefix(model, "o1") {
|
||||
testRequest.MaxCompletionTokens = 10
|
||||
} else if strings.HasPrefix(model, "gemini-2.0-flash-thinking") {
|
||||
|
||||
@@ -239,5 +239,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
c.Set("plugin", channel.Other)
|
||||
case common.ChannelCloudflare:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeMokaAI:
|
||||
c.Set("api_version", channel.Other)
|
||||
}
|
||||
}
|
||||
|
||||
104
relay/channel/mokaai/adaptor.go
Normal file
104
relay/channel/mokaai/adaptor.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package mokaai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
// "one-api/relay/adaptor"
|
||||
// "one-api/relay/meta"
|
||||
// "one-api/relay/model"
|
||||
// "one-api/relay/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
// ConvertImageRequest implements adaptor.Adaptor.
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
var urlPrefix = info.BaseUrl
|
||||
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
return fmt.Sprintf("%s/chat/completions", urlPrefix), nil
|
||||
case constant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/embeddings", urlPrefix), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/run/%s", urlPrefix, info.UpstreamModelName), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
return nil, errors.New("not implemented")
|
||||
case constant.RelayModeEmbeddings:
|
||||
// return ConvertCompletionsRequest(*request), nil
|
||||
return ConvertEmbeddingRequest(*request), nil
|
||||
default:
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
switch info.RelayMode {
|
||||
|
||||
case constant.RelayModeAudioTranscription:
|
||||
case constant.RelayModeAudioTranslation:
|
||||
case constant.RelayModeChatCompletions:
|
||||
fallthrough
|
||||
case constant.RelayModeEmbeddings:
|
||||
if info.IsStream {
|
||||
err, usage = StreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = Handler(c, resp, info)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
9
relay/channel/mokaai/constants.go
Normal file
9
relay/channel/mokaai/constants.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package mokaai
|
||||
|
||||
var ModelList = []string{
|
||||
"m3e-large",
|
||||
"m3e-base",
|
||||
"m3e-small",
|
||||
}
|
||||
|
||||
var ChannelName = "mokaai"
|
||||
30
relay/channel/mokaai/dto.go
Normal file
30
relay/channel/mokaai/dto.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package mokaai
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
|
||||
type Request struct {
|
||||
Messages []dto.Message `json:"messages,omitempty"`
|
||||
Lora string `json:"lora,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
NumCtx int `json:"num_ctx,omitempty"`
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
154
relay/channel/mokaai/relay-mokaai.go
Normal file
154
relay/channel/mokaai/relay-mokaai.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package mokaai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
// "one-api/common/ctxkey"
|
||||
// "one-api/common/render"
|
||||
|
||||
// "github.com/gin-gonic/gin"
|
||||
// "one-api/common"
|
||||
// "one-api/common/helper"
|
||||
// "one-api/common/logger"
|
||||
// "one-api/relay/adaptor/openai"
|
||||
// "one-api/relay/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"time"
|
||||
)
|
||||
|
||||
func ConvertCompletionsRequest(textRequest dto.GeneralOpenAIRequest) *Request {
|
||||
p, _ := textRequest.Prompt.(string)
|
||||
return &Request{
|
||||
Prompt: p,
|
||||
MaxTokens: textRequest.GetMaxTokens(),
|
||||
Stream: textRequest.Stream,
|
||||
Temperature: textRequest.Temperature,
|
||||
}
|
||||
}
|
||||
|
||||
func ConvertEmbeddingRequest(request dto.GeneralOpenAIRequest) *EmbeddingRequest {
|
||||
var input []string // Change input to []string
|
||||
|
||||
switch v := request.Input.(type) {
|
||||
case string:
|
||||
input = []string{v} // Convert string to []string
|
||||
case []string:
|
||||
input = v // Already a []string, no conversion needed
|
||||
case []interface{}:
|
||||
for _, part := range v {
|
||||
if str, ok := part.(string); ok {
|
||||
input = append(input, str) // Append each string to the slice
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &EmbeddingRequest{
|
||||
Model: request.Model,
|
||||
Input: input, // Assign []string to Input
|
||||
}
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
service.SetEventStreamHeaders(c)
|
||||
id := service.GetResponseID(c)
|
||||
var responseText string
|
||||
isFirst := true
|
||||
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < len("data: ") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "data: ")
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &response)
|
||||
if err != nil {
|
||||
common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
|
||||
continue
|
||||
}
|
||||
for _, choice := range response.Choices {
|
||||
choice.Delta.Role = "assistant"
|
||||
responseText += choice.Delta.GetContentString()
|
||||
}
|
||||
response.Id = id
|
||||
response.Model = info.UpstreamModelName
|
||||
err = service.ObjectData(c, response)
|
||||
if isFirst {
|
||||
isFirst = false
|
||||
info.FirstResponseTime = time.Now()
|
||||
}
|
||||
if err != nil {
|
||||
common.LogError(c, "error_rendering_stream_response: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||
}
|
||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
if info.ShouldIncludeUsage {
|
||||
response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||
err := service.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
|
||||
}
|
||||
}
|
||||
service.Done(c)
|
||||
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
common.LogError(c, "close_response_body_failed: "+err.Error())
|
||||
}
|
||||
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var response dto.TextResponse
|
||||
err = json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
response.Model = info.UpstreamModelName
|
||||
var responseText string
|
||||
for _, choice := range response.Choices {
|
||||
responseText += choice.Message.StringContent()
|
||||
}
|
||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
response.Usage = *usage
|
||||
response.Id = service.GetResponseID(c)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
return nil, usage
|
||||
}
|
||||
@@ -27,7 +27,7 @@ const (
|
||||
APITypeVertexAi
|
||||
APITypeMistral
|
||||
APITypeDeepSeek
|
||||
|
||||
APITypeMokaAI
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
|
||||
@@ -78,6 +78,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType = APITypeMistral
|
||||
case common.ChannelTypeDeepSeek:
|
||||
apiType = APITypeDeepSeek
|
||||
case common.ChannelTypeMokaAI:
|
||||
apiType = APITypeMokaAI
|
||||
}
|
||||
if apiType == -1 {
|
||||
return APITypeOpenAI, false
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"one-api/relay/channel/gemini"
|
||||
"one-api/relay/channel/jina"
|
||||
"one-api/relay/channel/mistral"
|
||||
"one-api/relay/channel/mokaai"
|
||||
"one-api/relay/channel/ollama"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/channel/palm"
|
||||
@@ -74,6 +75,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
||||
return &mistral.Adaptor{}
|
||||
case constant.APITypeDeepSeek:
|
||||
return &deepseek.Adaptor{}
|
||||
case constant.APITypeMokaAI:
|
||||
return &mokaai.Adaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -125,5 +125,12 @@ export const CHANNEL_OPTIONS = [
|
||||
value: 21,
|
||||
color: 'purple',
|
||||
label: '知识库:AI Proxy'
|
||||
},
|
||||
{
|
||||
key: 47,
|
||||
text: '嵌入模型:MokaAI M3E',
|
||||
value: 47,
|
||||
color: 'purple',
|
||||
label: '嵌入模型:MokaAI M3E'
|
||||
}
|
||||
];
|
||||
|
||||
Reference in New Issue
Block a user