190 lines
5.5 KiB
Go
190 lines
5.5 KiB
Go
package aws
|
||
|
||
import (
|
||
"io"
|
||
"net/http"
|
||
|
||
"github.com/QuantumNous/new-api/common"
|
||
"github.com/QuantumNous/new-api/dto"
|
||
"github.com/QuantumNous/new-api/relay/channel/claude"
|
||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||
"github.com/QuantumNous/new-api/types"
|
||
"github.com/aws/aws-sdk-go-v2/aws"
|
||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||
"github.com/pkg/errors"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// Request modes understood by this package.
const (
	// RequestModeCompletion has the value 1.
	RequestModeCompletion = iota + 1
	// RequestModeMessage has the value 2.
	RequestModeMessage
)
|
||
|
||
type Adaptor struct {
|
||
AwsClient *bedrockruntime.Client
|
||
AwsModelId string
|
||
AwsReq any
|
||
IsNova bool
|
||
}
|
||
|
||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||
//TODO implement me
|
||
return nil, errors.New("not implemented")
|
||
}
|
||
|
||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||
return request, nil
|
||
}
|
||
|
||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||
//TODO implement me
|
||
return nil, errors.New("not implemented")
|
||
}
|
||
|
||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||
//TODO implement me
|
||
return nil, errors.New("not implemented")
|
||
}
|
||
|
||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||
}
|
||
|
||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||
return "", nil
|
||
}
|
||
|
||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||
claude.CommonClaudeHeadersOperation(c, req, info)
|
||
return nil
|
||
}
|
||
|
||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||
if request == nil {
|
||
return nil, errors.New("request is nil")
|
||
}
|
||
// 检查是否为Nova模型
|
||
if isNovaModel(request.Model) {
|
||
novaReq := convertToNovaRequest(request)
|
||
a.IsNova = true
|
||
return novaReq, nil
|
||
}
|
||
|
||
// 原有的Claude模型处理逻辑
|
||
var claudeReq *dto.ClaudeRequest
|
||
var err error
|
||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return claudeReq, err
|
||
}
|
||
|
||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||
return nil, nil
|
||
}
|
||
|
||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||
//TODO implement me
|
||
return nil, errors.New("not implemented")
|
||
}
|
||
|
||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||
// TODO implement me
|
||
return nil, errors.New("not implemented")
|
||
}
|
||
|
||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||
awsCli, err := newAwsClient(c, info)
|
||
if err != nil {
|
||
return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError)
|
||
}
|
||
a.AwsClient = awsCli
|
||
|
||
awsModelId := awsModelID(info.UpstreamModelName)
|
||
|
||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||
if canCrossRegion {
|
||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||
}
|
||
|
||
if isNovaModel(awsModelId) {
|
||
var novaReq *NovaRequest
|
||
err = common.DecodeJson(requestBody, &novaReq)
|
||
if err != nil {
|
||
return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody)
|
||
}
|
||
|
||
// 使用InvokeModel API,但使用Nova格式的请求体
|
||
awsReq := &bedrockruntime.InvokeModelInput{
|
||
ModelId: aws.String(awsModelId),
|
||
Accept: aws.String("application/json"),
|
||
ContentType: aws.String("application/json"),
|
||
}
|
||
|
||
reqBody, err := common.Marshal(novaReq)
|
||
if err != nil {
|
||
return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
|
||
}
|
||
awsReq.Body = reqBody
|
||
return nil, nil
|
||
} else {
|
||
awsClaudeReq, err := formatRequest(requestBody)
|
||
if err != nil {
|
||
return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
|
||
}
|
||
|
||
if info.IsStream {
|
||
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
||
ModelId: aws.String(awsModelId),
|
||
Accept: aws.String("application/json"),
|
||
ContentType: aws.String("application/json"),
|
||
}
|
||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||
if err != nil {
|
||
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||
}
|
||
a.AwsReq = awsReq
|
||
return nil, nil
|
||
} else {
|
||
awsReq := &bedrockruntime.InvokeModelInput{
|
||
ModelId: aws.String(awsModelId),
|
||
Accept: aws.String("application/json"),
|
||
ContentType: aws.String("application/json"),
|
||
}
|
||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||
if err != nil {
|
||
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||
}
|
||
a.AwsReq = awsReq
|
||
return nil, nil
|
||
}
|
||
}
|
||
}
|
||
|
||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||
if a.IsNova {
|
||
err, usage = handleNovaRequest(c, info, a)
|
||
} else {
|
||
if info.IsStream {
|
||
err, usage = awsStreamHandler(c, info, a)
|
||
} else {
|
||
err, usage = awsHandler(c, info, a)
|
||
}
|
||
}
|
||
return
|
||
}
|
||
|
||
func (a *Adaptor) GetModelList() (models []string) {
|
||
for n := range awsModelIDMap {
|
||
models = append(models, n)
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func (a *Adaptor) GetChannelName() string {
|
||
return ChannelName
|
||
}
|