Merge pull request #845 from Sh1n3zZ/gemini-embedding
feat: gemini Embeddings support
@@ -70,6 +70,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
         return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
     }
 
+    if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
+        strings.HasPrefix(info.UpstreamModelName, "embedding") ||
+        strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
+        return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil
+    }
+
     action := "generateContent"
     if info.IsStream {
         action = "streamGenerateContent?alt=sse"
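For reference, here is a minimal sketch of the URL this new branch produces. The base URL and API version are illustrative assumptions (at runtime they come from info.BaseUrl and the resolved version, not from this PR):

```go
package main

import "fmt"

func main() {
	// Hypothetical values; the adaptor fills these in at runtime.
	baseUrl := "https://generativelanguage.googleapis.com"
	version := "v1beta"
	model := "text-embedding-004"

	// Same format string as the embedding branch in GetRequestURL.
	fmt.Printf("%s/%s/models/%s:embedContent\n", baseUrl, version, model)
	// Output: https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent
}
```

Note that the "embedding" prefix also matches "embedding-001", so all three model families added in this PR route to :embedContent.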
@@ -99,8 +105,37 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-    //TODO implement me
-    return nil, errors.New("not implemented")
+    if request.Input == nil {
+        return nil, errors.New("input is required")
+    }
+
+    inputs := request.ParseInput()
+    if len(inputs) == 0 {
+        return nil, errors.New("input is empty")
+    }
+
+    // only process the first input
+    geminiRequest := GeminiEmbeddingRequest{
+        Content: GeminiChatContent{
+            Parts: []GeminiPart{
+                {
+                    Text: inputs[0],
+                },
+            },
+        },
+    }
+
+    // set specific parameters for different models
+    // https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
+    switch info.UpstreamModelName {
+    case "text-embedding-004":
+        // all models except embedding-001 support setting `OutputDimensionality`
+        if request.Dimensions > 0 {
+            geminiRequest.OutputDimensionality = request.Dimensions
+        }
+    }
+
+    return geminiRequest, nil
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
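A self-contained sketch of the request body this conversion yields for an OpenAI-style embedding call with one input and dimensions set. The structs are trimmed local copies of the PR's types (the real GeminiPart/GeminiChatContent carry more fields, and the `parts`/`text` JSON tags are assumptions):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed stand-ins for GeminiPart / GeminiChatContent / GeminiEmbeddingRequest.
type GeminiPart struct {
	Text string `json:"text,omitempty"`
}

type GeminiChatContent struct {
	Parts []GeminiPart `json:"parts"`
}

type GeminiEmbeddingRequest struct {
	Content              GeminiChatContent `json:"content"`
	OutputDimensionality int               `json:"outputDimensionality,omitempty"`
}

func main() {
	// Mirrors ConvertEmbeddingRequest: only the first input is embedded,
	// and Dimensions is forwarded for text-embedding-004.
	req := GeminiEmbeddingRequest{
		Content:              GeminiChatContent{Parts: []GeminiPart{{Text: "hello"}}},
		OutputDimensionality: 256,
	}
	b, _ := json.Marshal(req)
	fmt.Println(string(b))
	// {"content":{"parts":[{"text":"hello"}]},"outputDimensionality":256}
}
```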
@@ -112,6 +147,13 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
         return GeminiImageHandler(c, resp, info)
     }
 
+    // check if the model is an embedding model
+    if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
+        strings.HasPrefix(info.UpstreamModelName, "embedding") ||
+        strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
+        return GeminiEmbeddingHandler(c, resp, info)
+    }
+
     if info.IsStream {
         err, usage = GeminiChatStreamHandler(c, resp, info)
     } else {
@@ -18,6 +18,10 @@ var ModelList = []string{
     "gemini-2.0-flash-thinking-exp",
     // imagen models
     "imagen-3.0-generate-002",
+    // embedding models
+    "gemini-embedding-exp-03-07",
+    "text-embedding-004",
+    "embedding-001",
 }
 
 var SafetySettingList = []string{
@@ -136,3 +136,19 @@ type GeminiImagePrediction struct {
     RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
     SafetyAttributes  any    `json:"safetyAttributes,omitempty"`
 }
+
+// Embedding related structs
+type GeminiEmbeddingRequest struct {
+    Content              GeminiChatContent `json:"content"`
+    TaskType             string            `json:"taskType,omitempty"`
+    Title                string            `json:"title,omitempty"`
+    OutputDimensionality int               `json:"outputDimensionality,omitempty"`
+}
+
+type GeminiEmbeddingResponse struct {
+    Embedding ContentEmbedding `json:"embedding"`
+}
+
+type ContentEmbedding struct {
+    Values []float64 `json:"values"`
+}
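These mirror the documented models.embedContent response, which is a single `embedding` object holding a `values` array. A quick sketch of the round trip through local copies of the two response structs (the numbers are illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type ContentEmbedding struct {
	Values []float64 `json:"values"`
}

type GeminiEmbeddingResponse struct {
	Embedding ContentEmbedding `json:"embedding"`
}

func main() {
	// Illustrative embedContent response body.
	body := []byte(`{"embedding":{"values":[0.0131, -0.0087, -0.0467]}}`)

	var resp GeminiEmbeddingResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		panic(err)
	}
	fmt.Println(len(resp.Embedding.Values), resp.Embedding.Values[0])
	// Output: 3 0.0131
}
```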
@@ -580,3 +580,52 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
     _, err = c.Writer.Write(jsonResponse)
     return nil, &usage
 }
+
+func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+    responseBody, readErr := io.ReadAll(resp.Body)
+    if readErr != nil {
+        return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
+    }
+    _ = resp.Body.Close()
+
+    var geminiResponse GeminiEmbeddingResponse
+    if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
+        return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
+    }
+
+    // convert to openai format response
+    openAIResponse := dto.OpenAIEmbeddingResponse{
+        Object: "list",
+        Data: []dto.OpenAIEmbeddingResponseItem{
+            {
+                Object:    "embedding",
+                Embedding: geminiResponse.Embedding.Values,
+                Index:     0,
+            },
+        },
+        Model: info.UpstreamModelName,
+    }
+
+    // calculate usage
+    // https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
+    // Google has not yet clarified how embedding models will be billed,
+    // so follow OpenAI's embedding billing and charge input tokens only
+    // https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
+    usage = &dto.Usage{
+        PromptTokens:     info.PromptTokens,
+        CompletionTokens: 0,
+        TotalTokens:      info.PromptTokens,
+    }
+    openAIResponse.Usage = *usage.(*dto.Usage)
+
+    jsonResponse, jsonErr := json.Marshal(openAIResponse)
+    if jsonErr != nil {
+        return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
+    }
+
+    c.Writer.Header().Set("Content-Type", "application/json")
+    c.Writer.WriteHeader(resp.StatusCode)
+    _, _ = c.Writer.Write(jsonResponse)
+
+    return usage, nil
+}
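To make the rewrapping concrete, here is a standalone sketch of the Gemini-to-OpenAI conversion this handler performs. The dto types are simplified stand-ins (the real dto.OpenAIEmbeddingResponse carries more fields) and the token count is hypothetical, since the handler takes it from info.PromptTokens:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Simplified stand-ins for dto.Usage, dto.OpenAIEmbeddingResponseItem and
// dto.OpenAIEmbeddingResponse.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type EmbeddingItem struct {
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"`
	Index     int       `json:"index"`
}

type EmbeddingResponse struct {
	Object string          `json:"object"`
	Data   []EmbeddingItem `json:"data"`
	Model  string          `json:"model"`
	Usage  Usage           `json:"usage"`
}

func main() {
	// Pretend these were parsed from the upstream embedContent response.
	values := []float64{0.0131, -0.0087, -0.0467}
	promptTokens := 5 // hypothetical; the handler uses info.PromptTokens

	out := EmbeddingResponse{
		Object: "list",
		Data:   []EmbeddingItem{{Object: "embedding", Embedding: values, Index: 0}},
		Model:  "text-embedding-004",
		// As with OpenAI embeddings, only input tokens are counted.
		Usage: Usage{PromptTokens: promptTokens, TotalTokens: promptTokens},
	}

	b, _ := json.Marshal(out)
	fmt.Println(string(b))
}
```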