package xai import ( "bytes" "encoding/json" "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" ) func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse { if xAIResp == nil { return nil } if xAIResp.Usage != nil { xAIResp.Usage.CompletionTokens = usage.CompletionTokens } openAIResp := &dto.ChatCompletionsStreamResponse{ Id: xAIResp.Id, Object: xAIResp.Object, Created: xAIResp.Created, Model: xAIResp.Model, Choices: xAIResp.Choices, Usage: xAIResp.Usage, } return openAIResp } func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { usage := &dto.Usage{} helper.SetEventStreamHeaders(c) helper.StreamScannerHandler(c, resp, info, func(data string) bool { var xAIResp *dto.ChatCompletionsStreamResponse err := json.Unmarshal([]byte(data), &xAIResp) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) return true } // 把 xAI 的usage转换为 OpenAI 的usage if xAIResp.Usage != nil { usage.PromptTokens = xAIResp.Usage.PromptTokens usage.TotalTokens = xAIResp.Usage.TotalTokens usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens } openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage) err = helper.ObjectData(c, openaiResponse) if err != nil { common.SysError(err.Error()) } return true }) helper.Done(c) err := resp.Body.Close() if err != nil { //return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil common.SysError("close_response_body_failed: " + err.Error()) } return nil, usage } func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) var response *dto.TextResponse err = common.DecodeJson(responseBody, &response) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) return nil, nil } response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens // new body encodeJson, err := common.EncodeJson(response) if err != nil { common.SysError("error marshalling stream response: " + err.Error()) return nil, nil } // set new body resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson)) for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) } c.Writer.WriteHeader(resp.StatusCode) _, err = io.Copy(c.Writer, resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } return nil, &response.Usage }