fix typo; add ParamOverride for Gemini Embedding
This commit is contained in:
@@ -115,7 +115,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
|
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
|
||||||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
|
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
|
||||||
action := "embedContent"
|
action := "embedContent"
|
||||||
if info.IsGeminiBatchEmbdding {
|
if info.IsGeminiBatchEmbedding {
|
||||||
action = "batchEmbedContents"
|
action = "batchEmbedContents"
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
||||||
@@ -199,7 +199,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.RelayMode == constant.RelayModeGemini {
|
if info.RelayMode == constant.RelayModeGemini {
|
||||||
if strings.Contains(info.RequestURLPath, "embed") {
|
if strings.HasSuffix(info.RequestURLPath, ":embedContent") ||
|
||||||
|
strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") {
|
||||||
return NativeGeminiEmbeddingHandler(c, resp, info)
|
return NativeGeminiEmbeddingHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
|
|||||||
TotalTokens: info.PromptTokens,
|
TotalTokens: info.PromptTokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
if info.IsGeminiBatchEmbdding {
|
if info.IsGeminiBatchEmbedding {
|
||||||
var geminiResponse dto.GeminiBatchEmbeddingResponse
|
var geminiResponse dto.GeminiBatchEmbeddingResponse
|
||||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -74,14 +74,14 @@ type RelayInfo struct {
|
|||||||
FirstResponseTime time.Time
|
FirstResponseTime time.Time
|
||||||
isFirstResponse bool
|
isFirstResponse bool
|
||||||
//SendLastReasoningResponse bool
|
//SendLastReasoningResponse bool
|
||||||
ApiType int
|
ApiType int
|
||||||
IsStream bool
|
IsStream bool
|
||||||
IsGeminiBatchEmbdding bool
|
IsGeminiBatchEmbedding bool
|
||||||
IsPlayground bool
|
IsPlayground bool
|
||||||
UsePrice bool
|
UsePrice bool
|
||||||
RelayMode int
|
RelayMode int
|
||||||
UpstreamModelName string
|
UpstreamModelName string
|
||||||
OriginModelName string
|
OriginModelName string
|
||||||
//RecodeModelName string
|
//RecodeModelName string
|
||||||
RequestURLPath string
|
RequestURLPath string
|
||||||
ApiVersion string
|
ApiVersion string
|
||||||
|
|||||||
@@ -269,7 +269,7 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
|
|||||||
relayInfo := relaycommon.GenRelayInfoGemini(c)
|
relayInfo := relaycommon.GenRelayInfoGemini(c)
|
||||||
|
|
||||||
isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
|
isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
|
||||||
relayInfo.IsGeminiBatchEmbdding = isBatch
|
relayInfo.IsGeminiBatchEmbedding = isBatch
|
||||||
|
|
||||||
var promptTokens int
|
var promptTokens int
|
||||||
var req any
|
var req any
|
||||||
@@ -338,6 +338,19 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// apply param override
|
||||||
|
if len(relayInfo.ParamOverride) > 0 {
|
||||||
|
reqMap := make(map[string]interface{})
|
||||||
|
_ = common.Unmarshal(jsonData, &reqMap)
|
||||||
|
for key, value := range relayInfo.ParamOverride {
|
||||||
|
reqMap[key] = value
|
||||||
|
}
|
||||||
|
jsonData, err = common.Marshal(reqMap)
|
||||||
|
if err != nil {
|
||||||
|
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||||
|
}
|
||||||
|
}
|
||||||
requestBody = bytes.NewReader(jsonData)
|
requestBody = bytes.NewReader(jsonData)
|
||||||
|
|
||||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||||
|
|||||||
Reference in New Issue
Block a user