diff --git a/common/constants.go b/common/constants.go
index 32e9d2fd..202d0660 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -235,6 +235,7 @@ const (
 	ChannelTypeVolcEngine = 45
 	ChannelTypeBaiduV2    = 46
 	ChannelTypeXinference = 47
+	ChannelTypeXai        = 48
 
 	ChannelTypeDummy // this one is only for count, do not add any channel after this
 )
@@ -288,4 +289,5 @@ var ChannelBaseURLs = []string{
 	"https://ark.cn-beijing.volces.com", //45
 	"https://qianfan.baidubce.com",      //46
 	"",                                  //47
+	"https://api.x.ai",                  //48
 }
diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go
new file mode 100644
index 00000000..5828ef0a
--- /dev/null
+++ b/relay/channel/xai/adaptor.go
@@ -0,0 +1,106 @@
+package xai
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel"
+	"one-api/relay/channel/openai"
+	relaycommon "one-api/relay/common"
+	"strings"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+	//TODO implement me
+	//panic("implement me")
+	return nil, errors.New("not available")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	//not available
+	return nil, errors.New("not available")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	request.Size = ""
+	return request, nil
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+	req.Set("Authorization", "Bearer "+info.ApiKey)
+	return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	request.StreamOptions = nil
+	if strings.HasPrefix(request.Model, "grok-3-mini") {
+		if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
+			request.MaxCompletionTokens = request.MaxTokens
+			request.MaxTokens = 0
+		}
+		if strings.HasSuffix(request.Model, "-high") {
+			request.ReasoningEffort = "high"
+			request.Model = strings.TrimSuffix(request.Model, "-high")
+		} else if strings.HasSuffix(request.Model, "-low") {
+			request.ReasoningEffort = "low"
+			request.Model = strings.TrimSuffix(request.Model, "-low")
+		} else if strings.HasSuffix(request.Model, "-medium") {
+			request.ReasoningEffort = "medium"
+			request.Model = strings.TrimSuffix(request.Model, "-medium")
+		}
+		info.ReasoningEffort = request.ReasoningEffort
+		info.UpstreamModelName = request.Model
+	}
+	return request, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//not available
+	return nil, errors.New("not available")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+	return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		err, usage = openai.OaiStreamHandler(c, resp, info)
+	} else {
+		err, usage = openai.OpenaiHandler(c, resp, info)
+	}
+	if u, ok := usage.(*dto.Usage); ok && u != nil {
+		u.CompletionTokens = u.TotalTokens - u.PromptTokens
+	}
+
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
diff --git a/relay/channel/xai/constants.go b/relay/channel/xai/constants.go
new file mode 100644
index 00000000..685fe3bb
--- /dev/null
+++ b/relay/channel/xai/constants.go
@@ -0,0 +1,18 @@
+package xai
+
+var ModelList = []string{
+	// grok-3
+	"grok-3-beta", "grok-3-mini-beta",
+	// grok-3 fast
+	"grok-3-fast-beta", "grok-3-mini-fast-beta",
+	// extend grok-3-mini reasoning
+	"grok-3-mini-beta-high", "grok-3-mini-beta-low", "grok-3-mini-beta-medium",
+	"grok-3-mini-fast-beta-high", "grok-3-mini-fast-beta-low", "grok-3-mini-fast-beta-medium",
+	// image model
+	"grok-2-image",
+	// legacy models
+	"grok-2", "grok-2-vision",
+	"grok-beta", "grok-vision-beta",
+}
+
+var ChannelName = "xai"
diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go
index 2cd0e399..fef38f23 100644
--- a/relay/constant/api_type.go
+++ b/relay/constant/api_type.go
@@ -32,6 +32,7 @@ const (
 	APITypeBaiduV2
 	APITypeOpenRouter
 	APITypeXinference
+	APITypeXai
 
 	APITypeDummy // this one is only for count, do not add any channel after this
 )
@@ -92,6 +93,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
 		apiType = APITypeOpenRouter
 	case common.ChannelTypeXinference:
 		apiType = APITypeXinference
+	case common.ChannelTypeXai:
+		apiType = APITypeXai
 	}
 	if apiType == -1 {
 		return APITypeOpenAI, false
diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go
index be7d07e6..8b4afcb3 100644
--- a/relay/relay_adaptor.go
+++ b/relay/relay_adaptor.go
@@ -25,6 +25,7 @@ import (
 	"one-api/relay/channel/tencent"
 	"one-api/relay/channel/vertex"
 	"one-api/relay/channel/volcengine"
+	"one-api/relay/channel/xai"
 	"one-api/relay/channel/xunfei"
 	"one-api/relay/channel/zhipu"
 	"one-api/relay/channel/zhipu_4v"
@@ -85,6 +86,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
 		return &openai.Adaptor{}
 	case constant.APITypeXinference:
 		return &openai.Adaptor{}
+	case constant.APITypeXai:
+		return &xai.Adaptor{}
 	}
 	return nil
 }
diff --git a/setting/operation_setting/model-ratio.go b/setting/operation_setting/model-ratio.go
index 3f215b88..68e50757 100644
--- a/setting/operation_setting/model-ratio.go
+++ b/setting/operation_setting/model-ratio.go
@@ -199,6 +199,15 @@ var defaultModelRatio = map[string]float64{
 	"llama-3-sonar-small-32k-online": 0.2 / 1000 * USD,
 	"llama-3-sonar-large-32k-chat":   1 / 1000 * USD,
 	"llama-3-sonar-large-32k-online": 1 / 1000 * USD,
+	// grok
+	"grok-3-beta":           1.5,
+	"grok-3-mini-beta":      0.15,
+	"grok-2":                1,
+	"grok-2-vision":         1,
+	"grok-beta":             2.5,
+	"grok-vision-beta":      2.5,
+	"grok-3-fast-beta":      2.5,
+	"grok-3-mini-fast-beta": 0.3,
 }
 
 var defaultModelPrice = map[string]float64{
diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js
index f1d0c88d..583995e7 100644
--- a/web/src/constants/channel.constants.js
+++ b/web/src/constants/channel.constants.js
@@ -115,4 +115,9 @@ export const CHANNEL_OPTIONS = [
     color: 'blue',
     label: '字节火山方舟、豆包、DeepSeek通用'
   },
+  {
+    value: 48,
+    color: 'blue',
+    label: 'xAI'
+  }
 ];