feat: add xAI handling and response processing
This commit is contained in:
@@ -12,3 +12,7 @@ func DecodeJson(data []byte, v any) error {
|
|||||||
func DecodeJsonStr(data string, v any) error {
|
func DecodeJsonStr(data string, v any) error {
|
||||||
return DecodeJson(StringToByteSlice(data), v)
|
return DecodeJson(StringToByteSlice(data), v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func EncodeJson(v any) ([]byte, error) {
|
||||||
|
return json.Marshal(v)
|
||||||
|
}
|
||||||
|
|||||||
@@ -45,15 +45,16 @@ type RealtimeUsage struct {
|
|||||||
|
|
||||||
type InputTokenDetails struct {
|
type InputTokenDetails struct {
|
||||||
CachedTokens int `json:"cached_tokens"`
|
CachedTokens int `json:"cached_tokens"`
|
||||||
CachedCreationTokens int
|
CachedCreationTokens int `json:"-"`
|
||||||
TextTokens int `json:"text_tokens"`
|
TextTokens int `json:"text_tokens"`
|
||||||
AudioTokens int `json:"audio_tokens"`
|
AudioTokens int `json:"audio_tokens"`
|
||||||
ImageTokens int `json:"image_tokens"`
|
ImageTokens int `json:"image_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OutputTokenDetails struct {
|
type OutputTokenDetails struct {
|
||||||
TextTokens int `json:"text_tokens"`
|
TextTokens int `json:"text_tokens"`
|
||||||
AudioTokens int `json:"audio_tokens"`
|
AudioTokens int `json:"audio_tokens"`
|
||||||
|
ReasoningTokens int `json:"reasoning_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type RealtimeSession struct {
|
type RealtimeSession struct {
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
"one-api/relay/channel/openai"
|
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -86,13 +85,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
err, usage = xAIStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
err, usage = xAIHandler(c, resp, info)
|
||||||
}
|
|
||||||
if _, ok := usage.(*dto.Usage); ok && usage != nil {
|
|
||||||
usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens
|
|
||||||
}
|
}
|
||||||
|
//if _, ok := usage.(*dto.Usage); ok && usage != nil {
|
||||||
|
// usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens
|
||||||
|
//}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
14
relay/channel/xai/dto.go
Normal file
14
relay/channel/xai/dto.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package xai
|
||||||
|
|
||||||
|
import "one-api/dto"
|
||||||
|
|
||||||
|
// ChatCompletionResponse represents the response from XAI chat completion API
|
||||||
|
type ChatCompletionResponse struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []dto.ChatCompletionsStreamResponseChoice
|
||||||
|
Usage *dto.Usage `json:"usage"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
|
}
|
||||||
107
relay/channel/xai/text.go
Normal file
107
relay/channel/xai/text.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package xai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/relay/helper"
|
||||||
|
"one-api/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
|
||||||
|
if xAIResp == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if xAIResp.Usage != nil {
|
||||||
|
xAIResp.Usage.CompletionTokens = usage.CompletionTokens
|
||||||
|
}
|
||||||
|
openAIResp := &dto.ChatCompletionsStreamResponse{
|
||||||
|
Id: xAIResp.Id,
|
||||||
|
Object: xAIResp.Object,
|
||||||
|
Created: xAIResp.Created,
|
||||||
|
Model: xAIResp.Model,
|
||||||
|
Choices: xAIResp.Choices,
|
||||||
|
Usage: xAIResp.Usage,
|
||||||
|
}
|
||||||
|
|
||||||
|
return openAIResp
|
||||||
|
}
|
||||||
|
|
||||||
|
func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
usage := &dto.Usage{}
|
||||||
|
|
||||||
|
helper.SetEventStreamHeaders(c)
|
||||||
|
|
||||||
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||||
|
var xAIResp *dto.ChatCompletionsStreamResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &xAIResp)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 把 xAI 的usage转换为 OpenAI 的usage
|
||||||
|
if xAIResp.Usage != nil {
|
||||||
|
usage.PromptTokens = xAIResp.Usage.PromptTokens
|
||||||
|
usage.TotalTokens = xAIResp.Usage.TotalTokens
|
||||||
|
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
|
||||||
|
err = helper.ObjectData(c, openaiResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError(err.Error())
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
helper.Done(c)
|
||||||
|
err := resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
common.SysError("close_response_body_failed: " + err.Error())
|
||||||
|
}
|
||||||
|
return nil, usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
var response *dto.TextResponse
|
||||||
|
err = common.DecodeJson(responseBody, &response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
|
||||||
|
response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
|
||||||
|
|
||||||
|
// new body
|
||||||
|
encodeJson, err := common.EncodeJson(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// set new body
|
||||||
|
resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
|
||||||
|
|
||||||
|
for k, v := range resp.Header {
|
||||||
|
c.Writer.Header().Set(k, v[0])
|
||||||
|
}
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = io.Copy(c.Writer, resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, &response.Usage
|
||||||
|
}
|
||||||
@@ -56,6 +56,9 @@ func StringData(c *gin.Context, str string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ObjectData(c *gin.Context, object interface{}) error {
|
func ObjectData(c *gin.Context, object interface{}) error {
|
||||||
|
if object == nil {
|
||||||
|
return errors.New("object is nil")
|
||||||
|
}
|
||||||
jsonData, err := json.Marshal(object)
|
jsonData, err := json.Marshal(object)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error marshalling object: %w", err)
|
return fmt.Errorf("error marshalling object: %w", err)
|
||||||
|
|||||||
Reference in New Issue
Block a user