fix typo; add ParamOverride for Gemini Embedding

This commit is contained in:
RedwindA
2025-08-09 01:07:48 +08:00
parent b70d2655ed
commit 7a31e481a6
4 changed files with 26 additions and 12 deletions

View File

@@ -115,7 +115,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
strings.HasPrefix(info.UpstreamModelName, "embedding") || strings.HasPrefix(info.UpstreamModelName, "embedding") ||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") { strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
action := "embedContent" action := "embedContent"
if info.IsGeminiBatchEmbdding { if info.IsGeminiBatchEmbedding {
action = "batchEmbedContents" action = "batchEmbedContents"
} }
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
@@ -199,7 +199,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeGemini { if info.RelayMode == constant.RelayModeGemini {
if strings.Contains(info.RequestURLPath, "embed") { if strings.HasSuffix(info.RequestURLPath, ":embedContent") ||
strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") {
return NativeGeminiEmbeddingHandler(c, resp, info) return NativeGeminiEmbeddingHandler(c, resp, info)
} }
if info.IsStream { if info.IsStream {

View File

@@ -81,7 +81,7 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
TotalTokens: info.PromptTokens, TotalTokens: info.PromptTokens,
} }
if info.IsGeminiBatchEmbdding { if info.IsGeminiBatchEmbedding {
var geminiResponse dto.GeminiBatchEmbeddingResponse var geminiResponse dto.GeminiBatchEmbeddingResponse
err = common.Unmarshal(responseBody, &geminiResponse) err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil { if err != nil {

View File

@@ -74,14 +74,14 @@ type RelayInfo struct {
FirstResponseTime time.Time FirstResponseTime time.Time
isFirstResponse bool isFirstResponse bool
//SendLastReasoningResponse bool //SendLastReasoningResponse bool
ApiType int ApiType int
IsStream bool IsStream bool
IsGeminiBatchEmbdding bool IsGeminiBatchEmbedding bool
IsPlayground bool IsPlayground bool
UsePrice bool UsePrice bool
RelayMode int RelayMode int
UpstreamModelName string UpstreamModelName string
OriginModelName string OriginModelName string
//RecodeModelName string //RecodeModelName string
RequestURLPath string RequestURLPath string
ApiVersion string ApiVersion string

View File

@@ -269,7 +269,7 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoGemini(c) relayInfo := relaycommon.GenRelayInfoGemini(c)
isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents") isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
relayInfo.IsGeminiBatchEmbdding = isBatch relayInfo.IsGeminiBatchEmbedding = isBatch
var promptTokens int var promptTokens int
var req any var req any
@@ -338,6 +338,19 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
} }
// apply param override
if len(relayInfo.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
for key, value := range relayInfo.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
}
requestBody = bytes.NewReader(jsonData) requestBody = bytes.NewReader(jsonData)
resp, err := adaptor.DoRequest(c, relayInfo, requestBody) resp, err := adaptor.DoRequest(c, relayInfo, requestBody)