fix: 支持aws 通过全局参数透传或者渠道参数透传来 调用 (#2423)
* fix: 支持aws 通过全局参数透传或者渠道参数透传来 调用 * fix(aws): replace json.Unmarshal with common.Unmarshal for request body processing --------- Co-authored-by: r0 <liangchunlei@01.ai> Co-authored-by: CaIon <i@caion.me>
This commit is contained in:
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||||
"github.com/aws/aws-sdk-go-v2/aws"
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||||
@@ -129,7 +130,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
|||||||
Accept: aws.String("application/json"),
|
Accept: aws.String("application/json"),
|
||||||
ContentType: aws.String("application/json"),
|
ContentType: aws.String("application/json"),
|
||||||
}
|
}
|
||||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
awsReq.Body, err = buildAwsRequestBody(c, info, awsClaudeReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||||||
}
|
}
|
||||||
@@ -141,7 +142,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
|||||||
Accept: aws.String("application/json"),
|
Accept: aws.String("application/json"),
|
||||||
ContentType: aws.String("application/json"),
|
ContentType: aws.String("application/json"),
|
||||||
}
|
}
|
||||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
awsReq.Body, err = buildAwsRequestBody(c, info, awsClaudeReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||||||
}
|
}
|
||||||
@@ -151,6 +152,24 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// buildAwsRequestBody prepares the payload for AWS requests, applying passthrough rules when enabled.
|
||||||
|
func buildAwsRequestBody(c *gin.Context, info *relaycommon.RelayInfo, awsClaudeReq any) ([]byte, error) {
|
||||||
|
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||||
|
body, err := common.GetRequestBody(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "get request body for pass-through fail")
|
||||||
|
}
|
||||||
|
var data map[string]interface{}
|
||||||
|
if err := common.Unmarshal(body, &data); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "pass-through unmarshal request body fail")
|
||||||
|
}
|
||||||
|
delete(data, "model")
|
||||||
|
delete(data, "stream")
|
||||||
|
return common.Marshal(data)
|
||||||
|
}
|
||||||
|
return common.Marshal(awsClaudeReq)
|
||||||
|
}
|
||||||
|
|
||||||
func getAwsRegionPrefix(awsRegionId string) string {
|
func getAwsRegionPrefix(awsRegionId string) string {
|
||||||
parts := strings.Split(awsRegionId, "-")
|
parts := strings.Split(awsRegionId, "-")
|
||||||
regionPrefix := ""
|
regionPrefix := ""
|
||||||
|
|||||||
Reference in New Issue
Block a user