284 lines
9.2 KiB
Go
284 lines
9.2 KiB
Go
package openai
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"github.com/gin-gonic/gin"
|
||
"io"
|
||
"mime/multipart"
|
||
"net/http"
|
||
"one-api/common"
|
||
constant2 "one-api/constant"
|
||
"one-api/dto"
|
||
"one-api/relay/channel"
|
||
"one-api/relay/channel/ai360"
|
||
"one-api/relay/channel/jina"
|
||
"one-api/relay/channel/lingyiwanwu"
|
||
"one-api/relay/channel/minimax"
|
||
"one-api/relay/channel/moonshot"
|
||
"one-api/relay/channel/xinference"
|
||
relaycommon "one-api/relay/common"
|
||
"one-api/relay/constant"
|
||
"strings"
|
||
)
|
||
|
||
// Adaptor relays requests to OpenAI and OpenAI-compatible upstream channels.
type Adaptor struct {
	// ChannelType is copied from the RelayInfo in Init and selects the model
	// list and channel name reported by GetModelList / GetChannelName.
	ChannelType int
	// ResponseFormat holds the audio request's response format, captured in
	// ConvertAudioRequest and handed to OpenaiSTTHandler in DoResponse.
	ResponseFormat string
}
|
||
|
||
// ConvertClaudeRequest would convert a Claude-format request into the OpenAI
// format, but that conversion is not implemented for this adaptor yet
// (presumably the method exists to satisfy the channel adaptor interface).
//
// It returns an error rather than panicking, so an unsupported request fails
// only its own relay call with a regular error value.
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
	// TODO: implement Claude -> OpenAI request conversion.
	return nil, errors.New("claude request conversion is not implemented for the openai adaptor")
}
|
||
|
||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||
a.ChannelType = info.ChannelType
|
||
}
|
||
|
||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||
if info.RelayMode == constant.RelayModeRealtime {
|
||
if strings.HasPrefix(info.BaseUrl, "https://") {
|
||
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
||
baseUrl = "wss://" + baseUrl
|
||
info.BaseUrl = baseUrl
|
||
} else if strings.HasPrefix(info.BaseUrl, "http://") {
|
||
baseUrl := strings.TrimPrefix(info.BaseUrl, "http://")
|
||
baseUrl = "ws://" + baseUrl
|
||
info.BaseUrl = baseUrl
|
||
}
|
||
}
|
||
switch info.ChannelType {
|
||
case common.ChannelTypeAzure:
|
||
apiVersion := info.ApiVersion
|
||
if apiVersion == "" {
|
||
apiVersion = constant2.AzureDefaultAPIVersion
|
||
}
|
||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||
requestURL := strings.Split(info.RequestURLPath, "?")[0]
|
||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
||
task := strings.TrimPrefix(requestURL, "/v1/")
|
||
model_ := info.UpstreamModelName
|
||
model_ = strings.Replace(model_, ".", "", -1)
|
||
// https://github.com/songquanpeng/one-api/issues/67
|
||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||
if info.RelayMode == constant.RelayModeRealtime {
|
||
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
|
||
}
|
||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||
case common.ChannelTypeMiniMax:
|
||
return minimax.GetRequestURL(info)
|
||
case common.ChannelTypeCustom:
|
||
url := info.BaseUrl
|
||
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
|
||
return url, nil
|
||
default:
|
||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||
}
|
||
}
|
||
|
||
// SetupRequestHeader fills in the upstream request headers.
//
// Common headers come from channel.SetupApiRequestHeader. Authentication is
// then added per channel and relay mode:
//   - Azure: an "api-key" header, and nothing else.
//   - OpenAI with a configured organization: an "OpenAI-Organization" header
//     in addition to the steps below.
//   - Non-realtime relays: "Authorization: Bearer <key>".
//   - Realtime relays whose client sent Sec-WebSocket-Protocol: the key is
//     carried inside the Sec-WebSocket-Protocol list instead of a header.
//   - Other realtime relays: "openai-beta: realtime=v1" plus the bearer token.
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
	channel.SetupApiRequestHeader(info, c, header)

	if info.ChannelType == common.ChannelTypeAzure {
		header.Set("api-key", info.ApiKey)
		return nil
	}

	if info.ChannelType == common.ChannelTypeOpenAI && info.Organization != "" {
		header.Set("OpenAI-Organization", info.Organization)
	}

	bearer := "Bearer " + info.ApiKey

	if info.RelayMode != constant.RelayModeRealtime {
		header.Set("Authorization", bearer)
		return nil
	}

	if c.Request.Header.Get("Sec-WebSocket-Protocol") == "" {
		header.Set("openai-beta", "realtime=v1")
		header.Set("Authorization", bearer)
		return nil
	}

	// The client negotiated WebSocket subprotocols, so pass the credentials
	// and beta flag through the subprotocol list rather than headers.
	protocols := []string{
		"realtime",
		"openai-insecure-api-key." + info.ApiKey,
		"openai-beta.realtime-v1",
	}
	header.Set("Sec-WebSocket-Protocol", strings.Join(protocols, ","))
	return nil
}
|
||
|
||
// ConvertRequest adapts an OpenAI-format chat request for the target channel
// and returns the (mutated) request itself.
//
// Adjustments made in place:
//   - StreamOptions is cleared for every channel other than OpenAI and Azure.
//   - For models whose name starts with "o1" or "o3":
//   - MaxTokens is moved to MaxCompletionTokens when the latter is unset.
//   - Temperature is cleared.
//   - A trailing "-high", "-low" or "-medium" is stripped from the model
//     name and stored as ReasoningEffort.
//   - The resulting effort and model name are copied onto info.
//   - For "o1", "o1-2024-12-17" and "o3*" models (checked after the suffix is
//     stripped), a leading "system" message is relabelled "developer".
//
// It returns an error if request is nil.
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}

	if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
		request.StreamOptions = nil
	}

	if strings.HasPrefix(request.Model, "o1") || strings.HasPrefix(request.Model, "o3") {
		if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
			request.MaxCompletionTokens = request.MaxTokens
			request.MaxTokens = 0
		}

		// The enclosing prefix check already restricts this to o1/o3 models,
		// so the temperature is cleared unconditionally here.
		request.Temperature = nil

		// A reasoning-effort suffix on the model name is turned into an
		// explicit ReasoningEffort and removed from the name sent upstream.
		switch {
		case strings.HasSuffix(request.Model, "-high"):
			request.ReasoningEffort = "high"
			request.Model = strings.TrimSuffix(request.Model, "-high")
		case strings.HasSuffix(request.Model, "-low"):
			request.ReasoningEffort = "low"
			request.Model = strings.TrimSuffix(request.Model, "-low")
		case strings.HasSuffix(request.Model, "-medium"):
			request.ReasoningEffort = "medium"
			request.Model = strings.TrimSuffix(request.Model, "-medium")
		}

		info.ReasoningEffort = request.ReasoningEffort
		info.UpstreamModelName = request.Model
	}

	if request.Model == "o1" || request.Model == "o1-2024-12-17" || strings.HasPrefix(request.Model, "o3") {
		// Change the role of the first message from "system" to "developer".
		if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
			request.Messages[0].Role = "developer"
		}
	}

	return request, nil
}
|
||
|
||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||
return request, nil
|
||
}
|
||
|
||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||
return request, nil
|
||
}
|
||
|
||
// ConvertAudioRequest builds the upstream request body for an audio relay.
//
// The request's ResponseFormat is stored on the adaptor so DoResponse can
// pass it to the speech-to-text handler.
//
// Speech (text-to-speech) requests are re-encoded as JSON. Every other audio
// request is rebuilt as multipart/form-data containing:
//   - "model" taken from request.Model;
//   - every other field of the incoming form;
//   - the uploaded "file" part.
//
// The incoming request's Content-Type header is overwritten with the new
// multipart boundary (presumably read back when the form request is sent
// upstream). Errors from reading the upload or writing the multipart body
// are wrapped and returned.
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
	a.ResponseFormat = request.ResponseFormat

	if info.RelayMode == constant.RelayModeAudioSpeech {
		jsonData, err := json.Marshal(request)
		if err != nil {
			return nil, fmt.Errorf("error marshalling object: %w", err)
		}
		return bytes.NewReader(jsonData), nil
	}

	// Fetch the upload first: FormFile parses the multipart form if it has
	// not been parsed yet, which also populates c.Request.PostForm, read below.
	file, fileHeader, err := c.Request.FormFile("file")
	if err != nil {
		return nil, fmt.Errorf("file is required: %w", err)
	}
	defer file.Close()

	var requestBody bytes.Buffer
	writer := multipart.NewWriter(&requestBody)

	if err := writer.WriteField("model", request.Model); err != nil {
		return nil, fmt.Errorf("write form field %q: %w", "model", err)
	}

	// Copy the remaining form fields; "model" was already written above.
	for key, values := range c.Request.PostForm {
		if key == "model" {
			continue
		}
		for _, value := range values {
			if err := writer.WriteField(key, value); err != nil {
				return nil, fmt.Errorf("write form field %q: %w", key, err)
			}
		}
	}

	part, err := writer.CreateFormFile("file", fileHeader.Filename)
	if err != nil {
		return nil, fmt.Errorf("create form file failed: %w", err)
	}
	if _, err := io.Copy(part, file); err != nil {
		return nil, fmt.Errorf("copy file failed: %w", err)
	}

	// Close writes the terminating boundary; a failure here leaves a
	// truncated body, so it must not be ignored.
	if err := writer.Close(); err != nil {
		return nil, fmt.Errorf("close multipart writer: %w", err)
	}

	c.Request.Header.Set("Content-Type", writer.FormDataContentType())
	return &requestBody, nil
}
|
||
|
||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||
return request, nil
|
||
}
|
||
|
||
// DoRequest sends requestBody upstream using the helper that matches the
// relay mode:
//   - audio transcription/translation: channel.DoFormRequest;
//   - realtime: channel.DoWssRequest;
//   - everything else: channel.DoApiRequest.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
	switch info.RelayMode {
	case constant.RelayModeAudioTranscription, constant.RelayModeAudioTranslation:
		return channel.DoFormRequest(a, c, info, requestBody)
	case constant.RelayModeRealtime:
		return channel.DoWssRequest(a, c, info, requestBody)
	default:
		return channel.DoApiRequest(a, c, info, requestBody)
	}
}
|
||
|
||
// DoResponse dispatches the upstream response to the handler for the relay
// mode and returns that handler's usage and error:
//   - realtime: OpenaiRealtimeHandler;
//   - audio speech and image generations: OpenaiTTSHandler;
//   - audio transcription/translation: OpenaiSTTHandler, with the response
//     format captured in ConvertAudioRequest;
//   - rerank: jina.JinaRerankHandler;
//   - anything else: OaiStreamHandler for streaming requests, otherwise
//     OpenaiHandler.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
	switch info.RelayMode {
	case constant.RelayModeRealtime:
		err, usage = OpenaiRealtimeHandler(c, info)
	case constant.RelayModeAudioSpeech, constant.RelayModeImagesGenerations:
		// NOTE(review): image generations share the TTS handler; presumably it
		// relays the body as-is — confirm this is intended for images.
		err, usage = OpenaiTTSHandler(c, resp, info)
	case constant.RelayModeAudioTranslation, constant.RelayModeAudioTranscription:
		err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
	case constant.RelayModeRerank:
		err, usage = jina.JinaRerankHandler(c, resp)
	default:
		if !info.IsStream {
			err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
			break
		}
		err, usage = OaiStreamHandler(c, resp, info)
	}
	return usage, err
}
|
||
|
||
func (a *Adaptor) GetModelList() []string {
|
||
switch a.ChannelType {
|
||
case common.ChannelType360:
|
||
return ai360.ModelList
|
||
case common.ChannelTypeMoonshot:
|
||
return moonshot.ModelList
|
||
case common.ChannelTypeLingYiWanWu:
|
||
return lingyiwanwu.ModelList
|
||
case common.ChannelTypeMiniMax:
|
||
return minimax.ModelList
|
||
case common.ChannelTypeXinference:
|
||
return xinference.ModelList
|
||
default:
|
||
return ModelList
|
||
}
|
||
}
|
||
|
||
func (a *Adaptor) GetChannelName() string {
|
||
switch a.ChannelType {
|
||
case common.ChannelType360:
|
||
return ai360.ChannelName
|
||
case common.ChannelTypeMoonshot:
|
||
return moonshot.ChannelName
|
||
case common.ChannelTypeLingYiWanWu:
|
||
return lingyiwanwu.ChannelName
|
||
case common.ChannelTypeMiniMax:
|
||
return minimax.ChannelName
|
||
case common.ChannelTypeXinference:
|
||
return xinference.ChannelName
|
||
default:
|
||
return ChannelName
|
||
}
|
||
}
|