diff --git a/controller/relay.go b/controller/relay.go
index 1a35c7d7..c97eca20 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -42,7 +42,11 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
 	case relayconstant.RelayModeResponses:
 		err = relay.ResponsesHelper(c)
 	case relayconstant.RelayModeGemini:
-		err = relay.GeminiHelper(c)
+		if strings.HasSuffix(c.Request.URL.Path, ":embedContent") || strings.HasSuffix(c.Request.URL.Path, ":batchEmbedContents") {
+			err = relay.GeminiEmbeddingHandler(c)
+		} else {
+			err = relay.GeminiHelper(c)
+		}
 	default:
 		err = relay.TextHelper(c)
 	}
diff --git a/dto/gemini.go b/dto/gemini.go
index 1bd1fe4c..915b7b81 100644
--- a/dto/gemini.go
+++ b/dto/gemini.go
@@ -210,6 +210,7 @@ type GeminiImagePrediction struct {
 
 // Embedding related structs
 type GeminiEmbeddingRequest struct {
+	Model    string            `json:"model,omitempty"`
 	Content  GeminiChatContent `json:"content"`
 	TaskType string            `json:"taskType,omitempty"`
 	Title    string            `json:"title,omitempty"`
@@ -220,10 +221,14 @@ type GeminiBatchEmbeddingRequest struct {
 	Requests []*GeminiEmbeddingRequest `json:"requests"`
 }
 
-type GeminiEmbedding struct {
-	Values []float64 `json:"values"`
+type GeminiEmbeddingResponse struct {
+	Embedding ContentEmbedding `json:"embedding"`
 }
 
 type GeminiBatchEmbeddingResponse struct {
-	Embeddings []*GeminiEmbedding `json:"embeddings"`
+	Embeddings []*ContentEmbedding `json:"embeddings"`
+}
+
+type ContentEmbedding struct {
+	Values []float64 `json:"values"`
 }
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index e5b4146a..4141caf7 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -114,7 +114,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
 		strings.HasPrefix(info.UpstreamModelName, "embedding") ||
 		strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
-		return fmt.Sprintf("%s/%s/models/%s:batchEmbedContents", info.BaseUrl, version, 
info.UpstreamModelName), nil + action := "embedContent" + if info.IsGeminiBatchEmbedding { + action = "batchEmbedContents" + } + return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil } action := "generateContent" @@ -159,6 +163,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela if len(inputs) == 0 { return nil, errors.New("input is empty") } + // We always build a batch-style payload with `requests`, so ensure we call the + // batch endpoint upstream to avoid payload/endpoint mismatches. + info.IsGeminiBatchEmbedding = true // process all inputs geminiRequests := make([]map[string]interface{}, 0, len(inputs)) for _, input := range inputs { @@ -176,7 +183,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela // set specific parameters for different models // https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent switch info.UpstreamModelName { - case "text-embedding-004","gemini-embedding-exp-03-07","gemini-embedding-001": + case "text-embedding-004", "gemini-embedding-exp-03-07", "gemini-embedding-001": // Only newer models introduced after 2024 support OutputDimensionality if request.Dimensions > 0 { geminiRequest["outputDimensionality"] = request.Dimensions @@ -201,6 +208,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeGemini { + if strings.HasSuffix(info.RequestURLPath, ":embedContent") || + strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") { + return NativeGeminiEmbeddingHandler(c, resp, info) + } if info.IsStream { return GeminiTextGenerationStreamHandler(c, info, resp) } else { @@ -225,18 +236,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom return GeminiChatHandler(c, 
info, resp) } - //if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 { - // // 没有请求-thinking的情况下,产生思考token,则按照思考模型计费 - // if !strings.HasSuffix(info.OriginModelName, "-thinking") && - // !strings.HasSuffix(info.OriginModelName, "-nothinking") { - // thinkingModelName := info.OriginModelName + "-thinking" - // if operation_setting.SelfUseModeEnabled || helper.ContainPriceOrRatio(thinkingModelName) { - // info.OriginModelName = thinkingModelName - // } - // } - //} - - return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody) } func (a *Adaptor) GetModelList() []string { diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 1ba599b3..247b41fd 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -1,7 +1,6 @@ package gemini import ( - "github.com/pkg/errors" "io" "net/http" "one-api/common" @@ -12,6 +11,8 @@ import ( "one-api/types" "strings" + "github.com/pkg/errors" + "github.com/gin-gonic/gin" ) @@ -63,6 +64,42 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re return &usage, nil } +func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) { + defer common.CloseResponseBodyGracefully(resp) + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + if common.DebugEnabled { + println(string(responseBody)) + } + + usage := &dto.Usage{ + PromptTokens: info.PromptTokens, + TotalTokens: info.PromptTokens, + } + + if info.IsGeminiBatchEmbedding { + var geminiResponse dto.GeminiBatchEmbeddingResponse + err = common.Unmarshal(responseBody, &geminiResponse) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + } else { + var 
geminiResponse dto.GeminiEmbeddingResponse + err = common.Unmarshal(responseBody, &geminiResponse) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + } + + common.IOCopyBytesGracefully(c, resp, responseBody) + + return usage, nil +} + func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { var usage = &dto.Usage{} var imageCount int diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 743070ca..9b7e3db5 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -74,13 +74,14 @@ type RelayInfo struct { FirstResponseTime time.Time isFirstResponse bool //SendLastReasoningResponse bool - ApiType int - IsStream bool - IsPlayground bool - UsePrice bool - RelayMode int - UpstreamModelName string - OriginModelName string + ApiType int + IsStream bool + IsGeminiBatchEmbedding bool + IsPlayground bool + UsePrice bool + RelayMode int + UpstreamModelName string + OriginModelName string //RecodeModelName string RequestURLPath string ApiVersion string diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 42b695b7..e0581156 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -264,3 +264,118 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") return nil } + +func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) { + relayInfo := relaycommon.GenRelayInfoGemini(c) + + isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents") + relayInfo.IsGeminiBatchEmbedding = isBatch + + var promptTokens int + var req any + var err error + var inputTexts []string + + if isBatch { + batchRequest := &dto.GeminiBatchEmbeddingRequest{} + err = common.UnmarshalBodyReusable(c, batchRequest) + if err != nil { + return 
types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + req = batchRequest + for _, r := range batchRequest.Requests { + for _, part := range r.Content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + } + } + } else { + singleRequest := &dto.GeminiEmbeddingRequest{} + err = common.UnmarshalBodyReusable(c, singleRequest) + if err != nil { + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + req = singleRequest + for _, part := range singleRequest.Content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + } + } + promptTokens = service.CountTokenInput(strings.Join(inputTexts, "\n"), relayInfo.UpstreamModelName) + relayInfo.SetPromptTokens(promptTokens) + c.Set("prompt_tokens", promptTokens) + + err = helper.ModelMappedHelper(c, relayInfo, req) + if err != nil { + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) + } + + priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, 0) + if err != nil { + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) + } + + preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if newAPIError != nil { + return newAPIError + } + defer func() { + if newAPIError != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() + + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + } + adaptor.Init(relayInfo) + + var requestBody io.Reader + jsonData, err := common.Marshal(req) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + + // apply param override + if len(relayInfo.ParamOverride) > 0 { 
+ reqMap := make(map[string]interface{}) + _ = common.Unmarshal(jsonData, &reqMap) + for key, value := range relayInfo.ParamOverride { + reqMap[key] = value + } + jsonData, err = common.Marshal(reqMap) + if err != nil { + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + } + } + requestBody = bytes.NewReader(jsonData) + + resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + if err != nil { + common.LogError(c, "Do gemini request failed: "+err.Error()) + return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) + } + + statusCodeMappingStr := c.GetString("status_code_mapping") + var httpResp *http.Response + if resp != nil { + httpResp = resp.(*http.Response) + if httpResp.StatusCode != http.StatusOK { + newAPIError = service.RelayErrorHandler(httpResp, false) + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError + } + } + + usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo) + if openaiErr != nil { + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } + + postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + return nil +}