feat: Improve embedding request handling and support across channels
- Update EmbeddingRequest DTO to support more flexible input types - Add input parsing method to handle various input formats - Implement ConvertEmbeddingRequest for multiple channel adaptors - Remove relayMode parameter from EmbeddingHelper - Add input validation for embedding requests - Simplify embedding request conversion for different channels
This commit is contained in:
@@ -34,7 +34,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
|||||||
case relayconstant.RelayModeRerank:
|
case relayconstant.RelayModeRerank:
|
||||||
err = relay.RerankHelper(c, relayMode)
|
err = relay.RerankHelper(c, relayMode)
|
||||||
case relayconstant.RelayModeEmbeddings:
|
case relayconstant.RelayModeEmbeddings:
|
||||||
err = relay.EmbeddingHelper(c,relayMode)
|
err = relay.EmbeddingHelper(c)
|
||||||
default:
|
default:
|
||||||
err = relay.TextHelper(c)
|
err = relay.TextHelper(c)
|
||||||
}
|
}
|
||||||
@@ -57,11 +57,6 @@ func Relay(c *gin.Context) {
|
|||||||
originalModel := c.GetString("original_model")
|
originalModel := c.GetString("original_model")
|
||||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||||
|
|
||||||
//获取request body 并输出到日志
|
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
|
||||||
common.LogInfo(c, fmt.Sprintf("relayMode: %d ,request body: %s",relayMode, string(requestBody)))
|
|
||||||
|
|
||||||
|
|
||||||
for i := 0; i <= common.RetryTimes; i++ {
|
for i := 0; i <= common.RetryTimes; i++ {
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
channel, err := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -161,7 +156,6 @@ func WssRelay(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||||
common.LogInfo(c, fmt.Sprintf("relayMode: %d ,channel Id : %s",relayMode, string(channel.Id)))
|
|
||||||
addUsedChannel(c, channel.Id)
|
addUsedChannel(c, channel.Id)
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
|||||||
@@ -12,8 +12,35 @@ type EmbeddingOptions struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingRequest struct {
|
type EmbeddingRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Input []string `json:"input"`
|
Input any `json:"input"`
|
||||||
|
EncodingFormat string `json:"encoding_format,omitempty"`
|
||||||
|
Dimensions int `json:"dimensions,omitempty"`
|
||||||
|
User string `json:"user,omitempty"`
|
||||||
|
Seed float64 `json:"seed,omitempty"`
|
||||||
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
|
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||||
|
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r EmbeddingRequest) ParseInput() []string {
|
||||||
|
if r.Input == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var input []string
|
||||||
|
switch r.Input.(type) {
|
||||||
|
case string:
|
||||||
|
input = []string{r.Input.(string)}
|
||||||
|
case []any:
|
||||||
|
input = make([]string, 0, len(r.Input.([]any)))
|
||||||
|
for _, item := range r.Input.([]any) {
|
||||||
|
if str, ok := item.(string); ok {
|
||||||
|
input = append(input, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return input
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingResponseItem struct {
|
type EmbeddingResponseItem struct {
|
||||||
@@ -23,8 +50,8 @@ type EmbeddingResponseItem struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingResponse struct {
|
type EmbeddingResponse struct {
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Data []EmbeddingResponseItem `json:"data"`
|
Data []EmbeddingResponseItem `json:"data"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Usage `json:"usage"`
|
Usage `json:"usage"`
|
||||||
}
|
}
|
||||||
@@ -49,9 +49,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
|||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeEmbeddings:
|
|
||||||
baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
|
|
||||||
return baiduEmbeddingRequest, nil
|
|
||||||
default:
|
default:
|
||||||
aliReq := requestOpenAI2Ali(*request)
|
aliReq := requestOpenAI2Ali(*request)
|
||||||
return aliReq, nil
|
return aliReq, nil
|
||||||
@@ -68,8 +65,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
//TODO implement me
|
return embeddingRequestOpenAI2Ali(request), nil
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
|
|||||||
@@ -25,9 +25,12 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReque
|
|||||||
return &request
|
return &request
|
||||||
}
|
}
|
||||||
|
|
||||||
func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
|
func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
|
||||||
|
if request.Model == "" {
|
||||||
|
request.Model = "text-embedding-v1"
|
||||||
|
}
|
||||||
return &AliEmbeddingRequest{
|
return &AliEmbeddingRequest{
|
||||||
Model: "text-embedding-v1",
|
Model: request.Model,
|
||||||
Input: struct {
|
Input: struct {
|
||||||
Texts []string `json:"texts"`
|
Texts []string `json:"texts"`
|
||||||
}{
|
}{
|
||||||
|
|||||||
@@ -109,9 +109,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
|||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeEmbeddings:
|
|
||||||
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
|
|
||||||
return baiduEmbeddingRequest, nil
|
|
||||||
default:
|
default:
|
||||||
baiduRequest := requestOpenAI2Baidu(*request)
|
baiduRequest := requestOpenAI2Baidu(*request)
|
||||||
return baiduRequest, nil
|
return baiduRequest, nil
|
||||||
@@ -123,8 +120,8 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
//TODO implement me
|
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(request)
|
||||||
return nil, errors.New("not implemented")
|
return baiduEmbeddingRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.Cha
|
|||||||
return &response
|
return &response
|
||||||
}
|
}
|
||||||
|
|
||||||
func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest {
|
func embeddingRequestOpenAI2Baidu(request dto.EmbeddingRequest) *BaiduEmbeddingRequest {
|
||||||
return &BaiduEmbeddingRequest{
|
return &BaiduEmbeddingRequest{
|
||||||
Input: request.ParseInput(),
|
Input: request.ParseInput(),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,11 +57,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
//TODO implement me
|
return request, nil
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
// 添加文件字段
|
// 添加文件字段
|
||||||
file, _, err := c.Request.FormFile("file")
|
file, _, err := c.Request.FormFile("file")
|
||||||
|
|||||||
@@ -56,11 +56,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
//TODO implement me
|
return request, nil
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.RelayMode == constant.RelayModeRerank {
|
||||||
err, usage = jinaRerankHandler(c, resp)
|
err, usage = jinaRerankHandler(c, resp)
|
||||||
|
|||||||
@@ -46,12 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
|||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
switch info.RelayMode {
|
return requestOpenAI2Ollama(*request), nil
|
||||||
case relayconstant.RelayModeEmbeddings:
|
|
||||||
return requestOpenAI2Embeddings(*request), nil
|
|
||||||
default:
|
|
||||||
return requestOpenAI2Ollama(*request), nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||||
@@ -59,11 +54,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
//TODO implement me
|
return requestOpenAI2Embeddings(request), nil
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
|
func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
|
||||||
return &OllamaEmbeddingRequest{
|
return &OllamaEmbeddingRequest{
|
||||||
Model: request.Model,
|
Model: request.Model,
|
||||||
Input: request.ParseInput(),
|
Input: request.ParseInput(),
|
||||||
@@ -123,9 +123,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
|
|||||||
}
|
}
|
||||||
|
|
||||||
func flattenEmbeddings(embeddings [][]float64) []float64 {
|
func flattenEmbeddings(embeddings [][]float64) []float64 {
|
||||||
flattened := []float64{}
|
flattened := []float64{}
|
||||||
for _, row := range embeddings {
|
for _, row := range embeddings {
|
||||||
flattened = append(flattened, row...)
|
flattened = append(flattened, row...)
|
||||||
}
|
}
|
||||||
return flattened
|
return flattened
|
||||||
}
|
}
|
||||||
@@ -150,8 +150,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
//TODO implement me
|
return request, nil
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
|
|||||||
@@ -59,11 +59,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
//TODO implement me
|
return request, nil
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
|
|||||||
@@ -19,7 +19,20 @@ func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
|||||||
return token
|
return token
|
||||||
}
|
}
|
||||||
|
|
||||||
func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error {
|
||||||
|
if embeddingRequest.Input == nil {
|
||||||
|
return fmt.Errorf("input is empty")
|
||||||
|
}
|
||||||
|
if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
|
||||||
|
embeddingRequest.Model = "omni-moderation-latest"
|
||||||
|
}
|
||||||
|
if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
|
||||||
|
embeddingRequest.Model = c.Param("model")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||||
relayInfo := relaycommon.GenRelayInfo(c)
|
relayInfo := relaycommon.GenRelayInfo(c)
|
||||||
|
|
||||||
var embeddingRequest *dto.EmbeddingRequest
|
var embeddingRequest *dto.EmbeddingRequest
|
||||||
@@ -28,15 +41,12 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
|
|||||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||||
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
|
|
||||||
embeddingRequest.Model = "m3e-base"
|
err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
||||||
}
|
if err != nil {
|
||||||
if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
|
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
|
||||||
embeddingRequest.Model = c.Param("model")
|
|
||||||
}
|
|
||||||
if embeddingRequest.Input == nil || len(embeddingRequest.Input) == 0 {
|
|
||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("input is empty"), "invalid_input", http.StatusBadRequest)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// map model name
|
// map model name
|
||||||
modelMapping := c.GetString("model_mapping")
|
modelMapping := c.GetString("model_mapping")
|
||||||
//isModelMapped := false
|
//isModelMapped := false
|
||||||
@@ -89,7 +99,7 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
|
|||||||
}
|
}
|
||||||
adaptor.Init(relayInfo)
|
adaptor.Init(relayInfo)
|
||||||
|
|
||||||
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c,relayInfo,*embeddingRequest)
|
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||||
@@ -100,7 +110,7 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
|
|||||||
}
|
}
|
||||||
requestBody := bytes.NewBuffer(jsonData)
|
requestBody := bytes.NewBuffer(jsonData)
|
||||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
resp, err := adaptor.DoRequest(c,relayInfo, requestBody)
|
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/pkoukk/tiktoken-go"
|
|
||||||
"image"
|
"image"
|
||||||
"log"
|
"log"
|
||||||
"math"
|
"math"
|
||||||
@@ -14,6 +13,8 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/pkoukk/tiktoken-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
// tokenEncoderMap won't grow after initialization
|
// tokenEncoderMap won't grow after initialization
|
||||||
@@ -323,6 +324,12 @@ func CountTokenInput(input any, model string) (int, error) {
|
|||||||
text += s
|
text += s
|
||||||
}
|
}
|
||||||
return CountTextToken(text, model)
|
return CountTextToken(text, model)
|
||||||
|
case []interface{}:
|
||||||
|
text := ""
|
||||||
|
for _, item := range v {
|
||||||
|
text += fmt.Sprintf("%v", item)
|
||||||
|
}
|
||||||
|
return CountTextToken(text, model)
|
||||||
}
|
}
|
||||||
return CountTokenInput(fmt.Sprintf("%v", input), model)
|
return CountTokenInput(fmt.Sprintf("%v", input), model)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user