feat: support xinference rerank to jina format
This commit is contained in:
@@ -10,13 +10,17 @@ type RerankRequest struct {
|
|||||||
OverLapTokens int `json:"overlap_tokens,omitempty"`
|
OverLapTokens int `json:"overlap_tokens,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type RerankResponseDocument struct {
|
type RerankResponseResult struct {
|
||||||
Document any `json:"document,omitempty"`
|
Document any `json:"document,omitempty"`
|
||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
RelevanceScore float64 `json:"relevance_score"`
|
RelevanceScore float64 `json:"relevance_score"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type RerankResponse struct {
|
type RerankDocument struct {
|
||||||
Results []RerankResponseDocument `json:"results"`
|
Text any `json:"text"`
|
||||||
Usage Usage `json:"usage"`
|
}
|
||||||
|
|
||||||
|
type RerankResponse struct {
|
||||||
|
Results []RerankResponseResult `json:"results"`
|
||||||
|
Usage Usage `json:"usage"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,8 +40,8 @@ type CohereRerankRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CohereRerankResponseResult struct {
|
type CohereRerankResponseResult struct {
|
||||||
Results []dto.RerankResponseDocument `json:"results"`
|
Results []dto.RerankResponseResult `json:"results"`
|
||||||
Meta CohereMeta `json:"meta"`
|
Meta CohereMeta `json:"meta"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CohereMeta struct {
|
type CohereMeta struct {
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
|||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.RelayMode == constant.RelayModeRerank {
|
||||||
err, usage = common_handler.RerankHandler(c, resp)
|
err, usage = common_handler.RerankHandler(c, info, resp)
|
||||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -262,7 +262,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
case constant.RelayModeImagesGenerations:
|
case constant.RelayModeImagesGenerations:
|
||||||
err, usage = OpenaiTTSHandler(c, resp, info)
|
err, usage = OpenaiTTSHandler(c, resp, info)
|
||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
err, usage = common_handler.RerankHandler(c, resp)
|
err, usage = common_handler.RerankHandler(c, info, resp)
|
||||||
default:
|
default:
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = OaiStreamHandler(c, resp, info)
|
err, usage = OaiStreamHandler(c, resp, info)
|
||||||
|
|||||||
@@ -12,6 +12,6 @@ type SFMeta struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SFRerankResponse struct {
|
type SFRerankResponse struct {
|
||||||
Results []dto.RerankResponseDocument `json:"results"`
|
Results []dto.RerankResponseResult `json:"results"`
|
||||||
Meta SFMeta `json:"meta"`
|
Meta SFMeta `json:"meta"`
|
||||||
}
|
}
|
||||||
|
|||||||
11
relay/channel/xinference/dto.go
Normal file
11
relay/channel/xinference/dto.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package xinference
|
||||||
|
|
||||||
|
// XinRerankResponseDocument is one element of the "results" array in an
// Xinference rerank response. Unlike dto.RerankResponseResult, Document is
// declared as a plain string rather than any.
type XinRerankResponseDocument struct {
|
||||||
|
// Document is the document text. RerankHandler treats an empty value as
// absent and substitutes the request document found at Index.
Document string `json:"document,omitempty"`
|
||||||
|
// Index is decoded from "index"; RerankHandler uses it to index into the
// documents stored on the relay info.
Index int `json:"index"`
|
||||||
|
// RelevanceScore is decoded from "relevance_score" and copied unchanged
// into the converted result.
RelevanceScore float64 `json:"relevance_score"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// XinRerankResponse is the body RerankHandler decodes when the channel type
// is Xinference. It declares only the results array; the handler fills the
// usage of the converted response from info.PromptTokens instead.
type XinRerankResponse struct {
|
||||||
|
Results []XinRerankResponseDocument `json:"results"`
|
||||||
|
}
|
||||||
@@ -33,6 +33,10 @@ const (
|
|||||||
RelayFormatClaude = "claude"
|
RelayFormatClaude = "claude"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// RerankerInfo carries the documents of a rerank request so the response
// handler can look them up by result index.
// NOTE(review): RelayInfo embeds this as *RerankerInfo and only
// GenRelayInfoRerank is seen populating it here; other constructors
// presumably leave it nil, in which case reading Documents through
// RelayInfo would panic — confirm before use.
type RerankerInfo struct {
|
||||||
|
// Documents holds the request documents exactly as passed to
// GenRelayInfoRerank.
Documents []any
|
||||||
|
}
|
||||||
|
|
||||||
type RelayInfo struct {
|
type RelayInfo struct {
|
||||||
ChannelType int
|
ChannelType int
|
||||||
ChannelId int
|
ChannelId int
|
||||||
@@ -78,6 +82,7 @@ type RelayInfo struct {
|
|||||||
SendResponseCount int
|
SendResponseCount int
|
||||||
ThinkingContentInfo
|
ThinkingContentInfo
|
||||||
ClaudeConvertInfo
|
ClaudeConvertInfo
|
||||||
|
*RerankerInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// 定义支持流式选项的通道类型
|
// 定义支持流式选项的通道类型
|
||||||
@@ -111,6 +116,15 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
|
|||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenRelayInfoRerank builds a RelayInfo for a rerank request. It starts from
// GenRelayInfo(c), forces RelayMode to RelayModeRerank, and attaches the
// request documents through a newly allocated RerankerInfo.
// The documents slice is stored as given (not copied).
func GenRelayInfoRerank(c *gin.Context, documents []any) *RelayInfo {
|
||||||
|
info := GenRelayInfo(c)
|
||||||
|
info.RelayMode = relayconstant.RelayModeRerank
|
||||||
|
// Always set a non-nil RerankerInfo, even when documents is nil.
info.RerankerInfo = &RerankerInfo{
|
||||||
|
Documents: documents,
|
||||||
|
}
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||||
channelType := c.GetInt("channel_type")
|
channelType := c.GetInt("channel_type")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
|
|||||||
@@ -1,15 +1,17 @@
|
|||||||
package common_handler
|
package common_handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/relay/channel/xinference"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
@@ -18,18 +20,48 @@ func RerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSta
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
|
if common.DebugEnabled {
|
||||||
|
println("reranker response body: ", string(responseBody))
|
||||||
|
}
|
||||||
var jinaResp dto.RerankResponse
|
var jinaResp dto.RerankResponse
|
||||||
err = json.Unmarshal(responseBody, &jinaResp)
|
if info.ChannelType == common.ChannelTypeXinference {
|
||||||
if err != nil {
|
var xinRerankResponse xinference.XinRerankResponse
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
err = common.DecodeJson(responseBody, &xinRerankResponse)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
|
||||||
|
for i, result := range xinRerankResponse.Results {
|
||||||
|
var document any
|
||||||
|
if result.Document == "" {
|
||||||
|
document = info.Documents[result.Index]
|
||||||
|
} else {
|
||||||
|
document = result.Document
|
||||||
|
}
|
||||||
|
jinaRespResults[i] = dto.RerankResponseResult{
|
||||||
|
Index: result.Index,
|
||||||
|
RelevanceScore: result.RelevanceScore,
|
||||||
|
Document: dto.RerankDocument{
|
||||||
|
Text: document,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
jinaResp = dto.RerankResponse{
|
||||||
|
Results: jinaRespResults,
|
||||||
|
Usage: dto.Usage{
|
||||||
|
PromptTokens: info.PromptTokens,
|
||||||
|
TotalTokens: info.PromptTokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = common.DecodeJson(responseBody, &jinaResp)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonResponse, err := json.Marshal(jinaResp)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.JSON(http.StatusOK, jinaResp)
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &jinaResp.Usage
|
return nil, &jinaResp.Usage
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||||
relayInfo := relaycommon.GenRelayInfo(c)
|
|
||||||
|
|
||||||
var rerankRequest *dto.RerankRequest
|
var rerankRequest *dto.RerankRequest
|
||||||
err := common.UnmarshalBodyReusable(c, &rerankRequest)
|
err := common.UnmarshalBodyReusable(c, &rerankRequest)
|
||||||
@@ -33,6 +32,9 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
|||||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||||
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest.Documents)
|
||||||
|
|
||||||
if rerankRequest.Query == "" {
|
if rerankRequest.Query == "" {
|
||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
|
return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user