114 lines
2.8 KiB
Go
114 lines
2.8 KiB
Go
package zhipu_4v
|
||
|
||
import (
|
||
"github.com/golang-jwt/jwt"
|
||
"one-api/common"
|
||
"one-api/dto"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
// Reference links for Zhipu's v3 ChatGLM API (models: chatglm_std, chatglm_lite):
// https://open.bigmodel.cn/doc/api#chatglm_std
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
|
||
|
||
// Token state used by getZhipuToken.
var (
	// zhipuTokens caches generated tokens keyed by the raw API key. Each value
	// is a tokenData holding the signed token and the time after which it is
	// regenerated.
	zhipuTokens sync.Map

	// expSeconds is the lifetime, in seconds, of each generated token.
	expSeconds int64 = 24 * 3600
)
|
||
|
||
func getZhipuToken(apikey string) string {
|
||
data, ok := zhipuTokens.Load(apikey)
|
||
if ok {
|
||
tokenData := data.(tokenData)
|
||
if time.Now().Before(tokenData.ExpiryTime) {
|
||
return tokenData.Token
|
||
}
|
||
}
|
||
|
||
split := strings.Split(apikey, ".")
|
||
if len(split) != 2 {
|
||
common.SysError("invalid zhipu key: " + apikey)
|
||
return ""
|
||
}
|
||
|
||
id := split[0]
|
||
secret := split[1]
|
||
|
||
expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
|
||
expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
|
||
|
||
timestamp := time.Now().UnixNano() / 1e6
|
||
|
||
payload := jwt.MapClaims{
|
||
"api_key": id,
|
||
"exp": expMillis,
|
||
"timestamp": timestamp,
|
||
}
|
||
|
||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
||
|
||
token.Header["alg"] = "HS256"
|
||
token.Header["sign_type"] = "SIGN"
|
||
|
||
tokenString, err := token.SignedString([]byte(secret))
|
||
if err != nil {
|
||
return ""
|
||
}
|
||
|
||
zhipuTokens.Store(apikey, tokenData{
|
||
Token: tokenString,
|
||
ExpiryTime: expiryTime,
|
||
})
|
||
|
||
return tokenString
|
||
}
|
||
|
||
func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
|
||
messages := make([]dto.Message, 0, len(request.Messages))
|
||
for _, message := range request.Messages {
|
||
if !message.IsStringContent() {
|
||
mediaMessages := message.ParseContent()
|
||
for j, mediaMessage := range mediaMessages {
|
||
if mediaMessage.Type == dto.ContentTypeImageURL {
|
||
imageUrl := mediaMessage.GetImageMedia()
|
||
// check if base64
|
||
if strings.HasPrefix(imageUrl.Url, "data:image/") {
|
||
// 去除base64数据的URL前缀(如果有)
|
||
if idx := strings.Index(imageUrl.Url, ","); idx != -1 {
|
||
imageUrl.Url = imageUrl.Url[idx+1:]
|
||
}
|
||
}
|
||
mediaMessage.ImageUrl = imageUrl
|
||
mediaMessages[j] = mediaMessage
|
||
}
|
||
}
|
||
message.SetMediaContent(mediaMessages)
|
||
}
|
||
messages = append(messages, dto.Message{
|
||
Role: message.Role,
|
||
Content: message.Content,
|
||
ToolCalls: message.ToolCalls,
|
||
ToolCallId: message.ToolCallId,
|
||
})
|
||
}
|
||
str, ok := request.Stop.(string)
|
||
var Stop []string
|
||
if ok {
|
||
Stop = []string{str}
|
||
} else {
|
||
Stop, _ = request.Stop.([]string)
|
||
}
|
||
return &dto.GeneralOpenAIRequest{
|
||
Model: request.Model,
|
||
Stream: request.Stream,
|
||
Messages: messages,
|
||
Temperature: request.Temperature,
|
||
TopP: request.TopP,
|
||
MaxTokens: request.GetMaxTokens(),
|
||
Stop: Stop,
|
||
Tools: request.Tools,
|
||
ToolChoice: request.ToolChoice,
|
||
}
|
||
}
|