feat: support native Gemini Embedding

This commit is contained in:
RedwindA
2025-08-09 00:27:33 +08:00
parent 962c40c1a7
commit b70d2655ed
6 changed files with 170 additions and 10 deletions

View File

@@ -42,7 +42,11 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
case relayconstant.RelayModeResponses: case relayconstant.RelayModeResponses:
err = relay.ResponsesHelper(c) err = relay.ResponsesHelper(c)
case relayconstant.RelayModeGemini: case relayconstant.RelayModeGemini:
err = relay.GeminiHelper(c) if strings.Contains(c.Request.URL.Path, "embed") {
err = relay.GeminiEmbeddingHandler(c)
} else {
err = relay.GeminiHelper(c)
}
default: default:
err = relay.TextHelper(c) err = relay.TextHelper(c)
} }

View File

@@ -210,16 +210,25 @@ type GeminiImagePrediction struct {
// Embedding related structs // Embedding related structs
type GeminiEmbeddingRequest struct { type GeminiEmbeddingRequest struct {
Model string `json:"model,omitempty"`
Content GeminiChatContent `json:"content"` Content GeminiChatContent `json:"content"`
TaskType string `json:"taskType,omitempty"` TaskType string `json:"taskType,omitempty"`
Title string `json:"title,omitempty"` Title string `json:"title,omitempty"`
OutputDimensionality int `json:"outputDimensionality,omitempty"` OutputDimensionality int `json:"outputDimensionality,omitempty"`
} }
// GeminiBatchEmbeddingRequest is the request body decoded for native Gemini
// batch embedding calls (request paths ending in "batchEmbedContents").
// It carries a list of individual embedding requests under the JSON key
// "requests".
type GeminiBatchEmbeddingRequest struct {
// Requests holds one GeminiEmbeddingRequest per item to embed.
Requests []GeminiEmbeddingRequest `json:"requests"`
}
type GeminiEmbeddingResponse struct { type GeminiEmbeddingResponse struct {
Embedding ContentEmbedding `json:"embedding"` Embedding ContentEmbedding `json:"embedding"`
} }
// GeminiBatchEmbeddingResponse is the response body decoded for native Gemini
// batch embedding calls. It carries a list of embeddings under the JSON key
// "embeddings".
type GeminiBatchEmbeddingResponse struct {
// Embeddings holds the returned embedding vectors.
// NOTE(review): presumably one entry per request, in request order — confirm against the upstream API.
Embeddings []ContentEmbedding `json:"embeddings"`
}
type ContentEmbedding struct { type ContentEmbedding struct {
Values []float64 `json:"values"` Values []float64 `json:"values"`
} }

View File

@@ -114,7 +114,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") || if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
strings.HasPrefix(info.UpstreamModelName, "embedding") || strings.HasPrefix(info.UpstreamModelName, "embedding") ||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") { strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil action := "embedContent"
if info.IsGeminiBatchEmbdding {
action = "batchEmbedContents"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
} }
action := "generateContent" action := "generateContent"
@@ -195,6 +199,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeGemini { if info.RelayMode == constant.RelayModeGemini {
if strings.Contains(info.RequestURLPath, "embed") {
return NativeGeminiEmbeddingHandler(c, resp, info)
}
if info.IsStream { if info.IsStream {
return GeminiTextGenerationStreamHandler(c, info, resp) return GeminiTextGenerationStreamHandler(c, info, resp)
} else { } else {

View File

@@ -1,7 +1,6 @@
package gemini package gemini
import ( import (
"github.com/pkg/errors"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
@@ -12,6 +11,8 @@ import (
"one-api/types" "one-api/types"
"strings" "strings"
"github.com/pkg/errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -63,6 +64,42 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
return &usage, nil return &usage, nil
} }
// NativeGeminiEmbeddingHandler processes the upstream response of a native
// Gemini embedding call and reports usage based on the prompt tokens already
// recorded on info.
//
// The body is read in full and decoded once — as a batch response when
// info.IsGeminiBatchEmbdding is set, otherwise as a single response. The
// decoded value is not used afterwards; decoding only rejects bodies that do
// not parse. The raw bytes are then handed to common.IOCopyBytesGracefully
// (presumably writing them to the client — confirm against that helper).
//
// A read or decode failure returns a nil usage and an error with code
// types.ErrorCodeBadResponseBody and HTTP status 500.
func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
	defer common.CloseResponseBodyGracefully(resp)

	body, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}

	if common.DebugEnabled {
		println(string(body))
	}

	// Choose the decode target that matches the endpoint that was called.
	var decoded any = &dto.GeminiEmbeddingResponse{}
	if info.IsGeminiBatchEmbdding {
		decoded = &dto.GeminiBatchEmbeddingResponse{}
	}
	if decodeErr := common.Unmarshal(body, decoded); decodeErr != nil {
		return nil, types.NewOpenAIError(decodeErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}

	common.IOCopyBytesGracefully(c, resp, body)

	// Only prompt tokens are reported, so the total equals the prompt count.
	return &dto.Usage{
		PromptTokens: info.PromptTokens,
		TotalTokens:  info.PromptTokens,
	}, nil
}
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var usage = &dto.Usage{} var usage = &dto.Usage{}
var imageCount int var imageCount int

View File

@@ -74,13 +74,14 @@ type RelayInfo struct {
FirstResponseTime time.Time FirstResponseTime time.Time
isFirstResponse bool isFirstResponse bool
//SendLastReasoningResponse bool //SendLastReasoningResponse bool
ApiType int ApiType int
IsStream bool IsStream bool
IsPlayground bool IsGeminiBatchEmbdding bool
UsePrice bool IsPlayground bool
RelayMode int UsePrice bool
UpstreamModelName string RelayMode int
OriginModelName string UpstreamModelName string
OriginModelName string
//RecodeModelName string //RecodeModelName string
RequestURLPath string RequestURLPath string
ApiVersion string ApiVersion string

View File

@@ -264,3 +264,105 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil return nil
} }
// GeminiEmbeddingHandler relays a native Gemini embedding request to the
// upstream channel and settles quota for it.
//
// The request is treated as a batch call when the URL path ends with
// "batchEmbedContents"; otherwise it is decoded as a single embedding request.
// Prompt tokens are counted from the non-empty text parts of the request,
// joined with newlines. Quota is pre-consumed before the upstream call, and
// the deferred function hands it back whenever the handler returns a non-nil
// error after that point.
//
// It returns nil on success, or a *types.NewAPIError describing the failure.
// A missing or unexpectedly typed upstream response is reported as a
// do-request failure instead of panicking.
func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
	relayInfo := relaycommon.GenRelayInfoGemini(c)

	isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
	relayInfo.IsGeminiBatchEmbdding = isBatch

	var (
		req        any
		err        error
		inputTexts []string
	)
	// collectText keeps only non-empty text parts for token counting.
	collectText := func(text string) {
		if text != "" {
			inputTexts = append(inputTexts, text)
		}
	}

	if isBatch {
		batchRequest := &dto.GeminiBatchEmbeddingRequest{}
		err = common.UnmarshalBodyReusable(c, batchRequest)
		if err != nil {
			return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
		}
		req = batchRequest
		for _, r := range batchRequest.Requests {
			for _, part := range r.Content.Parts {
				collectText(part.Text)
			}
		}
	} else {
		singleRequest := &dto.GeminiEmbeddingRequest{}
		err = common.UnmarshalBodyReusable(c, singleRequest)
		if err != nil {
			return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
		}
		req = singleRequest
		for _, part := range singleRequest.Content.Parts {
			collectText(part.Text)
		}
	}

	promptTokens := service.CountTokenInput(strings.Join(inputTexts, "\n"), relayInfo.UpstreamModelName)
	relayInfo.SetPromptTokens(promptTokens)
	c.Set("prompt_tokens", promptTokens)

	err = helper.ModelMappedHelper(c, relayInfo, req)
	if err != nil {
		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
	}

	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, 0)
	if err != nil {
		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
	}

	// newAPIError here is the named result (same block), not a shadow, so the
	// deferred refund below observes every error returned from now on.
	preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
	if newAPIError != nil {
		return newAPIError
	}
	defer func() {
		if newAPIError != nil {
			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
		}
	}()

	adaptor := GetAdaptor(relayInfo.ApiType)
	if adaptor == nil {
		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
	}
	adaptor.Init(relayInfo)

	jsonData, err := common.Marshal(req)
	if err != nil {
		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
	}
	var requestBody io.Reader = bytes.NewReader(jsonData)

	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
	if err != nil {
		common.LogError(c, "Do gemini request failed: "+err.Error())
		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
	}

	statusCodeMappingStr := c.GetString("status_code_mapping")

	// A nil or unexpectedly typed response would make an unchecked type
	// assertion panic, and a nil *http.Response would break the response
	// handler, so report it as a request failure instead.
	httpResp, ok := resp.(*http.Response)
	if !ok || httpResp == nil {
		return types.NewOpenAIError(fmt.Errorf("gemini embedding: upstream returned no http response"), types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
	}
	if httpResp.StatusCode != http.StatusOK {
		newAPIError = service.RelayErrorHandler(httpResp, false)
		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
		return newAPIError
	}

	usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
	if openaiErr != nil {
		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
		return openaiErr
	}

	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
	return nil
}