From b70d2655ed0c03b98b50a12dfda66445621bc065 Mon Sep 17 00:00:00 2001 From: RedwindA Date: Sat, 9 Aug 2025 00:27:33 +0800 Subject: [PATCH] feat: support native Gemini Embedding --- controller/relay.go | 6 +- dto/gemini.go | 9 ++ relay/channel/gemini/adaptor.go | 9 +- relay/channel/gemini/relay-gemini-native.go | 39 +++++++- relay/common/relay_info.go | 15 +-- relay/gemini_handler.go | 102 ++++++++++++++++++++ 6 files changed, 170 insertions(+), 10 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 1a35c7d7..c97eca20 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -42,7 +42,11 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError { case relayconstant.RelayModeResponses: err = relay.ResponsesHelper(c) case relayconstant.RelayModeGemini: - err = relay.GeminiHelper(c) + if strings.Contains(c.Request.URL.Path, "embed") { + err = relay.GeminiEmbeddingHandler(c) + } else { + err = relay.GeminiHelper(c) + } default: err = relay.TextHelper(c) } diff --git a/dto/gemini.go b/dto/gemini.go index f7acd355..60179c1a 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -210,16 +210,25 @@ type GeminiImagePrediction struct { // Embedding related structs type GeminiEmbeddingRequest struct { + Model string `json:"model,omitempty"` Content GeminiChatContent `json:"content"` TaskType string `json:"taskType,omitempty"` Title string `json:"title,omitempty"` OutputDimensionality int `json:"outputDimensionality,omitempty"` } +type GeminiBatchEmbeddingRequest struct { + Requests []GeminiEmbeddingRequest `json:"requests"` +} + type GeminiEmbeddingResponse struct { Embedding ContentEmbedding `json:"embedding"` } +type GeminiBatchEmbeddingResponse struct { + Embeddings []ContentEmbedding `json:"embeddings"` +} + type ContentEmbedding struct { Values []float64 `json:"values"` } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 01dfea2c..f10fcef0 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -114,7 +114,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if strings.HasPrefix(info.UpstreamModelName, "text-embedding") || strings.HasPrefix(info.UpstreamModelName, "embedding") || strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") { - return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil + action := "embedContent" + if info.IsGeminiBatchEmbdding { + action = "batchEmbedContents" + } + return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil } action := "generateContent" @@ -195,6 +199,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeGemini { + if strings.Contains(info.RequestURLPath, "embed") { + return NativeGeminiEmbeddingHandler(c, resp, info) + } if info.IsStream { return GeminiTextGenerationStreamHandler(c, info, resp) } else { diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 1ba599b3..1a94f936 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -1,7 +1,6 @@ package gemini import ( - "github.com/pkg/errors" "io" "net/http" "one-api/common" @@ -12,6 +11,8 @@ import ( "one-api/types" "strings" + "github.com/pkg/errors" + "github.com/gin-gonic/gin" ) @@ -63,6 +64,42 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re return &usage, nil } +func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) { + defer common.CloseResponseBodyGracefully(resp) + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + if common.DebugEnabled { + println(string(responseBody)) + } + + usage := &dto.Usage{ + PromptTokens: info.PromptTokens, + TotalTokens: info.PromptTokens, + } + + if info.IsGeminiBatchEmbdding { + var geminiResponse dto.GeminiBatchEmbeddingResponse + err = common.Unmarshal(responseBody, &geminiResponse) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + } else { + var geminiResponse dto.GeminiEmbeddingResponse + err = common.Unmarshal(responseBody, &geminiResponse) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + } + + common.IOCopyBytesGracefully(c, resp, responseBody) + + return usage, nil +} + func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { var usage = &dto.Usage{} var imageCount int diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 743070ca..c694a230 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -74,13 +74,14 @@ type RelayInfo struct { FirstResponseTime time.Time isFirstResponse bool //SendLastReasoningResponse bool - ApiType int - IsStream bool - IsPlayground bool - UsePrice bool - RelayMode int - UpstreamModelName string - OriginModelName string + ApiType int + IsStream bool + IsGeminiBatchEmbdding bool + IsPlayground bool + UsePrice bool + RelayMode int + UpstreamModelName string + OriginModelName string //RecodeModelName string RequestURLPath string ApiVersion string diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 42b695b7..04be36ad 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -264,3 +264,105 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") return nil } + +func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) { + relayInfo := relaycommon.GenRelayInfoGemini(c) + + isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents") + relayInfo.IsGeminiBatchEmbdding = isBatch + + var promptTokens int + var req any + var err error + var inputTexts []string + + if isBatch { + batchRequest := &dto.GeminiBatchEmbeddingRequest{} + err = common.UnmarshalBodyReusable(c, batchRequest) + if err != nil { + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + req = batchRequest + for _, r := range batchRequest.Requests { + for _, part := range r.Content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + } + } + } else { + singleRequest := &dto.GeminiEmbeddingRequest{} + err = common.UnmarshalBodyReusable(c, singleRequest) + if err != nil { + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + req = singleRequest + for _, part := range singleRequest.Content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + } + } + promptTokens = service.CountTokenInput(strings.Join(inputTexts, "\n"), relayInfo.UpstreamModelName) + relayInfo.SetPromptTokens(promptTokens) + c.Set("prompt_tokens", promptTokens) + + err = helper.ModelMappedHelper(c, relayInfo, req) + if err != nil { + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) + } + + priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, 0) + if err != nil { + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) + } + + preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if newAPIError != nil { + return newAPIError + } + defer func() { + if newAPIError != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() + + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + } + adaptor.Init(relayInfo) + + var requestBody io.Reader + jsonData, err := common.Marshal(req) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + requestBody = bytes.NewReader(jsonData) + + resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + if err != nil { + common.LogError(c, "Do gemini request failed: "+err.Error()) + return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) + } + + statusCodeMappingStr := c.GetString("status_code_mapping") + var httpResp *http.Response + if resp != nil { + httpResp = resp.(*http.Response) + if httpResp.StatusCode != http.StatusOK { + newAPIError = service.RelayErrorHandler(httpResp, false) + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError + } + } + + usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo) + if openaiErr != nil { + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } + + postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + return nil +}