Merge branch 'alpha' into 'feat/support-native-gemini-embedding'

This commit is contained in:
RedwindA
2025-08-09 18:05:11 +08:00
37 changed files with 1095 additions and 697 deletions

View File

@@ -119,6 +119,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
action = "batchEmbedContents"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
return fmt.Sprintf("%s/%s/models/%s:batchEmbedContents", info.BaseUrl, version, info.UpstreamModelName), nil
}
action := "generateContent"
@@ -163,29 +164,35 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
if len(inputs) == 0 {
return nil, errors.New("input is empty")
}
// only process the first input
geminiRequest := dto.GeminiEmbeddingRequest{
Content: dto.GeminiChatContent{
Parts: []dto.GeminiPart{
{
Text: inputs[0],
// process all inputs
geminiRequests := make([]map[string]interface{}, 0, len(inputs))
for _, input := range inputs {
geminiRequest := map[string]interface{}{
"model": fmt.Sprintf("models/%s", info.UpstreamModelName),
"content": dto.GeminiChatContent{
Parts: []dto.GeminiPart{
{
Text: input,
},
},
},
},
}
// set specific parameters for different models
// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
switch info.UpstreamModelName {
case "text-embedding-004":
// except embedding-001 supports setting `OutputDimensionality`
if request.Dimensions > 0 {
geminiRequest.OutputDimensionality = request.Dimensions
}
// set specific parameters for different models
// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
switch info.UpstreamModelName {
case "text-embedding-004", "gemini-embedding-exp-03-07", "gemini-embedding-001":
// Only newer models introduced after 2024 support OutputDimensionality
if request.Dimensions > 0 {
geminiRequest["outputDimensionality"] = request.Dimensions
}
}
geminiRequests = append(geminiRequests, geminiRequest)
}
return geminiRequest, nil
return map[string]interface{}{
"requests": geminiRequests,
}, nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {

View File

@@ -1071,7 +1071,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
var geminiResponse dto.GeminiEmbeddingResponse
var geminiResponse dto.GeminiBatchEmbeddingResponse
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
@@ -1079,14 +1079,16 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
// convert to openai format response
openAIResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
Data: []dto.OpenAIEmbeddingResponseItem{
{
Object: "embedding",
Embedding: geminiResponse.Embedding.Values,
Index: 0,
},
},
Model: info.UpstreamModelName,
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(geminiResponse.Embeddings)),
Model: info.UpstreamModelName,
}
for i, embedding := range geminiResponse.Embeddings {
openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{
Object: "embedding",
Embedding: embedding.Values,
Index: i,
})
}
// calculate usage

View File

@@ -54,8 +54,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// SetupRequestHeader applies the shared API request headers and then sets the
// Authorization header to "Bearer <api key>" taken from info.ApiKey.
//
// The header was previously first set to a token from getZhipuToken and then
// immediately overwritten by the Bearer value; http.Header.Set replaces any
// existing value, so only the Bearer assignment ever took effect. The redundant
// token generation has been dropped. It always returns nil.
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
	channel.SetupApiRequestHeader(info, c, req)
	req.Set("Authorization", "Bearer "+info.ApiKey)
	return nil
}

View File

@@ -1,69 +1,10 @@
package zhipu_4v
import (
"github.com/golang-jwt/jwt"
"one-api/common"
"one-api/dto"
"strings"
"sync"
"time"
)
// https://open.bigmodel.cn/doc/api#chatglm_std
// chatglm_std, chatglm_lite
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
// zhipuTokens caches signed tokens keyed by the raw API key; getZhipuToken
// stores tokenData values (token string plus expiry time) in it.
var zhipuTokens sync.Map
// expSeconds is the lifetime of a generated token in seconds (24 hours).
var expSeconds int64 = 24 * 3600
// getZhipuToken returns a signed HS256 JWT for the given API key, reusing the
// token cached in zhipuTokens while its recorded expiry time has not passed.
//
// apikey must have the form "<id>.<secret>": id is placed in the "api_key"
// claim and secret is used as the HMAC signing key. The token carries "exp"
// and "timestamp" claims in milliseconds since the Unix epoch.
//
// An empty string is returned when the key is malformed or signing fails; both
// cases are reported through common.SysError without echoing the key itself.
func getZhipuToken(apikey string) string {
	if cached, found := zhipuTokens.Load(apikey); found {
		// Checked assertion: an unexpected value in the cache must not panic.
		if entry, isToken := cached.(tokenData); isToken && time.Now().Before(entry.ExpiryTime) {
			return entry.Token
		}
	}

	split := strings.Split(apikey, ".")
	if len(split) != 2 {
		// Never log the key: it embeds the signing secret.
		common.SysError("invalid zhipu key: expected format <id>.<secret>")
		return ""
	}
	id := split[0]
	secret := split[1]

	// Sample the clock once so the "exp" claim and the cached expiry agree.
	now := time.Now()
	expiryTime := now.Add(time.Duration(expSeconds) * time.Second)

	payload := jwt.MapClaims{
		"api_key":   id,
		"exp":       expiryTime.UnixMilli(),
		"timestamp": now.UnixMilli(),
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
	token.Header["alg"] = "HS256"
	token.Header["sign_type"] = "SIGN"

	tokenString, err := token.SignedString([]byte(secret))
	if err != nil {
		common.SysError("failed to sign zhipu token: " + err.Error())
		return ""
	}

	zhipuTokens.Store(apikey, tokenData{
		Token:      tokenString,
		ExpiryTime: expiryTime,
	})
	return tokenString
}
func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {

View File

@@ -140,10 +140,10 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
includeUsage := false
includeUsage := true
// 判断用户是否需要返回使用情况
if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
includeUsage = true
if textRequest.StreamOptions != nil {
includeUsage = textRequest.StreamOptions.IncludeUsage
}
// 如果不支持StreamOptions将StreamOptions设置为nil
@@ -158,9 +158,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
if includeUsage {
relayInfo.ShouldIncludeUsage = true
}
relayInfo.ShouldIncludeUsage = includeUsage
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
@@ -201,6 +199,26 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
Content: relayInfo.ChannelSetting.SystemPrompt,
}
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
} else if relayInfo.ChannelSetting.SystemPromptOverride {
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
// 如果有系统提示,且允许覆盖,则拼接到前面
for i, message := range request.Messages {
if message.Role == request.GetSystemRoleName() {
if message.IsStringContent() {
request.Messages[i].SetStringContent(relayInfo.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
} else {
contents := message.ParseContent()
contents = append([]dto.MediaContent{
{
Type: dto.ContentTypeText,
Text: relayInfo.ChannelSetting.SystemPrompt,
},
}, contents...)
request.Messages[i].Content = contents
}
break
}
}
}
}