feat: Improve embedding request handling and support across channels
- Update EmbeddingRequest DTO to support more flexible input types
- Add input parsing method to handle various input formats
- Implement ConvertEmbeddingRequest for multiple channel adaptors
- Remove relayMode parameter from EmbeddingHelper
- Add input validation for embedding requests
- Simplify embedding request conversion for different channels
This commit is contained in:
@@ -19,7 +19,20 @@ func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
||||
return token
|
||||
}
|
||||
|
||||
func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
// validateEmbeddingRequest checks an embedding request before it is relayed.
// It returns an "input is empty" error when embeddingRequest.Input is nil and
// nil otherwise. When Model is empty it assigns a default based on
// info.RelayMode: "omni-moderation-latest" for moderations, or the "model"
// route parameter read from c for embeddings.
//
// NOTE(review): embeddingRequest is received by value, so the Model
// assignments below only modify this function's local copy and are not
// visible to the caller after return. If the defaults are meant to reach the
// caller, the parameter presumably needs to be a pointer — confirm against
// the call site before changing the signature.
func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error {
|
||||
// Only a nil Input is rejected here; no other emptiness check is performed.
if embeddingRequest.Input == nil {
|
||||
return fmt.Errorf("input is empty")
|
||||
}
|
||||
// Moderation requests without an explicit model get a fixed default name.
if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
|
||||
embeddingRequest.Model = "omni-moderation-latest"
|
||||
}
|
||||
// Embedding requests without an explicit model take it from the "model" route parameter.
if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
|
||||
embeddingRequest.Model = c.Param("model")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
|
||||
var embeddingRequest *dto.EmbeddingRequest
|
||||
@@ -28,15 +41,12 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||
}
|
||||
if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
|
||||
embeddingRequest.Model = "m3e-base"
|
||||
}
|
||||
if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
|
||||
embeddingRequest.Model = c.Param("model")
|
||||
}
|
||||
if embeddingRequest.Input == nil || len(embeddingRequest.Input) == 0 {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("input is empty"), "invalid_input", http.StatusBadRequest)
|
||||
|
||||
err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
//isModelMapped := false
|
||||
@@ -89,8 +99,8 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c,relayInfo,*embeddingRequest)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
||||
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
@@ -100,7 +110,7 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
resp, err := adaptor.DoRequest(c,relayInfo, requestBody)
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user