From ee302c063c417738d9024251900dfb4bf6bc8ff6 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Sun, 16 Mar 2025 16:47:16 +0800 Subject: [PATCH] refactor: Enhance error handling in AWS and Claude response processing by updating function signatures and improving error propagation --- relay/channel/aws/relay-aws.go | 11 +++++------ relay/channel/claude/relay-claude.go | 25 +++++++++++++------------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index 22200e32..da4bab89 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -10,7 +10,6 @@ import ( "one-api/dto" "one-api/relay/channel/claude" relaycommon "one-api/relay/common" - "one-api/service" "strings" "github.com/aws/aws-sdk-go-v2/aws" @@ -151,7 +150,10 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel switch v := event.(type) { case *types.ResponseStreamMemberChunk: info.SetFirstResponseTime() - claude.HandleResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage) + err = claude.HandleResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage) + if err != nil { + return wrapErr(err), nil + } case *types.UnknownUnionMember: fmt.Println("unknown tag:", v.Tag) return wrapErr(errors.New("unknown response type")), nil @@ -164,10 +166,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel claude.HandleFinalResponse(c, info, claudeInfo, RequestModeMessage) if resp != nil { - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + resp.Body.Close() } return nil, claudeInfo.Usage } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 6786a636..1214699b 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -479,12 +479,12 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons return true } -func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) bool { +func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) error { var claudeResponse dto.ClaudeResponse err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) - return false + return fmt.Errorf("error unmarshalling stream aws response: %w", err) } if info.RelayFormat == relaycommon.RelayFormatClaude { if requestMode == RequestModeCompletion { @@ -510,16 +510,10 @@ func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo } helper.ClaudeChunkData(c, claudeResponse, data) } else if info.RelayFormat == relaycommon.RelayFormatOpenAI { - err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return false - } - response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) { - return true + return nil } err = helper.ObjectData(c, response) @@ -527,7 +521,7 @@ func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo common.LogError(c, "send_stream_response_failed: "+err.Error()) } } - return true + return nil } func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) { @@ -573,10 +567,17 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. ResponseText: strings.Builder{}, Usage: &dto.Usage{}, } - + var err error helper.StreamScannerHandler(c, resp, info, func(data string) bool { - return HandleResponseData(c, info, claudeInfo, data, requestMode) + err = HandleResponseData(c, info, claudeInfo, data, requestMode) + if err != nil { + return false + } + return true }) + if err != nil { + return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError), nil + } HandleFinalResponse(c, info, claudeInfo, requestMode)