feat: claude relay
This commit is contained in:
310
service/convert.go
Normal file
310
service/convert.go
Normal file
@@ -0,0 +1,310 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
)
|
||||
|
||||
func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIRequest, error) {
|
||||
openAIRequest := dto.GeneralOpenAIRequest{
|
||||
Model: claudeRequest.Model,
|
||||
MaxTokens: claudeRequest.MaxTokens,
|
||||
Temperature: claudeRequest.Temperature,
|
||||
TopP: claudeRequest.TopP,
|
||||
Stream: claudeRequest.Stream,
|
||||
}
|
||||
|
||||
// Convert stop sequences
|
||||
if len(claudeRequest.StopSequences) == 1 {
|
||||
openAIRequest.Stop = claudeRequest.StopSequences[0]
|
||||
} else if len(claudeRequest.StopSequences) > 1 {
|
||||
openAIRequest.Stop = claudeRequest.StopSequences
|
||||
}
|
||||
|
||||
// Convert tools
|
||||
tools, _ := common.Any2Type[[]dto.Tool](claudeRequest.Tools)
|
||||
openAITools := make([]dto.ToolCallRequest, 0)
|
||||
for _, claudeTool := range tools {
|
||||
openAITool := dto.ToolCallRequest{
|
||||
Type: "function",
|
||||
Function: dto.FunctionRequest{
|
||||
Name: claudeTool.Name,
|
||||
Description: claudeTool.Description,
|
||||
Parameters: claudeTool.InputSchema,
|
||||
},
|
||||
}
|
||||
openAITools = append(openAITools, openAITool)
|
||||
}
|
||||
openAIRequest.Tools = openAITools
|
||||
|
||||
// Convert messages
|
||||
openAIMessages := make([]dto.Message, 0)
|
||||
|
||||
// Add system message if present
|
||||
if claudeRequest.IsStringSystem() {
|
||||
openAIMessage := dto.Message{
|
||||
Role: "system",
|
||||
}
|
||||
openAIMessage.SetStringContent(claudeRequest.GetStringSystem())
|
||||
openAIMessages = append(openAIMessages, openAIMessage)
|
||||
} else {
|
||||
systems := claudeRequest.ParseSystem()
|
||||
if len(systems) > 0 {
|
||||
systemStr := ""
|
||||
openAIMessage := dto.Message{
|
||||
Role: "system",
|
||||
}
|
||||
for _, system := range systems {
|
||||
systemStr += system.Type
|
||||
}
|
||||
openAIMessage.SetStringContent(systemStr)
|
||||
openAIMessages = append(openAIMessages, openAIMessage)
|
||||
}
|
||||
}
|
||||
for _, claudeMessage := range claudeRequest.Messages {
|
||||
openAIMessage := dto.Message{
|
||||
Role: claudeMessage.Role,
|
||||
}
|
||||
|
||||
//log.Printf("claudeMessage.Content: %v", claudeMessage.Content)
|
||||
if claudeMessage.IsStringContent() {
|
||||
openAIMessage.SetStringContent(claudeMessage.GetStringContent())
|
||||
} else {
|
||||
content, err := claudeMessage.ParseContent()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
contents := content
|
||||
var toolCalls []dto.ToolCallRequest
|
||||
mediaMessages := make([]dto.MediaContent, 0, len(contents))
|
||||
|
||||
for _, mediaMsg := range contents {
|
||||
switch mediaMsg.Type {
|
||||
case "text":
|
||||
message := dto.MediaContent{
|
||||
Type: "text",
|
||||
Text: mediaMsg.GetText(),
|
||||
}
|
||||
mediaMessages = append(mediaMessages, message)
|
||||
case "image":
|
||||
// Handle image conversion (base64 to URL or keep as is)
|
||||
imageData := fmt.Sprintf("data:%s;base64,%s", mediaMsg.Source.MediaType, mediaMsg.Source.Data)
|
||||
//textContent += fmt.Sprintf("[Image: %s]", imageData)
|
||||
mediaMessage := dto.MediaContent{
|
||||
Type: "image_url",
|
||||
ImageUrl: &dto.MessageImageUrl{Url: imageData},
|
||||
}
|
||||
mediaMessages = append(mediaMessages, mediaMessage)
|
||||
case "tool_use":
|
||||
toolCall := dto.ToolCallRequest{
|
||||
ID: mediaMsg.Id,
|
||||
Function: dto.FunctionRequest{
|
||||
Name: mediaMsg.Name,
|
||||
Arguments: toJSONString(mediaMsg.Input),
|
||||
},
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case "tool_result":
|
||||
// Add tool result as a separate message
|
||||
oaiToolMessage := dto.Message{
|
||||
Role: "tool",
|
||||
ToolCallId: mediaMsg.ToolUseId,
|
||||
}
|
||||
oaiToolMessage.Content = mediaMsg.Content
|
||||
}
|
||||
}
|
||||
|
||||
openAIMessage.SetMediaContent(mediaMessages)
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
openAIMessage.SetToolCalls(toolCalls)
|
||||
}
|
||||
}
|
||||
|
||||
openAIMessages = append(openAIMessages, openAIMessage)
|
||||
}
|
||||
|
||||
openAIRequest.Messages = openAIMessages
|
||||
|
||||
return &openAIRequest, nil
|
||||
}
|
||||
|
||||
func OpenAIErrorToClaudeError(openAIError *dto.OpenAIErrorWithStatusCode) *dto.ClaudeErrorWithStatusCode {
|
||||
claudeError := dto.ClaudeError{
|
||||
Type: "new_api_error",
|
||||
Message: openAIError.Error.Message,
|
||||
}
|
||||
return &dto.ClaudeErrorWithStatusCode{
|
||||
Error: claudeError,
|
||||
StatusCode: openAIError.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.OpenAIErrorWithStatusCode {
|
||||
openAIError := dto.OpenAIError{
|
||||
Message: claudeError.Error.Message,
|
||||
Type: "new_api_error",
|
||||
}
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: openAIError,
|
||||
StatusCode: claudeError.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
|
||||
var claudeResponses []*dto.ClaudeResponse
|
||||
if info.ResponseTimes == 1 {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_start",
|
||||
Message: &dto.ClaudeMediaMessage{
|
||||
Id: openAIResponse.Id,
|
||||
Model: openAIResponse.Model,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: info.PromptTokens,
|
||||
OutputTokens: 0,
|
||||
},
|
||||
},
|
||||
})
|
||||
if openAIResponse.IsToolCall() {
|
||||
resp := &dto.ClaudeResponse{
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Id: openAIResponse.GetFirstToolCall().ID,
|
||||
Type: "tool_use",
|
||||
Name: openAIResponse.GetFirstToolCall().Function.Name,
|
||||
},
|
||||
}
|
||||
resp.SetIndex(0)
|
||||
claudeResponses = append(claudeResponses, resp)
|
||||
} else {
|
||||
resp := &dto.ClaudeResponse{
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](""),
|
||||
},
|
||||
}
|
||||
resp.SetIndex(0)
|
||||
claudeResponses = append(claudeResponses, resp)
|
||||
}
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "ping",
|
||||
})
|
||||
return claudeResponses
|
||||
}
|
||||
|
||||
if len(openAIResponse.Choices) == 0 {
|
||||
// no choices
|
||||
// TODO: handle this case
|
||||
} else {
|
||||
chosenChoice := openAIResponse.Choices[0]
|
||||
if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
|
||||
// should be done
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "content_block_stop",
|
||||
Index: common.GetPointer[int](0),
|
||||
})
|
||||
if openAIResponse.Usage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: openAIResponse.Usage.PromptTokens,
|
||||
OutputTokens: openAIResponse.Usage.CompletionTokens,
|
||||
},
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(*chosenChoice.FinishReason)),
|
||||
},
|
||||
})
|
||||
}
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_stop",
|
||||
})
|
||||
} else {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
claudeResponse.SetIndex(0)
|
||||
claudeResponse.Type = "content_block_delta"
|
||||
if len(chosenChoice.Delta.ToolCalls) > 0 {
|
||||
// tools delta
|
||||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||||
Type: "input_json_delta",
|
||||
PartialJson: chosenChoice.Delta.ToolCalls[0].Function.Arguments,
|
||||
}
|
||||
} else {
|
||||
// text delta
|
||||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||||
Type: "text_delta",
|
||||
Text: common.GetPointer[string](chosenChoice.Delta.GetContentString()),
|
||||
}
|
||||
}
|
||||
claudeResponses = append(claudeResponses, &claudeResponse)
|
||||
}
|
||||
}
|
||||
|
||||
return claudeResponses
|
||||
}
|
||||
|
||||
func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.ClaudeResponse {
|
||||
var stopReason string
|
||||
contents := make([]dto.ClaudeMediaMessage, 0)
|
||||
claudeResponse := &dto.ClaudeResponse{
|
||||
Id: openAIResponse.Id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: openAIResponse.Model,
|
||||
}
|
||||
for _, choice := range openAIResponse.Choices {
|
||||
stopReason = stopReasonOpenAI2Claude(choice.FinishReason)
|
||||
claudeContent := dto.ClaudeMediaMessage{}
|
||||
if choice.FinishReason == "tool_calls" {
|
||||
claudeContent.Type = "tool_use"
|
||||
claudeContent.Id = choice.Message.ToolCallId
|
||||
claudeContent.Name = choice.Message.ParseToolCalls()[0].Function.Name
|
||||
var mapParams map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(choice.Message.ParseToolCalls()[0].Function.Arguments), &mapParams); err == nil {
|
||||
claudeContent.Input = mapParams
|
||||
} else {
|
||||
claudeContent.Input = choice.Message.ParseToolCalls()[0].Function.Arguments
|
||||
}
|
||||
} else {
|
||||
claudeContent.Type = "text"
|
||||
claudeContent.SetText(choice.Message.StringContent())
|
||||
}
|
||||
contents = append(contents, claudeContent)
|
||||
}
|
||||
claudeResponse.Content = contents
|
||||
claudeResponse.StopReason = stopReason
|
||||
claudeResponse.Usage = &dto.ClaudeUsage{
|
||||
InputTokens: openAIResponse.PromptTokens,
|
||||
OutputTokens: openAIResponse.CompletionTokens,
|
||||
}
|
||||
|
||||
return claudeResponse
|
||||
}
|
||||
|
||||
func stopReasonOpenAI2Claude(reason string) string {
|
||||
switch reason {
|
||||
case "stop":
|
||||
return "end_turn"
|
||||
case "stop_sequence":
|
||||
return "stop_sequence"
|
||||
case "max_tokens":
|
||||
return "max_tokens"
|
||||
case "tool_calls":
|
||||
return "tool_use"
|
||||
default:
|
||||
return reason
|
||||
}
|
||||
}
|
||||
|
||||
func toJSONString(v interface{}) string {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return "{}"
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
@@ -50,6 +50,30 @@ func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAI
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
|
||||
text := err.Error()
|
||||
lowerText := strings.ToLower(text)
|
||||
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
||||
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||
text = "请求上游地址失败"
|
||||
}
|
||||
claudeError := dto.ClaudeError{
|
||||
Message: text,
|
||||
Type: "new_api_error",
|
||||
//Code: code,
|
||||
}
|
||||
return &dto.ClaudeErrorWithStatusCode{
|
||||
Error: claudeError,
|
||||
StatusCode: statusCode,
|
||||
}
|
||||
}
|
||||
|
||||
func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
|
||||
claudeErr := ClaudeErrorWrapper(err, code, statusCode)
|
||||
claudeErr.LocalError = true
|
||||
return claudeErr
|
||||
}
|
||||
|
||||
func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||
errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
|
||||
StatusCode: resp.StatusCode,
|
||||
|
||||
@@ -53,3 +53,12 @@ func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
info["audio_completion_ratio"] = audioCompletionRatio
|
||||
return info
|
||||
}
|
||||
|
||||
func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
|
||||
cacheTokens int, cacheRatio float64, cacheCreationTokens int, cacheCreationRatio float64, modelPrice float64) map[string]interface{} {
|
||||
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice)
|
||||
info["claude"] = true
|
||||
info["cache_creation_tokens"] = cacheCreationTokens
|
||||
info["cache_creation_ratio"] = cacheCreationRatio
|
||||
return info
|
||||
}
|
||||
|
||||
@@ -194,6 +194,75 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
}
|
||||
|
||||
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
promptTokens := usage.PromptTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
modelName := relayInfo.OriginModelName
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := priceData.CompletionRatio
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
|
||||
cacheRatio := priceData.CacheRatio
|
||||
cacheTokens := usage.PromptTokensDetails.CachedTokens
|
||||
|
||||
cacheCreationRatio := priceData.CacheCreationRatio
|
||||
cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||
|
||||
calculateQuota := 0.0
|
||||
if !priceData.UsePrice {
|
||||
calculateQuota = float64(promptTokens)
|
||||
calculateQuota += float64(cacheTokens) * cacheRatio
|
||||
calculateQuota += float64(cacheCreationTokens) * cacheCreationRatio
|
||||
calculateQuota += float64(completionTokens) * completionRatio
|
||||
calculateQuota = calculateQuota * groupRatio * modelRatio
|
||||
} else {
|
||||
calculateQuota = modelPrice * common.QuotaPerUnit * groupRatio
|
||||
}
|
||||
|
||||
if modelRatio != 0 && calculateQuota <= 0 {
|
||||
calculateQuota = 1
|
||||
}
|
||||
|
||||
quota := int(calculateQuota)
|
||||
|
||||
totalTokens := promptTokens + completionTokens
|
||||
|
||||
var logContent string
|
||||
// record all the consume log even if quota is 0
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
logContent += fmt.Sprintf("(可能是上游出错)")
|
||||
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
|
||||
} else {
|
||||
//if sensitiveResp != nil {
|
||||
// logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
|
||||
//}
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
if quotaDelta != 0 {
|
||||
err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
}
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
|
||||
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
}
|
||||
|
||||
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
@@ -192,6 +193,110 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
|
||||
return tkm, nil
|
||||
}
|
||||
|
||||
func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
|
||||
tkm := 0
|
||||
|
||||
// Count tokens in messages
|
||||
msgTokens, err := CountTokenClaudeMessages(request.Messages, model, request.Stream)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
tkm += msgTokens
|
||||
|
||||
// Count tokens in system message
|
||||
if request.System != "" {
|
||||
systemTokens, err := CountTokenInput(request.System, model)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
tkm += systemTokens
|
||||
}
|
||||
|
||||
if request.Tools != nil {
|
||||
// check is array
|
||||
if tools, ok := request.Tools.([]any); ok {
|
||||
if len(tools) > 0 {
|
||||
parsedTools, err1 := common.Any2Type[[]dto.Tool](request.Tools)
|
||||
if err1 != nil {
|
||||
return 0, fmt.Errorf("tools: Input should be a valid list: %v", err)
|
||||
}
|
||||
toolTokens, err2 := CountTokenClaudeTools(parsedTools, model)
|
||||
if err2 != nil {
|
||||
return 0, fmt.Errorf("tools: %v", err)
|
||||
}
|
||||
tkm += toolTokens
|
||||
}
|
||||
} else {
|
||||
return 0, errors.New("tools: Input should be a valid list")
|
||||
}
|
||||
}
|
||||
|
||||
return tkm, nil
|
||||
}
|
||||
|
||||
func CountTokenClaudeMessages(messages []dto.ClaudeMessage, model string, stream bool) (int, error) {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
tokenNum := 0
|
||||
|
||||
for _, message := range messages {
|
||||
// Count tokens for role
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
if message.IsStringContent() {
|
||||
tokenNum += getTokenNum(tokenEncoder, message.GetStringContent())
|
||||
} else {
|
||||
content, err := message.ParseContent()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for _, mediaMessage := range content {
|
||||
switch mediaMessage.Type {
|
||||
case "text":
|
||||
tokenNum += getTokenNum(tokenEncoder, mediaMessage.GetText())
|
||||
case "image":
|
||||
//imageTokenNum, err := getClaudeImageToken(mediaMsg.Source, model, stream)
|
||||
//if err != nil {
|
||||
// return 0, err
|
||||
//}
|
||||
tokenNum += 1000
|
||||
case "tool_use":
|
||||
tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
|
||||
inputJSON, _ := json.Marshal(mediaMessage.Input)
|
||||
tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
|
||||
case "tool_result":
|
||||
contentJSON, _ := json.Marshal(mediaMessage.Content)
|
||||
tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add a constant for message formatting (this may need adjustment based on Claude's exact formatting)
|
||||
tokenNum += len(messages) * 2 // Assuming 2 tokens per message for formatting
|
||||
|
||||
return tokenNum, nil
|
||||
}
|
||||
|
||||
func CountTokenClaudeTools(tools []dto.Tool, model string) (int, error) {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
tokenNum := 0
|
||||
|
||||
for _, tool := range tools {
|
||||
tokenNum += getTokenNum(tokenEncoder, tool.Name)
|
||||
tokenNum += getTokenNum(tokenEncoder, tool.Description)
|
||||
|
||||
schemaJSON, err := json.Marshal(tool.InputSchema)
|
||||
if err != nil {
|
||||
return 0, errors.New(fmt.Sprintf("marshal_tool_schema_fail: %s", err.Error()))
|
||||
}
|
||||
tokenNum += getTokenNum(tokenEncoder, string(schemaJSON))
|
||||
}
|
||||
|
||||
// Add a constant for tool formatting (this may need adjustment based on Claude's exact formatting)
|
||||
tokenNum += len(tools) * 3 // Assuming 3 tokens per tool for formatting
|
||||
|
||||
return tokenNum, nil
|
||||
}
|
||||
|
||||
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
|
||||
audioToken := 0
|
||||
textToken := 0
|
||||
|
||||
Reference in New Issue
Block a user