package baidu

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"one-api/common"
	"one-api/constant"
	"one-api/dto"
	relaycommon "one-api/relay/common"
	"one-api/relay/helper"
	"one-api/service"
	"one-api/types"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
)

// API reference: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2

// baiduTokenStore caches BaiduAccessToken values keyed by the raw channel key
// (the "client_id|client_secret" string passed to getBaiduAccessToken).
var baiduTokenStore sync.Map

// baiduTokenRefreshing records the keys whose access token is currently being
// refreshed in the background, so at most one background refresh per key is
// in flight at a time.
var baiduTokenRefreshing sync.Map

// requestOpenAI2Baidu converts an OpenAI-style chat request into a Baidu chat
// request. System messages are moved into the dedicated System field (joined
// with newlines when there is more than one, so none are dropped); every other
// message is forwarded with its role and flattened string content.
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
	baiduRequest := BaiduChatRequest{
		Temperature:    request.Temperature,
		TopP:           request.TopP,
		PenaltyScore:   request.FrequencyPenalty,
		Stream:         request.Stream,
		DisableSearch:  false,
		EnableCitation: false,
		UserId:         request.User,
	}
	if requested := request.GetMaxTokens(); requested != 0 {
		maxTokens := int(requested)
		// A limit of 1 is bumped to 2. NOTE(review): presumably the upstream
		// rejects max_output_tokens=1 — confirm against the Baidu API docs.
		if maxTokens == 1 {
			maxTokens = 2
		}
		baiduRequest.MaxOutputTokens = &maxTokens
	}
	var systemPrompts []string
	for _, message := range request.Messages {
		if message.Role == "system" {
			systemPrompts = append(systemPrompts, message.StringContent())
			continue
		}
		baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
			Role:    message.Role,
			Content: message.StringContent(),
		})
	}
	if len(systemPrompts) > 0 {
		baiduRequest.System = strings.Join(systemPrompts, "\n")
	}
	return &baiduRequest
}

// responseBaidu2OpenAI wraps a non-streaming Baidu chat response in a
// single-choice OpenAI chat.completion object with finish reason "stop",
// carrying over the id, creation time and usage.
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
	choice := dto.OpenAITextResponseChoice{
		Index: 0,
		Message: dto.Message{
			Role:    "assistant",
			Content: response.Result,
		},
		FinishReason: "stop",
	}
	fullTextResponse := dto.OpenAITextResponse{
		Id:      response.Id,
		Object:  "chat.completion",
		Created: response.Created,
		Choices: []dto.OpenAITextResponseChoice{choice},
		Usage:   response.Usage,
	}
	return &fullTextResponse
}

// streamResponseBaidu2OpenAI converts one Baidu stream chunk into an OpenAI
// chat.completion.chunk. The finish reason is set to "stop" only on the chunk
// Baidu marks with IsEnd.
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
	var choice dto.ChatCompletionsStreamResponseChoice
	choice.Delta.SetContentString(baiduResponse.Result)
	if baiduResponse.IsEnd {
		choice.FinishReason = &constant.FinishReasonStop
	}
	response := dto.ChatCompletionsStreamResponse{
		Id:      baiduResponse.Id,
		Object:  "chat.completion.chunk",
		Created: baiduResponse.Created,
		Model:   "ernie-bot",
		Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
	}
	return &response
}

// embeddingRequestOpenAI2Baidu converts an OpenAI embedding request into the
// Baidu embedding request shape.
func embeddingRequestOpenAI2Baidu(request dto.EmbeddingRequest) *BaiduEmbeddingRequest {
	return &BaiduEmbeddingRequest{
		Input: request.ParseInput(),
	}
}

// embeddingResponseBaidu2OpenAI converts a Baidu embedding response into an
// OpenAI "list" embedding response, preserving item order and usage.
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
	openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
		Object: "list",
		Data:   make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
		Model:  "baidu-embedding",
		Usage:  response.Usage,
	}
	for _, item := range response.Data {
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
			Object:    item.Object,
			Index:     item.Index,
			Embedding: item.Embedding,
		})
	}
	return &openAIEmbeddingResponse
}

// baiduStreamHandler relays a Baidu streaming response to the client as OpenAI
// stream chunks. Chunks that fail to unmarshal are logged and skipped. Usage is
// taken from the last chunk reporting a non-zero TotalTokens; completion tokens
// are derived as total minus prompt tokens.
func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
	usage := &dto.Usage{}
	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
		var baiduResponse BaiduChatStreamResponse
		err := common.Unmarshal([]byte(data), &baiduResponse)
		if err != nil {
			common.SysLog("error unmarshalling stream response: " + err.Error())
			return true
		}
		if baiduResponse.Usage.TotalTokens != 0 {
			usage.TotalTokens = baiduResponse.Usage.TotalTokens
			usage.PromptTokens = baiduResponse.Usage.PromptTokens
			usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
		}
		response := streamResponseBaidu2OpenAI(&baiduResponse)
		err = helper.ObjectData(c, response)
		if err != nil {
			common.SysLog("error sending stream response: " + err.Error())
		}
		return true
	})
	service.CloseResponseBodyGracefully(resp)
	return nil, usage
}

// decodeBaiduResponse reads the whole upstream body, always closes it (also on
// read failure), and JSON-decodes it into out. Read and decode failures are
// reported with ErrorCodeBadResponseBody.
func decodeBaiduResponse(resp *http.Response, out any) *types.NewAPIError {
	defer service.CloseResponseBodyGracefully(resp)
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return types.NewError(err, types.ErrorCodeBadResponseBody)
	}
	if err := json.Unmarshal(responseBody, out); err != nil {
		return types.NewError(err, types.ErrorCodeBadResponseBody)
	}
	return nil
}

// writeBaiduJSON marshals payload and writes it to the client as JSON with the
// given status code. A marshal failure is returned; a failure while writing
// the body is only logged, because the status line has already been sent.
func writeBaiduJSON(c *gin.Context, statusCode int, payload any) *types.NewAPIError {
	jsonResponse, err := json.Marshal(payload)
	if err != nil {
		return types.NewError(err, types.ErrorCodeBadResponseBody)
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(statusCode)
	if _, err := c.Writer.Write(jsonResponse); err != nil {
		common.SysLog("error writing response body: " + err.Error())
	}
	return nil
}

// baiduHandler relays a non-streaming Baidu chat response to the client in
// OpenAI format, using the upstream status code. A non-empty ErrorMsg from
// Baidu is surfaced as an error; the message is used verbatim (never as a
// format string).
func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
	var baiduResponse BaiduChatResponse
	if apiErr := decodeBaiduResponse(resp, &baiduResponse); apiErr != nil {
		return apiErr, nil
	}
	if baiduResponse.ErrorMsg != "" {
		return types.NewError(errors.New(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
	}
	fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
	if apiErr := writeBaiduJSON(c, resp.StatusCode, fullTextResponse); apiErr != nil {
		return apiErr, nil
	}
	return nil, &fullTextResponse.Usage
}

// baiduEmbeddingHandler relays a Baidu embedding response to the client in
// OpenAI format, using the upstream status code. A non-empty ErrorMsg from
// Baidu is surfaced as an error; the message is used verbatim.
func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
	var baiduResponse BaiduEmbeddingResponse
	if apiErr := decodeBaiduResponse(resp, &baiduResponse); apiErr != nil {
		return apiErr, nil
	}
	if baiduResponse.ErrorMsg != "" {
		return types.NewError(errors.New(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
	}
	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
	if apiErr := writeBaiduJSON(c, resp.StatusCode, fullTextResponse); apiErr != nil {
		return apiErr, nil
	}
	return nil, &fullTextResponse.Usage
}

// getBaiduAccessToken returns an access token for apiKey ("id|secret").
//
// A cached token that is still valid is returned immediately; if it expires
// within the next hour, a single background refresh is started. An expired or
// missing token is fetched synchronously, so callers never receive a token
// whose recorded expiry has already passed.
func getBaiduAccessToken(apiKey string) (string, error) {
	if val, ok := baiduTokenStore.Load(apiKey); ok {
		if cached, ok := val.(BaiduAccessToken); ok {
			now := time.Now()
			if now.Before(cached.ExpiresAt) {
				// Soon this will expire: refresh proactively without blocking the caller.
				if now.Add(time.Hour).After(cached.ExpiresAt) {
					refreshBaiduAccessTokenAsync(apiKey)
				}
				return cached.AccessToken, nil
			}
			// Already expired (e.g. background refreshes kept failing): fall
			// through to a synchronous fetch instead of handing out a dead token.
		}
	}
	accessToken, err := getBaiduAccessTokenHelper(apiKey)
	if err != nil {
		return "", err
	}
	if accessToken == nil {
		return "", errors.New("getBaiduAccessToken return a nil token")
	}
	return accessToken.AccessToken, nil
}

// refreshBaiduAccessTokenAsync starts a background token refresh for apiKey
// unless one is already in flight for that key. Failures are logged; the
// previously cached token stays in place until it expires.
func refreshBaiduAccessTokenAsync(apiKey string) {
	if _, inFlight := baiduTokenRefreshing.LoadOrStore(apiKey, struct{}{}); inFlight {
		return
	}
	go func() {
		defer baiduTokenRefreshing.Delete(apiKey)
		if _, err := getBaiduAccessTokenHelper(apiKey); err != nil {
			common.SysLog("error refreshing baidu access token: " + err.Error())
		}
	}()
}

// getBaiduAccessTokenHelper exchanges the "client_id|client_secret" pair in
// apiKey for an access token via Baidu's OAuth client_credentials endpoint,
// stores the result (with its computed ExpiresAt) in baiduTokenStore, and
// returns it.
func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
	parts := strings.Split(apiKey, "|")
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		return nil, errors.New("invalid baidu apikey")
	}
	// Escape the credentials: secrets may contain '&', '+', '=' etc.
	query := url.Values{}
	query.Set("grant_type", "client_credentials")
	query.Set("client_id", parts[0])
	query.Set("client_secret", parts[1])
	req, err := http.NewRequest(http.MethodPost, "https://aip.baidubce.com/oauth/2.0/token?"+query.Encode(), nil)
	if err != nil {
		return nil, err
	}
	req.Header.Add("Content-Type", "application/json")
	req.Header.Add("Accept", "application/json")
	res, err := service.GetHttpClient().Do(req)
	if err != nil {
		// *url.Error embeds the request URL, which carries client_secret; unwrap
		// it so the secret cannot leak into logs or client-facing errors.
		var urlErr *url.Error
		if errors.As(err, &urlErr) {
			err = urlErr.Err
		}
		return nil, fmt.Errorf("requesting baidu access token: %w", err)
	}
	defer res.Body.Close()
	var accessToken BaiduAccessToken
	if err := json.NewDecoder(res.Body).Decode(&accessToken); err != nil {
		return nil, fmt.Errorf("decoding baidu access token response (status %d): %w", res.StatusCode, err)
	}
	if accessToken.Error != "" {
		return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
	}
	if accessToken.AccessToken == "" {
		return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
	}
	accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
	baiduTokenStore.Store(apiKey, accessToken)
	return &accessToken, nil
}