diff --git a/dto/rerank.go b/dto/rerank.go index dfa79633..38aca907 100644 --- a/dto/rerank.go +++ b/dto/rerank.go @@ -10,13 +10,17 @@ type RerankRequest struct { OverLapTokens int `json:"overlap_tokens,omitempty"` } -type RerankResponseDocument struct { +type RerankResponseResult struct { Document any `json:"document,omitempty"` Index int `json:"index"` RelevanceScore float64 `json:"relevance_score"` } -type RerankResponse struct { - Results []RerankResponseDocument `json:"results"` - Usage Usage `json:"usage"` +type RerankDocument struct { + Text any `json:"text"` +} + +type RerankResponse struct { + Results []RerankResponseResult `json:"results"` + Usage Usage `json:"usage"` } diff --git a/relay/channel/cohere/dto.go b/relay/channel/cohere/dto.go index e7452fd4..410540c0 100644 --- a/relay/channel/cohere/dto.go +++ b/relay/channel/cohere/dto.go @@ -40,8 +40,8 @@ type CohereRerankRequest struct { } type CohereRerankResponseResult struct { - Results []dto.RerankResponseDocument `json:"results"` - Meta CohereMeta `json:"meta"` + Results []dto.RerankResponseResult `json:"results"` + Meta CohereMeta `json:"meta"` } type CohereMeta struct { diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index ceffb79a..3faac243 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -69,7 +69,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { - err, usage = common_handler.RerankHandler(c, resp) + err, usage = common_handler.RerankHandler(c, info, resp) } else if info.RelayMode == constant.RelayModeEmbeddings { err, usage = openai.OpenaiHandler(c, resp, info) } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 91bc5066..a9f5b591 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -262,7 +262,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom case constant.RelayModeImagesGenerations: err, usage = OpenaiTTSHandler(c, resp, info) case constant.RelayModeRerank: - err, usage = common_handler.RerankHandler(c, resp) + err, usage = common_handler.RerankHandler(c, info, resp) default: if info.IsStream { err, usage = OaiStreamHandler(c, resp, info) diff --git a/relay/channel/siliconflow/dto.go b/relay/channel/siliconflow/dto.go index 58cf81cd..add0fd07 100644 --- a/relay/channel/siliconflow/dto.go +++ b/relay/channel/siliconflow/dto.go @@ -12,6 +12,6 @@ type SFMeta struct { } type SFRerankResponse struct { - Results []dto.RerankResponseDocument `json:"results"` - Meta SFMeta `json:"meta"` + Results []dto.RerankResponseResult `json:"results"` + Meta SFMeta `json:"meta"` } diff --git a/relay/channel/xinference/dto.go b/relay/channel/xinference/dto.go new file mode 100644 index 00000000..2f12ad10 --- /dev/null +++ b/relay/channel/xinference/dto.go @@ -0,0 +1,11 @@ +package xinference + +type XinRerankResponseDocument struct { + Document string `json:"document,omitempty"` + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` +} + +type XinRerankResponse struct { + Results []XinRerankResponseDocument `json:"results"` +} diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index baabd3e7..6b96419a 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -33,6 +33,10 @@ const ( RelayFormatClaude = "claude" ) +type RerankerInfo struct { + Documents []any +} + type RelayInfo struct { ChannelType int ChannelId int @@ -78,6 +82,7 @@ type RelayInfo struct { SendResponseCount int ThinkingContentInfo ClaudeConvertInfo + *RerankerInfo } // 定义支持流式选项的通道类型 @@ -111,6 +116,15 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo { return info } +func GenRelayInfoRerank(c *gin.Context, documents []any) *RelayInfo { + info := GenRelayInfo(c) + info.RelayMode = relayconstant.RelayModeRerank + info.RerankerInfo = &RerankerInfo{ + Documents: documents, + } + return info +} + func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := c.GetInt("channel_type") channelId := c.GetInt("channel_id") diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go index f33da85c..05aaa8ae 100644 --- a/relay/common_handler/rerank.go +++ b/relay/common_handler/rerank.go @@ -1,15 +1,17 @@ package common_handler import ( - "encoding/json" "github.com/gin-gonic/gin" "io" "net/http" + "one-api/common" "one-api/dto" + "one-api/relay/channel/xinference" + relaycommon "one-api/relay/common" "one-api/service" ) -func RerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil @@ -18,18 +20,48 @@ func RerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSta if err != nil { return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + if common.DebugEnabled { + println("reranker response body: ", string(responseBody)) + } var jinaResp dto.RerankResponse - err = json.Unmarshal(responseBody, &jinaResp) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + if info.ChannelType == common.ChannelTypeXinference { + var xinRerankResponse xinference.XinRerankResponse + err = common.DecodeJson(responseBody, &xinRerankResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results)) + for i, result := range xinRerankResponse.Results { + var document any + if result.Document == "" { + document = info.Documents[result.Index] + } else { + document = result.Document + } + jinaRespResults[i] = dto.RerankResponseResult{ + Index: result.Index, + RelevanceScore: result.RelevanceScore, + Document: dto.RerankDocument{ + Text: document, + }, + } + } + jinaResp = dto.RerankResponse{ + Results: jinaRespResults, + Usage: dto.Usage{ + PromptTokens: info.PromptTokens, + TotalTokens: info.PromptTokens, + }, + } + } else { + err = common.DecodeJson(responseBody, &jinaResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens } - jsonResponse, err := json.Marshal(jinaResp) - if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) + c.JSON(http.StatusOK, jinaResp) return nil, &jinaResp.Usage } diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go index 201166d6..69ab2247 100644 --- a/relay/relay_rerank.go +++ b/relay/relay_rerank.go @@ -25,7 +25,6 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int { } func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) { - relayInfo := relaycommon.GenRelayInfo(c) var rerankRequest *dto.RerankRequest err := common.UnmarshalBodyReusable(c, &rerankRequest) @@ -33,6 +32,9 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) } + + relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest.Documents) + if rerankRequest.Query == "" { return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest) }