This update standardizes how HTTP response bodies are closed across the relay handlers by routing every close through common.CloseResponseBodyGracefully, improving error handling and resource cleanup. The helper deals with close errors gracefully, so a failed close can no longer terminate a request that is already being served.
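For reference, here is a minimal sketch of what such a helper might look like. This is an assumption for illustration only: the real CloseResponseBodyGracefully lives in one-api/common and may differ; the sketch only shows the intended behaviour of a nil-safe close whose error is logged rather than returned.

package common

import (
	"log"
	"net/http"
)

// CloseResponseBodyGracefully is a sketch of the assumed behaviour: close the
// upstream body if there is one, and log a failed close instead of propagating
// the error, so cleanup never aborts a request that has already been served.
func CloseResponseBodyGracefully(resp *http.Response) {
	if resp == nil || resp.Body == nil {
		return
	}
	if err := resp.Body.Close(); err != nil {
		log.Printf("close response body failed: %v", err) // stand-in for the project's own logger
	}
}

The handlers in the file below call this helper after reading or streaming the upstream response body.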
295 lines · 8.8 KiB · Go
package baidu

import (
	"bufio"
	"encoding/json"
	"errors"
	"fmt"
	"github.com/gin-gonic/gin"
	"io"
	"net/http"
	"one-api/common"
	"one-api/constant"
	"one-api/dto"
	"one-api/relay/helper"
	"one-api/service"
	"strings"
	"sync"
	"time"
)

// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2

// baiduTokenStore caches Baidu access tokens, keyed by the raw API key.
var baiduTokenStore sync.Map

// requestOpenAI2Baidu converts an OpenAI-style chat request into Baidu's ERNIE format.
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
	baiduRequest := BaiduChatRequest{
		Temperature:    request.Temperature,
		TopP:           request.TopP,
		PenaltyScore:   request.FrequencyPenalty,
		Stream:         request.Stream,
		DisableSearch:  false,
		EnableCitation: false,
		UserId:         request.User,
	}
	if request.MaxTokens != 0 {
		maxTokens := int(request.MaxTokens)
		if request.MaxTokens == 1 {
			// Baidu does not accept a max_output_tokens of 1, so bump it to 2.
			maxTokens = 2
		}
		baiduRequest.MaxOutputTokens = &maxTokens
	}
	for _, message := range request.Messages {
		if message.Role == "system" {
			// Baidu takes the system prompt as a dedicated field rather than a message.
			baiduRequest.System = message.StringContent()
		} else {
			baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
				Role:    message.Role,
				Content: message.StringContent(),
			})
		}
	}
	return &baiduRequest
}

// responseBaidu2OpenAI converts a non-streaming Baidu chat response into the OpenAI format.
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
	choice := dto.OpenAITextResponseChoice{
		Index: 0,
		Message: dto.Message{
			Role:    "assistant",
			Content: response.Result,
		},
		FinishReason: "stop",
	}
	fullTextResponse := dto.OpenAITextResponse{
		Id:      response.Id,
		Object:  "chat.completion",
		Created: response.Created,
		Choices: []dto.OpenAITextResponseChoice{choice},
		Usage:   response.Usage,
	}
	return &fullTextResponse
}

// streamResponseBaidu2OpenAI converts a single Baidu stream chunk into an OpenAI-style chunk.
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
	var choice dto.ChatCompletionsStreamResponseChoice
	choice.Delta.SetContentString(baiduResponse.Result)
	if baiduResponse.IsEnd {
		choice.FinishReason = &constant.FinishReasonStop
	}
	response := dto.ChatCompletionsStreamResponse{
		Id:      baiduResponse.Id,
		Object:  "chat.completion.chunk",
		Created: baiduResponse.Created,
		Model:   "ernie-bot",
		Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
	}
	return &response
}

// embeddingRequestOpenAI2Baidu converts an OpenAI embedding request into Baidu's format.
func embeddingRequestOpenAI2Baidu(request dto.EmbeddingRequest) *BaiduEmbeddingRequest {
	return &BaiduEmbeddingRequest{
		Input: request.ParseInput(),
	}
}

// embeddingResponseBaidu2OpenAI converts a Baidu embedding response into the OpenAI format.
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
	openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
		Object: "list",
		Data:   make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
		Model:  "baidu-embedding",
		Usage:  response.Usage,
	}
	for _, item := range response.Data {
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
			Object:    item.Object,
			Index:     item.Index,
			Embedding: item.Embedding,
		})
	}
	return &openAIEmbeddingResponse
}

// baiduStreamHandler relays a Baidu SSE stream back to the client as OpenAI-style chunks
// and accumulates token usage along the way.
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	var usage dto.Usage
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "\n"); i >= 0 {
			return i + 1, data[0:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})
	dataChan := make(chan string)
	stopChan := make(chan bool)
	go func() {
		for scanner.Scan() {
			data := scanner.Text()
			if len(data) < 6 { // ignore blank line or wrong format
				continue
			}
			data = data[6:] // strip the leading "data: " SSE prefix
			dataChan <- data
		}
		stopChan <- true
	}()
	helper.SetEventStreamHeaders(c)
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			var baiduResponse BaiduChatStreamResponse
			err := json.Unmarshal([]byte(data), &baiduResponse)
			if err != nil {
				common.SysError("error unmarshalling stream response: " + err.Error())
				return true
			}
			if baiduResponse.Usage.TotalTokens != 0 {
				usage.TotalTokens = baiduResponse.Usage.TotalTokens
				usage.PromptTokens = baiduResponse.Usage.PromptTokens
				usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
			}
			response := streamResponseBaidu2OpenAI(&baiduResponse)
			jsonResponse, err := json.Marshal(response)
			if err != nil {
				common.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
			return true
		case <-stopChan:
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})
	common.CloseResponseBodyGracefully(resp)
	return nil, &usage
}

// baiduHandler relays a non-streaming Baidu chat response, converting it to the OpenAI format.
func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	var baiduResponse BaiduChatResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	common.CloseResponseBodyGracefully(resp)
	err = json.Unmarshal(responseBody, &baiduResponse)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if baiduResponse.ErrorMsg != "" {
		return &dto.OpenAIErrorWithStatusCode{
			Error: dto.OpenAIError{
				Message: baiduResponse.ErrorMsg,
				Type:    "baidu_error",
				Param:   "",
				Code:    baiduResponse.ErrorCode,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	return nil, &fullTextResponse.Usage
}

// baiduEmbeddingHandler relays a Baidu embedding response, converting it to the OpenAI format.
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	var baiduResponse BaiduEmbeddingResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	common.CloseResponseBodyGracefully(resp)
	err = json.Unmarshal(responseBody, &baiduResponse)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if baiduResponse.ErrorMsg != "" {
		return &dto.OpenAIErrorWithStatusCode{
			Error: dto.OpenAIError{
				Message: baiduResponse.ErrorMsg,
				Type:    "baidu_error",
				Param:   "",
				Code:    baiduResponse.ErrorCode,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	return nil, &fullTextResponse.Usage
}

// getBaiduAccessToken returns a cached access token for the given API key,
// refreshing it in the background shortly before it expires.
func getBaiduAccessToken(apiKey string) (string, error) {
	if val, ok := baiduTokenStore.Load(apiKey); ok {
		var accessToken BaiduAccessToken
		if accessToken, ok = val.(BaiduAccessToken); ok {
			// Token expires within the hour; refresh it in the background.
			if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
				go func() {
					_, _ = getBaiduAccessTokenHelper(apiKey)
				}()
			}
			return accessToken.AccessToken, nil
		}
	}
	accessToken, err := getBaiduAccessTokenHelper(apiKey)
	if err != nil {
		return "", err
	}
	if accessToken == nil {
		return "", errors.New("getBaiduAccessToken returned a nil token")
	}
	return (*accessToken).AccessToken, nil
}

// getBaiduAccessTokenHelper exchanges the "client_id|client_secret" API key for an
// OAuth access token and stores it in baiduTokenStore.
func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
	parts := strings.Split(apiKey, "|")
	if len(parts) != 2 {
		return nil, errors.New("invalid baidu apikey")
	}
	req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
		parts[0], parts[1]), nil)
	if err != nil {
		return nil, err
	}
	req.Header.Add("Content-Type", "application/json")
	req.Header.Add("Accept", "application/json")
	res, err := service.GetImpatientHttpClient().Do(req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	var accessToken BaiduAccessToken
	err = json.NewDecoder(res.Body).Decode(&accessToken)
	if err != nil {
		return nil, err
	}
	if accessToken.Error != "" {
		return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
	}
	if accessToken.AccessToken == "" {
		return nil, errors.New("getBaiduAccessTokenHelper got an empty access token")
	}
	accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
	baiduTokenStore.Store(apiKey, accessToken)
	return &accessToken, nil
}