From 2c81a5f0cc6bfdba70d462fd7ccc228ef8c9d244 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Sun, 16 Mar 2025 15:57:01 +0800 Subject: [PATCH] refactor: Streamline AWS and Claude response handling by consolidating logic and improving error management --- relay/channel/aws/relay-aws.go | 61 ++---------- relay/channel/claude/relay-claude.go | 143 +++++++++++++-------------- relay/channel/xinference/constant.go | 1 + 3 files changed, 79 insertions(+), 126 deletions(-) diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index 0d517256..22200e32 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -1,21 +1,17 @@ package aws import ( - "bytes" "encoding/json" "fmt" "github.com/gin-gonic/gin" "github.com/pkg/errors" - "io" "net/http" "one-api/common" "one-api/dto" "one-api/relay/channel/claude" relaycommon "one-api/relay/common" - "one-api/relay/helper" "one-api/service" "strings" - "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/credentials" @@ -143,7 +139,6 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel stream := awsResp.GetStream() defer stream.Close() - c.Writer.Header().Set("Content-Type", "text/event-stream") claudeInfo := &claude.ClaudeResponseInfo{ ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Created: common.GetTimestamp(), @@ -151,63 +146,23 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel ResponseText: strings.Builder{}, Usage: &dto.Usage{}, } - isFirst := true - c.Stream(func(w io.Writer) bool { - event, ok := <-stream.Events() - if !ok { - return false - } + for event := range stream.Events() { switch v := event.(type) { case *types.ResponseStreamMemberChunk: - if isFirst { - isFirst = false - info.FirstResponseTime = time.Now() - } - claudeResponse := new(dto.ClaudeResponse) - err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return false - } - - response := claude.StreamResponseClaude2OpenAI(requestMode, claudeResponse) - - if !claude.FormatClaudeResponseInfo(RequestModeMessage, claudeResponse, response, claudeInfo) { - return true - } - - jsonStr, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) - return true + info.SetFirstResponseTime() + claude.HandleResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage) case *types.UnknownUnionMember: fmt.Println("unknown tag:", v.Tag) - return false + return wrapErr(errors.New("unknown response type")), nil default: fmt.Println("union is nil or unknown type") - return false - } - }) - - if claudeInfo.Usage.PromptTokens == 0 { - //上游出错 - } - if claudeInfo.Usage.CompletionTokens == 0 { - claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) - } - - if info.ShouldIncludeUsage { - response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage) - err := helper.ObjectData(c, response) - if err != nil { - common.SysError("send final response failed: " + err.Error()) + return wrapErr(errors.New("nil or unknown response type")), nil } } - helper.Done(c) + + claude.HandleFinalResponse(c, info, claudeInfo, RequestModeMessage) + if resp != nil { err = resp.Body.Close() if err != nil { diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 5316a66e..6786a636 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -479,77 +479,41 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons return true } -func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - - if info.RelayFormat == relaycommon.RelayFormatOpenAI { - return toOpenAIStreamHandler(c, resp, info, requestMode) +func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) bool { + var claudeResponse dto.ClaudeResponse + err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return false } - - usage := &dto.Usage{} - responseText := strings.Builder{} - - helper.StreamScannerHandler(c, resp, info, func(data string) bool { - var claudeResponse dto.ClaudeResponse - err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return true - } + if info.RelayFormat == relaycommon.RelayFormatClaude { if requestMode == RequestModeCompletion { - responseText.WriteString(claudeResponse.Completion) + claudeInfo.ResponseText.WriteString(claudeResponse.Completion) } else { if claudeResponse.Type == "message_start" { // message_start, 获取usage info.UpstreamModelName = claudeResponse.Message.Model - usage.PromptTokens = claudeResponse.Message.Usage.InputTokens - usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens - usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens - usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens + claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens + claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens + claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens + claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens } else if claudeResponse.Type == "content_block_delta" { - responseText.WriteString(claudeResponse.Delta.GetText()) + claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText()) } else if claudeResponse.Type == "message_delta" { if claudeResponse.Usage.InputTokens > 0 { // 不叠加,只取最新的 - usage.PromptTokens = claudeResponse.Usage.InputTokens + claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens } - usage.CompletionTokens = claudeResponse.Usage.OutputTokens - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens + claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens } } helper.ClaudeChunkData(c, claudeResponse, data) - return true - }) - - if requestMode == RequestModeCompletion { - usage, _ = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens) - } else { - // 说明流模式建立失败,可能为官方出错 - if usage.PromptTokens == 0 { - //usage.PromptTokens = info.PromptTokens - } - if usage.CompletionTokens == 0 { - usage, _ = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, usage.PromptTokens) - } - } - return nil, usage -} - -func toOpenAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) - claudeInfo := &ClaudeResponseInfo{ - ResponseId: responseId, - Created: common.GetTimestamp(), - Model: info.UpstreamModelName, - ResponseText: strings.Builder{}, - Usage: &dto.Usage{}, - } - - helper.StreamScannerHandler(c, resp, info, func(data string) bool { - var claudeResponse dto.ClaudeResponse + } else if info.RelayFormat == relaycommon.RelayFormatOpenAI { err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) - return true + return false } response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) @@ -562,27 +526,60 @@ func toOpenAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo if err != nil { common.LogError(c, "send_stream_response_failed: "+err.Error()) } - return true + } + return true +} + +func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) { + if info.RelayFormat == relaycommon.RelayFormatClaude { + if requestMode == RequestModeCompletion { + claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) + } else { + // 说明流模式建立失败,可能为官方出错 + if claudeInfo.Usage.PromptTokens == 0 { + //usage.PromptTokens = info.PromptTokens + } + if claudeInfo.Usage.CompletionTokens == 0 { + claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) + } + } + } else if info.RelayFormat == relaycommon.RelayFormatOpenAI { + if requestMode == RequestModeCompletion { + claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) + } else { + if claudeInfo.Usage.PromptTokens == 0 { + //上游出错 + } + if claudeInfo.Usage.CompletionTokens == 0 { + claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) + } + } + if info.ShouldIncludeUsage { + response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage) + err := helper.ObjectData(c, response) + if err != nil { + common.SysError("send final response failed: " + err.Error()) + } + } + helper.Done(c) + } +} + +func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + claudeInfo := &ClaudeResponseInfo{ + ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Created: common.GetTimestamp(), + Model: info.UpstreamModelName, + ResponseText: strings.Builder{}, + Usage: &dto.Usage{}, + } + + helper.StreamScannerHandler(c, resp, info, func(data string) bool { + return HandleResponseData(c, info, claudeInfo, data, requestMode) }) - if requestMode == RequestModeCompletion { - claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) - } else { - if claudeInfo.Usage.PromptTokens == 0 { - //上游出错 - } - if claudeInfo.Usage.CompletionTokens == 0 { - claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) - } - } - if info.ShouldIncludeUsage { - response := helper.GenerateFinalUsageResponse(responseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage) - err := helper.ObjectData(c, response) - if err != nil { - common.SysError("send final response failed: " + err.Error()) - } - } - helper.Done(c) + HandleFinalResponse(c, info, claudeInfo, requestMode) + return nil, claudeInfo.Usage } diff --git a/relay/channel/xinference/constant.go b/relay/channel/xinference/constant.go index 98ec9b04..a119084f 100644 --- a/relay/channel/xinference/constant.go +++ b/relay/channel/xinference/constant.go @@ -2,6 +2,7 @@ package xinference var ModelList = []string{ "bge-reranker-v2-m3", + "jina-reranker-v2", } var ChannelName = "xinference"