feat: add xAI handling and response processing

Author: CaIon
Date: 2025-04-10 17:20:59 +08:00
Parent: 8efa12b941
Commit: 8723e3f239
6 changed files with 137 additions and 9 deletions


@@ -12,3 +12,7 @@ func DecodeJson(data []byte, v any) error {
 func DecodeJsonStr(data string, v any) error {
 	return DecodeJson(StringToByteSlice(data), v)
 }
+
+func EncodeJson(v any) ([]byte, error) {
+	return json.Marshal(v)
+}
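
For reference, a minimal sketch of how the new EncodeJson helper pairs with DecodeJson for a round trip. The helpers are re-declared locally on top of encoding/json and the usage struct is a stand-in, not the real one-api type:

// Illustrative only: EncodeJson/DecodeJson round trip with a stand-in struct.
// The local helpers simply wrap encoding/json, which may differ from the
// actual implementation in one-api/common.
package main

import (
	"encoding/json"
	"fmt"
)

func EncodeJson(v any) ([]byte, error) { return json.Marshal(v) }

func DecodeJson(data []byte, v any) error { return json.Unmarshal(data, v) }

type usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

func main() {
	b, err := EncodeJson(usage{PromptTokens: 12, TotalTokens: 40})
	if err != nil {
		panic(err)
	}
	var out usage
	if err := DecodeJson(b, &out); err != nil {
		panic(err)
	}
	// Completion tokens can be derived when the upstream omits them.
	out.CompletionTokens = out.TotalTokens - out.PromptTokens
	fmt.Println(string(b), out.CompletionTokens)
}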


@@ -45,15 +45,16 @@ type RealtimeUsage struct {
 type InputTokenDetails struct {
 	CachedTokens         int `json:"cached_tokens"`
-	CachedCreationTokens int
+	CachedCreationTokens int `json:"-"`
 	TextTokens           int `json:"text_tokens"`
 	AudioTokens          int `json:"audio_tokens"`
 	ImageTokens          int `json:"image_tokens"`
 }
 type OutputTokenDetails struct {
 	TextTokens      int `json:"text_tokens"`
 	AudioTokens     int `json:"audio_tokens"`
+	ReasoningTokens int `json:"reasoning_tokens"`
 }
 type RealtimeSession struct {
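
To illustrate the new field, a hedged sketch of decoding a usage detail object that reports reasoning tokens; the JSON payload is invented, and the struct is a local copy of the OutputTokenDetails shape above rather than an import of one-api/dto. (CachedCreationTokens, tagged json:"-", is now excluded from JSON entirely.)

// Illustrative only: ReasoningTokens now round-trips like the other detail fields.
package main

import (
	"encoding/json"
	"fmt"
)

type OutputTokenDetails struct {
	TextTokens      int `json:"text_tokens"`
	AudioTokens     int `json:"audio_tokens"`
	ReasoningTokens int `json:"reasoning_tokens"`
}

func main() {
	payload := []byte(`{"text_tokens":120,"audio_tokens":0,"reasoning_tokens":57}`)
	var d OutputTokenDetails
	if err := json.Unmarshal(payload, &d); err != nil {
		panic(err)
	}
	fmt.Println(d.ReasoningTokens) // 57
}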


@@ -8,7 +8,6 @@ import (
"net/http" "net/http"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"strings" "strings"
) )
@@ -86,13 +85,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		err, usage = openai.OaiStreamHandler(c, resp, info)
+		err, usage = xAIStreamHandler(c, resp, info)
 	} else {
-		err, usage = openai.OpenaiHandler(c, resp, info)
+		err, usage = xAIHandler(c, resp, info)
 	}
-	if _, ok := usage.(*dto.Usage); ok && usage != nil {
-		usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens
-	}
+	//if _, ok := usage.(*dto.Usage); ok && usage != nil {
+	//	usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens
+	//}
 	return
 }
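
The adaptor no longer recomputes completion tokens itself; that arithmetic now lives in the xAI-specific handlers below. A standalone sketch of the calculation, using a flattened local stand-in for dto.Usage and its completion token details:

// Illustrative only: the token arithmetic the xAI handlers apply,
// pulled out into a helper.
package main

import "fmt"

type usage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
	ReasoningTokens  int
	TextTokens       int
}

// deriveCompletion mirrors xAIHandler: completion = total - prompt,
// text = completion - reasoning.
func deriveCompletion(u *usage) {
	u.CompletionTokens = u.TotalTokens - u.PromptTokens
	u.TextTokens = u.CompletionTokens - u.ReasoningTokens
}

func main() {
	u := usage{PromptTokens: 100, TotalTokens: 260, ReasoningTokens: 40}
	deriveCompletion(&u)
	fmt.Println(u.CompletionTokens, u.TextTokens) // 160 120
}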

relay/channel/xai/dto.go (new file, 14 lines)

@@ -0,0 +1,14 @@
+package xai
+
+import "one-api/dto"
+
+// ChatCompletionResponse represents the response from the xAI chat completion API
+type ChatCompletionResponse struct {
+	Id                string                                    `json:"id"`
+	Object            string                                    `json:"object"`
+	Created           int64                                     `json:"created"`
+	Model             string                                    `json:"model"`
+	Choices           []dto.ChatCompletionsStreamResponseChoice
+	Usage             *dto.Usage                                `json:"usage"`
+	SystemFingerprint string                                    `json:"system_fingerprint"`
+}
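
As a rough usage sketch, decoding an xAI-style completion body into this shape; the JSON is fabricated and the dto types are replaced with minimal local stand-ins:

// Illustrative only: decoding an xAI-style completion into a local copy of
// the struct above. The choice type is a minimal stand-in for the dto type.
package main

import (
	"encoding/json"
	"fmt"
)

type choice struct {
	Index int `json:"index"`
	Delta struct {
		Content string `json:"content"`
	} `json:"delta"`
}

type chatCompletionResponse struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	Choices           []choice `json:"choices"`
	SystemFingerprint string   `json:"system_fingerprint"`
}

func main() {
	body := []byte(`{"id":"cmpl-1","object":"chat.completion.chunk","created":1744273259,"model":"grok-example","choices":[{"index":0,"delta":{"content":"hi"}}]}`)
	var r chatCompletionResponse
	if err := json.Unmarshal(body, &r); err != nil {
		panic(err)
	}
	fmt.Println(r.Model, r.Choices[0].Delta.Content)
}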

relay/channel/xai/text.go (new file, 107 lines)

@@ -0,0 +1,107 @@
+package xai
+
+import (
+	"bytes"
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
+	"one-api/service"
+)
+
+// streamResponseXAI2OpenAI converts an xAI stream chunk into an OpenAI-style
+// chunk, overwriting the chunk's completion token count with the derived value.
+func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
+	if xAIResp == nil {
+		return nil
+	}
+	if xAIResp.Usage != nil {
+		xAIResp.Usage.CompletionTokens = usage.CompletionTokens
+	}
+	openAIResp := &dto.ChatCompletionsStreamResponse{
+		Id:      xAIResp.Id,
+		Object:  xAIResp.Object,
+		Created: xAIResp.Created,
+		Model:   xAIResp.Model,
+		Choices: xAIResp.Choices,
+		Usage:   xAIResp.Usage,
+	}
+	return openAIResp
+}
+
+// xAIStreamHandler relays the xAI SSE stream, accumulating usage as chunks arrive.
+func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	usage := &dto.Usage{}
+	helper.SetEventStreamHeaders(c)
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+		var xAIResp *dto.ChatCompletionsStreamResponse
+		err := json.Unmarshal([]byte(data), &xAIResp)
+		if err != nil {
+			common.SysError("error unmarshalling stream response: " + err.Error())
+			return true
+		}
+		// Convert xAI usage into OpenAI-style usage.
+		if xAIResp.Usage != nil {
+			usage.PromptTokens = xAIResp.Usage.PromptTokens
+			usage.TotalTokens = xAIResp.Usage.TotalTokens
+			usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+		}
+		openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
+		err = helper.ObjectData(c, openaiResponse)
+		if err != nil {
+			common.SysError(err.Error())
+		}
+		return true
+	})
+	helper.Done(c)
+	err := resp.Body.Close()
+	if err != nil {
+		//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		common.SysError("close_response_body_failed: " + err.Error())
+	}
+	return nil, usage
+}
+
+// xAIHandler handles the non-stream response: it derives completion and text
+// token counts, then re-encodes the body before copying it to the client.
+func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	var response *dto.TextResponse
+	err = common.DecodeJson(responseBody, &response)
+	if err != nil {
+		common.SysError("error unmarshalling stream response: " + err.Error())
+		return nil, nil
+	}
+	response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
+	response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
+	// new body
+	encodeJson, err := common.EncodeJson(response)
+	if err != nil {
+		common.SysError("error marshalling stream response: " + err.Error())
+		return nil, nil
+	}
+	// set new body
+	resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	return nil, &response.Usage
+}
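
For reference, a minimal standalone sketch of the read/patch/re-wrap pattern that xAIHandler applies to the upstream body, written against net/http only; the usage JSON and field values are invented:

// Illustrative only: read the body, patch the decoded value, and put a
// re-encoded body back so downstream code can read resp.Body as usual.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

type usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

func patchBody(resp *http.Response) error {
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	var u usage
	if err := json.Unmarshal(raw, &u); err != nil {
		return err
	}
	// Derive the missing completion count, then swap in the new body.
	u.CompletionTokens = u.TotalTokens - u.PromptTokens
	patched, err := json.Marshal(u)
	if err != nil {
		return err
	}
	resp.Body = io.NopCloser(bytes.NewBuffer(patched))
	return nil
}

func main() {
	resp := &http.Response{
		Body: io.NopCloser(bytes.NewBufferString(`{"prompt_tokens":10,"total_tokens":30}`)),
	}
	if err := patchBody(resp); err != nil {
		panic(err)
	}
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out)) // {"prompt_tokens":10,"completion_tokens":20,"total_tokens":30}
}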


@@ -56,6 +56,9 @@ func StringData(c *gin.Context, str string) error {
 }
 
 func ObjectData(c *gin.Context, object interface{}) error {
+	if object == nil {
+		return errors.New("object is nil")
+	}
 	jsonData, err := json.Marshal(object)
 	if err != nil {
 		return fmt.Errorf("error marshalling object: %w", err)