This commit refactors the logging mechanism across the application by replacing direct logger calls with a centralized logging approach using the `common` package. Key changes include: - Replaced instances of `logger.SysLog` and `logger.FatalLog` with `common.SysLog` and `common.FatalLog` for consistent logging practices. - Updated resource initialization error handling to utilize the new logging structure, enhancing maintainability and readability. - Minor adjustments to improve code clarity and organization throughout various modules. This change aims to streamline logging and improve the overall architecture of the codebase.
288 lines
8.5 KiB
Go
288 lines
8.5 KiB
Go
package xunfei
|
|
|
|
import (
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/url"
|
|
"one-api/common"
|
|
"one-api/constant"
|
|
"one-api/dto"
|
|
"one-api/relay/helper"
|
|
"one-api/types"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// https://console.xfyun.cn/services/cbm
|
|
// https://www.xfyun.cn/doc/spark/Web.html
|
|
|
|
func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
|
|
messages := make([]XunfeiMessage, 0, len(request.Messages))
|
|
shouldCovertSystemMessage := !strings.HasSuffix(request.Model, "3.5")
|
|
for _, message := range request.Messages {
|
|
if message.Role == "system" && shouldCovertSystemMessage {
|
|
messages = append(messages, XunfeiMessage{
|
|
Role: "user",
|
|
Content: message.StringContent(),
|
|
})
|
|
messages = append(messages, XunfeiMessage{
|
|
Role: "assistant",
|
|
Content: "Okay",
|
|
})
|
|
} else {
|
|
messages = append(messages, XunfeiMessage{
|
|
Role: message.Role,
|
|
Content: message.StringContent(),
|
|
})
|
|
}
|
|
}
|
|
xunfeiRequest := XunfeiChatRequest{}
|
|
xunfeiRequest.Header.AppId = xunfeiAppId
|
|
xunfeiRequest.Parameter.Chat.Domain = domain
|
|
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
|
|
xunfeiRequest.Parameter.Chat.TopK = request.N
|
|
xunfeiRequest.Parameter.Chat.MaxTokens = request.GetMaxTokens()
|
|
xunfeiRequest.Payload.Message.Text = messages
|
|
return &xunfeiRequest
|
|
}
|
|
|
|
func responseXunfei2OpenAI(response *XunfeiChatResponse) *dto.OpenAITextResponse {
|
|
if len(response.Payload.Choices.Text) == 0 {
|
|
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
|
{
|
|
Content: "",
|
|
},
|
|
}
|
|
}
|
|
choice := dto.OpenAITextResponseChoice{
|
|
Index: 0,
|
|
Message: dto.Message{
|
|
Role: "assistant",
|
|
Content: response.Payload.Choices.Text[0].Content,
|
|
},
|
|
FinishReason: constant.FinishReasonStop,
|
|
}
|
|
fullTextResponse := dto.OpenAITextResponse{
|
|
Object: "chat.completion",
|
|
Created: common.GetTimestamp(),
|
|
Choices: []dto.OpenAITextResponseChoice{choice},
|
|
Usage: response.Payload.Usage.Text,
|
|
}
|
|
return &fullTextResponse
|
|
}
|
|
|
|
func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCompletionsStreamResponse {
|
|
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
|
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
|
{
|
|
Content: "",
|
|
},
|
|
}
|
|
}
|
|
var choice dto.ChatCompletionsStreamResponseChoice
|
|
choice.Delta.SetContentString(xunfeiResponse.Payload.Choices.Text[0].Content)
|
|
if xunfeiResponse.Payload.Choices.Status == 2 {
|
|
choice.FinishReason = &constant.FinishReasonStop
|
|
}
|
|
response := dto.ChatCompletionsStreamResponse{
|
|
Object: "chat.completion.chunk",
|
|
Created: common.GetTimestamp(),
|
|
Model: "SparkDesk",
|
|
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
|
}
|
|
return &response
|
|
}
|
|
|
|
func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
|
HmacWithShaToBase64 := func(algorithm, data, key string) string {
|
|
mac := hmac.New(sha256.New, []byte(key))
|
|
mac.Write([]byte(data))
|
|
encodeData := mac.Sum(nil)
|
|
return base64.StdEncoding.EncodeToString(encodeData)
|
|
}
|
|
ul, err := url.Parse(hostUrl)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
}
|
|
date := time.Now().UTC().Format(time.RFC1123)
|
|
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
|
|
sign := strings.Join(signString, "\n")
|
|
sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
|
|
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
|
|
"hmac-sha256", "host date request-line", sha)
|
|
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
|
|
v := url.Values{}
|
|
v.Add("host", ul.Host)
|
|
v.Add("date", date)
|
|
v.Add("authorization", authorization)
|
|
callUrl := hostUrl + "?" + v.Encode()
|
|
return callUrl
|
|
}
|
|
|
|
func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) {
|
|
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
|
|
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
|
if err != nil {
|
|
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed)
|
|
}
|
|
helper.SetEventStreamHeaders(c)
|
|
var usage dto.Usage
|
|
c.Stream(func(w io.Writer) bool {
|
|
select {
|
|
case xunfeiResponse := <-dataChan:
|
|
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
|
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
|
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
|
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
|
jsonResponse, err := json.Marshal(response)
|
|
if err != nil {
|
|
common.SysLog("error marshalling stream response: " + err.Error())
|
|
return true
|
|
}
|
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
return true
|
|
case <-stopChan:
|
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
return false
|
|
}
|
|
})
|
|
return &usage, nil
|
|
}
|
|
|
|
func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) {
|
|
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
|
|
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
|
if err != nil {
|
|
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed)
|
|
}
|
|
var usage dto.Usage
|
|
var content string
|
|
var xunfeiResponse XunfeiChatResponse
|
|
stop := false
|
|
for !stop {
|
|
select {
|
|
case xunfeiResponse = <-dataChan:
|
|
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
|
continue
|
|
}
|
|
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
|
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
|
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
|
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
|
case stop = <-stopChan:
|
|
}
|
|
}
|
|
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
|
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
|
{
|
|
Content: "",
|
|
},
|
|
}
|
|
}
|
|
xunfeiResponse.Payload.Choices.Text[0].Content = content
|
|
|
|
response := responseXunfei2OpenAI(&xunfeiResponse)
|
|
jsonResponse, err := json.Marshal(response)
|
|
if err != nil {
|
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
|
}
|
|
c.Writer.Header().Set("Content-Type", "application/json")
|
|
_, _ = c.Writer.Write(jsonResponse)
|
|
return &usage, nil
|
|
}
|
|
|
|
func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
|
|
d := websocket.Dialer{
|
|
HandshakeTimeout: 5 * time.Second,
|
|
}
|
|
conn, resp, err := d.Dial(authUrl, nil)
|
|
if err != nil || resp.StatusCode != 101 {
|
|
return nil, nil, err
|
|
}
|
|
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
|
err = conn.WriteJSON(data)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
dataChan := make(chan XunfeiChatResponse)
|
|
stopChan := make(chan bool)
|
|
go func() {
|
|
for {
|
|
_, msg, err := conn.ReadMessage()
|
|
if err != nil {
|
|
common.SysLog("error reading stream response: " + err.Error())
|
|
break
|
|
}
|
|
var response XunfeiChatResponse
|
|
err = json.Unmarshal(msg, &response)
|
|
if err != nil {
|
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
|
break
|
|
}
|
|
dataChan <- response
|
|
if response.Payload.Choices.Status == 2 {
|
|
err := conn.Close()
|
|
if err != nil {
|
|
common.SysLog("error closing websocket connection: " + err.Error())
|
|
}
|
|
break
|
|
}
|
|
}
|
|
stopChan <- true
|
|
}()
|
|
|
|
return dataChan, stopChan, nil
|
|
}
|
|
|
|
func apiVersion2domain(apiVersion string) string {
|
|
switch apiVersion {
|
|
case "v1.1":
|
|
return "lite"
|
|
case "v2.1":
|
|
return "generalv2"
|
|
case "v3.1":
|
|
return "generalv3"
|
|
case "v3.5":
|
|
return "generalv3.5"
|
|
case "v4.0":
|
|
return "4.0Ultra"
|
|
}
|
|
return "general" + apiVersion
|
|
}
|
|
|
|
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) {
|
|
apiVersion := getAPIVersion(c, modelName)
|
|
domain := apiVersion2domain(apiVersion)
|
|
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
|
return domain, authUrl
|
|
}
|
|
|
|
func getAPIVersion(c *gin.Context, modelName string) string {
|
|
query := c.Request.URL.Query()
|
|
apiVersion := query.Get("api-version")
|
|
if apiVersion != "" {
|
|
return apiVersion
|
|
}
|
|
parts := strings.Split(modelName, "-")
|
|
if len(parts) == 2 {
|
|
apiVersion = parts[1]
|
|
return apiVersion
|
|
|
|
}
|
|
apiVersion = c.GetString("api_version")
|
|
if apiVersion != "" {
|
|
return apiVersion
|
|
}
|
|
apiVersion = "v1.1"
|
|
common.SysLog("api_version not found, using default: " + apiVersion)
|
|
return apiVersion
|
|
}
|