From 8723e3f239f6b165e255a9b45c070c58cd28b7e5 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Thu, 10 Apr 2025 17:20:59 +0800
Subject: [PATCH] feat: add xAI handling and response processing

---
 common/json.go               |   4 ++
 dto/realtime.go              |   7 ++-
 relay/channel/xai/adaptor.go |  11 ++--
 relay/channel/xai/dto.go     |  14 +++++
 relay/channel/xai/text.go    | 107 +++++++++++++++++++++++++++++++++++
 relay/helper/common.go       |   3 +
 6 files changed, 137 insertions(+), 9 deletions(-)
 create mode 100644 relay/channel/xai/dto.go
 create mode 100644 relay/channel/xai/text.go

diff --git a/common/json.go b/common/json.go
index 5b2b1aac..cec8f16b 100644
--- a/common/json.go
+++ b/common/json.go
@@ -12,3 +12,7 @@ func DecodeJson(data []byte, v any) error {
 func DecodeJsonStr(data string, v any) error {
 	return DecodeJson(StringToByteSlice(data), v)
 }
+
+func EncodeJson(v any) ([]byte, error) {
+	return json.Marshal(v)
+}
diff --git a/dto/realtime.go b/dto/realtime.go
index 8c6e8932..bb572267 100644
--- a/dto/realtime.go
+++ b/dto/realtime.go
@@ -45,15 +45,16 @@ type RealtimeUsage struct {
 
 type InputTokenDetails struct {
 	CachedTokens         int `json:"cached_tokens"`
-	CachedCreationTokens int
+	CachedCreationTokens int `json:"-"`
 	TextTokens           int `json:"text_tokens"`
 	AudioTokens          int `json:"audio_tokens"`
 	ImageTokens          int `json:"image_tokens"`
 }
 
 type OutputTokenDetails struct {
-	TextTokens  int `json:"text_tokens"`
-	AudioTokens int `json:"audio_tokens"`
+	TextTokens      int `json:"text_tokens"`
+	AudioTokens     int `json:"audio_tokens"`
+	ReasoningTokens int `json:"reasoning_tokens"`
 }
 
 type RealtimeSession struct {
diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go
index 5828ef0a..2b032701 100644
--- a/relay/channel/xai/adaptor.go
+++ b/relay/channel/xai/adaptor.go
@@ -8,7 +8,6 @@ import (
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel"
-	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	"strings"
 )
@@ -86,13 +85,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		err, usage = openai.OaiStreamHandler(c, resp, info)
+		err, usage = xAIStreamHandler(c, resp, info)
 	} else {
-		err, usage = openai.OpenaiHandler(c, resp, info)
-	}
-	if _, ok := usage.(*dto.Usage); ok && usage != nil {
-		usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens
+		err, usage = xAIHandler(c, resp, info)
 	}
+	//if _, ok := usage.(*dto.Usage); ok && usage != nil {
+	//	usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens
+	//}
 	return
 }
 
diff --git a/relay/channel/xai/dto.go b/relay/channel/xai/dto.go
new file mode 100644
index 00000000..7036d5f1
--- /dev/null
+++ b/relay/channel/xai/dto.go
@@ -0,0 +1,14 @@
+package xai
+
+import "one-api/dto"
+
+// ChatCompletionResponse represents the response from the xAI chat completion API
+type ChatCompletionResponse struct {
+	Id                string                                    `json:"id"`
+	Object            string                                    `json:"object"`
+	Created           int64                                     `json:"created"`
+	Model             string                                    `json:"model"`
+	Choices           []dto.ChatCompletionsStreamResponseChoice `json:"choices"`
+	Usage             *dto.Usage                                `json:"usage"`
+	SystemFingerprint string                                    `json:"system_fingerprint"`
+}
diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go
new file mode 100644
index 00000000..0f66b735
--- /dev/null
+++ b/relay/channel/xai/text.go
@@ -0,0 +1,107 @@
+package xai
+
+import (
+	"bytes"
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
+	"one-api/service"
+)
+
+func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
+	if xAIResp == nil {
+		return nil
+	}
+	if xAIResp.Usage != nil {
+		xAIResp.Usage.CompletionTokens = usage.CompletionTokens
+	}
+	openAIResp := &dto.ChatCompletionsStreamResponse{
+		Id:      xAIResp.Id,
+		Object:  xAIResp.Object,
+		Created: xAIResp.Created,
+		Model:   xAIResp.Model,
+		Choices: xAIResp.Choices,
+		Usage:   xAIResp.Usage,
+	}
+
+	return openAIResp
+}
+
+func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	usage := &dto.Usage{}
+
+	helper.SetEventStreamHeaders(c)
+
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+		var xAIResp *dto.ChatCompletionsStreamResponse
+		err := json.Unmarshal([]byte(data), &xAIResp)
+		if err != nil {
+			common.SysError("error unmarshalling stream response: " + err.Error())
+			return true
+		}
+
+		// convert xAI usage into OpenAI-style usage
+		if xAIResp.Usage != nil {
+			usage.PromptTokens = xAIResp.Usage.PromptTokens
+			usage.TotalTokens = xAIResp.Usage.TotalTokens
+			usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+		}
+
+		openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
+		err = helper.ObjectData(c, openaiResponse)
+		if err != nil {
+			common.SysError(err.Error())
+		}
+		return true
+	})
+
+	helper.Done(c)
+	err := resp.Body.Close()
+	if err != nil {
+		//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		common.SysError("close_response_body_failed: " + err.Error())
+	}
+	return nil, usage
+}
+
+func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	var response *dto.TextResponse
+	err = common.DecodeJson(responseBody, &response)
+	if err != nil {
+		common.SysError("error unmarshalling response: " + err.Error())
+		return nil, nil
+	}
+	response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
+	response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
+
+	// new body
+	encodeJson, err := common.EncodeJson(response)
+	if err != nil {
+		common.SysError("error marshalling response: " + err.Error())
+		return nil, nil
+	}
+
+	// set new body
+	resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
+
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	return nil, &response.Usage
+}
diff --git a/relay/helper/common.go b/relay/helper/common.go
index 13fc85ab..200846f6 100644
--- a/relay/helper/common.go
+++ b/relay/helper/common.go
@@ -56,6 +56,9 @@ func StringData(c *gin.Context, str string) error {
 }
 
 func ObjectData(c *gin.Context, object interface{}) error {
+	if object == nil {
+		return errors.New("object is nil")
+	}
 	jsonData, err := json.Marshal(object)
 	if err != nil {
 		return fmt.Errorf("error marshalling object: %w", err)
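
Note on the token accounting introduced by this patch: both xAIStreamHandler and xAIHandler recover completion tokens as total minus prompt tokens, and xAIHandler then derives text tokens as completion minus reasoning tokens. The standalone sketch below illustrates that arithmetic only; the normalizeUsage helper and the trimmed-down usage struct are illustrative stand-ins, not part of the patch or of dto.Usage.

package main

import "fmt"

// usage is a hypothetical, cut-down stand-in for dto.Usage holding only the
// fields the xAI handlers adjust.
type usage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
	ReasoningTokens  int
	TextTokens       int
}

// normalizeUsage mirrors the arithmetic in xAIHandler: derive completion
// tokens from the total, then separate text tokens from reasoning tokens.
func normalizeUsage(u usage) usage {
	u.CompletionTokens = u.TotalTokens - u.PromptTokens
	u.TextTokens = u.CompletionTokens - u.ReasoningTokens
	return u
}

func main() {
	u := usage{PromptTokens: 100, TotalTokens: 260, ReasoningTokens: 40}
	fmt.Printf("%+v\n", normalizeUsage(u))
	// {PromptTokens:100 CompletionTokens:160 TotalTokens:260 ReasoningTokens:40 TextTokens:120}
}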