fix: xAI usage
This commit is contained in:
@@ -41,12 +41,7 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func processStreamResponse(item string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||||
var streamResponse dto.ChatCompletionsStreamResponse
|
|
||||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, choice := range streamResponse.Choices {
|
for _, choice := range streamResponse.Choices {
|
||||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||||
@@ -81,7 +76,11 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex
|
|||||||
// 一次性解析失败,逐个解析
|
// 一次性解析失败,逐个解析
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
for _, item := range streamItems {
|
for _, item := range streamItems {
|
||||||
if err := processStreamResponse(item, responseTextBuilder, toolCount); err != nil {
|
var streamResponse dto.ChatCompletionsStreamResponse
|
||||||
|
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
|
||||||
common.SysError("error processing stream response: " + err.Error())
|
common.SysError("error processing stream response: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -117,6 +117,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
model := info.UpstreamModelName
|
model := info.UpstreamModelName
|
||||||
|
|
||||||
var responseTextBuilder strings.Builder
|
var responseTextBuilder strings.Builder
|
||||||
|
var toolCount int
|
||||||
var usage = &dto.Usage{}
|
var usage = &dto.Usage{}
|
||||||
var streamItems []string // store stream items
|
var streamItems []string // store stream items
|
||||||
var forceFormat bool
|
var forceFormat bool
|
||||||
@@ -130,8 +131,6 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
thinkToContent = think2Content
|
thinkToContent = think2Content
|
||||||
}
|
}
|
||||||
|
|
||||||
toolCount := 0
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
lastStreamData string
|
lastStreamData string
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -48,7 +48,6 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
request.StreamOptions = nil
|
|
||||||
if strings.HasPrefix(request.Model, "grok-3-mini") {
|
if strings.HasPrefix(request.Model, "grok-3-mini") {
|
||||||
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
||||||
request.MaxCompletionTokens = request.MaxTokens
|
request.MaxCompletionTokens = request.MaxTokens
|
||||||
|
|||||||
@@ -8,9 +8,11 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/relay/channel/openai"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
|
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
|
||||||
@@ -34,6 +36,9 @@ func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage
|
|||||||
|
|
||||||
func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
|
var responseTextBuilder strings.Builder
|
||||||
|
var toolCount int
|
||||||
|
var containStreamUsage bool
|
||||||
|
|
||||||
helper.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
|
|
||||||
@@ -47,12 +52,14 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
|
|
||||||
// 把 xAI 的usage转换为 OpenAI 的usage
|
// 把 xAI 的usage转换为 OpenAI 的usage
|
||||||
if xAIResp.Usage != nil {
|
if xAIResp.Usage != nil {
|
||||||
|
containStreamUsage = true
|
||||||
usage.PromptTokens = xAIResp.Usage.PromptTokens
|
usage.PromptTokens = xAIResp.Usage.PromptTokens
|
||||||
usage.TotalTokens = xAIResp.Usage.TotalTokens
|
usage.TotalTokens = xAIResp.Usage.TotalTokens
|
||||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
|
openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
|
||||||
|
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
|
||||||
err = helper.ObjectData(c, openaiResponse)
|
err = helper.ObjectData(c, openaiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(err.Error())
|
common.SysError(err.Error())
|
||||||
@@ -60,6 +67,11 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if !containStreamUsage {
|
||||||
|
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
|
usage.CompletionTokens += toolCount * 7
|
||||||
|
}
|
||||||
|
|
||||||
helper.Done(c)
|
helper.Done(c)
|
||||||
err := resp.Body.Close()
|
err := resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ var streamSupportedChannels = map[int]bool{
|
|||||||
common.ChannelTypeAzure: true,
|
common.ChannelTypeAzure: true,
|
||||||
common.ChannelTypeVolcEngine: true,
|
common.ChannelTypeVolcEngine: true,
|
||||||
common.ChannelTypeOllama: true,
|
common.ChannelTypeOllama: true,
|
||||||
|
common.ChannelTypeXai: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||||
|
|||||||
Reference in New Issue
Block a user