This commit refactors the logging mechanism across the application by replacing direct logger calls with a centralized logging approach using the `common` package. Key changes include: - Replaced instances of `logger.SysLog` and `logger.FatalLog` with `common.SysLog` and `common.FatalLog` for consistent logging practices. - Updated resource initialization error handling to utilize the new logging structure, enhancing maintainability and readability. - Minor adjustments to improve code clarity and organization throughout various modules. This change aims to streamline logging and improve the overall architecture of the codebase.
246 lines
7.5 KiB
Go
246 lines
7.5 KiB
Go
package baidu
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"one-api/common"
|
|
"one-api/constant"
|
|
"one-api/dto"
|
|
relaycommon "one-api/relay/common"
|
|
"one-api/relay/helper"
|
|
"one-api/service"
|
|
"one-api/types"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2

var (
	// baiduTokenStore caches OAuth access tokens per channel key. Keys are the
	// raw "clientID|clientSecret" strings; values are BaiduAccessToken
	// structs. getBaiduAccessTokenHelper writes entries and
	// getBaiduAccessToken reads them.
	baiduTokenStore sync.Map
)
|
|
|
|
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
|
|
baiduRequest := BaiduChatRequest{
|
|
Temperature: request.Temperature,
|
|
TopP: request.TopP,
|
|
PenaltyScore: request.FrequencyPenalty,
|
|
Stream: request.Stream,
|
|
DisableSearch: false,
|
|
EnableCitation: false,
|
|
UserId: request.User,
|
|
}
|
|
if request.GetMaxTokens() != 0 {
|
|
maxTokens := int(request.GetMaxTokens())
|
|
if request.GetMaxTokens() == 1 {
|
|
maxTokens = 2
|
|
}
|
|
baiduRequest.MaxOutputTokens = &maxTokens
|
|
}
|
|
for _, message := range request.Messages {
|
|
if message.Role == "system" {
|
|
baiduRequest.System = message.StringContent()
|
|
} else {
|
|
baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
|
|
Role: message.Role,
|
|
Content: message.StringContent(),
|
|
})
|
|
}
|
|
}
|
|
return &baiduRequest
|
|
}
|
|
|
|
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
|
|
choice := dto.OpenAITextResponseChoice{
|
|
Index: 0,
|
|
Message: dto.Message{
|
|
Role: "assistant",
|
|
Content: response.Result,
|
|
},
|
|
FinishReason: "stop",
|
|
}
|
|
fullTextResponse := dto.OpenAITextResponse{
|
|
Id: response.Id,
|
|
Object: "chat.completion",
|
|
Created: response.Created,
|
|
Choices: []dto.OpenAITextResponseChoice{choice},
|
|
Usage: response.Usage,
|
|
}
|
|
return &fullTextResponse
|
|
}
|
|
|
|
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
|
|
var choice dto.ChatCompletionsStreamResponseChoice
|
|
choice.Delta.SetContentString(baiduResponse.Result)
|
|
if baiduResponse.IsEnd {
|
|
choice.FinishReason = &constant.FinishReasonStop
|
|
}
|
|
response := dto.ChatCompletionsStreamResponse{
|
|
Id: baiduResponse.Id,
|
|
Object: "chat.completion.chunk",
|
|
Created: baiduResponse.Created,
|
|
Model: "ernie-bot",
|
|
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
|
}
|
|
return &response
|
|
}
|
|
|
|
func embeddingRequestOpenAI2Baidu(request dto.EmbeddingRequest) *BaiduEmbeddingRequest {
|
|
return &BaiduEmbeddingRequest{
|
|
Input: request.ParseInput(),
|
|
}
|
|
}
|
|
|
|
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
|
|
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
|
|
Object: "list",
|
|
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
|
|
Model: "baidu-embedding",
|
|
Usage: response.Usage,
|
|
}
|
|
for _, item := range response.Data {
|
|
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
|
|
Object: item.Object,
|
|
Index: item.Index,
|
|
Embedding: item.Embedding,
|
|
})
|
|
}
|
|
return &openAIEmbeddingResponse
|
|
}
|
|
|
|
func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
|
usage := &dto.Usage{}
|
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
var baiduResponse BaiduChatStreamResponse
|
|
err := common.Unmarshal([]byte(data), &baiduResponse)
|
|
if err != nil {
|
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
|
return true
|
|
}
|
|
if baiduResponse.Usage.TotalTokens != 0 {
|
|
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
|
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
|
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
|
}
|
|
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
|
err = helper.ObjectData(c, response)
|
|
if err != nil {
|
|
common.SysLog("error sending stream response: " + err.Error())
|
|
}
|
|
return true
|
|
})
|
|
service.CloseResponseBodyGracefully(resp)
|
|
return nil, usage
|
|
}
|
|
|
|
func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
|
var baiduResponse BaiduChatResponse
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
|
}
|
|
service.CloseResponseBodyGracefully(resp)
|
|
err = json.Unmarshal(responseBody, &baiduResponse)
|
|
if err != nil {
|
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
|
}
|
|
if baiduResponse.ErrorMsg != "" {
|
|
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
|
}
|
|
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
|
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
if err != nil {
|
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
|
}
|
|
c.Writer.Header().Set("Content-Type", "application/json")
|
|
c.Writer.WriteHeader(resp.StatusCode)
|
|
_, err = c.Writer.Write(jsonResponse)
|
|
return nil, &fullTextResponse.Usage
|
|
}
|
|
|
|
func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
|
var baiduResponse BaiduEmbeddingResponse
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
|
}
|
|
service.CloseResponseBodyGracefully(resp)
|
|
err = json.Unmarshal(responseBody, &baiduResponse)
|
|
if err != nil {
|
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
|
}
|
|
if baiduResponse.ErrorMsg != "" {
|
|
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
|
}
|
|
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
|
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
if err != nil {
|
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
|
}
|
|
c.Writer.Header().Set("Content-Type", "application/json")
|
|
c.Writer.WriteHeader(resp.StatusCode)
|
|
_, err = c.Writer.Write(jsonResponse)
|
|
return nil, &fullTextResponse.Usage
|
|
}
|
|
|
|
func getBaiduAccessToken(apiKey string) (string, error) {
|
|
if val, ok := baiduTokenStore.Load(apiKey); ok {
|
|
var accessToken BaiduAccessToken
|
|
if accessToken, ok = val.(BaiduAccessToken); ok {
|
|
// soon this will expire
|
|
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
|
|
go func() {
|
|
_, _ = getBaiduAccessTokenHelper(apiKey)
|
|
}()
|
|
}
|
|
return accessToken.AccessToken, nil
|
|
}
|
|
}
|
|
accessToken, err := getBaiduAccessTokenHelper(apiKey)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if accessToken == nil {
|
|
return "", errors.New("getBaiduAccessToken return a nil token")
|
|
}
|
|
return (*accessToken).AccessToken, nil
|
|
}
|
|
|
|
func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
|
|
parts := strings.Split(apiKey, "|")
|
|
if len(parts) != 2 {
|
|
return nil, errors.New("invalid baidu apikey")
|
|
}
|
|
req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
|
|
parts[0], parts[1]), nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Add("Content-Type", "application/json")
|
|
req.Header.Add("Accept", "application/json")
|
|
res, err := service.GetHttpClient().Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
var accessToken BaiduAccessToken
|
|
err = json.NewDecoder(res.Body).Decode(&accessToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if accessToken.Error != "" {
|
|
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
|
|
}
|
|
if accessToken.AccessToken == "" {
|
|
return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
|
|
}
|
|
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
|
|
baiduTokenStore.Store(apiKey, accessToken)
|
|
return &accessToken, nil
|
|
}
|