feat: claude relay

This commit is contained in:
1808837298@qq.com
2025-03-12 21:31:46 +08:00
parent c0b9350785
commit bd48f43410
51 changed files with 1660 additions and 236 deletions

View File

@@ -30,9 +30,9 @@ func stopReasonClaude2OpenAI(reason string) string {
}
}
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.ClaudeRequest {
claudeRequest := ClaudeRequest{
claudeRequest := dto.ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
StopSequences: nil,
@@ -61,12 +61,12 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
return &claudeRequest
}
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
claudeTools := make([]Tool, 0, len(textRequest.Tools))
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
claudeTools := make([]dto.Tool, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
if params, ok := tool.Function.Parameters.(map[string]any); ok {
claudeTool := Tool{
claudeTool := dto.Tool{
Name: tool.Function.Name,
Description: tool.Function.Description,
}
@@ -84,7 +84,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
}
}
claudeRequest := ClaudeRequest{
claudeRequest := dto.ClaudeRequest{
Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens,
StopSequences: nil,
@@ -108,7 +108,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
}
// BudgetTokens 为 max_tokens 的 80%
claudeRequest.Thinking = &Thinking{
claudeRequest.Thinking = &dto.Thinking{
Type: "enabled",
BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
}
@@ -166,7 +166,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
lastMessage = fmtMessage
}
claudeMessages := make([]ClaudeMessage, 0)
claudeMessages := make([]dto.ClaudeMessage, 0)
isFirstMessage := true
for _, message := range formatMessages {
if message.Role == "system" {
@@ -187,63 +187,63 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
isFirstMessage = false
if message.Role != "user" {
// fix: first message is assistant, add user message
claudeMessage := ClaudeMessage{
claudeMessage := dto.ClaudeMessage{
Role: "user",
Content: []ClaudeMediaMessage{
Content: []dto.ClaudeMediaMessage{
{
Type: "text",
Text: "...",
Text: common.GetPointer[string]("..."),
},
},
}
claudeMessages = append(claudeMessages, claudeMessage)
}
}
claudeMessage := ClaudeMessage{
claudeMessage := dto.ClaudeMessage{
Role: message.Role,
}
if message.Role == "tool" {
if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
lastMessage := claudeMessages[len(claudeMessages)-1]
if content, ok := lastMessage.Content.(string); ok {
lastMessage.Content = []ClaudeMediaMessage{
lastMessage.Content = []dto.ClaudeMediaMessage{
{
Type: "text",
Text: content,
Text: common.GetPointer[string](content),
},
}
}
lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{
lastMessage.Content = append(lastMessage.Content.([]dto.ClaudeMediaMessage), dto.ClaudeMediaMessage{
Type: "tool_result",
ToolUseId: message.ToolCallId,
Content: message.StringContent(),
Content: message.Content,
})
claudeMessages[len(claudeMessages)-1] = lastMessage
continue
} else {
claudeMessage.Role = "user"
claudeMessage.Content = []ClaudeMediaMessage{
claudeMessage.Content = []dto.ClaudeMediaMessage{
{
Type: "tool_result",
ToolUseId: message.ToolCallId,
Content: message.StringContent(),
Content: message.Content,
},
}
}
} else if message.IsStringContent() && message.ToolCalls == nil {
claudeMessage.Content = message.StringContent()
} else {
claudeMediaMessages := make([]ClaudeMediaMessage, 0)
claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0)
for _, mediaMessage := range message.ParseContent() {
claudeMediaMessage := ClaudeMediaMessage{
claudeMediaMessage := dto.ClaudeMediaMessage{
Type: mediaMessage.Type,
}
if mediaMessage.Type == "text" {
claudeMediaMessage.Text = mediaMessage.Text
claudeMediaMessage.Text = common.GetPointer[string](mediaMessage.Text)
} else {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
claudeMediaMessage.Type = "image"
claudeMediaMessage.Source = &ClaudeMessageSource{
claudeMediaMessage.Source = &dto.ClaudeMessageSource{
Type: "base64",
}
// 判断是否是url
@@ -273,7 +273,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
continue
}
claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
Type: "tool_use",
Id: toolCall.ID,
Name: toolCall.Function.Name,
@@ -291,7 +291,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
return &claudeRequest, nil
}
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
@@ -329,8 +329,8 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *d
}
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta != nil {
choice.Index = claudeResponse.Index
choice.Delta.SetContentString(claudeResponse.Delta.Text)
choice.Index = *claudeResponse.Index
choice.Delta.SetContentString(*claudeResponse.Delta.Text)
switch claudeResponse.Delta.Type {
case "input_json_delta":
tools = append(tools, dto.ToolCallResponse{
@@ -368,7 +368,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *d
return &response
}
func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
choices := make([]dto.OpenAITextResponseChoice, 0)
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
@@ -377,7 +377,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
}
var responseText string
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
responseText = *claudeResponse.Content[0].Text
}
tools := make([]dto.ToolCallResponse, 0)
thinkingContent := ""
@@ -412,7 +412,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
// 加密的不管, 只输出明文的推理过程
thinkingContent = message.Thinking
case "text":
responseText = message.Text
responseText = *message.Text
}
}
}
@@ -442,7 +442,7 @@ type ClaudeResponseInfo struct {
Usage *dto.Usage
}
func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
if requestMode == RequestModeCompletion {
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
@@ -452,7 +452,7 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, o
claudeInfo.Model = claudeResponse.Message.Model
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Text)
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
} else if claudeResponse.Type == "message_delta" {
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
@@ -470,6 +470,61 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, o
}
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
if info.RelayFormat == relaycommon.RelayFormatOpenAI {
return toOpenAIStreamHandler(c, resp, info, requestMode)
}
usage := &dto.Usage{}
responseText := strings.Builder{}
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var claudeResponse dto.ClaudeResponse
err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if requestMode == RequestModeCompletion {
responseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
info.UpstreamModelName = claudeResponse.Message.Model
usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText.WriteString(claudeResponse.Delta.GetText())
} else if claudeResponse.Type == "message_delta" {
if claudeResponse.Usage.InputTokens > 0 {
// 不叠加,只取最新的
usage.PromptTokens = claudeResponse.Usage.InputTokens
}
usage.CompletionTokens = claudeResponse.Usage.OutputTokens
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
}
helper.ClaudeChunkData(c, claudeResponse, data)
return true
})
if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
// 说明流模式建立失败,可能为官方出错
if usage.PromptTokens == 0 {
//usage.PromptTokens = info.PromptTokens
}
if usage.CompletionTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, usage.PromptTokens)
}
}
return nil, usage
}
func toOpenAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
claudeInfo := &ClaudeResponseInfo{
ResponseId: responseId,
@@ -480,7 +535,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var claudeResponse ClaudeResponse
var claudeResponse dto.ClaudeResponse
err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
@@ -530,7 +585,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var claudeResponse ClaudeResponse
var claudeResponse dto.ClaudeResponse
err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
@@ -546,13 +601,12 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
}
usage := dto.Usage{}
if requestMode == RequestModeCompletion {
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
}
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens = completionTokens
usage.TotalTokens = info.PromptTokens + completionTokens
@@ -560,14 +614,23 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
usage.PromptTokens = claudeResponse.Usage.InputTokens
usage.CompletionTokens = claudeResponse.Usage.OutputTokens
usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
var responseData []byte
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
openaiResponse.Usage = usage
responseData, err = json.Marshal(openaiResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
case relaycommon.RelayFormatClaude:
responseData = responseBody
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
_, err = c.Writer.Write(responseData)
return nil, &usage
}