feat: support xinference rerank to jina format

This commit is contained in:
1808837298@qq.com
2025-03-16 21:06:29 +08:00
parent 53b3599827
commit d1c62a583d
9 changed files with 85 additions and 22 deletions

View File

@@ -10,13 +10,17 @@ type RerankRequest struct {
OverLapTokens int `json:"overlap_tokens,omitempty"`
}
type RerankResponseDocument struct {
// RerankResponseResult is a single scored entry in a rerank response.
//
// Document is left untyped and is dropped from the JSON output when unset
// (omitempty). Index and RelevanceScore are always serialized.
type RerankResponseResult struct {
Document any `json:"document,omitempty"`
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
}
type RerankResponse struct {
Results []RerankResponseDocument `json:"results"`
Usage Usage `json:"usage"`
// RerankDocument wraps a document's content under the JSON key "text".
// Text is untyped; the field is always serialized (no omitempty).
type RerankDocument struct {
Text any `json:"text"`
}
// RerankResponse is the rerank response body: the scored results plus the
// token usage reported for the request.
type RerankResponse struct {
Results []RerankResponseResult `json:"results"`
Usage Usage `json:"usage"`
}

View File

@@ -40,8 +40,8 @@ type CohereRerankRequest struct {
}
type CohereRerankResponseResult struct {
Results []dto.RerankResponseDocument `json:"results"`
Meta CohereMeta `json:"meta"`
Results []dto.RerankResponseResult `json:"results"`
Meta CohereMeta `json:"meta"`
}
type CohereMeta struct {

View File

@@ -69,7 +69,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = common_handler.RerankHandler(c, resp)
err, usage = common_handler.RerankHandler(c, info, resp)
} else if info.RelayMode == constant.RelayModeEmbeddings {
err, usage = openai.OpenaiHandler(c, resp, info)
}

View File

@@ -262,7 +262,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
case constant.RelayModeImagesGenerations:
err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeRerank:
err, usage = common_handler.RerankHandler(c, resp)
err, usage = common_handler.RerankHandler(c, info, resp)
default:
if info.IsStream {
err, usage = OaiStreamHandler(c, resp, info)

View File

@@ -12,6 +12,6 @@ type SFMeta struct {
}
type SFRerankResponse struct {
Results []dto.RerankResponseDocument `json:"results"`
Meta SFMeta `json:"meta"`
Results []dto.RerankResponseResult `json:"results"`
Meta SFMeta `json:"meta"`
}

View File

@@ -0,0 +1,11 @@
package xinference
// XinRerankResponseDocument is one scored entry of a Xinference rerank
// response.
//
// Document is decoded as a plain string and is empty when the upstream
// response does not include it. Index and RelevanceScore mirror the
// "index" and "relevance_score" JSON fields.
// NOTE(review): assumes the upstream sends "document" as a JSON string, not
// an object — confirm against the Xinference API.
type XinRerankResponseDocument struct {
Document string `json:"document,omitempty"`
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
}
// XinRerankResponse is the body of a Xinference rerank response: a list of
// scored results under the "results" key.
type XinRerankResponse struct {
Results []XinRerankResponseDocument `json:"results"`
}

View File

@@ -33,6 +33,10 @@ const (
RelayFormatClaude = "claude"
)
// RerankerInfo carries rerank-specific request state.
//
// Documents holds the documents passed in when the relay info was built for a
// rerank request; elements are untyped.
type RerankerInfo struct {
Documents []any
}
type RelayInfo struct {
ChannelType int
ChannelId int
@@ -78,6 +82,7 @@ type RelayInfo struct {
SendResponseCount int
ThinkingContentInfo
ClaudeConvertInfo
*RerankerInfo
}
// 定义支持流式选项的通道类型
@@ -111,6 +116,15 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
return info
}
// GenRelayInfoRerank builds the RelayInfo for a rerank request.
//
// It starts from the generic RelayInfo produced by GenRelayInfo, switches the
// relay mode to rerank, and attaches the request's documents through a
// RerankerInfo so they remain available on the returned info.
func GenRelayInfoRerank(c *gin.Context, documents []any) *RelayInfo {
	relayInfo := GenRelayInfo(c)
	relayInfo.RelayMode = relayconstant.RelayModeRerank

	// Keep the original documents alongside the relay state.
	rerankerInfo := new(RerankerInfo)
	rerankerInfo.Documents = documents
	relayInfo.RerankerInfo = rerankerInfo

	return relayInfo
}
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")

View File

@@ -1,15 +1,17 @@
package common_handler
import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
"one-api/service"
)
func RerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -18,18 +20,48 @@ func RerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSta
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if common.DebugEnabled {
println("reranker response body: ", string(responseBody))
}
var jinaResp dto.RerankResponse
err = json.Unmarshal(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
if info.ChannelType == common.ChannelTypeXinference {
var xinRerankResponse xinference.XinRerankResponse
err = common.DecodeJson(responseBody, &xinRerankResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
for i, result := range xinRerankResponse.Results {
var document any
if result.Document == "" {
document = info.Documents[result.Index]
} else {
document = result.Document
}
jinaRespResults[i] = dto.RerankResponseResult{
Index: result.Index,
RelevanceScore: result.RelevanceScore,
Document: dto.RerankDocument{
Text: document,
},
}
}
jinaResp = dto.RerankResponse{
Results: jinaRespResults,
Usage: dto.Usage{
PromptTokens: info.PromptTokens,
TotalTokens: info.PromptTokens,
},
}
} else {
err = common.DecodeJson(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
}
jsonResponse, err := json.Marshal(jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
c.JSON(http.StatusOK, jinaResp)
return nil, &jinaResp.Usage
}

View File

@@ -25,7 +25,6 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
}
func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c)
var rerankRequest *dto.RerankRequest
err := common.UnmarshalBodyReusable(c, &rerankRequest)
@@ -33,6 +32,9 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
}
relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest.Documents)
if rerankRequest.Query == "" {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
}