Fix M3E not working
This commit is contained in:
@@ -33,6 +33,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
|||||||
err = relay.AudioHelper(c)
|
err = relay.AudioHelper(c)
|
||||||
case relayconstant.RelayModeRerank:
|
case relayconstant.RelayModeRerank:
|
||||||
err = relay.RerankHelper(c, relayMode)
|
err = relay.RerankHelper(c, relayMode)
|
||||||
|
case relayconstant.RelayModeEmbeddings:
|
||||||
|
err = relay.EmbeddingHelper(c,relayMode)
|
||||||
default:
|
default:
|
||||||
err = relay.TextHelper(c)
|
err = relay.TextHelper(c)
|
||||||
}
|
}
|
||||||
@@ -55,6 +57,11 @@ func Relay(c *gin.Context) {
|
|||||||
originalModel := c.GetString("original_model")
|
originalModel := c.GetString("original_model")
|
||||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||||
|
|
||||||
|
//获取request body 并输出到日志
|
||||||
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
|
common.LogInfo(c, fmt.Sprintf("relayMode: %d ,request body: %s",relayMode, string(requestBody)))
|
||||||
|
|
||||||
|
|
||||||
for i := 0; i <= common.RetryTimes; i++ {
|
for i := 0; i <= common.RetryTimes; i++ {
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
channel, err := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -154,6 +161,7 @@ func WssRelay(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||||
|
common.LogInfo(c, fmt.Sprintf("relayMode: %d ,channel Id : %s",relayMode, string(channel.Id)))
|
||||||
addUsedChannel(c, channel.Id)
|
addUsedChannel(c, channel.Id)
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
|||||||
@@ -1,19 +1,6 @@
|
|||||||
package mokaai
|
package dto
|
||||||
|
|
||||||
import "one-api/dto"
|
type EmbeddingOptions struct {
|
||||||
|
|
||||||
|
|
||||||
type Request struct {
|
|
||||||
Messages []dto.Message `json:"messages,omitempty"`
|
|
||||||
Lora string `json:"lora,omitempty"`
|
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
|
||||||
Prompt string `json:"prompt,omitempty"`
|
|
||||||
Raw bool `json:"raw,omitempty"`
|
|
||||||
Stream bool `json:"stream,omitempty"`
|
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Options struct {
|
|
||||||
Seed int `json:"seed,omitempty"`
|
Seed int `json:"seed,omitempty"`
|
||||||
Temperature *float64 `json:"temperature,omitempty"`
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
TopK int `json:"top_k,omitempty"`
|
TopK int `json:"top_k,omitempty"`
|
||||||
@@ -27,4 +14,17 @@ type Options struct {
|
|||||||
type EmbeddingRequest struct {
|
type EmbeddingRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Input []string `json:"input"`
|
Input []string `json:"input"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingResponseItem struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
Embedding []float64 `json:"embedding"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingResponse struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []EmbeddingResponseItem `json:"data"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Usage `json:"usage"`
|
||||||
}
|
}
|
||||||
@@ -15,6 +15,7 @@ type Adaptor interface {
|
|||||||
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
|
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
|
||||||
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
||||||
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
||||||
|
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
|
||||||
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
|
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
|
||||||
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
|
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
|
||||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
|
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
|
||||||
|
|||||||
@@ -67,6 +67,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
|
|||||||
@@ -59,6 +59,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -122,6 +122,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,6 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,6 +56,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
// 添加文件字段
|
// 添加文件字段
|
||||||
file, _, err := c.Request.FormFile("file")
|
file, _, err := c.Request.FormFile("file")
|
||||||
|
|||||||
@@ -54,6 +54,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return requestConvertRerank2Cohere(request), nil
|
return requestConvertRerank2Cohere(request), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.RelayMode == constant.RelayModeRerank {
|
||||||
err, usage = cohereRerankHandler(c, resp, info)
|
err, usage = cohereRerankHandler(c, resp, info)
|
||||||
|
|||||||
@@ -49,6 +49,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,6 +48,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -68,6 +68,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -55,6 +55,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.RelayMode == constant.RelayModeRerank {
|
||||||
err, usage = jinaRerankHandler(c, resp)
|
err, usage = jinaRerankHandler(c, resp)
|
||||||
|
|||||||
@@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,54 +3,46 @@ package mokaai
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
// "one-api/relay/adaptor"
|
|
||||||
// "one-api/relay/meta"
|
|
||||||
// "one-api/relay/model"
|
|
||||||
// "one-api/relay/constant"
|
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertImageRequest implements adaptor.Adaptor.
|
|
||||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
|
||||||
//TODO implement me
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
}
|
||||||
|
|
||||||
var urlPrefix = info.BaseUrl
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
|
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
|
||||||
switch info.RelayMode {
|
suffix := "chat/"
|
||||||
case constant.RelayModeChatCompletions:
|
if strings.HasPrefix(info.UpstreamModelName, "m3e") {
|
||||||
return fmt.Sprintf("%s/chat/completions", urlPrefix), nil
|
suffix = "embeddings"
|
||||||
case constant.RelayModeEmbeddings:
|
|
||||||
return fmt.Sprintf("%s/embeddings", urlPrefix), nil
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("%s/run/%s", urlPrefix, info.UpstreamModelName), nil
|
|
||||||
}
|
}
|
||||||
|
fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix)
|
||||||
|
return fullRequestURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
@@ -64,33 +56,30 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
|||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeChatCompletions:
|
case constant.RelayModeEmbeddings:
|
||||||
return nil, errors.New("not implemented")
|
baiduEmbeddingRequest := embeddingRequestOpenAI2Moka(*request)
|
||||||
case constant.RelayModeEmbeddings:
|
return baiduEmbeddingRequest, nil
|
||||||
// return ConvertCompletionsRequest(*request), nil
|
|
||||||
return ConvertEmbeddingRequest(*request), nil
|
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
switch info.RelayMode {
|
|
||||||
|
|
||||||
case constant.RelayModeAudioTranscription:
|
switch info.RelayMode {
|
||||||
case constant.RelayModeAudioTranslation:
|
|
||||||
case constant.RelayModeChatCompletions:
|
|
||||||
fallthrough
|
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
if info.IsStream {
|
err, usage = mokaEmbeddingHandler(c, resp)
|
||||||
err, usage = StreamHandler(c, resp, info)
|
default:
|
||||||
} else {
|
// err, usage = mokaHandler(c, resp)
|
||||||
err, usage = Handler(c, resp, info)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,41 +1,15 @@
|
|||||||
package mokaai
|
package mokaai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
// "one-api/common/ctxkey"
|
|
||||||
// "one-api/common/render"
|
|
||||||
|
|
||||||
// "github.com/gin-gonic/gin"
|
|
||||||
// "one-api/common"
|
|
||||||
// "one-api/common/helper"
|
|
||||||
// "one-api/common/logger"
|
|
||||||
// "one-api/relay/adaptor/openai"
|
|
||||||
// "one-api/relay/model"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func ConvertCompletionsRequest(textRequest dto.GeneralOpenAIRequest) *Request {
|
func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest {
|
||||||
p, _ := textRequest.Prompt.(string)
|
|
||||||
return &Request{
|
|
||||||
Prompt: p,
|
|
||||||
MaxTokens: textRequest.GetMaxTokens(),
|
|
||||||
Stream: textRequest.Stream,
|
|
||||||
Temperature: textRequest.Temperature,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ConvertEmbeddingRequest(request dto.GeneralOpenAIRequest) *EmbeddingRequest {
|
|
||||||
var input []string // Change input to []string
|
var input []string // Change input to []string
|
||||||
|
|
||||||
switch v := request.Input.(type) {
|
switch v := request.Input.(type) {
|
||||||
@@ -50,105 +24,60 @@ func ConvertEmbeddingRequest(request dto.GeneralOpenAIRequest) *EmbeddingRequest
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return &dto.EmbeddingRequest{
|
||||||
return &EmbeddingRequest{
|
Input: input,
|
||||||
Model: request.Model,
|
Model: request.Model,
|
||||||
Input: input, // Assign []string to Input
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func StreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEmbeddingResponse {
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
|
||||||
scanner.Split(bufio.ScanLines)
|
Object: "list",
|
||||||
|
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
|
||||||
service.SetEventStreamHeaders(c)
|
Model: "baidu-embedding",
|
||||||
id := service.GetResponseID(c)
|
Usage: response.Usage,
|
||||||
var responseText string
|
|
||||||
isFirst := true
|
|
||||||
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
if len(data) < len("data: ") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = strings.TrimPrefix(data, "data: ")
|
|
||||||
data = strings.TrimSuffix(data, "\r")
|
|
||||||
|
|
||||||
if data == "[DONE]" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
var response dto.ChatCompletionsStreamResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &response)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, choice := range response.Choices {
|
|
||||||
choice.Delta.Role = "assistant"
|
|
||||||
responseText += choice.Delta.GetContentString()
|
|
||||||
}
|
|
||||||
response.Id = id
|
|
||||||
response.Model = info.UpstreamModelName
|
|
||||||
err = service.ObjectData(c, response)
|
|
||||||
if isFirst {
|
|
||||||
isFirst = false
|
|
||||||
info.FirstResponseTime = time.Now()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(c, "error_rendering_stream_response: "+err.Error())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
for _, item := range response.Data {
|
||||||
if err := scanner.Err(); err != nil {
|
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
|
||||||
common.LogError(c, "error_scanning_stream_response: "+err.Error())
|
Object: item.Object,
|
||||||
|
Index: item.Index,
|
||||||
|
Embedding: item.Embedding,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
return &openAIEmbeddingResponse
|
||||||
if info.ShouldIncludeUsage {
|
|
||||||
response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
|
||||||
err := service.ObjectData(c, response)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
service.Done(c)
|
|
||||||
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(c, "close_response_body_failed: "+err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, usage
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Handler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
var baiduResponse dto.EmbeddingResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
err = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
var response dto.TextResponse
|
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||||
err = json.Unmarshal(responseBody, &response)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
response.Model = info.UpstreamModelName
|
// if baiduResponse.ErrorMsg != "" {
|
||||||
var responseText string
|
// return &dto.OpenAIErrorWithStatusCode{
|
||||||
for _, choice := range response.Choices {
|
// Error: dto.OpenAIError{
|
||||||
responseText += choice.Message.StringContent()
|
// Type: "baidu_error",
|
||||||
}
|
// Param: "",
|
||||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
// },
|
||||||
response.Usage = *usage
|
// StatusCode: resp.StatusCode,
|
||||||
response.Id = service.GetResponseID(c)
|
// }, nil
|
||||||
jsonResponse, err := json.Marshal(response)
|
// }
|
||||||
|
fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse)
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, _ = c.Writer.Write(jsonResponse)
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
return nil, usage
|
return nil, &fullTextResponse.Usage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -58,6 +58,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -129,6 +129,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
a.ResponseFormat = request.ResponseFormat
|
a.ResponseFormat = request.ResponseFormat
|
||||||
if info.RelayMode == constant.RelayModeAudioSpeech {
|
if info.RelayMode == constant.RelayModeAudioSpeech {
|
||||||
|
|||||||
@@ -49,6 +49,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,6 +52,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,6 +58,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
|
|||||||
@@ -73,6 +73,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -151,6 +151,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
// xunfei's request is not http request, so we don't need to do anything here
|
// xunfei's request is not http request, so we don't need to do anything here
|
||||||
dummyResp := &http.Response{}
|
dummyResp := &http.Response{}
|
||||||
|
|||||||
@@ -56,6 +56,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,6 +53,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
127
relay/relay_embedding.go
Normal file
127
relay/relay_embedding.go
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
relayconstant "one-api/relay/constant"
|
||||||
|
"one-api/service"
|
||||||
|
"one-api/setting"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
||||||
|
token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
relayInfo := relaycommon.GenRelayInfo(c)
|
||||||
|
|
||||||
|
var embeddingRequest *dto.EmbeddingRequest
|
||||||
|
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
|
||||||
|
embeddingRequest.Model = "m3e-base"
|
||||||
|
}
|
||||||
|
if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
|
||||||
|
embeddingRequest.Model = c.Param("model")
|
||||||
|
}
|
||||||
|
if embeddingRequest.Input == nil || len(embeddingRequest.Input) == 0 {
|
||||||
|
return service.OpenAIErrorWrapperLocal(fmt.Errorf("input is empty"), "invalid_input", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
// map model name
|
||||||
|
modelMapping := c.GetString("model_mapping")
|
||||||
|
//isModelMapped := false
|
||||||
|
if modelMapping != "" && modelMapping != "{}" {
|
||||||
|
modelMap := make(map[string]string)
|
||||||
|
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
if modelMap[embeddingRequest.Model] != "" {
|
||||||
|
embeddingRequest.Model = modelMap[embeddingRequest.Model]
|
||||||
|
// set upstream model name
|
||||||
|
//isModelMapped = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
relayInfo.UpstreamModelName = embeddingRequest.Model
|
||||||
|
modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false)
|
||||||
|
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||||
|
|
||||||
|
var preConsumedQuota int
|
||||||
|
var ratio float64
|
||||||
|
var modelRatio float64
|
||||||
|
|
||||||
|
promptToken := getEmbeddingPromptToken(*embeddingRequest)
|
||||||
|
if !success {
|
||||||
|
preConsumedTokens := promptToken
|
||||||
|
modelRatio = common.GetModelRatio(embeddingRequest.Model)
|
||||||
|
ratio = modelRatio * groupRatio
|
||||||
|
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||||
|
} else {
|
||||||
|
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||||
|
}
|
||||||
|
relayInfo.PromptTokens = promptToken
|
||||||
|
|
||||||
|
// pre-consume quota 预消耗配额
|
||||||
|
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||||
|
if openaiErr != nil {
|
||||||
|
return openaiErr
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if openaiErr != nil {
|
||||||
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
|
if adaptor == nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
adaptor.Init(relayInfo)
|
||||||
|
|
||||||
|
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c,relayInfo,*embeddingRequest)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(convertedRequest)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
requestBody := bytes.NewBuffer(jsonData)
|
||||||
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
|
resp, err := adaptor.DoRequest(c,relayInfo, requestBody)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
var httpResp *http.Response
|
||||||
|
if resp != nil {
|
||||||
|
httpResp = resp.(*http.Response)
|
||||||
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
|
openaiErr = service.RelayErrorHandler(httpResp)
|
||||||
|
// reset status code 重置状态码
|
||||||
|
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||||
|
return openaiErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||||
|
if openaiErr != nil {
|
||||||
|
// reset status code 重置状态码
|
||||||
|
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||||
|
return openaiErr
|
||||||
|
}
|
||||||
|
postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user