feat: ali rerank
This commit is contained in:
@@ -31,6 +31,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
|
||||||
|
case constant.RelayModeRerank:
|
||||||
|
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
|
||||||
case constant.RelayModeImagesGenerations:
|
case constant.RelayModeImagesGenerations:
|
||||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
|
||||||
case constant.RelayModeCompletions:
|
case constant.RelayModeCompletions:
|
||||||
@@ -76,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||||
return nil, errors.New("not implemented")
|
return ConvertRerankRequest(request), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
@@ -103,6 +105,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
err, usage = aliImageHandler(c, resp, info)
|
err, usage = aliImageHandler(c, resp, info)
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
err, usage = aliEmbeddingHandler(c, resp)
|
err, usage = aliEmbeddingHandler(c, resp)
|
||||||
|
case constant.RelayModeRerank:
|
||||||
|
err, usage = RerankHandler(c, resp, info)
|
||||||
default:
|
default:
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ var ModelList = []string{
|
|||||||
"qwq-32b",
|
"qwq-32b",
|
||||||
"qwen3-235b-a22b",
|
"qwen3-235b-a22b",
|
||||||
"text-embedding-v1",
|
"text-embedding-v1",
|
||||||
|
"gte-rerank-v2",
|
||||||
}
|
}
|
||||||
|
|
||||||
var ChannelName = "ali"
|
var ChannelName = "ali"
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package ali
|
package ali
|
||||||
|
|
||||||
|
import "one-api/dto"
|
||||||
|
|
||||||
type AliMessage struct {
|
type AliMessage struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
@@ -97,3 +99,28 @@ type AliImageRequest struct {
|
|||||||
} `json:"parameters,omitempty"`
|
} `json:"parameters,omitempty"`
|
||||||
ResponseFormat string `json:"response_format,omitempty"`
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AliRerankParameters struct {
|
||||||
|
TopN *int `json:"top_n,omitempty"`
|
||||||
|
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliRerankInput struct {
|
||||||
|
Query string `json:"query"`
|
||||||
|
Documents []any `json:"documents"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliRerankRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Input AliRerankInput `json:"input"`
|
||||||
|
Parameters AliRerankParameters `json:"parameters,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliRerankResponse struct {
|
||||||
|
Output struct {
|
||||||
|
Results []dto.RerankResponseResult `json:"results"`
|
||||||
|
} `json:"output"`
|
||||||
|
Usage AliUsage `json:"usage"`
|
||||||
|
RequestId string `json:"request_id"`
|
||||||
|
AliError
|
||||||
|
}
|
||||||
|
|||||||
83
relay/channel/ali/rerank.go
Normal file
83
relay/channel/ali/rerank.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package ali
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/dto"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
|
||||||
|
returnDocuments := request.ReturnDocuments
|
||||||
|
if returnDocuments == nil {
|
||||||
|
t := true
|
||||||
|
returnDocuments = &t
|
||||||
|
}
|
||||||
|
return &AliRerankRequest{
|
||||||
|
Model: request.Model,
|
||||||
|
Input: AliRerankInput{
|
||||||
|
Query: request.Query,
|
||||||
|
Documents: request.Documents,
|
||||||
|
},
|
||||||
|
Parameters: AliRerankParameters{
|
||||||
|
TopN: &request.TopN,
|
||||||
|
ReturnDocuments: returnDocuments,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var aliResponse AliRerankResponse
|
||||||
|
err = json.Unmarshal(responseBody, &aliResponse)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if aliResponse.Code != "" {
|
||||||
|
return &dto.OpenAIErrorWithStatusCode{
|
||||||
|
Error: dto.OpenAIError{
|
||||||
|
Message: aliResponse.Message,
|
||||||
|
Type: aliResponse.Code,
|
||||||
|
Param: aliResponse.RequestId,
|
||||||
|
Code: aliResponse.Code,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
usage := dto.Usage{
|
||||||
|
PromptTokens: aliResponse.Usage.TotalTokens,
|
||||||
|
CompletionTokens: 0,
|
||||||
|
TotalTokens: aliResponse.Usage.TotalTokens,
|
||||||
|
}
|
||||||
|
rerankResponse := dto.RerankResponse{
|
||||||
|
Results: aliResponse.Output.Results,
|
||||||
|
Usage: usage,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonResponse, err := json.Marshal(rerankResponse)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user