diff --git a/Dockerfile b/Dockerfile
index 44a7837a..4e0d0511 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-FROM oven/bun:latest as builder
+FROM oven/bun:latest AS builder
 
 WORKDIR /build
 COPY web/package.json .
diff --git a/common/constants.go b/common/constants.go
index e2acf83b..3c8d262a 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -231,7 +231,7 @@ const (
 	ChannelTypeVertexAi = 41
 	ChannelTypeMistral  = 42
 	ChannelTypeDeepSeek = 43
-
+	ChannelTypeMokaAI   = 47 // NOTE(review): 44-46 are skipped; ChannelBaseURLs below must carry placeholder entries for them
 	ChannelTypeDummy // this one is only for count, do not add any channel after this
 )
 
@@ -281,4 +281,8 @@ var ChannelBaseURLs = []string{
 	"",                         //41
 	"https://api.mistral.ai",   //42
 	"https://api.deepseek.com", //43
+	"",                         //44
+	"",                         //45
+	"",                         //46
+	"https://api.moka.ai",      //47
 }
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 5f9c990f..ea325b47 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -41,14 +41,27 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	}
 	w := httptest.NewRecorder()
 	c, _ := gin.CreateTestContext(w)
+
+	requestPath := "/v1/chat/completions"
+
+	// Decide first whether the test model is an embedding model.
+	if strings.Contains(strings.ToLower(testModel), "embedding") ||
+		strings.HasPrefix(testModel, "m3e") || // m3e family
+		strings.Contains(testModel, "bge-") || // bge family
+		testModel == "text-embedding-v1" ||
+		channel.Type == common.ChannelTypeMokaAI { // embedding-only channel
+		requestPath = "/v1/embeddings" // switch to the embeddings endpoint
+	}
+
 	c.Request = &http.Request{
 		Method: "POST",
-		URL:    &url.URL{Path: "/v1/chat/completions"},
+		URL:    &url.URL{Path: requestPath}, // use the mode-dependent path
 		Body:   nil,
 		Header: make(http.Header),
 	}
 
 	if testModel == "" {
+		common.SysLog("testModel is empty, falling back to the channel's test model")
 		if channel.TestModel != nil && *channel.TestModel != "" {
 			testModel = *channel.TestModel
 		} else {
@@ -57,6 +70,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 			} else {
 				testModel = "gpt-3.5-turbo"
 			}
+			common.SysLog(fmt.Sprintf("channel TestModel is empty, test model resolved to %s", testModel))
 		}
 	} else {
 		modelMapping := *channel.ModelMapping
@@ -88,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 
 	request := buildTestRequest(testModel)
 	meta.UpstreamModelName = testModel
-	common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
+	common.SysLog(fmt.Sprintf("testing channel %d with model %s, meta %+v", channel.Id, testModel, meta))
 
 	adaptor.Init(meta)
 
@@ -156,6 +170,16 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
 		Model:  "", // this will be set later
 		Stream: false,
 	}
+	// Decide first whether this is an embedding model.
+	if strings.Contains(strings.ToLower(model), "embedding") ||
+		strings.HasPrefix(model, "m3e") || // m3e family
+		strings.Contains(model, "bge-") || // bge family
+		model == "text-embedding-v1" { // other embedding models
+		// Embedding request: a minimal input is enough for a health check.
+		testRequest.Input = []string{"hello world"}
+		return testRequest
+	}
+	// Not an embedding model.
 	if strings.HasPrefix(model, "o1") {
 		testRequest.MaxCompletionTokens = 10
 	} else if strings.HasPrefix(model, "gemini-2.0-flash-thinking") {
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 49cca260..c90f3e5e 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -239,5 +239,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 		c.Set("plugin", channel.Other)
 	case common.ChannelCloudflare:
 		c.Set("api_version", channel.Other)
+	case common.ChannelTypeMokaAI:
+		c.Set("api_version", channel.Other)
 	}
 }
diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go
new file mode 100644
index 00000000..e9fe34cf
--- /dev/null
+++ b/relay/channel/mokaai/adaptor.go
@@ -0,0 +1,106 @@
+package mokaai
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+
+	"github.com/gin-gonic/gin"
+
+	"one-api/dto"
+	"one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
+)
+
+// Adaptor relays requests to the MokaAI (M3E) embedding API.
+type Adaptor struct {
+}
+
+// ConvertImageRequest implements adaptor.Adaptor; MokaAI has no image API.
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertAudioRequest implements adaptor.Adaptor; MokaAI has no audio API.
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertRerankRequest implements adaptor.Adaptor; MokaAI has no rerank API.
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// Init implements adaptor.Adaptor; no per-request state is needed.
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+// GetRequestURL builds the upstream URL for the given relay mode.
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	urlPrefix := info.BaseUrl
+
+	switch info.RelayMode {
+	case constant.RelayModeChatCompletions:
+		return fmt.Sprintf("%s/chat/completions", urlPrefix), nil
+	case constant.RelayModeEmbeddings:
+		return fmt.Sprintf("%s/embeddings", urlPrefix), nil
+	default:
+		return fmt.Sprintf("%s/run/%s", urlPrefix, info.UpstreamModelName), nil
+	}
+}
+
+// SetupRequestHeader adds the bearer token on top of the common headers.
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+	req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+	return nil
+}
+
+// ConvertRequest converts an OpenAI-style request into the upstream format.
+// Only the embeddings relay mode is currently supported.
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	switch info.RelayMode {
+	case constant.RelayModeEmbeddings:
+		return ConvertEmbeddingRequest(*request), nil
+	default:
+		return nil, errors.New("not implemented")
+	}
+}
+
+// DoRequest performs the upstream HTTP request.
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+	return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+// DoResponse parses the upstream response for chat/embedding modes.
+// Audio modes are accepted but intentionally produce no output.
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	switch info.RelayMode {
+	case constant.RelayModeAudioTranscription:
+	case constant.RelayModeAudioTranslation:
+	case constant.RelayModeChatCompletions:
+		fallthrough
+	case constant.RelayModeEmbeddings:
+		if info.IsStream {
+			err, usage = StreamHandler(c, resp, info)
+		} else {
+			err, usage = Handler(c, resp, info)
+		}
+	}
+	return
+}
+
+// GetModelList returns the models this channel supports.
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+// GetChannelName returns the channel identifier.
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
diff --git a/relay/channel/mokaai/constants.go b/relay/channel/mokaai/constants.go
new file mode 100644
index 00000000..415d83b7
--- /dev/null
+++ b/relay/channel/mokaai/constants.go
@@ -0,0 +1,11 @@
+package mokaai
+
+// ModelList enumerates the MokaAI M3E embedding models exposed by this channel.
+var ModelList = []string{
+	"m3e-large",
+	"m3e-base",
+	"m3e-small",
+}
+
+// ChannelName identifies this channel in logs and configuration.
+var ChannelName = "mokaai"
diff --git a/relay/channel/mokaai/dto.go b/relay/channel/mokaai/dto.go
new file mode 100644
index 00000000..18d2dccb
--- /dev/null
+++ b/relay/channel/mokaai/dto.go
@@ -0,0 +1,34 @@
+package mokaai
+
+import "one-api/dto"
+
+// Request is the upstream chat/completions payload.
+// NOTE(review): only referenced by ConvertCompletionsRequest, which is currently unused.
+type Request struct {
+	Messages    []dto.Message `json:"messages,omitempty"`
+	Lora        string        `json:"lora,omitempty"`
+	MaxTokens   int           `json:"max_tokens,omitempty"`
+	Prompt      string        `json:"prompt,omitempty"`
+	Raw         bool          `json:"raw,omitempty"`
+	Stream      bool          `json:"stream,omitempty"`
+	Temperature float64       `json:"temperature,omitempty"`
+}
+
+// Options mirrors optional sampling parameters.
+// NOTE(review): not referenced anywhere yet — confirm it is needed before keeping.
+type Options struct {
+	Seed             int      `json:"seed,omitempty"`
+	Temperature      *float64 `json:"temperature,omitempty"`
+	TopK             int      `json:"top_k,omitempty"`
+	TopP             *float64 `json:"top_p,omitempty"`
+	FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
+	PresencePenalty  *float64 `json:"presence_penalty,omitempty"`
+	NumPredict       int      `json:"num_predict,omitempty"`
+	NumCtx           int      `json:"num_ctx,omitempty"`
+}
+
+// EmbeddingRequest is the payload for POST /v1/embeddings.
+type EmbeddingRequest struct {
+	Model string   `json:"model"`
+	Input []string `json:"input"`
+}
diff --git a/relay/channel/mokaai/relay-mokaai.go b/relay/channel/mokaai/relay-mokaai.go
new file mode 100644
index 00000000..44d7f7c2
--- /dev/null
+++ b/relay/channel/mokaai/relay-mokaai.go
@@ -0,0 +1,151 @@
+package mokaai
+
+import (
+	"bufio"
+	"encoding/json"
+	"io"
+	"net/http"
+	"strings"
+	"time"
+
+	"github.com/gin-gonic/gin"
+
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+)
+
+// ConvertCompletionsRequest maps an OpenAI completion request onto the
+// upstream Request payload. Non-string prompts degrade to an empty prompt.
+func ConvertCompletionsRequest(textRequest dto.GeneralOpenAIRequest) *Request {
+	p, _ := textRequest.Prompt.(string)
+	return &Request{
+		Prompt:      p,
+		MaxTokens:   textRequest.GetMaxTokens(),
+		Stream:      textRequest.Stream,
+		Temperature: textRequest.Temperature,
+	}
+}
+
+// ConvertEmbeddingRequest normalizes the OpenAI "input" field (string,
+// []string or []any of strings) into the []string the upstream expects.
+func ConvertEmbeddingRequest(request dto.GeneralOpenAIRequest) *EmbeddingRequest {
+	var input []string
+	switch v := request.Input.(type) {
+	case string:
+		input = []string{v}
+	case []string:
+		input = v
+	case []interface{}:
+		for _, part := range v {
+			if str, ok := part.(string); ok {
+				input = append(input, str)
+			}
+		}
+	}
+	return &EmbeddingRequest{
+		Model: request.Model,
+		Input: input,
+	}
+}
+
+// StreamHandler relays an SSE stream back to the client, accumulating the
+// response text so token usage can be estimated afterwards.
+func StreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(bufio.ScanLines)
+
+	service.SetEventStreamHeaders(c)
+	id := service.GetResponseID(c)
+	var responseText string
+	isFirst := true
+
+	for scanner.Scan() {
+		data := scanner.Text()
+		if len(data) < len("data: ") {
+			continue
+		}
+		data = strings.TrimPrefix(data, "data: ")
+		data = strings.TrimSuffix(data, "\r")
+
+		if data == "[DONE]" {
+			break
+		}
+
+		var response dto.ChatCompletionsStreamResponse
+		err := json.Unmarshal([]byte(data), &response)
+		if err != nil {
+			common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
+			continue
+		}
+		for _, choice := range response.Choices {
+			choice.Delta.Role = "assistant"
+			responseText += choice.Delta.GetContentString()
+		}
+		response.Id = id
+		response.Model = info.UpstreamModelName
+		err = service.ObjectData(c, response)
+		if isFirst {
+			isFirst = false
+			info.FirstResponseTime = time.Now()
+		}
+		if err != nil {
+			common.LogError(c, "error_rendering_stream_response: "+err.Error())
+		}
+	}
+
+	if err := scanner.Err(); err != nil {
+		common.LogError(c, "error_scanning_stream_response: "+err.Error())
+	}
+	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	if info.ShouldIncludeUsage {
+		response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
+		err := service.ObjectData(c, response)
+		if err != nil {
+			common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
+		}
+	}
+	service.Done(c)
+
+	err := resp.Body.Close()
+	if err != nil {
+		common.LogError(c, "close_response_body_failed: "+err.Error())
+	}
+
+	return nil, usage
+}
+
+// Handler relays a non-streaming response, rewriting model/id and
+// recomputing usage from the returned text.
+func Handler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var response dto.TextResponse
+	err = json.Unmarshal(responseBody, &response)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	response.Model = info.UpstreamModelName
+	var responseText string
+	for _, choice := range response.Choices {
+		responseText += choice.Message.StringContent()
+	}
+	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	response.Usage = *usage
+	response.Id = service.GetResponseID(c)
+	jsonResponse, err := json.Marshal(response)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, _ = c.Writer.Write(jsonResponse)
+	return nil, usage
+}
diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go
index c5c1d089..1a40a6ee 100644
--- a/relay/constant/api_type.go
+++ b/relay/constant/api_type.go
@@ -27,7 +27,7 @@ const (
 	APITypeVertexAi
 	APITypeMistral
 	APITypeDeepSeek
-
+	APITypeMokaAI
 	APITypeDummy // this one is only for count, do not add any channel after this
 )
 
@@ -78,6 +78,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
 		apiType = APITypeMistral
 	case common.ChannelTypeDeepSeek:
 		apiType = APITypeDeepSeek
+	case common.ChannelTypeMokaAI:
+		apiType = APITypeMokaAI
 	}
 	if apiType == -1 {
 		return APITypeOpenAI, false
diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go
index 1c7d11e9..9304bd6d 100644
--- a/relay/relay_adaptor.go
+++ b/relay/relay_adaptor.go
@@ -14,6 +14,7 @@ import (
 	"one-api/relay/channel/gemini"
 	"one-api/relay/channel/jina"
 	"one-api/relay/channel/mistral"
+	"one-api/relay/channel/mokaai"
 	"one-api/relay/channel/ollama"
 	"one-api/relay/channel/openai"
 	"one-api/relay/channel/palm"
@@ -74,6 +75,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
 		return &mistral.Adaptor{}
 	case constant.APITypeDeepSeek:
 		return &deepseek.Adaptor{}
+	case constant.APITypeMokaAI:
+		return &mokaai.Adaptor{}
 	}
 	return nil
 }
diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js
index 32bf8bce..c1be95b7 100644
--- a/web/src/constants/channel.constants.js
+++ b/web/src/constants/channel.constants.js
@@ -125,5 +125,12 @@ export const CHANNEL_OPTIONS = [
     value: 21,
     color: 'purple',
     label: '知识库:AI Proxy'
+  },
+  {
+    key: 47,
+    text: '嵌入模型:MokaAI M3E',
+    value: 47,
+    color: 'purple',
+    label: '嵌入模型:MokaAI M3E'
   }
 ];