feat: support native Gemini Embedding

This commit is contained in:
RedwindA
2025-08-09 00:27:33 +08:00
parent 962c40c1a7
commit b70d2655ed
6 changed files with 170 additions and 10 deletions

View File

@@ -42,7 +42,11 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
case relayconstant.RelayModeResponses:
err = relay.ResponsesHelper(c)
case relayconstant.RelayModeGemini:
err = relay.GeminiHelper(c)
if strings.Contains(c.Request.URL.Path, "embed") {
err = relay.GeminiEmbeddingHandler(c)
} else {
err = relay.GeminiHelper(c)
}
default:
err = relay.TextHelper(c)
}

View File

@@ -210,16 +210,25 @@ type GeminiImagePrediction struct {
// Embedding related structs
type GeminiEmbeddingRequest struct {
Model string `json:"model,omitempty"`
Content GeminiChatContent `json:"content"`
TaskType string `json:"taskType,omitempty"`
Title string `json:"title,omitempty"`
OutputDimensionality int `json:"outputDimensionality,omitempty"`
}
type GeminiBatchEmbeddingRequest struct {
Requests []GeminiEmbeddingRequest `json:"requests"`
}
type GeminiEmbeddingResponse struct {
Embedding ContentEmbedding `json:"embedding"`
}
type GeminiBatchEmbeddingResponse struct {
Embeddings []ContentEmbedding `json:"embeddings"`
}
type ContentEmbedding struct {
Values []float64 `json:"values"`
}

View File

@@ -114,7 +114,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil
action := "embedContent"
if info.IsGeminiBatchEmbdding {
action = "batchEmbedContents"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
}
action := "generateContent"
@@ -195,6 +199,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeGemini {
if strings.Contains(info.RequestURLPath, "embed") {
return NativeGeminiEmbeddingHandler(c, resp, info)
}
if info.IsStream {
return GeminiTextGenerationStreamHandler(c, info, resp)
} else {

View File

@@ -1,7 +1,6 @@
package gemini
import (
"github.com/pkg/errors"
"io"
"net/http"
"one-api/common"
@@ -12,6 +11,8 @@ import (
"one-api/types"
"strings"
"github.com/pkg/errors"
"github.com/gin-gonic/gin"
)
@@ -63,6 +64,42 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
return &usage, nil
}
func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if common.DebugEnabled {
println(string(responseBody))
}
usage := &dto.Usage{
PromptTokens: info.PromptTokens,
TotalTokens: info.PromptTokens,
}
if info.IsGeminiBatchEmbdding {
var geminiResponse dto.GeminiBatchEmbeddingResponse
err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
} else {
var geminiResponse dto.GeminiEmbeddingResponse
err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
}
common.IOCopyBytesGracefully(c, resp, responseBody)
return usage, nil
}
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var usage = &dto.Usage{}
var imageCount int

View File

@@ -74,13 +74,14 @@ type RelayInfo struct {
FirstResponseTime time.Time
isFirstResponse bool
//SendLastReasoningResponse bool
ApiType int
IsStream bool
IsPlayground bool
UsePrice bool
RelayMode int
UpstreamModelName string
OriginModelName string
ApiType int
IsStream bool
IsGeminiBatchEmbdding bool
IsPlayground bool
UsePrice bool
RelayMode int
UpstreamModelName string
OriginModelName string
//RecodeModelName string
RequestURLPath string
ApiVersion string

View File

@@ -264,3 +264,105 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}
func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoGemini(c)
isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
relayInfo.IsGeminiBatchEmbdding = isBatch
var promptTokens int
var req any
var err error
var inputTexts []string
if isBatch {
batchRequest := &dto.GeminiBatchEmbeddingRequest{}
err = common.UnmarshalBodyReusable(c, batchRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
req = batchRequest
for _, r := range batchRequest.Requests {
for _, part := range r.Content.Parts {
if part.Text != "" {
inputTexts = append(inputTexts, part.Text)
}
}
}
} else {
singleRequest := &dto.GeminiEmbeddingRequest{}
err = common.UnmarshalBodyReusable(c, singleRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
req = singleRequest
for _, part := range singleRequest.Content.Parts {
if part.Text != "" {
inputTexts = append(inputTexts, part.Text)
}
}
}
promptTokens = service.CountTokenInput(strings.Join(inputTexts, "\n"), relayInfo.UpstreamModelName)
relayInfo.SetPromptTokens(promptTokens)
c.Set("prompt_tokens", promptTokens)
err = helper.ModelMappedHelper(c, relayInfo, req)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
var requestBody io.Reader
jsonData, err := common.Marshal(req)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewReader(jsonData)
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
common.LogError(c, "Do gemini request failed: "+err.Error())
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false)
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
}
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
if openaiErr != nil {
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}