fix: Gemini embedding model only embeds the first text in a batch

This commit is contained in:
antecanis8
2025-08-04 13:02:57 +00:00
parent faaa5a2949
commit 43263a3bc8
3 changed files with 43 additions and 31 deletions

View File

@@ -216,10 +216,14 @@ type GeminiEmbeddingRequest struct {
OutputDimensionality int `json:"outputDimensionality,omitempty"` OutputDimensionality int `json:"outputDimensionality,omitempty"`
} }
type GeminiEmbeddingResponse struct { type GeminiBatchEmbeddingRequest struct {
Embedding ContentEmbedding `json:"embedding"` Requests []*GeminiEmbeddingRequest `json:"requests"`
} }
type ContentEmbedding struct { type GeminiEmbedding struct {
Values []float64 `json:"values"` Values []float64 `json:"values"`
} }
type GeminiBatchEmbeddingResponse struct {
Embeddings []*GeminiEmbedding `json:"embeddings"`
}

View File

@@ -114,7 +114,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") || if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
strings.HasPrefix(info.UpstreamModelName, "embedding") || strings.HasPrefix(info.UpstreamModelName, "embedding") ||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") { strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil return fmt.Sprintf("%s/%s/models/%s:batchEmbedContents", info.BaseUrl, version, info.UpstreamModelName), nil
} }
action := "generateContent" action := "generateContent"
@@ -156,13 +156,15 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
if len(inputs) == 0 { if len(inputs) == 0 {
return nil, errors.New("input is empty") return nil, errors.New("input is empty")
} }
// process all inputs
// only process the first input geminiRequests := make([]map[string]interface{}, 0, len(inputs))
geminiRequest := dto.GeminiEmbeddingRequest{ for _, input := range inputs {
Content: dto.GeminiChatContent{ geminiRequest := map[string]interface{}{
"model": fmt.Sprintf("models/%s", info.UpstreamModelName),
"content": dto.GeminiChatContent{
Parts: []dto.GeminiPart{ Parts: []dto.GeminiPart{
{ {
Text: inputs[0], Text: input,
}, },
}, },
}, },
@@ -174,11 +176,15 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
case "text-embedding-004": case "text-embedding-004":
// except embedding-001 supports setting `OutputDimensionality` // except embedding-001 supports setting `OutputDimensionality`
if request.Dimensions > 0 { if request.Dimensions > 0 {
geminiRequest.OutputDimensionality = request.Dimensions geminiRequest["outputDimensionality"] = request.Dimensions
} }
} }
geminiRequests = append(geminiRequests, geminiRequest)
}
return geminiRequest, nil return map[string]interface{}{
"requests": geminiRequests,
}, nil
} }
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {

View File

@@ -974,7 +974,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
var geminiResponse dto.GeminiEmbeddingResponse var geminiResponse dto.GeminiBatchEmbeddingResponse
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
@@ -982,16 +982,18 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
// convert to openai format response // convert to openai format response
openAIResponse := dto.OpenAIEmbeddingResponse{ openAIResponse := dto.OpenAIEmbeddingResponse{
Object: "list", Object: "list",
Data: []dto.OpenAIEmbeddingResponseItem{ Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(geminiResponse.Embeddings)),
{
Object: "embedding",
Embedding: geminiResponse.Embedding.Values,
Index: 0,
},
},
Model: info.UpstreamModelName, Model: info.UpstreamModelName,
} }
for i, embedding := range geminiResponse.Embeddings {
openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{
Object: "embedding",
Embedding: embedding.Values,
Index: i,
})
}
// calculate usage // calculate usage
// https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004 // https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
// Google has not yet clarified how embedding models will be billed // Google has not yet clarified how embedding models will be billed