refactor: aws
This commit is contained in:
@@ -88,50 +88,9 @@ func awsModelID(requestModel string) string {
|
||||
return requestModel
|
||||
}
|
||||
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||
|
||||
awsModelId := awsModelID(c.GetString("request_model"))
|
||||
// 检查是否为Nova模型
|
||||
isNova, _ := c.Get("is_nova_model")
|
||||
if isNova == true {
|
||||
// Nova模型也支持跨区域
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
return handleNovaRequest(c, awsCli, info, awsModelId)
|
||||
}
|
||||
|
||||
// 原有的Claude处理逻辑
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
claudeReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
|
||||
}
|
||||
@@ -156,39 +115,8 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
awsModelId := awsModelID(c.GetString("request_model"))
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
|
||||
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
claudeReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
|
||||
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||
awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
|
||||
}
|
||||
@@ -225,27 +153,9 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
// Nova模型处理函数
|
||||
func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
|
||||
novaReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
novaReq := novaReq_.(*NovaRequest)
|
||||
func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||
|
||||
// 使用InvokeModel API,但使用Nova格式的请求体
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(novaReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
awsReq.Body = reqBody
|
||||
|
||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user