From 892d014c265181320a1d590aeb84e11e0eef8c66 Mon Sep 17 00:00:00 2001 From: neotf Date: Fri, 14 Mar 2025 21:38:57 +0800 Subject: [PATCH] feat: support AWS Model CrossRegion --- relay/channel/aws/constants.go | 37 ++++++++++++++++++++++++++++++++++ relay/channel/aws/relay-aws.go | 28 +++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go index 66dc7cd9..37196fd8 100644 --- a/relay/channel/aws/constants.go +++ b/relay/channel/aws/constants.go @@ -13,4 +13,41 @@ var awsModelIDMap = map[string]string{ "claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0", } +var awsModelCanCrossRegionMap = map[string]map[string]bool{ + "anthropic.claude-3-sonnet-20240229-v1:0": { + "us": true, + "eu": true, + "ap": true, + }, + "anthropic.claude-3-opus-20240229-v1:0": { + "us": true, + }, + "anthropic.claude-3-haiku-20240307-v1:0": { + "us": true, + "eu": true, + "ap": true, + }, + "anthropic.claude-3-5-sonnet-20240620-v1:0": { + "us": true, + "eu": true, + "ap": true, + }, + "anthropic.claude-3-5-sonnet-20241022-v2:0": { + "us": true, + "ap": true, + }, + "anthropic.claude-3-5-haiku-20241022-v1:0": { + "us": true, + }, + "anthropic.claude-3-7-sonnet-20250219-v1:0": { + "us": true, + }, +} + +var awsRegionCrossModelPrefixMap = map[string]string{ + "us": "us", + "eu": "eu", + "ap": "apac", +} + var ChannelName = "aws" diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index 0d517256..89b7b7eb 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -48,6 +48,28 @@ func wrapErr(err error) *dto.OpenAIErrorWithStatusCode { } } +func awsRegionPrefix(awsRegionId string) string { + parts := strings.Split(awsRegionId, "-") + regionPrefix := "" + if len(parts) > 0 { + regionPrefix = parts[0] + } + return regionPrefix +} + +func awsModelCanCrossRegion(awsModelId, awsRegionPrefix string) bool { + regionSet, exists := awsModelCanCrossRegionMap[awsModelId] + return exists && regionSet[awsRegionPrefix] +} + +func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string { + modelPrefix, find := awsRegionCrossModelPrefixMap[awsRegionPrefix] + if !find { + return awsModelId + } + return modelPrefix + "." + awsModelId +} + func awsModelID(requestModel string) (string, error) { if awsModelID, ok := awsModelIDMap[requestModel]; ok { return awsModelID, nil @@ -67,6 +89,12 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (* return wrapErr(errors.Wrap(err, "awsModelID")), nil } + awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region) + canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix) + if canCrossRegion { + awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix) + } + awsReq := &bedrockruntime.InvokeModelInput{ ModelId: aws.String(awsModelId), Accept: aws.String("application/json"),