refactor: Streamline AWS and Claude response handling by consolidating logic and improving error management
This commit is contained in:
@@ -1,21 +1,17 @@
|
|||||||
package aws
|
package aws
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel/claude"
|
"one-api/relay/channel/claude"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go-v2/aws"
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||||
@@ -143,7 +139,6 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
stream := awsResp.GetStream()
|
stream := awsResp.GetStream()
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
claudeInfo := &claude.ClaudeResponseInfo{
|
claudeInfo := &claude.ClaudeResponseInfo{
|
||||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
@@ -151,63 +146,23 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
ResponseText: strings.Builder{},
|
ResponseText: strings.Builder{},
|
||||||
Usage: &dto.Usage{},
|
Usage: &dto.Usage{},
|
||||||
}
|
}
|
||||||
isFirst := true
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
event, ok := <-stream.Events()
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
|
for event := range stream.Events() {
|
||||||
switch v := event.(type) {
|
switch v := event.(type) {
|
||||||
case *types.ResponseStreamMemberChunk:
|
case *types.ResponseStreamMemberChunk:
|
||||||
if isFirst {
|
info.SetFirstResponseTime()
|
||||||
isFirst = false
|
claude.HandleResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
|
||||||
info.FirstResponseTime = time.Now()
|
|
||||||
}
|
|
||||||
claudeResponse := new(dto.ClaudeResponse)
|
|
||||||
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
response := claude.StreamResponseClaude2OpenAI(requestMode, claudeResponse)
|
|
||||||
|
|
||||||
if !claude.FormatClaudeResponseInfo(RequestModeMessage, claudeResponse, response, claudeInfo) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonStr, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
|
||||||
return true
|
|
||||||
case *types.UnknownUnionMember:
|
case *types.UnknownUnionMember:
|
||||||
fmt.Println("unknown tag:", v.Tag)
|
fmt.Println("unknown tag:", v.Tag)
|
||||||
return false
|
return wrapErr(errors.New("unknown response type")), nil
|
||||||
default:
|
default:
|
||||||
fmt.Println("union is nil or unknown type")
|
fmt.Println("union is nil or unknown type")
|
||||||
return false
|
return wrapErr(errors.New("nil or unknown response type")), nil
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if claudeInfo.Usage.PromptTokens == 0 {
|
|
||||||
//上游出错
|
|
||||||
}
|
|
||||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
|
||||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
if info.ShouldIncludeUsage {
|
|
||||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
|
||||||
err := helper.ObjectData(c, response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("send final response failed: " + err.Error())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
helper.Done(c)
|
|
||||||
|
claude.HandleFinalResponse(c, info, claudeInfo, RequestModeMessage)
|
||||||
|
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
err = resp.Body.Close()
|
err = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -479,77 +479,41 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) bool {
|
||||||
|
var claudeResponse dto.ClaudeResponse
|
||||||
if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
|
||||||
return toOpenAIStreamHandler(c, resp, info, requestMode)
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||||
usage := &dto.Usage{}
|
|
||||||
responseText := strings.Builder{}
|
|
||||||
|
|
||||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
||||||
var claudeResponse dto.ClaudeResponse
|
|
||||||
err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if requestMode == RequestModeCompletion {
|
if requestMode == RequestModeCompletion {
|
||||||
responseText.WriteString(claudeResponse.Completion)
|
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
||||||
} else {
|
} else {
|
||||||
if claudeResponse.Type == "message_start" {
|
if claudeResponse.Type == "message_start" {
|
||||||
// message_start, 获取usage
|
// message_start, 获取usage
|
||||||
info.UpstreamModelName = claudeResponse.Message.Model
|
info.UpstreamModelName = claudeResponse.Message.Model
|
||||||
usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||||
usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||||
usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||||
usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
||||||
} else if claudeResponse.Type == "content_block_delta" {
|
} else if claudeResponse.Type == "content_block_delta" {
|
||||||
responseText.WriteString(claudeResponse.Delta.GetText())
|
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
|
||||||
} else if claudeResponse.Type == "message_delta" {
|
} else if claudeResponse.Type == "message_delta" {
|
||||||
if claudeResponse.Usage.InputTokens > 0 {
|
if claudeResponse.Usage.InputTokens > 0 {
|
||||||
// 不叠加,只取最新的
|
// 不叠加,只取最新的
|
||||||
usage.PromptTokens = claudeResponse.Usage.InputTokens
|
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||||
}
|
}
|
||||||
usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
helper.ClaudeChunkData(c, claudeResponse, data)
|
helper.ClaudeChunkData(c, claudeResponse, data)
|
||||||
return true
|
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||||
})
|
|
||||||
|
|
||||||
if requestMode == RequestModeCompletion {
|
|
||||||
usage, _ = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
|
||||||
} else {
|
|
||||||
// 说明流模式建立失败,可能为官方出错
|
|
||||||
if usage.PromptTokens == 0 {
|
|
||||||
//usage.PromptTokens = info.PromptTokens
|
|
||||||
}
|
|
||||||
if usage.CompletionTokens == 0 {
|
|
||||||
usage, _ = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, usage.PromptTokens)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func toOpenAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
||||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
|
||||||
claudeInfo := &ClaudeResponseInfo{
|
|
||||||
ResponseId: responseId,
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Model: info.UpstreamModelName,
|
|
||||||
ResponseText: strings.Builder{},
|
|
||||||
Usage: &dto.Usage{},
|
|
||||||
}
|
|
||||||
|
|
||||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
||||||
var claudeResponse dto.ClaudeResponse
|
|
||||||
err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
|
err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
return true
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||||
@@ -562,27 +526,60 @@ func toOpenAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, "send_stream_response_failed: "+err.Error())
|
common.LogError(c, "send_stream_response_failed: "+err.Error())
|
||||||
}
|
}
|
||||||
return true
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||||
|
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||||
|
if requestMode == RequestModeCompletion {
|
||||||
|
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
|
} else {
|
||||||
|
// 说明流模式建立失败,可能为官方出错
|
||||||
|
if claudeInfo.Usage.PromptTokens == 0 {
|
||||||
|
//usage.PromptTokens = info.PromptTokens
|
||||||
|
}
|
||||||
|
if claudeInfo.Usage.CompletionTokens == 0 {
|
||||||
|
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||||
|
if requestMode == RequestModeCompletion {
|
||||||
|
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
|
} else {
|
||||||
|
if claudeInfo.Usage.PromptTokens == 0 {
|
||||||
|
//上游出错
|
||||||
|
}
|
||||||
|
if claudeInfo.Usage.CompletionTokens == 0 {
|
||||||
|
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if info.ShouldIncludeUsage {
|
||||||
|
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
||||||
|
err := helper.ObjectData(c, response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("send final response failed: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
helper.Done(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
claudeInfo := &ClaudeResponseInfo{
|
||||||
|
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: info.UpstreamModelName,
|
||||||
|
ResponseText: strings.Builder{},
|
||||||
|
Usage: &dto.Usage{},
|
||||||
|
}
|
||||||
|
|
||||||
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||||
|
return HandleResponseData(c, info, claudeInfo, data, requestMode)
|
||||||
})
|
})
|
||||||
|
|
||||||
if requestMode == RequestModeCompletion {
|
HandleFinalResponse(c, info, claudeInfo, requestMode)
|
||||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
|
||||||
} else {
|
|
||||||
if claudeInfo.Usage.PromptTokens == 0 {
|
|
||||||
//上游出错
|
|
||||||
}
|
|
||||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
|
||||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if info.ShouldIncludeUsage {
|
|
||||||
response := helper.GenerateFinalUsageResponse(responseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
|
||||||
err := helper.ObjectData(c, response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("send final response failed: " + err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
helper.Done(c)
|
|
||||||
return nil, claudeInfo.Usage
|
return nil, claudeInfo.Usage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package xinference
|
|||||||
|
|
||||||
var ModelList = []string{
|
var ModelList = []string{
|
||||||
"bge-reranker-v2-m3",
|
"bge-reranker-v2-m3",
|
||||||
|
"jina-reranker-v2",
|
||||||
}
|
}
|
||||||
|
|
||||||
var ChannelName = "xinference"
|
var ChannelName = "xinference"
|
||||||
|
|||||||
Reference in New Issue
Block a user