Merge pull request #1537 from RedwindA/feat/support-native-gemini-embedding
feat: support native Gemini Embedding format
@@ -42,7 +42,11 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
 	case relayconstant.RelayModeResponses:
 		err = relay.ResponsesHelper(c)
 	case relayconstant.RelayModeGemini:
-		err = relay.GeminiHelper(c)
+		if strings.Contains(c.Request.URL.Path, "embed") {
+			err = relay.GeminiEmbeddingHandler(c)
+		} else {
+			err = relay.GeminiHelper(c)
+		}
 	default:
 		err = relay.TextHelper(c)
 	}
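For context, both native Gemini embedding actions contain the substring "embed" (`:embedContent`, `:batchEmbedContents`), while the generation actions (`:generateContent`, `:streamGenerateContent`) do not, which is what makes the `strings.Contains` check sufficient here. A minimal standalone sketch of that routing predicate (the paths are illustrative, not taken from the project's router):

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	// Paths in the native Gemini API format; only the embedding
	// actions contain the substring "embed".
	paths := []string{
		"/v1beta/models/gemini-embedding-001:embedContent",
		"/v1beta/models/gemini-embedding-001:batchEmbedContents",
		"/v1beta/models/gemini-2.0-flash:generateContent",
		"/v1beta/models/gemini-2.0-flash:streamGenerateContent",
	}
	for _, p := range paths {
		fmt.Printf("%-58s embed=%v\n", p, strings.Contains(p, "embed"))
	}
}
```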
@@ -210,6 +210,7 @@ type GeminiImagePrediction struct {
 
 // Embedding related structs
 type GeminiEmbeddingRequest struct {
+	Model    string            `json:"model,omitempty"`
 	Content  GeminiChatContent `json:"content"`
 	TaskType string            `json:"taskType,omitempty"`
 	Title    string            `json:"title,omitempty"`
@@ -220,10 +221,14 @@ type GeminiBatchEmbeddingRequest struct {
 	Requests []*GeminiEmbeddingRequest `json:"requests"`
 }
 
-type GeminiEmbedding struct {
-	Values []float64 `json:"values"`
+type GeminiEmbeddingResponse struct {
+	Embedding ContentEmbedding `json:"embedding"`
 }
 
 type GeminiBatchEmbeddingResponse struct {
-	Embeddings []*GeminiEmbedding `json:"embeddings"`
+	Embeddings []*ContentEmbedding `json:"embeddings"`
+}
+
+type ContentEmbedding struct {
+	Values []float64 `json:"values"`
 }
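As a sanity check on the new struct shapes, here is a standalone sketch with local copies of the dto types; the sample JSON follows the shape of the public `models.embedContent` / `models.batchEmbedContents` responses, and the literal values are illustrative:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local copies of the new dto types, for illustration only.
type ContentEmbedding struct {
	Values []float64 `json:"values"`
}

type GeminiEmbeddingResponse struct {
	Embedding ContentEmbedding `json:"embedding"`
}

type GeminiBatchEmbeddingResponse struct {
	Embeddings []*ContentEmbedding `json:"embeddings"`
}

func main() {
	// Single-request response: one "embedding" object.
	single := []byte(`{"embedding":{"values":[0.012,-0.034,0.056]}}`)
	var sr GeminiEmbeddingResponse
	if err := json.Unmarshal(single, &sr); err != nil {
		panic(err)
	}
	fmt.Println("single, dims:", len(sr.Embedding.Values))

	// Batch response: an "embeddings" array, one entry per input.
	batch := []byte(`{"embeddings":[{"values":[0.1,0.2]},{"values":[0.3,0.4]}]}`)
	var br GeminiBatchEmbeddingResponse
	if err := json.Unmarshal(batch, &br); err != nil {
		panic(err)
	}
	fmt.Println("batch, count:", len(br.Embeddings))
}
```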
@@ -114,7 +114,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
 		strings.HasPrefix(info.UpstreamModelName, "embedding") ||
 		strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
-		return fmt.Sprintf("%s/%s/models/%s:batchEmbedContents", info.BaseUrl, version, info.UpstreamModelName), nil
+		action := "embedContent"
+		if info.IsGeminiBatchEmbedding {
+			action = "batchEmbedContents"
+		}
+		return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
 	}
 
 	action := "generateContent"
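With the `IsGeminiBatchEmbedding` flag, the same model-prefix match now selects the correct upstream action instead of always forcing `batchEmbedContents`. A standalone sketch of the resulting URLs (the base URL and version are illustrative values, not read from a real `RelayInfo`):

```go
package main

import "fmt"

// buildEmbeddingURL mirrors the adaptor's action selection in a
// self-contained form; baseURL and version are placeholder inputs.
func buildEmbeddingURL(baseURL, version, model string, batch bool) string {
	action := "embedContent"
	if batch {
		action = "batchEmbedContents"
	}
	return fmt.Sprintf("%s/%s/models/%s:%s", baseURL, version, model, action)
}

func main() {
	base := "https://generativelanguage.googleapis.com"
	fmt.Println(buildEmbeddingURL(base, "v1beta", "gemini-embedding-001", false))
	// -> .../v1beta/models/gemini-embedding-001:embedContent
	fmt.Println(buildEmbeddingURL(base, "v1beta", "gemini-embedding-001", true))
	// -> .../v1beta/models/gemini-embedding-001:batchEmbedContents
}
```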
@@ -159,6 +163,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	if len(inputs) == 0 {
 		return nil, errors.New("input is empty")
 	}
+	// We always build a batch-style payload with `requests`, so ensure we call the
+	// batch endpoint upstream to avoid payload/endpoint mismatches.
+	info.IsGeminiBatchEmbedding = true
 	// process all inputs
 	geminiRequests := make([]map[string]interface{}, 0, len(inputs))
 	for _, input := range inputs {
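The OpenAI-compatible conversion path always wraps inputs in a `requests` array, which is why it pins the flag to the batch endpoint. A rough sketch of the payload shape it produces (field names follow the public `batchEmbedContents` format; the exact map construction in the adaptor may differ):

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	inputs := []string{"hello", "world"}
	model := "gemini-embedding-001"

	// One request entry per input, batch-style, regardless of how
	// many inputs the OpenAI-format request carried.
	requests := make([]map[string]interface{}, 0, len(inputs))
	for _, input := range inputs {
		requests = append(requests, map[string]interface{}{
			"model": "models/" + model,
			"content": map[string]interface{}{
				"parts": []map[string]string{{"text": input}},
			},
		})
	}
	payload, _ := json.MarshalIndent(map[string]interface{}{"requests": requests}, "", "  ")
	fmt.Println(string(payload)) // body sent to ...:batchEmbedContents
}
```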
@@ -176,7 +183,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	// set specific parameters for different models
 	// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
 	switch info.UpstreamModelName {
-	case "text-embedding-004","gemini-embedding-exp-03-07","gemini-embedding-001":
+	case "text-embedding-004", "gemini-embedding-exp-03-07", "gemini-embedding-001":
 		// Only newer models introduced after 2024 support OutputDimensionality
 		if request.Dimensions > 0 {
 			geminiRequest["outputDimensionality"] = request.Dimensions
@@ -201,6 +208,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
 	if info.RelayMode == constant.RelayModeGemini {
+		if strings.HasSuffix(info.RequestURLPath, ":embedContent") ||
+			strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") {
+			return NativeGeminiEmbeddingHandler(c, resp, info)
+		}
 		if info.IsStream {
 			return GeminiTextGenerationStreamHandler(c, info, resp)
 		} else {
@@ -225,18 +236,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		return GeminiChatHandler(c, info, resp)
 	}
 
-	//if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
-	//	// If thinking tokens are produced without -thinking being requested, bill according to the thinking model
-	//	if !strings.HasSuffix(info.OriginModelName, "-thinking") &&
-	//		!strings.HasSuffix(info.OriginModelName, "-nothinking") {
-	//		thinkingModelName := info.OriginModelName + "-thinking"
-	//		if operation_setting.SelfUseModeEnabled || helper.ContainPriceOrRatio(thinkingModelName) {
-	//			info.OriginModelName = thinkingModelName
-	//		}
-	//	}
-	//}
-
 	return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
 }
 
 func (a *Adaptor) GetModelList() []string {
@@ -1,7 +1,6 @@
 package gemini
 
 import (
-	"github.com/pkg/errors"
 	"io"
 	"net/http"
 	"one-api/common"
@@ -12,6 +11,8 @@ import (
 	"one-api/types"
 	"strings"
 
+	"github.com/pkg/errors"
+
 	"github.com/gin-gonic/gin"
 )
 
@@ -63,6 +64,42 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
 	return &usage, nil
 }
 
+func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
+	defer common.CloseResponseBodyGracefully(resp)
+
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+
+	if common.DebugEnabled {
+		println(string(responseBody))
+	}
+
+	usage := &dto.Usage{
+		PromptTokens: info.PromptTokens,
+		TotalTokens:  info.PromptTokens,
+	}
+
+	if info.IsGeminiBatchEmbedding {
+		var geminiResponse dto.GeminiBatchEmbeddingResponse
+		err = common.Unmarshal(responseBody, &geminiResponse)
+		if err != nil {
+			return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+		}
+	} else {
+		var geminiResponse dto.GeminiEmbeddingResponse
+		err = common.Unmarshal(responseBody, &geminiResponse)
+		if err != nil {
+			return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+		}
+	}
+
+	common.IOCopyBytesGracefully(c, resp, responseBody)
+
+	return usage, nil
+}
+
 func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 	var usage = &dto.Usage{}
 	var imageCount int
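Note the design here: the handler unmarshals the upstream body only to confirm it is a well-formed embedding response, then forwards the original bytes verbatim with `IOCopyBytesGracefully`, so the client sees the native format untouched. Usage is billed on prompt tokens alone, since embedding responses carry no completion tokens. A self-contained sketch of that validate-then-forward idea (the probe structs are local stand-ins, not the project's types):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// validateThenForward decodes only to confirm the upstream body is
// well-formed, then returns the original bytes untouched.
func validateThenForward(body []byte, batch bool) ([]byte, error) {
	var probe interface{}
	if batch {
		probe = &struct {
			Embeddings []json.RawMessage `json:"embeddings"`
		}{}
	} else {
		probe = &struct {
			Embedding json.RawMessage `json:"embedding"`
		}{}
	}
	if err := json.Unmarshal(body, probe); err != nil {
		return nil, fmt.Errorf("bad upstream body: %w", err)
	}
	return body, nil // forwarded verbatim, no re-marshal
}

func main() {
	out, err := validateThenForward([]byte(`{"embedding":{"values":[0.1]}}`), false)
	fmt.Println(string(out), err)
}
```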
@@ -74,13 +74,14 @@ type RelayInfo struct {
 	FirstResponseTime time.Time
 	isFirstResponse   bool
 	//SendLastReasoningResponse bool
 	ApiType           int
 	IsStream          bool
-	IsPlayground      bool
-	UsePrice          bool
-	RelayMode         int
-	UpstreamModelName string
-	OriginModelName   string
+	IsGeminiBatchEmbedding bool
+	IsPlayground           bool
+	UsePrice               bool
+	RelayMode              int
+	UpstreamModelName      string
+	OriginModelName        string
 	//RecodeModelName string
 	RequestURLPath string
 	ApiVersion     string
@@ -264,3 +264,118 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
 	return nil
 }
+
+func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
+	relayInfo := relaycommon.GenRelayInfoGemini(c)
+
+	isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
+	relayInfo.IsGeminiBatchEmbedding = isBatch
+
+	var promptTokens int
+	var req any
+	var err error
+	var inputTexts []string
+
+	if isBatch {
+		batchRequest := &dto.GeminiBatchEmbeddingRequest{}
+		err = common.UnmarshalBodyReusable(c, batchRequest)
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+		}
+		req = batchRequest
+		for _, r := range batchRequest.Requests {
+			for _, part := range r.Content.Parts {
+				if part.Text != "" {
+					inputTexts = append(inputTexts, part.Text)
+				}
+			}
+		}
+	} else {
+		singleRequest := &dto.GeminiEmbeddingRequest{}
+		err = common.UnmarshalBodyReusable(c, singleRequest)
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+		}
+		req = singleRequest
+		for _, part := range singleRequest.Content.Parts {
+			if part.Text != "" {
+				inputTexts = append(inputTexts, part.Text)
+			}
+		}
+	}
+	promptTokens = service.CountTokenInput(strings.Join(inputTexts, "\n"), relayInfo.UpstreamModelName)
+	relayInfo.SetPromptTokens(promptTokens)
+	c.Set("prompt_tokens", promptTokens)
+
+	err = helper.ModelMappedHelper(c, relayInfo, req)
+	if err != nil {
+		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
+	}
+
+	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, 0)
+	if err != nil {
+		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
+	}
+
+	preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+	if newAPIError != nil {
+		return newAPIError
+	}
+	defer func() {
+		if newAPIError != nil {
+			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+		}
+	}()
+
+	adaptor := GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+	}
+	adaptor.Init(relayInfo)
+
+	var requestBody io.Reader
+	jsonData, err := common.Marshal(req)
+	if err != nil {
+		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+	}
+
+	// apply param override
+	if len(relayInfo.ParamOverride) > 0 {
+		reqMap := make(map[string]interface{})
+		_ = common.Unmarshal(jsonData, &reqMap)
+		for key, value := range relayInfo.ParamOverride {
+			reqMap[key] = value
+		}
+		jsonData, err = common.Marshal(reqMap)
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+		}
+	}
+	requestBody = bytes.NewReader(jsonData)
+
+	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	if err != nil {
+		common.LogError(c, "Do gemini request failed: "+err.Error())
+		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
+	}
+
+	statusCodeMappingStr := c.GetString("status_code_mapping")
+	var httpResp *http.Response
+	if resp != nil {
+		httpResp = resp.(*http.Response)
+		if httpResp.StatusCode != http.StatusOK {
+			newAPIError = service.RelayErrorHandler(httpResp, false)
+			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+			return newAPIError
+		}
+	}
+
+	usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
+	if openaiErr != nil {
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
+	}
+
+	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	return nil
+}
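End to end, a client can now call the relay with the native Gemini payload directly. A hypothetical invocation (the relay address, port, model, and auth header are assumptions about a particular deployment; some setups use `Authorization: Bearer` instead of `x-goog-api-key`):

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Hypothetical relay address and key; adjust for your deployment.
	url := "http://localhost:3000/v1beta/models/gemini-embedding-001:embedContent"
	body := []byte(`{"content":{"parts":[{"text":"hello world"}]}}`)

	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("x-goog-api-key", "sk-your-relay-key") // placeholder key

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	// On success the body is the native {"embedding":{"values":[...]}} shape.
	fmt.Println(resp.Status, string(out))
}
```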