Files
new-api/relay/channel/openai/adaptor.go
1808837298@qq.com bd48f43410 feat: claude relay
2025-03-12 21:31:46 +08:00

284 lines
9.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package openai
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"mime/multipart"
"net/http"
"one-api/common"
constant2 "one-api/constant"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/ai360"
"one-api/relay/channel/jina"
"one-api/relay/channel/lingyiwanwu"
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"strings"
)
// Adaptor is the relay adaptor implemented by this package. It serves the
// OpenAI channel as well as the other channel types dispatched on in
// GetRequestURL, GetModelList and GetChannelName.
type Adaptor struct {
// ChannelType is copied from RelayInfo.ChannelType by Init and selects the
// model list and channel name reported by GetModelList / GetChannelName.
ChannelType int
// ResponseFormat is recorded by ConvertAudioRequest from the audio request
// and later handed to the speech-to-text handler in DoResponse.
ResponseFormat string
}
// ConvertClaudeRequest is not implemented for this adaptor yet.
//
// It always returns a nil payload together with a "not implemented" error.
// Reporting the condition as an error (instead of panicking) lets the caller
// fail the single request through its normal error path.
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
	// TODO: implement the Claude -> OpenAI request conversion.
	return nil, errors.New("not implemented")
}
// Init records the channel type of the current relay on the adaptor so that
// later calls (GetModelList, GetChannelName) can dispatch on it.
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.ChannelType = info.ChannelType
}
// GetRequestURL returns the upstream URL for the current request.
//
// In realtime mode an "https://" or "http://" base URL is first rewritten, in
// place on info, to the corresponding "wss://" or "ws://" scheme. The URL is
// then built according to the channel type:
//   - Azure: a deployment-scoped path carrying an api-version query, or the
//     /openai/realtime endpoint when in realtime mode;
//   - MiniMax: delegated to minimax.GetRequestURL;
//   - Custom: the base URL with every "{model}" placeholder replaced by the
//     upstream model name;
//   - anything else: the base URL combined with the incoming request path.
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
	realtime := info.RelayMode == constant.RelayModeRealtime
	if realtime {
		// Websocket upstreams need a ws(s) scheme; other schemes are left untouched.
		switch {
		case strings.HasPrefix(info.BaseUrl, "https://"):
			info.BaseUrl = "wss://" + strings.TrimPrefix(info.BaseUrl, "https://")
		case strings.HasPrefix(info.BaseUrl, "http://"):
			info.BaseUrl = "ws://" + strings.TrimPrefix(info.BaseUrl, "http://")
		}
	}

	switch info.ChannelType {
	case common.ChannelTypeAzure:
		version := info.ApiVersion
		if version == "" {
			version = constant2.AzureDefaultAPIVersion
		}
		// The deployment name is the upstream model name with dots removed.
		// https://github.com/songquanpeng/one-api/issues/67
		deployment := strings.ReplaceAll(info.UpstreamModelName, ".", "")

		var azurePath string
		if realtime {
			azurePath = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", deployment, version)
		} else {
			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
			// Drop any incoming query string, attach the api-version, and strip
			// the "/v1/" prefix so only the task part remains.
			pathOnly, _, _ := strings.Cut(info.RequestURLPath, "?")
			withVersion := fmt.Sprintf("%s?api-version=%s", pathOnly, version)
			task := strings.TrimPrefix(withVersion, "/v1/")
			azurePath = fmt.Sprintf("/openai/deployments/%s/%s", deployment, task)
		}
		return relaycommon.GetFullRequestURL(info.BaseUrl, azurePath, info.ChannelType), nil

	case common.ChannelTypeMiniMax:
		return minimax.GetRequestURL(info)

	case common.ChannelTypeCustom:
		return strings.ReplaceAll(info.BaseUrl, "{model}", info.UpstreamModelName), nil

	default:
		return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
	}
}
// SetupRequestHeader fills in the headers of the upstream request.
//
// After the shared channel.SetupApiRequestHeader step:
//   - Azure channels get only an "api-key" header and nothing else;
//   - OpenAI channels with a configured organization get "OpenAI-Organization";
//   - non-realtime requests get a bearer "Authorization" header;
//   - realtime requests whose client sent a Sec-WebSocket-Protocol header get
//     the API key passed as a websocket sub-protocol (no Authorization header);
//     otherwise they get "openai-beta: realtime=v1" plus the bearer header.
//
// The returned error is always nil.
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
	channel.SetupApiRequestHeader(info, c, header)

	if info.ChannelType == common.ChannelTypeAzure {
		header.Set("api-key", info.ApiKey)
		return nil
	}

	if info.Organization != "" && info.ChannelType == common.ChannelTypeOpenAI {
		header.Set("OpenAI-Organization", info.Organization)
	}

	bearer := "Bearer " + info.ApiKey

	if info.RelayMode != constant.RelayModeRealtime {
		header.Set("Authorization", bearer)
		return nil
	}

	if c.Request.Header.Get("Sec-WebSocket-Protocol") == "" {
		header.Set("openai-beta", "realtime=v1")
		header.Set("Authorization", bearer)
		return nil
	}

	// The client negotiated a websocket sub-protocol: hand the key upstream
	// the same way. The client's other Sec-WebSocket-* headers are not
	// forwarded here.
	subProtocols := []string{
		"realtime",
		"openai-insecure-api-key." + info.ApiKey,
		"openai-beta.realtime-v1",
	}
	header.Set("Sec-WebSocket-Protocol", strings.Join(subProtocols, ","))
	return nil
}
// ConvertRequest adjusts a chat/completions request in place before it is
// sent upstream and returns that same request.
//
// It returns an error only when request is nil. Adjustments made:
//   - StreamOptions is cleared for every channel type other than OpenAI and Azure;
//   - for models whose name starts with "o1" or "o3": MaxTokens is moved to
//     MaxCompletionTokens when the latter is unset, Temperature is cleared, and
//     a "-high" / "-low" / "-medium" model suffix is stripped and turned into
//     the ReasoningEffort value (also mirrored onto info, together with the
//     resulting model name);
//   - for "o1", "o1-2024-12-17" and any "o3*" model, a leading "system"
//     message is re-labelled as "developer".
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}

	keepStreamOptions := info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAzure
	if !keepStreamOptions {
		request.StreamOptions = nil
	}

	// Everything below only concerns o1 / o3 models.
	if !strings.HasPrefix(request.Model, "o1") && !strings.HasPrefix(request.Model, "o3") {
		return request, nil
	}

	if request.MaxTokens != 0 && request.MaxCompletionTokens == 0 {
		request.MaxCompletionTokens = request.MaxTokens
		request.MaxTokens = 0
	}
	request.Temperature = nil

	// A reasoning-effort suffix on the model name is translated into the
	// dedicated request field and removed from the name.
	switch {
	case strings.HasSuffix(request.Model, "-high"):
		request.ReasoningEffort = "high"
		request.Model = strings.TrimSuffix(request.Model, "-high")
	case strings.HasSuffix(request.Model, "-low"):
		request.ReasoningEffort = "low"
		request.Model = strings.TrimSuffix(request.Model, "-low")
	case strings.HasSuffix(request.Model, "-medium"):
		request.ReasoningEffort = "medium"
		request.Model = strings.TrimSuffix(request.Model, "-medium")
	}
	info.ReasoningEffort = request.ReasoningEffort
	info.UpstreamModelName = request.Model

	usesDeveloperRole := request.Model == "o1" ||
		request.Model == "o1-2024-12-17" ||
		strings.HasPrefix(request.Model, "o3")
	// Change the role of the first message from "system" to "developer".
	if usesDeveloperRole && len(request.Messages) > 0 && request.Messages[0].Role == "system" {
		request.Messages[0].Role = "developer"
	}
	return request, nil
}
// ConvertRerankRequest returns the rerank request unchanged; c and relayMode
// are not used and the error is always nil.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}
// ConvertEmbeddingRequest returns the embedding request unchanged; c and info
// are not used and the error is always nil.
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
// ConvertAudioRequest builds the upstream request body for the audio endpoints.
//
// request.ResponseFormat is first stored on the adaptor; DoResponse passes it
// to the speech-to-text handler.
//
// For RelayModeAudioSpeech the request is marshalled to JSON. For every other
// relay mode a fresh multipart/form-data body is assembled: the "model" field
// is taken from request.Model, all other posted form fields are copied over
// as-is, and the uploaded "file" part is re-attached under its original file
// name. The inbound request's Content-Type header is overwritten with the
// content type (including boundary) of the new body.
//
// An error is returned when JSON marshalling fails, when the "file" part is
// missing, or when writing any part of the multipart body fails.
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
	a.ResponseFormat = request.ResponseFormat

	if info.RelayMode == constant.RelayModeAudioSpeech {
		jsonData, err := json.Marshal(request)
		if err != nil {
			return nil, fmt.Errorf("error marshalling object: %w", err)
		}
		return bytes.NewReader(jsonData), nil
	}

	// Fetch the uploaded file first: FormFile parses the multipart form when
	// that has not happened yet, which is what populates Request.PostForm.
	// Reading PostForm before this call could silently drop every extra field.
	file, header, err := c.Request.FormFile("file")
	if err != nil {
		return nil, fmt.Errorf("file is required: %w", err)
	}
	defer file.Close()

	var requestBody bytes.Buffer
	writer := multipart.NewWriter(&requestBody)

	if err := writer.WriteField("model", request.Model); err != nil {
		return nil, fmt.Errorf("write form field %q failed: %w", "model", err)
	}

	// Copy every posted form field except "model", which was written above
	// from the request.
	for key, values := range c.Request.PostForm {
		if key == "model" {
			continue
		}
		for _, value := range values {
			if err := writer.WriteField(key, value); err != nil {
				return nil, fmt.Errorf("write form field %q failed: %w", key, err)
			}
		}
	}

	// Re-attach the uploaded file.
	part, err := writer.CreateFormFile("file", header.Filename)
	if err != nil {
		return nil, fmt.Errorf("create form file failed: %w", err)
	}
	if _, err := io.Copy(part, file); err != nil {
		return nil, fmt.Errorf("copy file failed: %w", err)
	}

	// Closing the writer emits the terminating boundary; without it the body
	// is not a complete multipart message.
	if err := writer.Close(); err != nil {
		return nil, fmt.Errorf("close multipart writer failed: %w", err)
	}

	c.Request.Header.Set("Content-Type", writer.FormDataContentType())
	return &requestBody, nil
}
// ConvertImageRequest returns the image request unchanged; c and info are not
// used and the error is always nil.
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
return request, nil
}
// DoRequest sends requestBody upstream using the transport that matches the
// relay mode: a form request for audio transcription/translation, a websocket
// request for realtime, and a plain API request for everything else.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
	switch info.RelayMode {
	case constant.RelayModeAudioTranscription, constant.RelayModeAudioTranslation:
		return channel.DoFormRequest(a, c, info, requestBody)
	case constant.RelayModeRealtime:
		return channel.DoWssRequest(a, c, info, requestBody)
	default:
		return channel.DoApiRequest(a, c, info, requestBody)
	}
}
// DoResponse hands the upstream response to the handler matching the relay
// mode and returns that handler's usage value and error.
//
// Realtime, speech-to-text (transcription and translation) and rerank each
// have a dedicated handler; every remaining mode goes to the streaming or
// non-streaming chat handler depending on info.IsStream.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
	switch info.RelayMode {
	case constant.RelayModeRealtime:
		err, usage = OpenaiRealtimeHandler(c, info)
	case constant.RelayModeAudioTranslation, constant.RelayModeAudioTranscription:
		// The response format captured in ConvertAudioRequest is passed along.
		err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
	case constant.RelayModeAudioSpeech, constant.RelayModeImagesGenerations:
		// NOTE(review): image generation responses go through the TTS handler
		// as well — confirm this sharing is intentional.
		err, usage = OpenaiTTSHandler(c, resp, info)
	case constant.RelayModeRerank:
		err, usage = jina.JinaRerankHandler(c, resp)
	default:
		if !info.IsStream {
			err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
		} else {
			err, usage = OaiStreamHandler(c, resp, info)
		}
	}
	return usage, err
}
// GetModelList returns the model list of the channel this adaptor was
// initialised for: the dedicated list for 360, Moonshot, LingYiWanWu, MiniMax
// and Xinference channels, and this package's ModelList for any other type.
func (a *Adaptor) GetModelList() []string {
	var models []string
	switch a.ChannelType {
	case common.ChannelType360:
		models = ai360.ModelList
	case common.ChannelTypeMoonshot:
		models = moonshot.ModelList
	case common.ChannelTypeLingYiWanWu:
		models = lingyiwanwu.ModelList
	case common.ChannelTypeMiniMax:
		models = minimax.ModelList
	case common.ChannelTypeXinference:
		models = xinference.ModelList
	default:
		models = ModelList
	}
	return models
}
// GetChannelName returns the name of the channel this adaptor was initialised
// for: the dedicated name for 360, Moonshot, LingYiWanWu, MiniMax and
// Xinference channels, and this package's ChannelName for any other type.
func (a *Adaptor) GetChannelName() string {
	var name string
	switch a.ChannelType {
	case common.ChannelType360:
		name = ai360.ChannelName
	case common.ChannelTypeMoonshot:
		name = moonshot.ChannelName
	case common.ChannelTypeLingYiWanWu:
		name = lingyiwanwu.ChannelName
	case common.ChannelTypeMiniMax:
		name = minimax.ChannelName
	case common.ChannelTypeXinference:
		name = xinference.ChannelName
	default:
		name = ChannelName
	}
	return name
}