This update standardizes the closure of HTTP response bodies across multiple stream handlers, enhancing error management and resource cleanup. The new method ensures that any errors during closure are handled gracefully, preventing potential request termination issues.
113 lines
3.2 KiB
Go
113 lines
3.2 KiB
Go
package xai
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"github.com/gin-gonic/gin"
|
|
"io"
|
|
"net/http"
|
|
"one-api/common"
|
|
"one-api/dto"
|
|
"one-api/relay/channel/openai"
|
|
relaycommon "one-api/relay/common"
|
|
"one-api/relay/helper"
|
|
"one-api/service"
|
|
"strings"
|
|
)
|
|
|
|
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
|
|
if xAIResp == nil {
|
|
return nil
|
|
}
|
|
if xAIResp.Usage != nil {
|
|
xAIResp.Usage.CompletionTokens = usage.CompletionTokens
|
|
}
|
|
openAIResp := &dto.ChatCompletionsStreamResponse{
|
|
Id: xAIResp.Id,
|
|
Object: xAIResp.Object,
|
|
Created: xAIResp.Created,
|
|
Model: xAIResp.Model,
|
|
Choices: xAIResp.Choices,
|
|
Usage: xAIResp.Usage,
|
|
}
|
|
|
|
return openAIResp
|
|
}
|
|
|
|
func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
usage := &dto.Usage{}
|
|
var responseTextBuilder strings.Builder
|
|
var toolCount int
|
|
var containStreamUsage bool
|
|
|
|
helper.SetEventStreamHeaders(c)
|
|
|
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
var xAIResp *dto.ChatCompletionsStreamResponse
|
|
err := json.Unmarshal([]byte(data), &xAIResp)
|
|
if err != nil {
|
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
return true
|
|
}
|
|
|
|
// 把 xAI 的usage转换为 OpenAI 的usage
|
|
if xAIResp.Usage != nil {
|
|
containStreamUsage = true
|
|
usage.PromptTokens = xAIResp.Usage.PromptTokens
|
|
usage.TotalTokens = xAIResp.Usage.TotalTokens
|
|
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
|
}
|
|
|
|
openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
|
|
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
|
|
err = helper.ObjectData(c, openaiResponse)
|
|
if err != nil {
|
|
common.SysError(err.Error())
|
|
}
|
|
return true
|
|
})
|
|
|
|
if !containStreamUsage {
|
|
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
|
usage.CompletionTokens += toolCount * 7
|
|
}
|
|
|
|
helper.Done(c)
|
|
common.CloseResponseBodyGracefully(resp)
|
|
return nil, usage
|
|
}
|
|
|
|
func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
var response *dto.TextResponse
|
|
err = common.DecodeJson(responseBody, &response)
|
|
if err != nil {
|
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
return nil, nil
|
|
}
|
|
response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
|
|
response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
|
|
|
|
// new body
|
|
encodeJson, err := common.EncodeJson(response)
|
|
if err != nil {
|
|
common.SysError("error marshalling stream response: " + err.Error())
|
|
return nil, nil
|
|
}
|
|
|
|
// set new body
|
|
resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
|
|
|
|
for k, v := range resp.Header {
|
|
c.Writer.Header().Set(k, v[0])
|
|
}
|
|
c.Writer.WriteHeader(resp.StatusCode)
|
|
_, err = io.Copy(c.Writer, resp.Body)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
common.CloseResponseBodyGracefully(resp)
|
|
|
|
return nil, &response.Usage
|
|
}
|