feat: Improve embedding request handling and support across channels
- Update EmbeddingRequest DTO to support more flexible input types
- Add input parsing method to handle various input formats
- Implement ConvertEmbeddingRequest for multiple channel adaptors
- Remove relayMode parameter from EmbeddingHelper
- Add input validation for embedding requests
- Simplify embedding request conversion for different channels
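
For context, a minimal usage sketch (not part of the commit) of the reworked DTO: the same dto.EmbeddingRequest now accepts either a single string or an array of strings, and ParseInput normalizes both shapes for the channel adaptors. The "one-api/dto" import path is assumed from the module name visible in the diff below.

    package main

    import (
        "encoding/json"
        "fmt"

        "one-api/dto" // assumed import path; the module name "one-api" appears in the diff below
    )

    func main() {
        // Both request shapes decode into the same dto.EmbeddingRequest,
        // because Input is declared as `any` after this change.
        bodies := []string{
            `{"model":"text-embedding-v1","input":"hello world"}`,
            `{"model":"text-embedding-v1","input":["hello","world"]}`,
        }
        for _, body := range bodies {
            var req dto.EmbeddingRequest
            if err := json.Unmarshal([]byte(body), &req); err != nil {
                panic(err)
            }
            // ParseInput normalizes both shapes to []string for the adaptors.
            fmt.Println(req.Model, req.ParseInput())
        }
    }
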
@@ -34,7 +34,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
     case relayconstant.RelayModeRerank:
         err = relay.RerankHelper(c, relayMode)
     case relayconstant.RelayModeEmbeddings:
-        err = relay.EmbeddingHelper(c,relayMode)
+        err = relay.EmbeddingHelper(c)
     default:
         err = relay.TextHelper(c)
     }

@@ -57,11 +57,6 @@ func Relay(c *gin.Context) {
     originalModel := c.GetString("original_model")
     var openaiErr *dto.OpenAIErrorWithStatusCode
 
-    // Get the request body and log it
-    requestBody, _ := common.GetRequestBody(c)
-    common.LogInfo(c, fmt.Sprintf("relayMode: %d ,request body: %s",relayMode, string(requestBody)))
-
-
     for i := 0; i <= common.RetryTimes; i++ {
         channel, err := getChannel(c, group, originalModel, i)
         if err != nil {

@@ -161,7 +156,6 @@ func WssRelay(c *gin.Context) {
 }
 
 func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
-    common.LogInfo(c, fmt.Sprintf("relayMode: %d ,channel Id : %s",relayMode, string(channel.Id)))
     addUsedChannel(c, channel.Id)
     requestBody, _ := common.GetRequestBody(c)
     c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))

@@ -12,8 +12,35 @@ type EmbeddingOptions struct {
 }
 
 type EmbeddingRequest struct {
-    Model string   `json:"model"`
-    Input []string `json:"input"`
+    Model            string   `json:"model"`
+    Input            any      `json:"input"`
+    EncodingFormat   string   `json:"encoding_format,omitempty"`
+    Dimensions       int      `json:"dimensions,omitempty"`
+    User             string   `json:"user,omitempty"`
+    Seed             float64  `json:"seed,omitempty"`
+    Temperature      *float64 `json:"temperature,omitempty"`
+    TopP             float64  `json:"top_p,omitempty"`
+    FrequencyPenalty float64  `json:"frequency_penalty,omitempty"`
+    PresencePenalty  float64  `json:"presence_penalty,omitempty"`
 }
 
+func (r EmbeddingRequest) ParseInput() []string {
+    if r.Input == nil {
+        return nil
+    }
+    var input []string
+    switch r.Input.(type) {
+    case string:
+        input = []string{r.Input.(string)}
+    case []any:
+        input = make([]string, 0, len(r.Input.([]any)))
+        for _, item := range r.Input.([]any) {
+            if str, ok := item.(string); ok {
+                input = append(input, str)
+            }
+        }
+    }
+    return input
+}
+
 type EmbeddingResponseItem struct {

@@ -23,8 +50,8 @@ type EmbeddingResponseItem struct {
 }
 
 type EmbeddingResponse struct {
-    Object string `json:"object"`
+    Object string                  `json:"object"`
     Data   []EmbeddingResponseItem `json:"data"`
-    Model string `json:"model"`
+    Model  string                  `json:"model"`
     Usage  `json:"usage"`
 }

@@ -49,9 +49,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
         return nil, errors.New("request is nil")
     }
     switch info.RelayMode {
-    case constant.RelayModeEmbeddings:
-        baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
-        return baiduEmbeddingRequest, nil
     default:
         aliReq := requestOpenAI2Ali(*request)
         return aliReq, nil

@@ -68,8 +65,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-    //TODO implement me
-    return nil, errors.New("not implemented")
+    return embeddingRequestOpenAI2Ali(request), nil
 }
 
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {

@@ -25,9 +25,12 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReque
     return &request
 }
 
-func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
+func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
+    if request.Model == "" {
+        request.Model = "text-embedding-v1"
+    }
     return &AliEmbeddingRequest{
-        Model: "text-embedding-v1",
+        Model: request.Model,
         Input: struct {
             Texts []string `json:"texts"`
         }{

@@ -109,9 +109,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
         return nil, errors.New("request is nil")
     }
     switch info.RelayMode {
-    case constant.RelayModeEmbeddings:
-        baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
-        return baiduEmbeddingRequest, nil
     default:
         baiduRequest := requestOpenAI2Baidu(*request)
         return baiduRequest, nil

@@ -123,8 +120,8 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-    //TODO implement me
-    return nil, errors.New("not implemented")
+    baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(request)
+    return baiduEmbeddingRequest, nil
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {

@@ -87,7 +87,7 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.Cha
     return &response
 }
 
-func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest {
+func embeddingRequestOpenAI2Baidu(request dto.EmbeddingRequest) *BaiduEmbeddingRequest {
     return &BaiduEmbeddingRequest{
         Input: request.ParseInput(),
     }

@@ -57,11 +57,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-    //TODO implement me
-    return nil, errors.New("not implemented")
+    return request, nil
 }
-
 
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
     // Add the file field
     file, _, err := c.Request.FormFile("file")

@@ -56,11 +56,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-    //TODO implement me
-    return nil, errors.New("not implemented")
+    return request, nil
 }
-
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
     if info.RelayMode == constant.RelayModeRerank {
         err, usage = jinaRerankHandler(c, resp)

@@ -46,12 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
     if request == nil {
         return nil, errors.New("request is nil")
     }
-    switch info.RelayMode {
-    case relayconstant.RelayModeEmbeddings:
-        return requestOpenAI2Embeddings(*request), nil
-    default:
-        return requestOpenAI2Ollama(*request), nil
-    }
+    return requestOpenAI2Ollama(*request), nil
 }
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {

@@ -59,11 +54,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-    //TODO implement me
-    return nil, errors.New("not implemented")
+    return requestOpenAI2Embeddings(request), nil
 }
-
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
     return channel.DoApiRequest(a, c, info, requestBody)
 }

@@ -42,7 +42,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
     }
 }
 
-func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
+func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
     return &OllamaEmbeddingRequest{
         Model: request.Model,
         Input: request.ParseInput(),

@@ -123,9 +123,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
 }
 
 func flattenEmbeddings(embeddings [][]float64) []float64 {
-    flattened := []float64{}
-    for _, row := range embeddings {
-        flattened = append(flattened, row...)
-    }
-    return flattened
+    flattened := []float64{}
+    for _, row := range embeddings {
+        flattened = append(flattened, row...)
+    }
+    return flattened
 }

@@ -150,8 +150,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-    //TODO implement me
-    return nil, errors.New("not implemented")
+    return request, nil
 }
 
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {

@@ -59,11 +59,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-    //TODO implement me
-    return nil, errors.New("not implemented")
+    return request, nil
 }
-
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
     switch info.RelayMode {
     case constant.RelayModeRerank:

@@ -19,7 +19,20 @@ func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
     return token
 }
 
-func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
+func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error {
+    if embeddingRequest.Input == nil {
+        return fmt.Errorf("input is empty")
+    }
+    if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
+        embeddingRequest.Model = "omni-moderation-latest"
+    }
+    if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
+        embeddingRequest.Model = c.Param("model")
+    }
+    return nil
+}
+
+func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
     relayInfo := relaycommon.GenRelayInfo(c)
 
     var embeddingRequest *dto.EmbeddingRequest

@@ -28,15 +41,12 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
         common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
         return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
     }
-    if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
-        embeddingRequest.Model = "m3e-base"
-    }
-    if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
-        embeddingRequest.Model = c.Param("model")
-    }
-    if embeddingRequest.Input == nil || len(embeddingRequest.Input) == 0 {
-        return service.OpenAIErrorWrapperLocal(fmt.Errorf("input is empty"), "invalid_input", http.StatusBadRequest)
+
+    err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
+    if err != nil {
+        return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
     }
+
     // map model name
     modelMapping := c.GetString("model_mapping")
     //isModelMapped := false

@@ -89,8 +99,8 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
     }
     adaptor.Init(relayInfo)
 
-    convertedRequest, err := adaptor.ConvertEmbeddingRequest(c,relayInfo,*embeddingRequest)
+    convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
 
     if err != nil {
         return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
     }

@@ -100,7 +110,7 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
     }
     requestBody := bytes.NewBuffer(jsonData)
     statusCodeMappingStr := c.GetString("status_code_mapping")
-    resp, err := adaptor.DoRequest(c,relayInfo, requestBody)
+    resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
     if err != nil {
         return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
     }

@@ -4,7 +4,6 @@ import (
     "encoding/json"
     "errors"
     "fmt"
-    "github.com/pkoukk/tiktoken-go"
     "image"
     "log"
     "math"

@@ -14,6 +13,8 @@ import (
     relaycommon "one-api/relay/common"
     "strings"
     "unicode/utf8"
+
+    "github.com/pkoukk/tiktoken-go"
 )
 
 // tokenEncoderMap won't grow after initialization

@@ -323,6 +324,12 @@ func CountTokenInput(input any, model string) (int, error) {
             text += s
         }
         return CountTextToken(text, model)
+    case []interface{}:
+        text := ""
+        for _, item := range v {
+            text += fmt.Sprintf("%v", item)
+        }
+        return CountTextToken(text, model)
     }
     return CountTokenInput(fmt.Sprintf("%v", input), model)
 }
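
The last hunk adds a []interface{} branch to CountTokenInput. A small, self-contained illustration (not taken from the repository) of why that branch is reached: when an embedding request's "input" field is a JSON array, decoding it into an any-typed value yields []interface{}, and the new case concatenates the items before counting tokens.

    package main

    import (
        "encoding/json"
        "fmt"
    )

    func main() {
        // A JSON array decodes into []interface{} when the target is `any`,
        // which is the shape the new CountTokenInput branch handles.
        var input any
        if err := json.Unmarshal([]byte(`["hello", "world", 42]`), &input); err != nil {
            panic(err)
        }
        items, ok := input.([]interface{})
        fmt.Println(ok) // true
        // Mirror the new branch: stringify and concatenate each item.
        text := ""
        for _, item := range items {
            text += fmt.Sprintf("%v", item)
        }
        fmt.Println(text) // helloworld42
    }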