From e1b9f164f990a04874289d522e40fbbbce0628a9 Mon Sep 17 00:00:00 2001
From: Sh1n3zZ
Date: Mon, 10 Mar 2025 23:32:06 +0800
Subject: [PATCH] feat: gemini Embeddings support

---
 relay/channel/gemini/adaptor.go      | 46 ++++++++++++++++++++++--
 relay/channel/gemini/constant.go     |  4 +++
 relay/channel/gemini/dto.go          | 18 ++++++++++
 relay/channel/gemini/relay-gemini.go | 54 ++++++++++++++++++++++++++++
 4 files changed, 120 insertions(+), 2 deletions(-)

diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index 37c6c9df..1b7131dc 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -70,6 +70,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
 	}
 
+	if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
+		strings.HasPrefix(info.UpstreamModelName, "embedding") ||
+		strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
+		return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil
+	}
+
 	action := "generateContent"
 	if info.IsStream {
 		action = "streamGenerateContent?alt=sse"
@@ -99,8 +105,37 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
+// ConvertEmbeddingRequest converts an OpenAI-style embedding request into a
+// GeminiEmbeddingRequest. It returns an error when the input is nil or empty.
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	if request.Input == nil {
+		return nil, errors.New("input is required")
+	}
+
+	inputs := request.ParseInput()
+	if len(inputs) == 0 {
+		return nil, errors.New("input is empty")
+	}
+
+	// The request body carries a single `content`, so only the first input is
+	// forwarded; any additional inputs are ignored.
+	geminiRequest := GeminiEmbeddingRequest{
+		Content: GeminiChatContent{
+			Parts: []GeminiPart{
+				{
+					Text: inputs[0],
+				},
+			},
+		},
+	}
+
+	// Every embedding model except the legacy embedding-001 accepts
+	// `outputDimensionality`, so forward the requested size for all others.
+	// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
+	if request.Dimensions > 0 && info.UpstreamModelName != "embedding-001" {
+		geminiRequest.OutputDimensionality = request.Dimensions
+	}
+
+	return geminiRequest, nil
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
@@ -112,6 +147,13 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		return GeminiImageHandler(c, resp, info)
 	}
 
+	// check if the model is an embedding model
+	if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
+		strings.HasPrefix(info.UpstreamModelName, "embedding") ||
+		strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
+		return GeminiEmbeddingHandler(c, resp, info)
+	}
+
 	if info.IsStream {
 		err, usage = GeminiChatStreamHandler(c, resp, info)
 	} else {
diff --git a/relay/channel/gemini/constant.go b/relay/channel/gemini/constant.go
index 1f402cbc..c40baaaa 100644
--- a/relay/channel/gemini/constant.go
+++ b/relay/channel/gemini/constant.go
@@ -18,6 +18,10 @@ var ModelList = []string{
 	"gemini-2.0-flash-thinking-exp",
 	// imagen models
 	"imagen-3.0-generate-002",
+	// embedding models
+	"gemini-embedding-exp-03-07",
+	"text-embedding-004",
+	"embedding-001",
 }
 
 var SafetySettingList = []string{
diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go
index bbcb1248..cbf55576 100644
--- a/relay/channel/gemini/dto.go
+++ b/relay/channel/gemini/dto.go
@@ -136,3 +136,21 @@ type GeminiImagePrediction struct {
 	RaiFilteredReason  string `json:"raiFilteredReason,omitempty"`
 	SafetyAttributes   any    `json:"safetyAttributes,omitempty"`
 }
+
+// GeminiEmbeddingRequest is the JSON body sent to the embedContent endpoint.
+type GeminiEmbeddingRequest struct {
+	Content              GeminiChatContent `json:"content"`
+	TaskType             string            `json:"taskType,omitempty"`
+	Title                string            `json:"title,omitempty"`
+	OutputDimensionality int               `json:"outputDimensionality,omitempty"`
+}
+
+// GeminiEmbeddingResponse is the JSON body returned by the embedContent endpoint.
+type GeminiEmbeddingResponse struct {
+	Embedding ContentEmbedding `json:"embedding"`
+}
+
+// ContentEmbedding holds the values of a single embedding vector.
+type ContentEmbedding struct {
+	Values []float64 `json:"values"`
+}
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index c1ce8219..2fcb8c4d 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -580,3 +580,57 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 	_, err = c.Writer.Write(jsonResponse)
 	return nil, &usage
 }
+
+// GeminiEmbeddingHandler reads a Gemini embedContent response, converts it to
+// the OpenAI embedding response format and writes it to the client. On success
+// it returns the *dto.Usage to bill; it returns an error when the upstream body
+// cannot be read or decoded, or when the converted response cannot be encoded.
+func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	// Deferred so the body is also closed when reading it fails.
+	defer func() { _ = resp.Body.Close() }()
+
+	responseBody, readErr := io.ReadAll(resp.Body)
+	if readErr != nil {
+		return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
+	}
+
+	var geminiResponse GeminiEmbeddingResponse
+	if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
+		return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
+	}
+
+	// Google has not yet clarified how embedding models are billed, so follow
+	// the OpenAI billing method and charge input tokens only.
+	// https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
+	// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
+	embeddingUsage := &dto.Usage{
+		PromptTokens:     info.PromptTokens,
+		CompletionTokens: 0,
+		TotalTokens:      info.PromptTokens,
+	}
+
+	// Convert to the OpenAI embedding response format.
+	openAIResponse := dto.OpenAIEmbeddingResponse{
+		Object: "list",
+		Data: []dto.OpenAIEmbeddingResponseItem{
+			{
+				Object:    "embedding",
+				Embedding: geminiResponse.Embedding.Values,
+				Index:     0,
+			},
+		},
+		Model: info.UpstreamModelName,
+		Usage: *embeddingUsage,
+	}
+
+	jsonResponse, jsonErr := json.Marshal(openAIResponse)
+	if jsonErr != nil {
+		return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
+	}
+
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, _ = c.Writer.Write(jsonResponse)
+
+	return embeddingUsage, nil
+}