Merge pull request #1537 from RedwindA/feat/support-native-gemini-embedding

feat: Support native Gemini Embedding format
This commit is contained in:
Calcium-Ion
2025-08-10 10:26:46 +08:00
committed by GitHub
6 changed files with 187 additions and 26 deletions

View File

@@ -42,7 +42,11 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
 	case relayconstant.RelayModeResponses:
 		err = relay.ResponsesHelper(c)
 	case relayconstant.RelayModeGemini:
-		err = relay.GeminiHelper(c)
+		if strings.Contains(c.Request.URL.Path, "embed") {
+			err = relay.GeminiEmbeddingHandler(c)
+		} else {
+			err = relay.GeminiHelper(c)
+		}
 	default:
 		err = relay.TextHelper(c)
 	}
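
Note: a standalone sketch (not part of this commit) of which request paths the new strings.Contains(path, "embed") check routes to the embedding handler. The example paths assume the public v1beta URL layout.

package main

import (
	"fmt"
	"strings"
)

func main() {
	paths := []string{
		"/v1beta/models/gemini-embedding-001:embedContent",       // -> GeminiEmbeddingHandler
		"/v1beta/models/gemini-embedding-001:batchEmbedContents", // -> GeminiEmbeddingHandler
		"/v1beta/models/gemini-2.0-flash:generateContent",        // -> GeminiHelper
	}
	for _, p := range paths {
		fmt.Println(p, "-> embedding:", strings.Contains(p, "embed"))
	}
}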

View File

@@ -210,6 +210,7 @@ type GeminiImagePrediction struct {

 // Embedding related structs
 type GeminiEmbeddingRequest struct {
+	Model    string            `json:"model,omitempty"`
 	Content  GeminiChatContent `json:"content"`
 	TaskType string            `json:"taskType,omitempty"`
 	Title    string            `json:"title,omitempty"`
@@ -220,10 +221,14 @@ type GeminiBatchEmbeddingRequest struct {
 	Requests []*GeminiEmbeddingRequest `json:"requests"`
 }

-type GeminiEmbedding struct {
-	Values []float64 `json:"values"`
+type GeminiEmbeddingResponse struct {
+	Embedding ContentEmbedding `json:"embedding"`
 }

 type GeminiBatchEmbeddingResponse struct {
-	Embeddings []*GeminiEmbedding `json:"embeddings"`
+	Embeddings []*ContentEmbedding `json:"embeddings"`
+}
+
+type ContentEmbedding struct {
+	Values []float64 `json:"values"`
 }
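
Note: a minimal sketch of how the new structs map onto the two native response formats, assuming the JSON shapes documented for the public embedContent/batchEmbedContents endpoints; it inlines the dto structs and uses encoding/json in place of the project's common.Unmarshal wrapper.

package main

import (
	"encoding/json"
	"fmt"
)

// These mirror the structs added in this commit (dto package).
type ContentEmbedding struct {
	Values []float64 `json:"values"`
}

type GeminiEmbeddingResponse struct {
	Embedding ContentEmbedding `json:"embedding"`
}

type GeminiBatchEmbeddingResponse struct {
	Embeddings []*ContentEmbedding `json:"embeddings"`
}

func main() {
	// Single embedContent response shape.
	var single GeminiEmbeddingResponse
	_ = json.Unmarshal([]byte(`{"embedding":{"values":[0.1,0.2]}}`), &single)
	fmt.Println(single.Embedding.Values) // [0.1 0.2]

	// batchEmbedContents response shape.
	var batch GeminiBatchEmbeddingResponse
	_ = json.Unmarshal([]byte(`{"embeddings":[{"values":[0.1]},{"values":[0.2]}]}`), &batch)
	fmt.Println(len(batch.Embeddings)) // 2
}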

View File

@@ -114,7 +114,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") || if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
strings.HasPrefix(info.UpstreamModelName, "embedding") || strings.HasPrefix(info.UpstreamModelName, "embedding") ||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") { strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
return fmt.Sprintf("%s/%s/models/%s:batchEmbedContents", info.BaseUrl, version, info.UpstreamModelName), nil action := "embedContent"
if info.IsGeminiBatchEmbedding {
action = "batchEmbedContents"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
} }
action := "generateContent" action := "generateContent"
@@ -159,6 +163,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	if len(inputs) == 0 {
 		return nil, errors.New("input is empty")
 	}
+	// We always build a batch-style payload with `requests`, so ensure we call the
+	// batch endpoint upstream to avoid payload/endpoint mismatches.
+	info.IsGeminiBatchEmbedding = true
 	// process all inputs
 	geminiRequests := make([]map[string]interface{}, 0, len(inputs))
 	for _, input := range inputs {
@@ -176,7 +183,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
// set specific parameters for different models // set specific parameters for different models
// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent // https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
switch info.UpstreamModelName { switch info.UpstreamModelName {
case "text-embedding-004","gemini-embedding-exp-03-07","gemini-embedding-001": case "text-embedding-004", "gemini-embedding-exp-03-07", "gemini-embedding-001":
// Only newer models introduced after 2024 support OutputDimensionality // Only newer models introduced after 2024 support OutputDimensionality
if request.Dimensions > 0 { if request.Dimensions > 0 {
geminiRequest["outputDimensionality"] = request.Dimensions geminiRequest["outputDimensionality"] = request.Dimensions
@@ -201,6 +208,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
 	if info.RelayMode == constant.RelayModeGemini {
+		if strings.HasSuffix(info.RequestURLPath, ":embedContent") ||
+			strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") {
+			return NativeGeminiEmbeddingHandler(c, resp, info)
+		}
 		if info.IsStream {
 			return GeminiTextGenerationStreamHandler(c, info, resp)
 		} else {
@@ -225,18 +236,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		return GeminiChatHandler(c, info, resp)
 	}

-	//if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
-	//	// If thinking tokens are produced without "-thinking" being requested, bill as the thinking model
-	//	if !strings.HasSuffix(info.OriginModelName, "-thinking") &&
-	//		!strings.HasSuffix(info.OriginModelName, "-nothinking") {
-	//		thinkingModelName := info.OriginModelName + "-thinking"
-	//		if operation_setting.SelfUseModeEnabled || helper.ContainPriceOrRatio(thinkingModelName) {
-	//			info.OriginModelName = thinkingModelName
-	//		}
-	//	}
-	//}
-	return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
 }

 func (a *Adaptor) GetModelList() []string {
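
Note: a standalone sketch of the URLs GetRequestURL now builds for embedding models. The base URL and API version shown are assumptions (the common public defaults), not values taken from this diff.

package main

import "fmt"

func buildURL(baseURL, version, model string, isBatch bool) string {
	action := "embedContent"
	if isBatch {
		action = "batchEmbedContents"
	}
	return fmt.Sprintf("%s/%s/models/%s:%s", baseURL, version, model, action)
}

func main() {
	base := "https://generativelanguage.googleapis.com"
	fmt.Println(buildURL(base, "v1beta", "gemini-embedding-001", false))
	// .../v1beta/models/gemini-embedding-001:embedContent
	fmt.Println(buildURL(base, "v1beta", "gemini-embedding-001", true))
	// .../v1beta/models/gemini-embedding-001:batchEmbedContents
}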

View File

@@ -1,7 +1,6 @@
package gemini package gemini
import ( import (
"github.com/pkg/errors"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
@@ -12,6 +11,8 @@ import (
"one-api/types" "one-api/types"
"strings" "strings"
"github.com/pkg/errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -63,6 +64,42 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
 	return &usage, nil
 }

+func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
+	defer common.CloseResponseBodyGracefully(resp)
+
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+
+	if common.DebugEnabled {
+		println(string(responseBody))
+	}
+
+	usage := &dto.Usage{
+		PromptTokens: info.PromptTokens,
+		TotalTokens:  info.PromptTokens,
+	}
+
+	if info.IsGeminiBatchEmbedding {
+		var geminiResponse dto.GeminiBatchEmbeddingResponse
+		err = common.Unmarshal(responseBody, &geminiResponse)
+		if err != nil {
+			return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+		}
+	} else {
+		var geminiResponse dto.GeminiEmbeddingResponse
+		err = common.Unmarshal(responseBody, &geminiResponse)
+		if err != nil {
+			return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+		}
+	}
+
+	common.IOCopyBytesGracefully(c, resp, responseBody)
+
+	return usage, nil
+}
+
 func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 	var usage = &dto.Usage{}
 	var imageCount int
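
Note: the handler above unmarshals the upstream body only to validate it against the expected shape, bills usage from the pre-counted prompt tokens (embedding responses carry no completion tokens), and then forwards the original bytes unchanged. A minimal sketch of the same validate-then-pass-through pattern, assuming plain encoding/json and gin's c.Data in place of the project's common.Unmarshal and common.IOCopyBytesGracefully helpers:

package gemini

import (
	"encoding/json"
	"errors"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
)

func passThrough(c *gin.Context, resp *http.Response) error {
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	if !json.Valid(body) {
		return errors.New("malformed upstream JSON") // fail fast instead of forwarding garbage
	}
	// Forward the upstream bytes verbatim; the client sees the native format.
	c.Data(resp.StatusCode, "application/json", body)
	return nil
}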

View File

@@ -74,13 +74,14 @@ type RelayInfo struct {
FirstResponseTime time.Time FirstResponseTime time.Time
isFirstResponse bool isFirstResponse bool
//SendLastReasoningResponse bool //SendLastReasoningResponse bool
ApiType int ApiType int
IsStream bool IsStream bool
IsPlayground bool IsGeminiBatchEmbedding bool
UsePrice bool IsPlayground bool
RelayMode int UsePrice bool
UpstreamModelName string RelayMode int
OriginModelName string UpstreamModelName string
OriginModelName string
//RecodeModelName string //RecodeModelName string
RequestURLPath string RequestURLPath string
ApiVersion string ApiVersion string

View File

@@ -264,3 +264,118 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
 	return nil
 }
+
+func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
+	relayInfo := relaycommon.GenRelayInfoGemini(c)
+	isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
+	relayInfo.IsGeminiBatchEmbedding = isBatch
+
+	var promptTokens int
+	var req any
+	var err error
+	var inputTexts []string
+
+	if isBatch {
+		batchRequest := &dto.GeminiBatchEmbeddingRequest{}
+		err = common.UnmarshalBodyReusable(c, batchRequest)
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+		}
+		req = batchRequest
+		for _, r := range batchRequest.Requests {
+			for _, part := range r.Content.Parts {
+				if part.Text != "" {
+					inputTexts = append(inputTexts, part.Text)
+				}
+			}
+		}
+	} else {
+		singleRequest := &dto.GeminiEmbeddingRequest{}
+		err = common.UnmarshalBodyReusable(c, singleRequest)
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+		}
+		req = singleRequest
+		for _, part := range singleRequest.Content.Parts {
+			if part.Text != "" {
+				inputTexts = append(inputTexts, part.Text)
+			}
+		}
+	}
+
+	promptTokens = service.CountTokenInput(strings.Join(inputTexts, "\n"), relayInfo.UpstreamModelName)
+	relayInfo.SetPromptTokens(promptTokens)
+	c.Set("prompt_tokens", promptTokens)
+
+	err = helper.ModelMappedHelper(c, relayInfo, req)
+	if err != nil {
+		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
+	}
+
+	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, 0)
+	if err != nil {
+		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
+	}
+
+	preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+	if newAPIError != nil {
+		return newAPIError
+	}
+	defer func() {
+		if newAPIError != nil {
+			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+		}
+	}()
+
+	adaptor := GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+	}
+	adaptor.Init(relayInfo)
+
+	var requestBody io.Reader
+	jsonData, err := common.Marshal(req)
+	if err != nil {
+		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+	}
+
+	// apply param override
+	if len(relayInfo.ParamOverride) > 0 {
+		reqMap := make(map[string]interface{})
+		_ = common.Unmarshal(jsonData, &reqMap)
+		for key, value := range relayInfo.ParamOverride {
+			reqMap[key] = value
+		}
+		jsonData, err = common.Marshal(reqMap)
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+		}
+	}
+	requestBody = bytes.NewReader(jsonData)
+
+	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	if err != nil {
+		common.LogError(c, "Do gemini request failed: "+err.Error())
+		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
+	}
+
+	statusCodeMappingStr := c.GetString("status_code_mapping")
+	var httpResp *http.Response
+	if resp != nil {
+		httpResp = resp.(*http.Response)
+		if httpResp.StatusCode != http.StatusOK {
+			newAPIError = service.RelayErrorHandler(httpResp, false)
+			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+			return newAPIError
+		}
+	}
+
+	usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
+	if openaiErr != nil {
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
+	}
+
+	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	return nil
+}
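
Note: a standalone sketch of the two native request payloads GeminiEmbeddingHandler parses, and from which it joins the text parts to count prompt tokens. The shapes follow the public embedContent/batchEmbedContents API; the local struct names here are illustrative, not the dto names.

package main

import (
	"encoding/json"
	"fmt"
)

type part struct {
	Text string `json:"text"`
}
type content struct {
	Parts []part `json:"parts"`
}
type embedReq struct {
	Model   string  `json:"model,omitempty"`
	Content content `json:"content"`
}
type batchEmbedReq struct {
	Requests []embedReq `json:"requests"`
}

func main() {
	// Single request, POSTed to .../models/<model>:embedContent
	single, _ := json.Marshal(embedReq{
		Model:   "models/gemini-embedding-001",
		Content: content{Parts: []part{{Text: "hello"}}},
	})
	fmt.Println(string(single))
	// {"model":"models/gemini-embedding-001","content":{"parts":[{"text":"hello"}]}}

	// Batch request, POSTed to .../models/<model>:batchEmbedContents
	batch, _ := json.Marshal(batchEmbedReq{Requests: []embedReq{
		{Model: "models/gemini-embedding-001", Content: content{Parts: []part{{Text: "a"}}}},
		{Model: "models/gemini-embedding-001", Content: content{Parts: []part{{Text: "b"}}}},
	}})
	fmt.Println(string(batch))
}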