diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index a5bc896b..ea59aeda 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -1,11 +1,13 @@ package aws import ( + "context" "encoding/json" "fmt" "io" "net/http" "strings" + "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/dto" @@ -37,6 +39,13 @@ func getAwsErrorStatusCode(err error) int { return http.StatusInternalServerError } +func newAwsInvokeContext() (context.Context, context.CancelFunc) { + if common.RelayTimeout <= 0 { + return context.Background(), func() {} + } + return context.WithTimeout(context.Background(), time.Duration(common.RelayTimeout)*time.Second) +} + func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) { var ( httpClient *http.Client @@ -117,6 +126,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor, return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody) } awsReq.Body = reqBody + a.AwsReq = awsReq return nil, nil } else { awsClaudeReq, err := formatRequest(requestBody, requestHeader) @@ -201,7 +211,10 @@ func getAwsModelID(requestModel string) string { func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) { - awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput)) + ctx, cancel := newAwsInvokeContext() + defer cancel() + + awsResp, err := a.AwsClient.InvokeModel(ctx, a.AwsReq.(*bedrockruntime.InvokeModelInput)) if err != nil { statusCode := getAwsErrorStatusCode(err) return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil @@ -228,7 +241,10 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types } func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) { - awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput)) + ctx, cancel := newAwsInvokeContext() + defer cancel() + + awsResp, err := a.AwsClient.InvokeModelWithResponseStream(ctx, a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput)) if err != nil { statusCode := getAwsErrorStatusCode(err) return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, statusCode), nil @@ -268,7 +284,10 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) ( // Nova模型处理函数 func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) { - awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput)) + ctx, cancel := newAwsInvokeContext() + defer cancel() + + awsResp, err := a.AwsClient.InvokeModel(ctx, a.AwsReq.(*bedrockruntime.InvokeModelInput)) if err != nil { statusCode := getAwsErrorStatusCode(err) return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil diff --git a/service/http_client.go b/service/http_client.go index 3ae6a676..783aac89 100644 --- a/service/http_client.go +++ b/service/http_client.go @@ -82,6 +82,9 @@ func ResetProxyClientCache() { // NewProxyHttpClient 创建支持代理的 HTTP 客户端 func NewProxyHttpClient(proxyURL string) (*http.Client, error) { if proxyURL == "" { + if client := GetHttpClient(); client != nil { + return client, nil + } return http.DefaultClient, nil }