refactor: aws

This commit is contained in:
CaIon
2025-10-15 16:44:33 +08:00
parent 5069482d22
commit 31e09b960f
7 changed files with 108 additions and 113 deletions

View File

@@ -1,14 +1,17 @@
package aws
import (
"errors"
"io"
"net/http"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel/claude"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/pkg/errors"
"github.com/gin-gonic/gin"
)
@@ -19,7 +22,10 @@ const (
)
type Adaptor struct {
RequestMode int
AwsClient *bedrockruntime.Client
AwsModelId string
AwsReq any
IsNova bool
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
@@ -28,8 +34,6 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
c.Set("request_model", request.Model)
c.Set("converted_request", request)
return request, nil
}
@@ -44,7 +48,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.RequestMode = RequestModeMessage
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -63,9 +66,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
// 检查是否为Nova模型
if isNovaModel(request.Model) {
novaReq := convertToNovaRequest(request)
c.Set("request_model", request.Model)
c.Set("converted_request", novaReq)
c.Set("is_nova_model", true)
a.IsNova = true
return novaReq, nil
}
@@ -76,9 +77,6 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if err != nil {
return nil, err
}
c.Set("request_model", claudeReq.Model)
c.Set("converted_request", claudeReq)
c.Set("is_nova_model", false)
return claudeReq, err
}
@@ -97,14 +95,83 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return nil, nil
awsCli, err := newAwsClient(c, info)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError)
}
a.AwsClient = awsCli
awsModelId := awsModelID(info.UpstreamModelName)
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
if isNovaModel(awsModelId) {
var novaReq *NovaRequest
err = common.DecodeJson(requestBody, &novaReq)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody)
}
// 使用InvokeModel API但使用Nova格式的请求体
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
reqBody, err := common.Marshal(novaReq)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
}
awsReq.Body = reqBody
return nil, nil
} else {
awsClaudeReq, err := formatRequest(requestBody)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
}
if info.IsStream {
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
awsReq.Body, err = common.Marshal(awsClaudeReq)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
}
a.AwsReq = awsReq
return nil, nil
} else {
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
awsReq.Body, err = common.Marshal(awsClaudeReq)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
}
a.AwsReq = awsReq
return nil, nil
}
}
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
if a.IsNova {
err, usage = handleNovaRequest(c, info, a)
} else {
err, usage = awsHandler(c, info, a.RequestMode)
if info.IsStream {
err, usage = awsStreamHandler(c, info, a)
} else {
err, usage = awsHandler(c, info, a)
}
}
return
}