diff --git a/controller/relay.go b/controller/relay.go index 25af7e20..d7e0f00a 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -34,7 +34,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode case relayconstant.RelayModeRerank: err = relay.RerankHelper(c, relayMode) case relayconstant.RelayModeEmbeddings: - err = relay.EmbeddingHelper(c,relayMode) + err = relay.EmbeddingHelper(c) default: err = relay.TextHelper(c) } @@ -57,11 +57,6 @@ func Relay(c *gin.Context) { originalModel := c.GetString("original_model") var openaiErr *dto.OpenAIErrorWithStatusCode - //获取request body 并输出到日志 - requestBody, _ := common.GetRequestBody(c) - common.LogInfo(c, fmt.Sprintf("relayMode: %d ,request body: %s",relayMode, string(requestBody))) - - for i := 0; i <= common.RetryTimes; i++ { channel, err := getChannel(c, group, originalModel, i) if err != nil { @@ -161,7 +156,6 @@ func WssRelay(c *gin.Context) { } func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { - common.LogInfo(c, fmt.Sprintf("relayMode: %d ,channel Id : %s",relayMode, string(channel.Id))) addUsedChannel(c, channel.Id) requestBody, _ := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) diff --git a/dto/embedding.go b/dto/embedding.go index 828faaab..9d722292 100644 --- a/dto/embedding.go +++ b/dto/embedding.go @@ -12,8 +12,35 @@ type EmbeddingOptions struct { } type EmbeddingRequest struct { - Model string `json:"model"` - Input []string `json:"input"` + Model string `json:"model"` + Input any `json:"input"` + EncodingFormat string `json:"encoding_format,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + User string `json:"user,omitempty"` + Seed float64 `json:"seed,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` +} + +func (r EmbeddingRequest) ParseInput() []string { + if r.Input == nil { + return nil + } + var input []string + switch r.Input.(type) { + case string: + input = []string{r.Input.(string)} + case []any: + input = make([]string, 0, len(r.Input.([]any))) + for _, item := range r.Input.([]any) { + if str, ok := item.(string); ok { + input = append(input, str) + } + } + } + return input } type EmbeddingResponseItem struct { @@ -23,8 +50,8 @@ type EmbeddingResponseItem struct { } type EmbeddingResponse struct { - Object string `json:"object"` + Object string `json:"object"` Data []EmbeddingResponseItem `json:"data"` - Model string `json:"model"` + Model string `json:"model"` Usage `json:"usage"` -} \ No newline at end of file +} diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index c4974a62..32be399b 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -49,9 +49,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re return nil, errors.New("request is nil") } switch info.RelayMode { - case constant.RelayModeEmbeddings: - baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request) - return baiduEmbeddingRequest, nil default: aliReq := requestOpenAI2Ali(*request) return aliReq, nil @@ -68,8 +65,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt } func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + return embeddingRequestOpenAI2Ali(request), nil } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go index aec857fa..db4df0a9 100644 --- a/relay/channel/ali/text.go +++ b/relay/channel/ali/text.go @@ -25,9 +25,12 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReque return &request } -func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest { +func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest { + if request.Model == "" { + request.Model = "text-embedding-v1" + } return &AliEmbeddingRequest{ - Model: "text-embedding-v1", + Model: request.Model, Input: struct { Texts []string `json:"texts"` }{ diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 35271b41..46a1f964 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -109,9 +109,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re return nil, errors.New("request is nil") } switch info.RelayMode { - case constant.RelayModeEmbeddings: - baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request) - return baiduEmbeddingRequest, nil default: baiduRequest := requestOpenAI2Baidu(*request) return baiduRequest, nil @@ -123,8 +120,8 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt } func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(request) + return baiduEmbeddingRequest, nil } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index 09a99e4d..d88f5212 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -87,7 +87,7 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.Cha return &response } -func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest { +func embeddingRequestOpenAI2Baidu(request dto.EmbeddingRequest) *BaiduEmbeddingRequest { return &BaiduEmbeddingRequest{ Input: request.ParseInput(), } diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index cf41d9d7..75400098 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -57,11 +57,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt } func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + return request, nil } - func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { // 添加文件字段 file, _, err := c.Request.FormFile("file") diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index 98fb073d..3706e3b8 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -56,11 +56,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt } func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + return request, nil } - func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { err, usage = jinaRerankHandler(c, resp) diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index d5185084..36889cb8 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -46,12 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re if request == nil { return nil, errors.New("request is nil") } - switch info.RelayMode { - case relayconstant.RelayModeEmbeddings: - return requestOpenAI2Embeddings(*request), nil - default: - return requestOpenAI2Ollama(*request), nil - } + return requestOpenAI2Ollama(*request), nil } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { @@ -59,11 +54,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt } func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + return requestOpenAI2Embeddings(request), nil } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index 2ef716b3..4ecdd19b 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -42,7 +42,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest { } } -func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest { +func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest { return &OllamaEmbeddingRequest{ Model: request.Model, Input: request.ParseInput(), @@ -123,9 +123,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in } func flattenEmbeddings(embeddings [][]float64) []float64 { -flattened := []float64{} -for _, row := range embeddings { - flattened = append(flattened, row...) + flattened := []float64{} + for _, row := range embeddings { + flattened = append(flattened, row...) + } + return flattened } -return flattened -} \ No newline at end of file diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index d86b33e0..e94399ea 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -150,8 +150,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt } func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + return request, nil } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index f9ddedeb..c02d18a3 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -59,11 +59,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt } func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + return request, nil } - func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { switch info.RelayMode { case constant.RelayModeRerank: diff --git a/relay/relay_embedding.go b/relay/relay_embedding.go index a3304d8e..0a41c11d 100644 --- a/relay/relay_embedding.go +++ b/relay/relay_embedding.go @@ -19,7 +19,20 @@ func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { return token } -func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) { +func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error { + if embeddingRequest.Input == nil { + return fmt.Errorf("input is empty") + } + if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" { + embeddingRequest.Model = "omni-moderation-latest" + } + if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" { + embeddingRequest.Model = c.Param("model") + } + return nil +} + +func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { relayInfo := relaycommon.GenRelayInfo(c) var embeddingRequest *dto.EmbeddingRequest @@ -28,15 +41,12 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) } - if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" { - embeddingRequest.Model = "m3e-base" - } - if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" { - embeddingRequest.Model = c.Param("model") - } - if embeddingRequest.Input == nil || len(embeddingRequest.Input) == 0 { - return service.OpenAIErrorWrapperLocal(fmt.Errorf("input is empty"), "invalid_input", http.StatusBadRequest) + + err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest) } + // map model name modelMapping := c.GetString("model_mapping") //isModelMapped := false @@ -89,8 +99,8 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW } adaptor.Init(relayInfo) - convertedRequest, err := adaptor.ConvertEmbeddingRequest(c,relayInfo,*embeddingRequest) - + convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest) + if err != nil { return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) } @@ -100,7 +110,7 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW } requestBody := bytes.NewBuffer(jsonData) statusCodeMappingStr := c.GetString("status_code_mapping") - resp, err := adaptor.DoRequest(c,relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } diff --git a/service/token_counter.go b/service/token_counter.go index 93feab2d..319c9b11 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/pkoukk/tiktoken-go" "image" "log" "math" @@ -14,6 +13,8 @@ import ( relaycommon "one-api/relay/common" "strings" "unicode/utf8" + + "github.com/pkoukk/tiktoken-go" ) // tokenEncoderMap won't grow after initialization @@ -323,6 +324,12 @@ func CountTokenInput(input any, model string) (int, error) { text += s } return CountTextToken(text, model) + case []interface{}: + text := "" + for _, item := range v { + text += fmt.Sprintf("%v", item) + } + return CountTextToken(text, model) } return CountTokenInput(fmt.Sprintf("%v", input), model) }