From ef8ae4db8052c44b1382d3f3f7cd20a177338728 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 11 Apr 2025 23:31:32 +0800 Subject: [PATCH] fix: xAI usage --- relay/channel/openai/helper.go | 13 ++++++------- relay/channel/openai/relay-openai.go | 3 +-- relay/channel/xai/adaptor.go | 1 - relay/channel/xai/text.go | 12 ++++++++++++ relay/common/relay_info.go | 1 + 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go index 1a394f6f..e7ba2e7b 100644 --- a/relay/channel/openai/helper.go +++ b/relay/channel/openai/helper.go @@ -41,12 +41,7 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo return nil } -func processStreamResponse(item string, responseTextBuilder *strings.Builder, toolCount *int) error { - var streamResponse dto.ChatCompletionsStreamResponse - if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { - return err - } - +func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error { for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString()) responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) @@ -81,7 +76,11 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex // 一次性解析失败,逐个解析 common.SysError("error unmarshalling stream response: " + err.Error()) for _, item := range streamItems { - if err := processStreamResponse(item, responseTextBuilder, toolCount); err != nil { + var streamResponse dto.ChatCompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { + return err + } + if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil { common.SysError("error processing stream response: " + err.Error()) } } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index e690f3d1..cb209fed 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -117,6 +117,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel model := info.UpstreamModelName var responseTextBuilder strings.Builder + var toolCount int var usage = &dto.Usage{} var streamItems []string // store stream items var forceFormat bool @@ -130,8 +131,6 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel thinkToContent = think2Content } - toolCount := 0 - var ( lastStreamData string ) diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go index 2b032701..669b8c68 100644 --- a/relay/channel/xai/adaptor.go +++ b/relay/channel/xai/adaptor.go @@ -48,7 +48,6 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } - request.StreamOptions = nil if strings.HasPrefix(request.Model, "grok-3-mini") { if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 { request.MaxCompletionTokens = request.MaxTokens diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go index 0f66b735..e019c2dc 100644 --- a/relay/channel/xai/text.go +++ b/relay/channel/xai/text.go @@ -8,9 +8,11 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "strings" ) func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse { @@ -34,6 +36,9 @@ func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { usage := &dto.Usage{} + var responseTextBuilder strings.Builder + var toolCount int + var containStreamUsage bool helper.SetEventStreamHeaders(c) @@ -47,12 +52,14 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel // 把 xAI 的usage转换为 OpenAI 的usage if xAIResp.Usage != nil { + containStreamUsage = true usage.PromptTokens = xAIResp.Usage.PromptTokens usage.TotalTokens = xAIResp.Usage.TotalTokens usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens } openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage) + _ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount) err = helper.ObjectData(c, openaiResponse) if err != nil { common.SysError(err.Error()) @@ -60,6 +67,11 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel return true }) + if !containStreamUsage { + usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) + usage.CompletionTokens += toolCount * 7 + } + helper.Done(c) err := resp.Body.Close() if err != nil { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 0fdb1eb3..f10d3826 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -102,6 +102,7 @@ var streamSupportedChannels = map[int]bool{ common.ChannelTypeAzure: true, common.ChannelTypeVolcEngine: true, common.ChannelTypeOllama: true, + common.ChannelTypeXai: true, } func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {