feat: update dto for embeddings

This commit is contained in:
RedwindA
2025-08-09 18:31:56 +08:00
parent 03b670971b
commit f242220132
2 changed files with 5 additions and 15 deletions

View File

@@ -218,7 +218,7 @@ type GeminiEmbeddingRequest struct {
} }
type GeminiBatchEmbeddingRequest struct { type GeminiBatchEmbeddingRequest struct {
Requests []GeminiEmbeddingRequest `json:"requests"` Requests []*GeminiEmbeddingRequest `json:"requests"`
} }
type GeminiEmbeddingResponse struct { type GeminiEmbeddingResponse struct {
@@ -226,7 +226,7 @@ type GeminiEmbeddingResponse struct {
} }
type GeminiBatchEmbeddingResponse struct { type GeminiBatchEmbeddingResponse struct {
Embeddings []ContentEmbedding `json:"embeddings"` Embeddings []*ContentEmbedding `json:"embeddings"`
} }
type ContentEmbedding struct { type ContentEmbedding struct {

View File

@@ -119,7 +119,6 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
action = "batchEmbedContents" action = "batchEmbedContents"
} }
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
return fmt.Sprintf("%s/%s/models/%s:batchEmbedContents", info.BaseUrl, version, info.UpstreamModelName), nil
} }
action := "generateContent" action := "generateContent"
@@ -164,6 +163,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
if len(inputs) == 0 { if len(inputs) == 0 {
return nil, errors.New("input is empty") return nil, errors.New("input is empty")
} }
// We always build a batch-style payload with `requests`, so ensure we call the
// batch endpoint upstream to avoid payload/endpoint mismatches.
info.IsGeminiBatchEmbedding = true
// process all inputs // process all inputs
geminiRequests := make([]map[string]interface{}, 0, len(inputs)) geminiRequests := make([]map[string]interface{}, 0, len(inputs))
for _, input := range inputs { for _, input := range inputs {
@@ -234,18 +236,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return GeminiChatHandler(c, info, resp) return GeminiChatHandler(c, info, resp)
} }
//if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
// // 没有请求-thinking的情况下产生思考token则按照思考模型计费
// if !strings.HasSuffix(info.OriginModelName, "-thinking") &&
// !strings.HasSuffix(info.OriginModelName, "-nothinking") {
// thinkingModelName := info.OriginModelName + "-thinking"
// if operation_setting.SelfUseModeEnabled || helper.ContainPriceOrRatio(thinkingModelName) {
// info.OriginModelName = thinkingModelName
// }
// }
//}
return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
} }
func (a *Adaptor) GetModelList() []string { func (a *Adaptor) GetModelList() []string {