@@ -1,4 +1,4 @@
|
|||||||
FROM oven/bun:latest as builder
|
FROM oven/bun:latest AS builder
|
||||||
|
|
||||||
WORKDIR /build
|
WORKDIR /build
|
||||||
COPY web/package.json .
|
COPY web/package.json .
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ const (
|
|||||||
ChannelTypeVertexAi = 41
|
ChannelTypeVertexAi = 41
|
||||||
ChannelTypeMistral = 42
|
ChannelTypeMistral = 42
|
||||||
ChannelTypeDeepSeek = 43
|
ChannelTypeDeepSeek = 43
|
||||||
|
ChannelTypeMokaAI = 47
|
||||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||||
|
|
||||||
)
|
)
|
||||||
@@ -281,4 +281,5 @@ var ChannelBaseURLs = []string{
|
|||||||
"", //41
|
"", //41
|
||||||
"https://api.mistral.ai", //42
|
"https://api.mistral.ai", //42
|
||||||
"https://api.deepseek.com", //43
|
"https://api.deepseek.com", //43
|
||||||
|
"https://api.moka.ai", //43
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,14 +41,27 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
}
|
}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
requestPath := "/v1/chat/completions"
|
||||||
|
|
||||||
|
// 先判断是否为 Embedding 模型
|
||||||
|
if strings.Contains(strings.ToLower(testModel), "embedding") ||
|
||||||
|
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
||||||
|
strings.Contains(testModel, "bge-") || // bge 系列模型
|
||||||
|
testModel == "text-embedding-v1" ||
|
||||||
|
channel.Type == common.ChannelTypeMokaAI{ // 其他 embedding 模型
|
||||||
|
requestPath = "/v1/embeddings" // 修改请求路径
|
||||||
|
}
|
||||||
|
|
||||||
c.Request = &http.Request{
|
c.Request = &http.Request{
|
||||||
Method: "POST",
|
Method: "POST",
|
||||||
URL: &url.URL{Path: "/v1/chat/completions"},
|
URL: &url.URL{Path: requestPath}, // 使用动态路径
|
||||||
Body: nil,
|
Body: nil,
|
||||||
Header: make(http.Header),
|
Header: make(http.Header),
|
||||||
}
|
}
|
||||||
|
|
||||||
if testModel == "" {
|
if testModel == "" {
|
||||||
|
common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 是 %s", string(*channel.TestModel)))
|
||||||
if channel.TestModel != nil && *channel.TestModel != "" {
|
if channel.TestModel != nil && *channel.TestModel != "" {
|
||||||
testModel = *channel.TestModel
|
testModel = *channel.TestModel
|
||||||
} else {
|
} else {
|
||||||
@@ -57,6 +70,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
} else {
|
} else {
|
||||||
testModel = "gpt-3.5-turbo"
|
testModel = "gpt-3.5-turbo"
|
||||||
}
|
}
|
||||||
|
common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 为空:", string(testModel)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
|
|
||||||
request := buildTestRequest(testModel)
|
request := buildTestRequest(testModel)
|
||||||
meta.UpstreamModelName = testModel
|
meta.UpstreamModelName = testModel
|
||||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
|
common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %s ", channel.Id, testModel, meta))
|
||||||
|
|
||||||
adaptor.Init(meta)
|
adaptor.Init(meta)
|
||||||
|
|
||||||
@@ -156,6 +170,17 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
|||||||
Model: "", // this will be set later
|
Model: "", // this will be set later
|
||||||
Stream: false,
|
Stream: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 先判断是否为 Embedding 模型
|
||||||
|
if strings.Contains(strings.ToLower(model), "embedding") ||
|
||||||
|
strings.HasPrefix(model, "m3e") || // m3e 系列模型
|
||||||
|
strings.Contains(model, "bge-") || // bge 系列模型
|
||||||
|
model == "text-embedding-v1" { // 其他 embedding 模型
|
||||||
|
// Embedding 请求
|
||||||
|
testRequest.Input = []string{"hello world"}
|
||||||
|
return testRequest
|
||||||
|
}
|
||||||
|
// 并非Embedding 模型
|
||||||
if strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") {
|
if strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") {
|
||||||
testRequest.MaxCompletionTokens = 10
|
testRequest.MaxCompletionTokens = 10
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
|||||||
err = relay.AudioHelper(c)
|
err = relay.AudioHelper(c)
|
||||||
case relayconstant.RelayModeRerank:
|
case relayconstant.RelayModeRerank:
|
||||||
err = relay.RerankHelper(c, relayMode)
|
err = relay.RerankHelper(c, relayMode)
|
||||||
|
case relayconstant.RelayModeEmbeddings:
|
||||||
|
err = relay.EmbeddingHelper(c,relayMode)
|
||||||
default:
|
default:
|
||||||
err = relay.TextHelper(c)
|
err = relay.TextHelper(c)
|
||||||
}
|
}
|
||||||
@@ -55,6 +57,11 @@ func Relay(c *gin.Context) {
|
|||||||
originalModel := c.GetString("original_model")
|
originalModel := c.GetString("original_model")
|
||||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||||
|
|
||||||
|
//获取request body 并输出到日志
|
||||||
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
|
common.LogInfo(c, fmt.Sprintf("relayMode: %d ,request body: %s",relayMode, string(requestBody)))
|
||||||
|
|
||||||
|
|
||||||
for i := 0; i <= common.RetryTimes; i++ {
|
for i := 0; i <= common.RetryTimes; i++ {
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
channel, err := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -154,6 +161,7 @@ func WssRelay(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||||
|
common.LogInfo(c, fmt.Sprintf("relayMode: %d ,channel Id : %s",relayMode, string(channel.Id)))
|
||||||
addUsedChannel(c, channel.Id)
|
addUsedChannel(c, channel.Id)
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
|||||||
30
dto/embedding.go
Normal file
30
dto/embedding.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
type EmbeddingOptions struct {
|
||||||
|
Seed int `json:"seed,omitempty"`
|
||||||
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
|
TopK int `json:"top_k,omitempty"`
|
||||||
|
TopP *float64 `json:"top_p,omitempty"`
|
||||||
|
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||||
|
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||||
|
NumPredict int `json:"num_predict,omitempty"`
|
||||||
|
NumCtx int `json:"num_ctx,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Input []string `json:"input"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingResponseItem struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
Embedding []float64 `json:"embedding"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingResponse struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []EmbeddingResponseItem `json:"data"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Usage `json:"usage"`
|
||||||
|
}
|
||||||
@@ -239,5 +239,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
c.Set("plugin", channel.Other)
|
c.Set("plugin", channel.Other)
|
||||||
case common.ChannelCloudflare:
|
case common.ChannelCloudflare:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
|
case common.ChannelTypeMokaAI:
|
||||||
|
c.Set("api_version", channel.Other)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ type Adaptor interface {
|
|||||||
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
|
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
|
||||||
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
||||||
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
||||||
|
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
|
||||||
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
|
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
|
||||||
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
|
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
|
||||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
|
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
|
||||||
|
|||||||
@@ -67,6 +67,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
|
|||||||
@@ -59,6 +59,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -122,6 +122,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,6 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,6 +56,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
// 添加文件字段
|
// 添加文件字段
|
||||||
file, _, err := c.Request.FormFile("file")
|
file, _, err := c.Request.FormFile("file")
|
||||||
|
|||||||
@@ -54,6 +54,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return requestConvertRerank2Cohere(request), nil
|
return requestConvertRerank2Cohere(request), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.RelayMode == constant.RelayModeRerank {
|
||||||
err, usage = cohereRerankHandler(c, resp, info)
|
err, usage = cohereRerankHandler(c, resp, info)
|
||||||
|
|||||||
@@ -49,6 +49,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,6 +48,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -68,6 +68,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -55,6 +55,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.RelayMode == constant.RelayModeRerank {
|
||||||
err, usage = jinaRerankHandler(c, resp)
|
err, usage = jinaRerankHandler(c, resp)
|
||||||
|
|||||||
@@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
93
relay/channel/mokaai/adaptor.go
Normal file
93
relay/channel/mokaai/adaptor.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
package mokaai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/relay/channel"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/relay/constant"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Adaptor struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return request, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
|
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
|
||||||
|
suffix := "chat/"
|
||||||
|
if strings.HasPrefix(info.UpstreamModelName, "m3e") {
|
||||||
|
suffix = "embeddings"
|
||||||
|
}
|
||||||
|
fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix)
|
||||||
|
return fullRequestURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
|
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
|
if request == nil {
|
||||||
|
return nil, errors.New("request is nil")
|
||||||
|
}
|
||||||
|
switch info.RelayMode {
|
||||||
|
case constant.RelayModeEmbeddings:
|
||||||
|
baiduEmbeddingRequest := embeddingRequestOpenAI2Moka(*request)
|
||||||
|
return baiduEmbeddingRequest, nil
|
||||||
|
default:
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
|
||||||
|
switch info.RelayMode {
|
||||||
|
case constant.RelayModeEmbeddings:
|
||||||
|
err, usage = mokaEmbeddingHandler(c, resp)
|
||||||
|
default:
|
||||||
|
// err, usage = mokaHandler(c, resp)
|
||||||
|
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetModelList() []string {
|
||||||
|
return ModelList
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetChannelName() string {
|
||||||
|
return ChannelName
|
||||||
|
}
|
||||||
9
relay/channel/mokaai/constants.go
Normal file
9
relay/channel/mokaai/constants.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
package mokaai
|
||||||
|
|
||||||
|
var ModelList = []string{
|
||||||
|
"m3e-large",
|
||||||
|
"m3e-base",
|
||||||
|
"m3e-small",
|
||||||
|
}
|
||||||
|
|
||||||
|
var ChannelName = "mokaai"
|
||||||
83
relay/channel/mokaai/relay-mokaai.go
Normal file
83
relay/channel/mokaai/relay-mokaai.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package mokaai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest {
|
||||||
|
var input []string // Change input to []string
|
||||||
|
|
||||||
|
switch v := request.Input.(type) {
|
||||||
|
case string:
|
||||||
|
input = []string{v} // Convert string to []string
|
||||||
|
case []string:
|
||||||
|
input = v // Already a []string, no conversion needed
|
||||||
|
case []interface{}:
|
||||||
|
for _, part := range v {
|
||||||
|
if str, ok := part.(string); ok {
|
||||||
|
input = append(input, str) // Append each string to the slice
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &dto.EmbeddingRequest{
|
||||||
|
Input: input,
|
||||||
|
Model: request.Model,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEmbeddingResponse {
|
||||||
|
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
|
||||||
|
Object: "list",
|
||||||
|
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
|
||||||
|
Model: "baidu-embedding",
|
||||||
|
Usage: response.Usage,
|
||||||
|
}
|
||||||
|
for _, item := range response.Data {
|
||||||
|
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
|
||||||
|
Object: item.Object,
|
||||||
|
Index: item.Index,
|
||||||
|
Embedding: item.Embedding,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return &openAIEmbeddingResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
var baiduResponse dto.EmbeddingResponse
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
// if baiduResponse.ErrorMsg != "" {
|
||||||
|
// return &dto.OpenAIErrorWithStatusCode{
|
||||||
|
// Error: dto.OpenAIError{
|
||||||
|
// Type: "baidu_error",
|
||||||
|
// Param: "",
|
||||||
|
// },
|
||||||
|
// StatusCode: resp.StatusCode,
|
||||||
|
// }, nil
|
||||||
|
// }
|
||||||
|
fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse)
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &fullTextResponse.Usage
|
||||||
|
}
|
||||||
|
|
||||||
@@ -58,6 +58,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -149,6 +149,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
a.ResponseFormat = request.ResponseFormat
|
a.ResponseFormat = request.ResponseFormat
|
||||||
if info.RelayMode == constant.RelayModeAudioSpeech {
|
if info.RelayMode == constant.RelayModeAudioSpeech {
|
||||||
|
|||||||
@@ -49,6 +49,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,6 +52,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,6 +58,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
|
|||||||
@@ -73,6 +73,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -151,6 +151,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
// xunfei's request is not http request, so we don't need to do anything here
|
// xunfei's request is not http request, so we don't need to do anything here
|
||||||
dummyResp := &http.Response{}
|
dummyResp := &http.Response{}
|
||||||
|
|||||||
@@ -56,6 +56,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,6 +53,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ const (
|
|||||||
APITypeVertexAi
|
APITypeVertexAi
|
||||||
APITypeMistral
|
APITypeMistral
|
||||||
APITypeDeepSeek
|
APITypeDeepSeek
|
||||||
|
APITypeMokaAI
|
||||||
APITypeDummy // this one is only for count, do not add any channel after this
|
APITypeDummy // this one is only for count, do not add any channel after this
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -78,6 +78,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
|||||||
apiType = APITypeMistral
|
apiType = APITypeMistral
|
||||||
case common.ChannelTypeDeepSeek:
|
case common.ChannelTypeDeepSeek:
|
||||||
apiType = APITypeDeepSeek
|
apiType = APITypeDeepSeek
|
||||||
|
case common.ChannelTypeMokaAI:
|
||||||
|
apiType = APITypeMokaAI
|
||||||
}
|
}
|
||||||
if apiType == -1 {
|
if apiType == -1 {
|
||||||
return APITypeOpenAI, false
|
return APITypeOpenAI, false
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"one-api/relay/channel/gemini"
|
"one-api/relay/channel/gemini"
|
||||||
"one-api/relay/channel/jina"
|
"one-api/relay/channel/jina"
|
||||||
"one-api/relay/channel/mistral"
|
"one-api/relay/channel/mistral"
|
||||||
|
"one-api/relay/channel/mokaai"
|
||||||
"one-api/relay/channel/ollama"
|
"one-api/relay/channel/ollama"
|
||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
"one-api/relay/channel/palm"
|
"one-api/relay/channel/palm"
|
||||||
@@ -74,6 +75,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
|||||||
return &mistral.Adaptor{}
|
return &mistral.Adaptor{}
|
||||||
case constant.APITypeDeepSeek:
|
case constant.APITypeDeepSeek:
|
||||||
return &deepseek.Adaptor{}
|
return &deepseek.Adaptor{}
|
||||||
|
case constant.APITypeMokaAI:
|
||||||
|
return &mokaai.Adaptor{}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
127
relay/relay_embedding.go
Normal file
127
relay/relay_embedding.go
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
relayconstant "one-api/relay/constant"
|
||||||
|
"one-api/service"
|
||||||
|
"one-api/setting"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
||||||
|
token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
relayInfo := relaycommon.GenRelayInfo(c)
|
||||||
|
|
||||||
|
var embeddingRequest *dto.EmbeddingRequest
|
||||||
|
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
|
||||||
|
embeddingRequest.Model = "m3e-base"
|
||||||
|
}
|
||||||
|
if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
|
||||||
|
embeddingRequest.Model = c.Param("model")
|
||||||
|
}
|
||||||
|
if embeddingRequest.Input == nil || len(embeddingRequest.Input) == 0 {
|
||||||
|
return service.OpenAIErrorWrapperLocal(fmt.Errorf("input is empty"), "invalid_input", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
// map model name
|
||||||
|
modelMapping := c.GetString("model_mapping")
|
||||||
|
//isModelMapped := false
|
||||||
|
if modelMapping != "" && modelMapping != "{}" {
|
||||||
|
modelMap := make(map[string]string)
|
||||||
|
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
if modelMap[embeddingRequest.Model] != "" {
|
||||||
|
embeddingRequest.Model = modelMap[embeddingRequest.Model]
|
||||||
|
// set upstream model name
|
||||||
|
//isModelMapped = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
relayInfo.UpstreamModelName = embeddingRequest.Model
|
||||||
|
modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false)
|
||||||
|
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||||
|
|
||||||
|
var preConsumedQuota int
|
||||||
|
var ratio float64
|
||||||
|
var modelRatio float64
|
||||||
|
|
||||||
|
promptToken := getEmbeddingPromptToken(*embeddingRequest)
|
||||||
|
if !success {
|
||||||
|
preConsumedTokens := promptToken
|
||||||
|
modelRatio = common.GetModelRatio(embeddingRequest.Model)
|
||||||
|
ratio = modelRatio * groupRatio
|
||||||
|
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||||
|
} else {
|
||||||
|
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||||
|
}
|
||||||
|
relayInfo.PromptTokens = promptToken
|
||||||
|
|
||||||
|
// pre-consume quota 预消耗配额
|
||||||
|
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||||
|
if openaiErr != nil {
|
||||||
|
return openaiErr
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if openaiErr != nil {
|
||||||
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
|
if adaptor == nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
adaptor.Init(relayInfo)
|
||||||
|
|
||||||
|
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c,relayInfo,*embeddingRequest)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(convertedRequest)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
requestBody := bytes.NewBuffer(jsonData)
|
||||||
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
|
resp, err := adaptor.DoRequest(c,relayInfo, requestBody)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
var httpResp *http.Response
|
||||||
|
if resp != nil {
|
||||||
|
httpResp = resp.(*http.Response)
|
||||||
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
|
openaiErr = service.RelayErrorHandler(httpResp)
|
||||||
|
// reset status code 重置状态码
|
||||||
|
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||||
|
return openaiErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||||
|
if openaiErr != nil {
|
||||||
|
// reset status code 重置状态码
|
||||||
|
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||||
|
return openaiErr
|
||||||
|
}
|
||||||
|
postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -44,7 +44,7 @@ function renderTimestamp(timestamp) {
|
|||||||
|
|
||||||
const ChannelsTable = () => {
|
const ChannelsTable = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
let type2label = undefined;
|
let type2label = undefined;
|
||||||
|
|
||||||
const renderType = (type) => {
|
const renderType = (type) => {
|
||||||
@@ -559,7 +559,7 @@ const ChannelsTable = () => {
|
|||||||
if (!enableTagMode) {
|
if (!enableTagMode) {
|
||||||
channelDates.push(channels[i]);
|
channelDates.push(channels[i]);
|
||||||
} else {
|
} else {
|
||||||
let tag = channels[i].tag?channels[i].tag:"";
|
let tag = channels[i].tag ? channels[i].tag : "";
|
||||||
// find from channelTags
|
// find from channelTags
|
||||||
let tagIndex = channelTags[tag];
|
let tagIndex = channelTags[tag];
|
||||||
let tagChannelDates = undefined;
|
let tagChannelDates = undefined;
|
||||||
@@ -805,6 +805,9 @@ const ChannelsTable = () => {
|
|||||||
record.response_time = time * 1000;
|
record.response_time = time * 1000;
|
||||||
record.test_time = Date.now() / 1000;
|
record.test_time = Date.now() / 1000;
|
||||||
showInfo(t('通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。').replace('${name}', record.name).replace('${time.toFixed(2)}', time.toFixed(2)));
|
showInfo(t('通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。').replace('${name}', record.name).replace('${time.toFixed(2)}', time.toFixed(2)));
|
||||||
|
|
||||||
|
// 刷新列表
|
||||||
|
await refresh();
|
||||||
} else {
|
} else {
|
||||||
showError(message);
|
showError(message);
|
||||||
}
|
}
|
||||||
@@ -838,6 +841,8 @@ const ChannelsTable = () => {
|
|||||||
record.balance = balance;
|
record.balance = balance;
|
||||||
record.balance_updated_time = Date.now() / 1000;
|
record.balance_updated_time = Date.now() / 1000;
|
||||||
showInfo(t('通道 ${name} 余额更新成功!').replace('${name}', record.name));
|
showInfo(t('通道 ${name} 余额更新成功!').replace('${name}', record.name));
|
||||||
|
// 刷新列表
|
||||||
|
await refresh();
|
||||||
} else {
|
} else {
|
||||||
showError(message);
|
showError(message);
|
||||||
}
|
}
|
||||||
@@ -1186,7 +1191,7 @@ const ChannelsTable = () => {
|
|||||||
</Space>
|
</Space>
|
||||||
</div>
|
</div>
|
||||||
<div style={{ marginTop: 20 }}>
|
<div style={{ marginTop: 20 }}>
|
||||||
<Space>
|
<Space>
|
||||||
<Typography.Text strong>{t('标签聚合模式')}</Typography.Text>
|
<Typography.Text strong>{t('标签聚合模式')}</Typography.Text>
|
||||||
<Switch
|
<Switch
|
||||||
checked={enableTagMode}
|
checked={enableTagMode}
|
||||||
@@ -1199,14 +1204,14 @@ const ChannelsTable = () => {
|
|||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
<Button
|
<Button
|
||||||
disabled={!enableBatchDelete}
|
disabled={!enableBatchDelete}
|
||||||
theme="light"
|
theme="light"
|
||||||
type="primary"
|
type="primary"
|
||||||
style={{ marginRight: 8 }}
|
style={{ marginRight: 8 }}
|
||||||
onClick={() => setShowBatchSetTag(true)}
|
onClick={() => setShowBatchSetTag(true)}
|
||||||
>
|
>
|
||||||
{t('批量设置标签')}
|
{t('批量设置标签')}
|
||||||
</Button>
|
</Button>
|
||||||
</Space>
|
</Space>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -125,5 +125,12 @@ export const CHANNEL_OPTIONS = [
|
|||||||
value: 21,
|
value: 21,
|
||||||
color: 'purple',
|
color: 'purple',
|
||||||
label: '知识库:AI Proxy'
|
label: '知识库:AI Proxy'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: 47,
|
||||||
|
text: '嵌入模型:MokaAI M3E',
|
||||||
|
value: 47,
|
||||||
|
color: 'purple',
|
||||||
|
label: '嵌入模型:MokaAI M3E'
|
||||||
}
|
}
|
||||||
];
|
];
|
||||||
|
|||||||
Reference in New Issue
Block a user