feat: Improve embedding request handling and support across channels

- Update EmbeddingRequest DTO to support more flexible input types
- Add input parsing method to handle various input formats
- Implement ConvertEmbeddingRequest for multiple channel adaptors
- Remove relayMode parameter from EmbeddingHelper
- Add input validation for embedding requests
- Simplify embedding request conversion for different channels
This commit is contained in:
1808837298@qq.com
2025-02-12 14:39:36 +08:00
parent eceb6afcdd
commit f5e3063f33
14 changed files with 84 additions and 64 deletions

View File

@@ -34,7 +34,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
case relayconstant.RelayModeRerank: case relayconstant.RelayModeRerank:
err = relay.RerankHelper(c, relayMode) err = relay.RerankHelper(c, relayMode)
case relayconstant.RelayModeEmbeddings: case relayconstant.RelayModeEmbeddings:
err = relay.EmbeddingHelper(c,relayMode) err = relay.EmbeddingHelper(c)
default: default:
err = relay.TextHelper(c) err = relay.TextHelper(c)
} }
@@ -57,11 +57,6 @@ func Relay(c *gin.Context) {
originalModel := c.GetString("original_model") originalModel := c.GetString("original_model")
var openaiErr *dto.OpenAIErrorWithStatusCode var openaiErr *dto.OpenAIErrorWithStatusCode
//获取request body 并输出到日志
requestBody, _ := common.GetRequestBody(c)
common.LogInfo(c, fmt.Sprintf("relayMode: %d ,request body: %s",relayMode, string(requestBody)))
for i := 0; i <= common.RetryTimes; i++ { for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i) channel, err := getChannel(c, group, originalModel, i)
if err != nil { if err != nil {
@@ -161,7 +156,6 @@ func WssRelay(c *gin.Context) {
} }
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
common.LogInfo(c, fmt.Sprintf("relayMode: %d ,channel Id : %s",relayMode, string(channel.Id)))
addUsedChannel(c, channel.Id) addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c) requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))

View File

@@ -12,8 +12,35 @@ type EmbeddingOptions struct {
} }
type EmbeddingRequest struct { type EmbeddingRequest struct {
Model string `json:"model"` Model string `json:"model"`
Input []string `json:"input"` Input any `json:"input"`
EncodingFormat string `json:"encoding_format,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
User string `json:"user,omitempty"`
Seed float64 `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
}
func (r EmbeddingRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
var input []string
switch r.Input.(type) {
case string:
input = []string{r.Input.(string)}
case []any:
input = make([]string, 0, len(r.Input.([]any)))
for _, item := range r.Input.([]any) {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
} }
type EmbeddingResponseItem struct { type EmbeddingResponseItem struct {
@@ -23,8 +50,8 @@ type EmbeddingResponseItem struct {
} }
type EmbeddingResponse struct { type EmbeddingResponse struct {
Object string `json:"object"` Object string `json:"object"`
Data []EmbeddingResponseItem `json:"data"` Data []EmbeddingResponseItem `json:"data"`
Model string `json:"model"` Model string `json:"model"`
Usage `json:"usage"` Usage `json:"usage"`
} }

View File

@@ -49,9 +49,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
return baiduEmbeddingRequest, nil
default: default:
aliReq := requestOpenAI2Ali(*request) aliReq := requestOpenAI2Ali(*request)
return aliReq, nil return aliReq, nil
@@ -68,8 +65,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me return embeddingRequestOpenAI2Ali(request), nil
return nil, errors.New("not implemented")
} }
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {

View File

@@ -25,9 +25,12 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReque
return &request return &request
} }
func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest { func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
if request.Model == "" {
request.Model = "text-embedding-v1"
}
return &AliEmbeddingRequest{ return &AliEmbeddingRequest{
Model: "text-embedding-v1", Model: request.Model,
Input: struct { Input: struct {
Texts []string `json:"texts"` Texts []string `json:"texts"`
}{ }{

View File

@@ -109,9 +109,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
return baiduEmbeddingRequest, nil
default: default:
baiduRequest := requestOpenAI2Baidu(*request) baiduRequest := requestOpenAI2Baidu(*request)
return baiduRequest, nil return baiduRequest, nil
@@ -123,8 +120,8 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(request)
return nil, errors.New("not implemented") return baiduEmbeddingRequest, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {

View File

@@ -87,7 +87,7 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.Cha
return &response return &response
} }
func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest { func embeddingRequestOpenAI2Baidu(request dto.EmbeddingRequest) *BaiduEmbeddingRequest {
return &BaiduEmbeddingRequest{ return &BaiduEmbeddingRequest{
Input: request.ParseInput(), Input: request.ParseInput(),
} }

View File

@@ -57,11 +57,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me return request, nil
return nil, errors.New("not implemented")
} }
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
// 添加文件字段 // 添加文件字段
file, _, err := c.Request.FormFile("file") file, _, err := c.Request.FormFile("file")

View File

@@ -56,11 +56,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me return request, nil
return nil, errors.New("not implemented")
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
err, usage = jinaRerankHandler(c, resp) err, usage = jinaRerankHandler(c, resp)

View File

@@ -46,12 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch info.RelayMode { return requestOpenAI2Ollama(*request), nil
case relayconstant.RelayModeEmbeddings:
return requestOpenAI2Embeddings(*request), nil
default:
return requestOpenAI2Ollama(*request), nil
}
} }
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -59,11 +54,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me return requestOpenAI2Embeddings(request), nil
return nil, errors.New("not implemented")
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }

View File

@@ -42,7 +42,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
} }
} }
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest { func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
return &OllamaEmbeddingRequest{ return &OllamaEmbeddingRequest{
Model: request.Model, Model: request.Model,
Input: request.ParseInput(), Input: request.ParseInput(),
@@ -123,9 +123,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
} }
func flattenEmbeddings(embeddings [][]float64) []float64 { func flattenEmbeddings(embeddings [][]float64) []float64 {
flattened := []float64{} flattened := []float64{}
for _, row := range embeddings { for _, row := range embeddings {
flattened = append(flattened, row...) flattened = append(flattened, row...)
}
return flattened
} }
return flattened
}

View File

@@ -150,8 +150,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me return request, nil
return nil, errors.New("not implemented")
} }
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {

View File

@@ -59,11 +59,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me return request, nil
return nil, errors.New("not implemented")
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeRerank: case constant.RelayModeRerank:

View File

@@ -19,7 +19,20 @@ func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
return token return token
} }
func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) { func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error {
if embeddingRequest.Input == nil {
return fmt.Errorf("input is empty")
}
if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
embeddingRequest.Model = "omni-moderation-latest"
}
if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
embeddingRequest.Model = c.Param("model")
}
return nil
}
func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfo(c)
var embeddingRequest *dto.EmbeddingRequest var embeddingRequest *dto.EmbeddingRequest
@@ -28,15 +41,12 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
} }
if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
embeddingRequest.Model = "m3e-base" err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
} if err != nil {
if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" { return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
embeddingRequest.Model = c.Param("model")
}
if embeddingRequest.Input == nil || len(embeddingRequest.Input) == 0 {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("input is empty"), "invalid_input", http.StatusBadRequest)
} }
// map model name // map model name
modelMapping := c.GetString("model_mapping") modelMapping := c.GetString("model_mapping")
//isModelMapped := false //isModelMapped := false
@@ -89,8 +99,8 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
} }
adaptor.Init(relayInfo) adaptor.Init(relayInfo)
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c,relayInfo,*embeddingRequest) convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
} }
@@ -100,7 +110,7 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
} }
requestBody := bytes.NewBuffer(jsonData) requestBody := bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping") statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c,relayInfo, requestBody) resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
} }

View File

@@ -4,7 +4,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/pkoukk/tiktoken-go"
"image" "image"
"log" "log"
"math" "math"
@@ -14,6 +13,8 @@ import (
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
"github.com/pkoukk/tiktoken-go"
) )
// tokenEncoderMap won't grow after initialization // tokenEncoderMap won't grow after initialization
@@ -323,6 +324,12 @@ func CountTokenInput(input any, model string) (int, error) {
text += s text += s
} }
return CountTextToken(text, model) return CountTextToken(text, model)
case []interface{}:
text := ""
for _, item := range v {
text += fmt.Sprintf("%v", item)
}
return CountTextToken(text, model)
} }
return CountTokenInput(fmt.Sprintf("%v", input), model) return CountTokenInput(fmt.Sprintf("%v", input), model)
} }