diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 8c74af08..ba20adea 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -454,6 +454,7 @@ type ClaudeResponseInfo struct { Model string ResponseText strings.Builder Usage *dto.Usage + Done bool } func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool { @@ -461,20 +462,32 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons claudeInfo.ResponseText.WriteString(claudeResponse.Completion) } else { if claudeResponse.Type == "message_start" { - // message_start, 获取usage claudeInfo.ResponseId = claudeResponse.Message.Id claudeInfo.Model = claudeResponse.Message.Model + + // message_start, 获取usage claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens + claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens + claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens + claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens } else if claudeResponse.Type == "content_block_delta" { if claudeResponse.Delta.Text != nil { claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text) } + if claudeResponse.Delta.Thinking != "" { + claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking) + } } else if claudeResponse.Type == "message_delta" { - claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens + // 最终的usage获取 if claudeResponse.Usage.InputTokens > 0 { + // 不叠加,只取最新的 claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens } - claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens + claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens + claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens + + // 判断是否完整 + claudeInfo.Done = true } else if claudeResponse.Type == "content_block_start" { } else { return false @@ -506,25 +519,15 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } } if info.RelayFormat == relaycommon.RelayFormatClaude { + FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo) + if requestMode == RequestModeCompletion { - claudeInfo.ResponseText.WriteString(claudeResponse.Completion) } else { if claudeResponse.Type == "message_start" { // message_start, 获取usage info.UpstreamModelName = claudeResponse.Message.Model - claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens - claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens - claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens - claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens } else if claudeResponse.Type == "content_block_delta" { - claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText()) } else if claudeResponse.Type == "message_delta" { - if claudeResponse.Usage.InputTokens > 0 { - // 不叠加,只取最新的 - claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens - } - claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens - claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens } } helper.ClaudeChunkData(c, claudeResponse, data) @@ -544,29 +547,25 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) { + + if requestMode == RequestModeCompletion { + claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) + } else { + if claudeInfo.Usage.PromptTokens == 0 { + //上游出错 + } + if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done { + if common.DebugEnabled { + common.SysError("claude response usage is not complete, maybe upstream error") + } + claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) + } + } + if info.RelayFormat == relaycommon.RelayFormatClaude { - if requestMode == RequestModeCompletion { - claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) - } else { - // 说明流模式建立失败,可能为官方出错 - if claudeInfo.Usage.PromptTokens == 0 { - //usage.PromptTokens = info.PromptTokens - } - if claudeInfo.Usage.CompletionTokens == 0 { - claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) - } - } + // } else if info.RelayFormat == relaycommon.RelayFormatOpenAI { - if requestMode == RequestModeCompletion { - claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) - } else { - if claudeInfo.Usage.PromptTokens == 0 { - //上游出错 - } - if claudeInfo.Usage.CompletionTokens == 0 { - claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) - } - } + if info.ShouldIncludeUsage { response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage) err := helper.ObjectData(c, response)