diff --git a/controller/relay.go b/controller/relay.go index 72d421e3..25af7e20 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -33,6 +33,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode err = relay.AudioHelper(c) case relayconstant.RelayModeRerank: err = relay.RerankHelper(c, relayMode) + case relayconstant.RelayModeEmbeddings: + err = relay.EmbeddingHelper(c,relayMode) default: err = relay.TextHelper(c) } @@ -55,6 +57,11 @@ func Relay(c *gin.Context) { originalModel := c.GetString("original_model") var openaiErr *dto.OpenAIErrorWithStatusCode + //获取request body 并输出到日志 + requestBody, _ := common.GetRequestBody(c) + common.LogInfo(c, fmt.Sprintf("relayMode: %d ,request body: %s",relayMode, string(requestBody))) + + for i := 0; i <= common.RetryTimes; i++ { channel, err := getChannel(c, group, originalModel, i) if err != nil { @@ -154,6 +161,7 @@ func WssRelay(c *gin.Context) { } func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { + common.LogInfo(c, fmt.Sprintf("relayMode: %d ,channel Id : %s",relayMode, string(channel.Id))) addUsedChannel(c, channel.Id) requestBody, _ := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) diff --git a/relay/channel/mokaai/dto.go b/dto/embedding.go similarity index 54% rename from relay/channel/mokaai/dto.go rename to dto/embedding.go index 18d2dccb..828faaab 100644 --- a/relay/channel/mokaai/dto.go +++ b/dto/embedding.go @@ -1,19 +1,6 @@ -package mokaai +package dto -import "one-api/dto" - - -type Request struct { - Messages []dto.Message `json:"messages,omitempty"` - Lora string `json:"lora,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Prompt string `json:"prompt,omitempty"` - Raw bool `json:"raw,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` -} - -type Options struct { +type EmbeddingOptions struct { Seed int `json:"seed,omitempty"` Temperature *float64 `json:"temperature,omitempty"` TopK int `json:"top_k,omitempty"` @@ -27,4 +14,17 @@ type Options struct { type EmbeddingRequest struct { Model string `json:"model"` Input []string `json:"input"` +} + +type EmbeddingResponseItem struct { + Object string `json:"object"` + Index int `json:"index"` + Embedding []float64 `json:"embedding"` +} + +type EmbeddingResponse struct { + Object string `json:"object"` + Data []EmbeddingResponseItem `json:"data"` + Model string `json:"model"` + Usage `json:"usage"` } \ No newline at end of file diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index d72db6e4..c970fd48 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -15,6 +15,7 @@ type Adaptor interface { SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) + ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index aa01ca66..c4974a62 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -67,6 +67,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index be72c04c..5a3d09b9 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -59,6 +59,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return nil, nil } diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 3991a5e9..35271b41 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -122,6 +122,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 488d87dc..83168382 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -73,6 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index fc0ec271..cf41d9d7 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -56,6 +56,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return request, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { // 添加文件字段 file, _, err := c.Request.FormFile("file") diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index f8b190ec..d552a53b 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -54,6 +54,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return requestConvertRerank2Cohere(request), nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { err, usage = cohereRerankHandler(c, resp, info) diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index cc94a58f..1682dc3f 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -49,6 +49,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 53ba26e6..ce73c78c 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -48,6 +48,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 9a5bc251..681e9988 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -68,6 +68,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index ad488f28..98fb073d 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -55,6 +55,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return request, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { err, usage = jinaRerankHandler(c, resp) diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index 4ab1a35a..c99e5396 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go index e9fe34cf..9670ec94 100644 --- a/relay/channel/mokaai/adaptor.go +++ b/relay/channel/mokaai/adaptor.go @@ -3,54 +3,46 @@ package mokaai import ( "errors" "fmt" + "github.com/gin-gonic/gin" "io" "net/http" - - "github.com/gin-gonic/gin" - // "one-api/relay/adaptor" - // "one-api/relay/meta" - // "one-api/relay/model" - // "one-api/relay/constant" "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/relay/constant" + "strings" ) type Adaptor struct { } -// ConvertImageRequest implements adaptor.Adaptor. -func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") -} - func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me return nil, errors.New("not implemented") } -func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { //TODO implement me return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo) { + +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return request, nil } +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { -func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - - var urlPrefix = info.BaseUrl - - switch info.RelayMode { - case constant.RelayModeChatCompletions: - return fmt.Sprintf("%s/chat/completions", urlPrefix), nil - case constant.RelayModeEmbeddings: - return fmt.Sprintf("%s/embeddings", urlPrefix), nil - default: - return fmt.Sprintf("%s/run/%s", urlPrefix, info.UpstreamModelName), nil +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t + suffix := "chat/" + if strings.HasPrefix(info.UpstreamModelName, "m3e") { + suffix = "embeddings" } + fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix) + return fullRequestURL, nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { @@ -64,33 +56,30 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re return nil, errors.New("request is nil") } switch info.RelayMode { - case constant.RelayModeChatCompletions: - return nil, errors.New("not implemented") - case constant.RelayModeEmbeddings: - // return ConvertCompletionsRequest(*request), nil - return ConvertEmbeddingRequest(*request), nil + case constant.RelayModeEmbeddings: + baiduEmbeddingRequest := embeddingRequestOpenAI2Moka(*request) + return baiduEmbeddingRequest, nil default: return nil, errors.New("not implemented") } } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { - switch info.RelayMode { - case constant.RelayModeAudioTranscription: - case constant.RelayModeAudioTranslation: - case constant.RelayModeChatCompletions: - fallthrough + switch info.RelayMode { case constant.RelayModeEmbeddings: - if info.IsStream { - err, usage = StreamHandler(c, resp, info) - } else { - err, usage = Handler(c, resp, info) - } + err, usage = mokaEmbeddingHandler(c, resp) + default: + // err, usage = mokaHandler(c, resp) + } return } diff --git a/relay/channel/mokaai/relay-mokaai.go b/relay/channel/mokaai/relay-mokaai.go index 44d7f7c2..d7580d7a 100644 --- a/relay/channel/mokaai/relay-mokaai.go +++ b/relay/channel/mokaai/relay-mokaai.go @@ -1,41 +1,15 @@ package mokaai import ( - "bufio" "encoding/json" + "github.com/gin-gonic/gin" "io" "net/http" - "strings" - - // "one-api/common/ctxkey" - // "one-api/common/render" - - // "github.com/gin-gonic/gin" - // "one-api/common" - // "one-api/common/helper" - // "one-api/common/logger" - // "one-api/relay/adaptor/openai" - // "one-api/relay/model" - - "github.com/gin-gonic/gin" - "one-api/common" "one-api/dto" - relaycommon "one-api/relay/common" "one-api/service" - "time" ) -func ConvertCompletionsRequest(textRequest dto.GeneralOpenAIRequest) *Request { - p, _ := textRequest.Prompt.(string) - return &Request{ - Prompt: p, - MaxTokens: textRequest.GetMaxTokens(), - Stream: textRequest.Stream, - Temperature: textRequest.Temperature, - } -} - -func ConvertEmbeddingRequest(request dto.GeneralOpenAIRequest) *EmbeddingRequest { +func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest { var input []string // Change input to []string switch v := request.Input.(type) { @@ -50,105 +24,60 @@ func ConvertEmbeddingRequest(request dto.GeneralOpenAIRequest) *EmbeddingRequest } } } - - return &EmbeddingRequest{ - Model: request.Model, - Input: input, // Assign []string to Input + return &dto.EmbeddingRequest{ + Input: input, + Model: request.Model, } } -func StreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - scanner := bufio.NewScanner(resp.Body) - scanner.Split(bufio.ScanLines) - - service.SetEventStreamHeaders(c) - id := service.GetResponseID(c) - var responseText string - isFirst := true - - for scanner.Scan() { - data := scanner.Text() - if len(data) < len("data: ") { - continue - } - data = strings.TrimPrefix(data, "data: ") - data = strings.TrimSuffix(data, "\r") - - if data == "[DONE]" { - break - } - - var response dto.ChatCompletionsStreamResponse - err := json.Unmarshal([]byte(data), &response) - if err != nil { - common.LogError(c, "error_unmarshalling_stream_response: "+err.Error()) - continue - } - for _, choice := range response.Choices { - choice.Delta.Role = "assistant" - responseText += choice.Delta.GetContentString() - } - response.Id = id - response.Model = info.UpstreamModelName - err = service.ObjectData(c, response) - if isFirst { - isFirst = false - info.FirstResponseTime = time.Now() - } - if err != nil { - common.LogError(c, "error_rendering_stream_response: "+err.Error()) - } +func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEmbeddingResponse { + openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{ + Object: "list", + Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)), + Model: "baidu-embedding", + Usage: response.Usage, } - - if err := scanner.Err(); err != nil { - common.LogError(c, "error_scanning_stream_response: "+err.Error()) + for _, item := range response.Data { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{ + Object: item.Object, + Index: item.Index, + Embedding: item.Embedding, + }) } - usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) - if info.ShouldIncludeUsage { - response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) - err := service.ObjectData(c, response) - if err != nil { - common.LogError(c, "error_rendering_final_usage_response: "+err.Error()) - } - } - service.Done(c) - - err := resp.Body.Close() - if err != nil { - common.LogError(c, "close_response_body_failed: "+err.Error()) - } - - return nil, usage + return &openAIEmbeddingResponse } -func Handler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var baiduResponse dto.EmbeddingResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - var response dto.TextResponse - err = json.Unmarshal(responseBody, &response) + err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } - response.Model = info.UpstreamModelName - var responseText string - for _, choice := range response.Choices { - responseText += choice.Message.StringContent() - } - usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) - response.Usage = *usage - response.Id = service.GetResponseID(c) - jsonResponse, err := json.Marshal(response) + // if baiduResponse.ErrorMsg != "" { + // return &dto.OpenAIErrorWithStatusCode{ + // Error: dto.OpenAIError{ + // Type: "baidu_error", + // Param: "", + // }, + // StatusCode: resp.StatusCode, + // }, nil + // } + fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse) + jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, _ = c.Writer.Write(jsonResponse) - return nil, usage + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage } + diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 30798402..d5185084 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -58,6 +58,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 230ab55c..718b26a1 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -129,6 +129,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { a.ResponseFormat = request.ResponseFormat if info.RelayMode == constant.RelayModeAudioSpeech { diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 91272337..f38fa95b 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -49,6 +49,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 18b66a9a..2b27bdb1 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -52,6 +52,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index ac722b22..f9ddedeb 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -58,6 +58,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return request, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { switch info.RelayMode { case constant.RelayModeRerank: diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index d831cc83..768ef646 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -73,6 +73,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 764e5c4b..07659c20 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -151,6 +151,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 31d426a6..71fd1367 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { // xunfei's request is not http request, so we don't need to do anything here dummyResp := &http.Response{} diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index f0538edc..87ff20d5 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -56,6 +56,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 3d46b799..5983c1d9 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -53,6 +53,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/relay_embedding.go b/relay/relay_embedding.go new file mode 100644 index 00000000..a3304d8e --- /dev/null +++ b/relay/relay_embedding.go @@ -0,0 +1,127 @@ +package relay + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + relayconstant "one-api/relay/constant" + "one-api/service" + "one-api/setting" +) + +func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { + token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model) + return token +} + +func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) { + relayInfo := relaycommon.GenRelayInfo(c) + + var embeddingRequest *dto.EmbeddingRequest + err := common.UnmarshalBodyReusable(c, &embeddingRequest) + if err != nil { + common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) + return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) + } + if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" { + embeddingRequest.Model = "m3e-base" + } + if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" { + embeddingRequest.Model = c.Param("model") + } + if embeddingRequest.Input == nil || len(embeddingRequest.Input) == 0 { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("input is empty"), "invalid_input", http.StatusBadRequest) + } + // map model name + modelMapping := c.GetString("model_mapping") + //isModelMapped := false + if modelMapping != "" && modelMapping != "{}" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap[embeddingRequest.Model] != "" { + embeddingRequest.Model = modelMap[embeddingRequest.Model] + // set upstream model name + //isModelMapped = true + } + } + + relayInfo.UpstreamModelName = embeddingRequest.Model + modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false) + groupRatio := setting.GetGroupRatio(relayInfo.Group) + + var preConsumedQuota int + var ratio float64 + var modelRatio float64 + + promptToken := getEmbeddingPromptToken(*embeddingRequest) + if !success { + preConsumedTokens := promptToken + modelRatio = common.GetModelRatio(embeddingRequest.Model) + ratio = modelRatio * groupRatio + preConsumedQuota = int(float64(preConsumedTokens) * ratio) + } else { + preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) + } + relayInfo.PromptTokens = promptToken + + // pre-consume quota 预消耗配额 + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) + if openaiErr != nil { + return openaiErr + } + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() + + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + } + adaptor.Init(relayInfo) + + convertedRequest, err := adaptor.ConvertEmbeddingRequest(c,relayInfo,*embeddingRequest) + + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) + } + jsonData, err := json.Marshal(convertedRequest) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) + } + requestBody := bytes.NewBuffer(jsonData) + statusCodeMappingStr := c.GetString("status_code_mapping") + resp, err := adaptor.DoRequest(c,relayInfo, requestBody) + if err != nil { + return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + + var httpResp *http.Response + if resp != nil { + httpResp = resp.(*http.Response) + if httpResp.StatusCode != http.StatusOK { + openaiErr = service.RelayErrorHandler(httpResp) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } + } + + usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) + if openaiErr != nil { + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } + postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") + return nil +}