diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go index 12634c84..b5896415 100644 --- a/relay/channel/xai/adaptor.go +++ b/relay/channel/xai/adaptor.go @@ -2,14 +2,16 @@ package xai import ( "errors" - "fmt" "io" "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "strings" + "one-api/relay/constant" + "github.com/gin-gonic/gin" ) @@ -28,15 +30,20 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - request.Size = "" - return request, nil + xaiRequest := ImageRequest{ + Model: request.Model, + Prompt: request.Prompt, + N: request.N, + ResponseFormat: request.ResponseFormat, + } + return xaiRequest, nil } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { @@ -89,15 +96,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - err, usage = xAIStreamHandler(c, resp, info) - } else { - err, usage = xAIHandler(c, resp, info) + switch info.RelayMode { + case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits: + err, usage = openai.OpenaiHandlerWithUsage(c, resp, info) + default: + if info.IsStream { + err, usage = xAIStreamHandler(c, resp, info) + } else { + err, usage = xAIHandler(c, resp, info) + } } - //if _, ok := usage.(*dto.Usage); ok && usage != nil { - // usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens - //} - return } diff --git a/relay/channel/xai/dto.go b/relay/channel/xai/dto.go index 7036d5f1..b8098475 100644 --- a/relay/channel/xai/dto.go +++ b/relay/channel/xai/dto.go @@ -12,3 +12,16 @@ type ChatCompletionResponse struct { Usage *dto.Usage `json:"usage"` SystemFingerprint string `json:"system_fingerprint"` } + +// quality, size or style are not supported by xAI API at the moment. +type ImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n,omitempty"` + // Size string `json:"size,omitempty"` + // Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + // Style string `json:"style,omitempty"` + // User string `json:"user,omitempty"` + // ExtraFields json.RawMessage `json:"extra_fields,omitempty"` +} \ No newline at end of file