feat: support xinference rerank to jina format

This commit is contained in:
1808837298@qq.com
2025-03-16 21:06:29 +08:00
parent 53b3599827
commit d1c62a583d
9 changed files with 85 additions and 22 deletions

View File

@@ -10,13 +10,17 @@ type RerankRequest struct {
OverLapTokens int `json:"overlap_tokens,omitempty"` OverLapTokens int `json:"overlap_tokens,omitempty"`
} }
type RerankResponseDocument struct { type RerankResponseResult struct {
Document any `json:"document,omitempty"` Document any `json:"document,omitempty"`
Index int `json:"index"` Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"` RelevanceScore float64 `json:"relevance_score"`
} }
type RerankResponse struct { type RerankDocument struct {
Results []RerankResponseDocument `json:"results"` Text any `json:"text"`
Usage Usage `json:"usage"` }
type RerankResponse struct {
Results []RerankResponseResult `json:"results"`
Usage Usage `json:"usage"`
} }

View File

@@ -40,8 +40,8 @@ type CohereRerankRequest struct {
} }
type CohereRerankResponseResult struct { type CohereRerankResponseResult struct {
Results []dto.RerankResponseDocument `json:"results"` Results []dto.RerankResponseResult `json:"results"`
Meta CohereMeta `json:"meta"` Meta CohereMeta `json:"meta"`
} }
type CohereMeta struct { type CohereMeta struct {

View File

@@ -69,7 +69,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
err, usage = common_handler.RerankHandler(c, resp) err, usage = common_handler.RerankHandler(c, info, resp)
} else if info.RelayMode == constant.RelayModeEmbeddings { } else if info.RelayMode == constant.RelayModeEmbeddings {
err, usage = openai.OpenaiHandler(c, resp, info) err, usage = openai.OpenaiHandler(c, resp, info)
} }

View File

@@ -262,7 +262,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
case constant.RelayModeImagesGenerations: case constant.RelayModeImagesGenerations:
err, usage = OpenaiTTSHandler(c, resp, info) err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeRerank: case constant.RelayModeRerank:
err, usage = common_handler.RerankHandler(c, resp) err, usage = common_handler.RerankHandler(c, info, resp)
default: default:
if info.IsStream { if info.IsStream {
err, usage = OaiStreamHandler(c, resp, info) err, usage = OaiStreamHandler(c, resp, info)

View File

@@ -12,6 +12,6 @@ type SFMeta struct {
} }
type SFRerankResponse struct { type SFRerankResponse struct {
Results []dto.RerankResponseDocument `json:"results"` Results []dto.RerankResponseResult `json:"results"`
Meta SFMeta `json:"meta"` Meta SFMeta `json:"meta"`
} }

View File

@@ -0,0 +1,11 @@
package xinference
// XinRerankResponseDocument is a single entry of the "results" array in a
// Xinference rerank response.
type XinRerankResponseDocument struct {
// Document is the document text as a plain string. It is omitted from the
// encoded JSON when empty; RerankHandler substitutes the request document
// at Index when this field is empty.
Document string `json:"document,omitempty"`
// Index is used by RerankHandler to look up the matching document in the
// original request's document list.
Index int `json:"index"`
// RelevanceScore is copied unchanged into the Jina-format result.
RelevanceScore float64 `json:"relevance_score"`
}
// XinRerankResponse is the body decoded from a Xinference rerank reply.
// Only the "results" array is modelled here; no usage field is declared, and
// RerankHandler fills usage from the relay info's prompt token count instead.
type XinRerankResponse struct {
Results []XinRerankResponseDocument `json:"results"`
}

View File

@@ -33,6 +33,10 @@ const (
RelayFormatClaude = "claude" RelayFormatClaude = "claude"
) )
// RerankerInfo carries rerank-specific request data alongside RelayInfo.
// It is embedded in RelayInfo by pointer and populated by GenRelayInfoRerank.
// NOTE(review): presumably nil for relays created through other constructors;
// confirm before dereferencing Documents outside the rerank path.
type RerankerInfo struct {
// Documents holds the documents passed to GenRelayInfoRerank, kept so the
// response handler can echo a document back by its index.
Documents []any
}
type RelayInfo struct { type RelayInfo struct {
ChannelType int ChannelType int
ChannelId int ChannelId int
@@ -78,6 +82,7 @@ type RelayInfo struct {
SendResponseCount int SendResponseCount int
ThinkingContentInfo ThinkingContentInfo
ClaudeConvertInfo ClaudeConvertInfo
*RerankerInfo
} }
// 定义支持流式选项的通道类型 // 定义支持流式选项的通道类型
@@ -111,6 +116,15 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
return info return info
} }
// GenRelayInfoRerank builds the RelayInfo for a rerank request.
//
// It starts from the generic GenRelayInfo(c) result, attaches the request's
// documents through a freshly allocated RerankerInfo, and marks the relay mode
// as RelayModeRerank. The documents slice is stored as given (not copied).
func GenRelayInfoRerank(c *gin.Context, documents []any) *RelayInfo {
	relayInfo := GenRelayInfo(c)
	relayInfo.RerankerInfo = &RerankerInfo{Documents: documents}
	relayInfo.RelayMode = relayconstant.RelayModeRerank
	return relayInfo
}
func GenRelayInfo(c *gin.Context) *RelayInfo { func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel_type") channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")

View File

@@ -1,15 +1,17 @@
package common_handler package common_handler
import ( import (
"encoding/json"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common"
"one-api/dto" "one-api/dto"
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
) )
func RerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -18,18 +20,48 @@ func RerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSta
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
if common.DebugEnabled {
println("reranker response body: ", string(responseBody))
}
var jinaResp dto.RerankResponse var jinaResp dto.RerankResponse
err = json.Unmarshal(responseBody, &jinaResp) if info.ChannelType == common.ChannelTypeXinference {
if err != nil { var xinRerankResponse xinference.XinRerankResponse
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil err = common.DecodeJson(responseBody, &xinRerankResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
for i, result := range xinRerankResponse.Results {
var document any
if result.Document == "" {
document = info.Documents[result.Index]
} else {
document = result.Document
}
jinaRespResults[i] = dto.RerankResponseResult{
Index: result.Index,
RelevanceScore: result.RelevanceScore,
Document: dto.RerankDocument{
Text: document,
},
}
}
jinaResp = dto.RerankResponse{
Results: jinaRespResults,
Usage: dto.Usage{
PromptTokens: info.PromptTokens,
TotalTokens: info.PromptTokens,
},
}
} else {
err = common.DecodeJson(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
} }
jsonResponse, err := json.Marshal(jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.JSON(http.StatusOK, jinaResp)
_, err = c.Writer.Write(jsonResponse)
return nil, &jinaResp.Usage return nil, &jinaResp.Usage
} }

View File

@@ -25,7 +25,6 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
} }
func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) { func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c)
var rerankRequest *dto.RerankRequest var rerankRequest *dto.RerankRequest
err := common.UnmarshalBodyReusable(c, &rerankRequest) err := common.UnmarshalBodyReusable(c, &rerankRequest)
@@ -33,6 +32,9 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
} }
relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest.Documents)
if rerankRequest.Query == "" { if rerankRequest.Query == "" {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
} }