This update standardizes the closure of HTTP response bodies across multiple stream handlers, enhancing error management and resource cleanup. The new method ensures that any errors during closure are handled gracefully, preventing potential request termination issues.
249 lines
6.8 KiB
Go
249 lines
6.8 KiB
Go
package zhipu
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/json"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/golang-jwt/jwt"
|
|
"io"
|
|
"net/http"
|
|
"one-api/common"
|
|
"one-api/constant"
|
|
"one-api/dto"
|
|
"one-api/relay/helper"
|
|
"one-api/service"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// https://open.bigmodel.cn/doc/api#chatglm_std
|
|
// chatglm_std, chatglm_lite
|
|
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
|
|
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
|
|
|
|
var (
	// zhipuTokens caches signed tokens, keyed by the raw API key string.
	zhipuTokens sync.Map

	// expSeconds is the lifetime, in seconds, given to each newly signed token
	// (24 hours).
	expSeconds int64 = 24 * 3600
)
|
|
|
|
func getZhipuToken(apikey string) string {
|
|
data, ok := zhipuTokens.Load(apikey)
|
|
if ok {
|
|
tokenData := data.(zhipuTokenData)
|
|
if time.Now().Before(tokenData.ExpiryTime) {
|
|
return tokenData.Token
|
|
}
|
|
}
|
|
|
|
split := strings.Split(apikey, ".")
|
|
if len(split) != 2 {
|
|
common.SysError("invalid zhipu key: " + apikey)
|
|
return ""
|
|
}
|
|
|
|
id := split[0]
|
|
secret := split[1]
|
|
|
|
expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
|
|
expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
|
|
|
|
timestamp := time.Now().UnixNano() / 1e6
|
|
|
|
payload := jwt.MapClaims{
|
|
"api_key": id,
|
|
"exp": expMillis,
|
|
"timestamp": timestamp,
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
|
|
|
token.Header["alg"] = "HS256"
|
|
token.Header["sign_type"] = "SIGN"
|
|
|
|
tokenString, err := token.SignedString([]byte(secret))
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
|
|
zhipuTokens.Store(apikey, zhipuTokenData{
|
|
Token: tokenString,
|
|
ExpiryTime: expiryTime,
|
|
})
|
|
|
|
return tokenString
|
|
}
|
|
|
|
func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *ZhipuRequest {
|
|
messages := make([]ZhipuMessage, 0, len(request.Messages))
|
|
for _, message := range request.Messages {
|
|
if message.Role == "system" {
|
|
messages = append(messages, ZhipuMessage{
|
|
Role: "system",
|
|
Content: message.StringContent(),
|
|
})
|
|
messages = append(messages, ZhipuMessage{
|
|
Role: "user",
|
|
Content: "Okay",
|
|
})
|
|
} else {
|
|
messages = append(messages, ZhipuMessage{
|
|
Role: message.Role,
|
|
Content: message.StringContent(),
|
|
})
|
|
}
|
|
}
|
|
return &ZhipuRequest{
|
|
Prompt: messages,
|
|
Temperature: request.Temperature,
|
|
TopP: request.TopP,
|
|
Incremental: false,
|
|
}
|
|
}
|
|
|
|
func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse {
|
|
fullTextResponse := dto.OpenAITextResponse{
|
|
Id: response.Data.TaskId,
|
|
Object: "chat.completion",
|
|
Created: common.GetTimestamp(),
|
|
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Data.Choices)),
|
|
Usage: response.Data.Usage,
|
|
}
|
|
for i, choice := range response.Data.Choices {
|
|
openaiChoice := dto.OpenAITextResponseChoice{
|
|
Index: i,
|
|
Message: dto.Message{
|
|
Role: choice.Role,
|
|
Content: strings.Trim(choice.Content, "\""),
|
|
},
|
|
FinishReason: "",
|
|
}
|
|
if i == len(response.Data.Choices)-1 {
|
|
openaiChoice.FinishReason = "stop"
|
|
}
|
|
fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
|
|
}
|
|
return &fullTextResponse
|
|
}
|
|
|
|
func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStreamResponse {
|
|
var choice dto.ChatCompletionsStreamResponseChoice
|
|
choice.Delta.SetContentString(zhipuResponse)
|
|
response := dto.ChatCompletionsStreamResponse{
|
|
Object: "chat.completion.chunk",
|
|
Created: common.GetTimestamp(),
|
|
Model: "chatglm",
|
|
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
|
}
|
|
return &response
|
|
}
|
|
|
|
func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
|
|
var choice dto.ChatCompletionsStreamResponseChoice
|
|
choice.Delta.SetContentString("")
|
|
choice.FinishReason = &constant.FinishReasonStop
|
|
response := dto.ChatCompletionsStreamResponse{
|
|
Id: zhipuResponse.RequestId,
|
|
Object: "chat.completion.chunk",
|
|
Created: common.GetTimestamp(),
|
|
Model: "chatglm",
|
|
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
|
}
|
|
return &response, &zhipuResponse.Usage
|
|
}
|
|
|
|
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
var usage *dto.Usage
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
scanner.Split(bufio.ScanLines)
|
|
dataChan := make(chan string)
|
|
metaChan := make(chan string)
|
|
stopChan := make(chan bool)
|
|
go func() {
|
|
for scanner.Scan() {
|
|
data := scanner.Text()
|
|
lines := strings.Split(data, "\n")
|
|
for i, line := range lines {
|
|
if len(line) < 5 {
|
|
continue
|
|
}
|
|
if line[:5] == "data:" {
|
|
dataChan <- line[5:]
|
|
if i != len(lines)-1 {
|
|
dataChan <- "\n"
|
|
}
|
|
} else if line[:5] == "meta:" {
|
|
metaChan <- line[5:]
|
|
}
|
|
}
|
|
}
|
|
stopChan <- true
|
|
}()
|
|
helper.SetEventStreamHeaders(c)
|
|
c.Stream(func(w io.Writer) bool {
|
|
select {
|
|
case data := <-dataChan:
|
|
response := streamResponseZhipu2OpenAI(data)
|
|
jsonResponse, err := json.Marshal(response)
|
|
if err != nil {
|
|
common.SysError("error marshalling stream response: " + err.Error())
|
|
return true
|
|
}
|
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
return true
|
|
case data := <-metaChan:
|
|
var zhipuResponse ZhipuStreamMetaResponse
|
|
err := json.Unmarshal([]byte(data), &zhipuResponse)
|
|
if err != nil {
|
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
return true
|
|
}
|
|
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
|
|
jsonResponse, err := json.Marshal(response)
|
|
if err != nil {
|
|
common.SysError("error marshalling stream response: " + err.Error())
|
|
return true
|
|
}
|
|
usage = zhipuUsage
|
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
return true
|
|
case <-stopChan:
|
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
return false
|
|
}
|
|
})
|
|
common.CloseResponseBodyGracefully(resp)
|
|
return nil, usage
|
|
}
|
|
|
|
func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
var zhipuResponse ZhipuResponse
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
common.CloseResponseBodyGracefully(resp)
|
|
err = json.Unmarshal(responseBody, &zhipuResponse)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
if !zhipuResponse.Success {
|
|
return &dto.OpenAIErrorWithStatusCode{
|
|
Error: dto.OpenAIError{
|
|
Message: zhipuResponse.Msg,
|
|
Type: "zhipu_error",
|
|
Param: "",
|
|
Code: zhipuResponse.Code,
|
|
},
|
|
StatusCode: resp.StatusCode,
|
|
}, nil
|
|
}
|
|
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
|
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
c.Writer.Header().Set("Content-Type", "application/json")
|
|
c.Writer.WriteHeader(resp.StatusCode)
|
|
_, err = c.Writer.Write(jsonResponse)
|
|
return nil, &fullTextResponse.Usage
|
|
}
|