647 lines
18 KiB
Go
647 lines
18 KiB
Go
package service
|
||
|
||
import (
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"image"
|
||
"log"
|
||
"math"
|
||
"one-api/common"
|
||
"one-api/constant"
|
||
"one-api/dto"
|
||
relaycommon "one-api/relay/common"
|
||
"one-api/types"
|
||
"strings"
|
||
"sync"
|
||
"unicode/utf8"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/tiktoken-go/tokenizer"
|
||
"github.com/tiktoken-go/tokenizer/codec"
|
||
)
|
||
|
||
// tokenEncoderMap won't grow after initialization
|
||
var defaultTokenEncoder tokenizer.Codec
|
||
|
||
// tokenEncoderMap is used to store token encoders for different models
|
||
var tokenEncoderMap = make(map[string]tokenizer.Codec)
|
||
|
||
// tokenEncoderMutex protects tokenEncoderMap for concurrent access
|
||
var tokenEncoderMutex sync.RWMutex
|
||
|
||
// InitTokenEncoders initializes the default (cl100k_base) token encoder.
// Model-specific encoders are not created here; they are built lazily on
// first use by getTokenEncoder.
func InitTokenEncoders() {
	common.SysLog("initializing token encoders")
	defaultTokenEncoder = codec.NewCl100kBase()
	common.SysLog("token encoders initialized")
}
|
||
|
||
func getTokenEncoder(model string) tokenizer.Codec {
|
||
// First, try to get the encoder from cache with read lock
|
||
tokenEncoderMutex.RLock()
|
||
if encoder, exists := tokenEncoderMap[model]; exists {
|
||
tokenEncoderMutex.RUnlock()
|
||
return encoder
|
||
}
|
||
tokenEncoderMutex.RUnlock()
|
||
|
||
// If not in cache, create new encoder with write lock
|
||
tokenEncoderMutex.Lock()
|
||
defer tokenEncoderMutex.Unlock()
|
||
|
||
// Double-check if another goroutine already created the encoder
|
||
if encoder, exists := tokenEncoderMap[model]; exists {
|
||
return encoder
|
||
}
|
||
|
||
// Create new encoder
|
||
modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
|
||
if err != nil {
|
||
// Cache the default encoder for this model to avoid repeated failures
|
||
tokenEncoderMap[model] = defaultTokenEncoder
|
||
return defaultTokenEncoder
|
||
}
|
||
|
||
// Cache the new encoder
|
||
tokenEncoderMap[model] = modelCodec
|
||
return modelCodec
|
||
}
|
||
|
||
func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
|
||
if text == "" {
|
||
return 0
|
||
}
|
||
tkm, _ := tokenEncoder.Count(text)
|
||
return tkm
|
||
}
|
||
|
||
func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
|
||
if fileMeta == nil {
|
||
return 0, fmt.Errorf("image_url_is_nil")
|
||
}
|
||
|
||
// Defaults for 4o/4.1/4.5 family unless overridden below
|
||
baseTokens := 85
|
||
tileTokens := 170
|
||
|
||
// Model classification
|
||
lowerModel := strings.ToLower(model)
|
||
|
||
// Special cases from existing behavior
|
||
if strings.HasPrefix(lowerModel, "glm-4") {
|
||
return 1047, nil
|
||
}
|
||
|
||
// Patch-based models (32x32 patches, capped at 1536, with multiplier)
|
||
isPatchBased := false
|
||
multiplier := 1.0
|
||
switch {
|
||
case strings.Contains(lowerModel, "gpt-4.1-mini"):
|
||
isPatchBased = true
|
||
multiplier = 1.62
|
||
case strings.Contains(lowerModel, "gpt-4.1-nano"):
|
||
isPatchBased = true
|
||
multiplier = 2.46
|
||
case strings.HasPrefix(lowerModel, "o4-mini"):
|
||
isPatchBased = true
|
||
multiplier = 1.72
|
||
case strings.HasPrefix(lowerModel, "gpt-5-mini"):
|
||
isPatchBased = true
|
||
multiplier = 1.62
|
||
case strings.HasPrefix(lowerModel, "gpt-5-nano"):
|
||
isPatchBased = true
|
||
multiplier = 2.46
|
||
}
|
||
|
||
// Tile-based model tokens and bases per doc
|
||
if !isPatchBased {
|
||
if strings.HasPrefix(lowerModel, "gpt-4o-mini") {
|
||
baseTokens = 2833
|
||
tileTokens = 5667
|
||
} else if strings.HasPrefix(lowerModel, "gpt-5-chat-latest") || (strings.HasPrefix(lowerModel, "gpt-5") && !strings.Contains(lowerModel, "mini") && !strings.Contains(lowerModel, "nano")) {
|
||
baseTokens = 70
|
||
tileTokens = 140
|
||
} else if strings.HasPrefix(lowerModel, "o1") || strings.HasPrefix(lowerModel, "o3") || strings.HasPrefix(lowerModel, "o1-pro") {
|
||
baseTokens = 75
|
||
tileTokens = 150
|
||
} else if strings.Contains(lowerModel, "computer-use-preview") {
|
||
baseTokens = 65
|
||
tileTokens = 129
|
||
} else if strings.Contains(lowerModel, "4.1") || strings.Contains(lowerModel, "4o") || strings.Contains(lowerModel, "4.5") {
|
||
baseTokens = 85
|
||
tileTokens = 170
|
||
}
|
||
}
|
||
|
||
// Respect existing feature flags/short-circuits
|
||
if fileMeta.Detail == "low" && !isPatchBased {
|
||
return baseTokens, nil
|
||
}
|
||
if !constant.GetMediaTokenNotStream && !stream {
|
||
return 3 * baseTokens, nil
|
||
}
|
||
// Normalize detail
|
||
if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
|
||
fileMeta.Detail = "high"
|
||
}
|
||
// Whether to count image tokens at all
|
||
if !constant.GetMediaToken {
|
||
return 3 * baseTokens, nil
|
||
}
|
||
|
||
// Decode image to get dimensions
|
||
var config image.Config
|
||
var err error
|
||
var format string
|
||
var b64str string
|
||
|
||
if fileMeta.ParsedData != nil {
|
||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.ParsedData.Base64Data)
|
||
} else {
|
||
if strings.HasPrefix(fileMeta.OriginData, "http") {
|
||
config, format, err = DecodeUrlImageData(fileMeta.OriginData)
|
||
} else {
|
||
common.SysLog(fmt.Sprintf("decoding image"))
|
||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.OriginData)
|
||
}
|
||
fileMeta.MimeType = format
|
||
}
|
||
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
if config.Width == 0 || config.Height == 0 {
|
||
// not an image
|
||
if format != "" && b64str != "" {
|
||
// file type
|
||
return 3 * baseTokens, nil
|
||
}
|
||
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.OriginData))
|
||
}
|
||
|
||
width := config.Width
|
||
height := config.Height
|
||
log.Printf("format: %s, width: %d, height: %d", format, width, height)
|
||
|
||
if isPatchBased {
|
||
// 32x32 patch-based calculation with 1536 cap and model multiplier
|
||
ceilDiv := func(a, b int) int { return (a + b - 1) / b }
|
||
rawPatchesW := ceilDiv(width, 32)
|
||
rawPatchesH := ceilDiv(height, 32)
|
||
rawPatches := rawPatchesW * rawPatchesH
|
||
if rawPatches > 1536 {
|
||
// scale down
|
||
area := float64(width * height)
|
||
r := math.Sqrt(float64(32*32*1536) / area)
|
||
wScaled := float64(width) * r
|
||
hScaled := float64(height) * r
|
||
// adjust to fit whole number of patches after scaling
|
||
adjW := math.Floor(wScaled/32.0) / (wScaled / 32.0)
|
||
adjH := math.Floor(hScaled/32.0) / (hScaled / 32.0)
|
||
adj := math.Min(adjW, adjH)
|
||
if !math.IsNaN(adj) && adj > 0 {
|
||
r = r * adj
|
||
}
|
||
wScaled = float64(width) * r
|
||
hScaled = float64(height) * r
|
||
patchesW := math.Ceil(wScaled / 32.0)
|
||
patchesH := math.Ceil(hScaled / 32.0)
|
||
imageTokens := int(patchesW * patchesH)
|
||
if imageTokens > 1536 {
|
||
imageTokens = 1536
|
||
}
|
||
return int(math.Round(float64(imageTokens) * multiplier)), nil
|
||
}
|
||
// below cap
|
||
imageTokens := rawPatches
|
||
return int(math.Round(float64(imageTokens) * multiplier)), nil
|
||
}
|
||
|
||
// Tile-based calculation for 4o/4.1/4.5/o1/o3/etc.
|
||
// Step 1: fit within 2048x2048 square
|
||
maxSide := math.Max(float64(width), float64(height))
|
||
fitScale := 1.0
|
||
if maxSide > 2048 {
|
||
fitScale = maxSide / 2048.0
|
||
}
|
||
fitW := int(math.Round(float64(width) / fitScale))
|
||
fitH := int(math.Round(float64(height) / fitScale))
|
||
|
||
// Step 2: scale so that shortest side is exactly 768
|
||
minSide := math.Min(float64(fitW), float64(fitH))
|
||
if minSide == 0 {
|
||
return baseTokens, nil
|
||
}
|
||
shortScale := 768.0 / minSide
|
||
finalW := int(math.Round(float64(fitW) * shortScale))
|
||
finalH := int(math.Round(float64(fitH) * shortScale))
|
||
|
||
// Count 512px tiles
|
||
tilesW := (finalW + 512 - 1) / 512
|
||
tilesH := (finalH + 512 - 1) / 512
|
||
tiles := tilesW * tilesH
|
||
|
||
if common.DebugEnabled {
|
||
log.Printf("scaled to: %dx%d, tiles: %d", finalW, finalH, tiles)
|
||
}
|
||
|
||
return tiles*tileTokens + baseTokens, nil
|
||
}
|
||
|
||
func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
||
if meta == nil {
|
||
return 0, errors.New("token count meta is nil")
|
||
}
|
||
|
||
if info.RelayFormat == types.RelayFormatOpenAIRealtime {
|
||
return 0, nil
|
||
}
|
||
|
||
model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
||
tkm := 0
|
||
|
||
if meta.TokenType == types.TokenTypeTextNumber {
|
||
tkm += utf8.RuneCountInString(meta.CombineText)
|
||
} else {
|
||
tkm += CountTextToken(meta.CombineText, model)
|
||
}
|
||
|
||
if info.RelayFormat == types.RelayFormatOpenAI {
|
||
tkm += meta.ToolsCount * 8
|
||
tkm += meta.MessagesCount * 3 // 每条消息的格式化token数量
|
||
tkm += meta.NameCount * 3
|
||
tkm += 3
|
||
}
|
||
|
||
shouldFetchFiles := true
|
||
|
||
if info.RelayFormat == types.RelayFormatOpenAIRealtime || info.RelayFormat == types.RelayFormatGemini {
|
||
shouldFetchFiles = false
|
||
}
|
||
|
||
if shouldFetchFiles {
|
||
for _, file := range meta.Files {
|
||
if strings.HasPrefix(file.OriginData, "http") {
|
||
localFileData, err := GetFileBase64FromUrl(c, file.OriginData, "token_counter")
|
||
if err != nil {
|
||
return 0, fmt.Errorf("error getting file base64 from url: %v", err)
|
||
}
|
||
if strings.HasPrefix(localFileData.MimeType, "image/") {
|
||
file.FileType = types.FileTypeImage
|
||
} else if strings.HasPrefix(localFileData.MimeType, "video/") {
|
||
file.FileType = types.FileTypeVideo
|
||
} else if strings.HasPrefix(localFileData.MimeType, "audio/") {
|
||
file.FileType = types.FileTypeAudio
|
||
} else {
|
||
file.FileType = types.FileTypeFile
|
||
}
|
||
file.MimeType = localFileData.MimeType
|
||
file.ParsedData = localFileData
|
||
}
|
||
}
|
||
}
|
||
|
||
for _, file := range meta.Files {
|
||
switch file.FileType {
|
||
case types.FileTypeImage:
|
||
if info.RelayFormat == types.RelayFormatGemini {
|
||
tkm += 240
|
||
} else {
|
||
token, err := getImageToken(file, model, info.IsStream)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("error counting image token: %v", err)
|
||
}
|
||
tkm += token
|
||
}
|
||
case types.FileTypeAudio:
|
||
tkm += 100
|
||
case types.FileTypeVideo:
|
||
tkm += 5000
|
||
case types.FileTypeFile:
|
||
tkm += 5000
|
||
}
|
||
}
|
||
|
||
common.SetContextKey(c, constant.ContextKeyPromptTokens, tkm)
|
||
return tkm, nil
|
||
}
|
||
|
||
//func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
|
||
// tkm := 0
|
||
// msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
|
||
// if err != nil {
|
||
// return 0, err
|
||
// }
|
||
// tkm += msgTokens
|
||
// if request.Tools != nil {
|
||
// openaiTools := request.Tools
|
||
// countStr := ""
|
||
// for _, tool := range openaiTools {
|
||
// countStr = tool.Function.Name
|
||
// if tool.Function.Description != "" {
|
||
// countStr += tool.Function.Description
|
||
// }
|
||
// if tool.Function.Parameters != nil {
|
||
// countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||
// }
|
||
// }
|
||
// toolTokens := CountTokenInput(countStr, request.Model)
|
||
// tkm += 8
|
||
// tkm += toolTokens
|
||
// }
|
||
//
|
||
// return tkm, nil
|
||
//}
|
||
|
||
func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
|
||
tkm := 0
|
||
|
||
// Count tokens in messages
|
||
msgTokens, err := CountTokenClaudeMessages(request.Messages, model, request.Stream)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
tkm += msgTokens
|
||
|
||
// Count tokens in system message
|
||
if request.System != "" {
|
||
systemTokens := CountTokenInput(request.System, model)
|
||
tkm += systemTokens
|
||
}
|
||
|
||
if request.Tools != nil {
|
||
// check is array
|
||
if tools, ok := request.Tools.([]any); ok {
|
||
if len(tools) > 0 {
|
||
parsedTools, err1 := common.Any2Type[[]dto.Tool](request.Tools)
|
||
if err1 != nil {
|
||
return 0, fmt.Errorf("tools: Input should be a valid list: %v", err)
|
||
}
|
||
toolTokens, err2 := CountTokenClaudeTools(parsedTools, model)
|
||
if err2 != nil {
|
||
return 0, fmt.Errorf("tools: %v", err)
|
||
}
|
||
tkm += toolTokens
|
||
}
|
||
} else {
|
||
return 0, errors.New("tools: Input should be a valid list")
|
||
}
|
||
}
|
||
|
||
return tkm, nil
|
||
}
|
||
|
||
func CountTokenClaudeMessages(messages []dto.ClaudeMessage, model string, stream bool) (int, error) {
|
||
tokenEncoder := getTokenEncoder(model)
|
||
tokenNum := 0
|
||
|
||
for _, message := range messages {
|
||
// Count tokens for role
|
||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||
if message.IsStringContent() {
|
||
tokenNum += getTokenNum(tokenEncoder, message.GetStringContent())
|
||
} else {
|
||
content, err := message.ParseContent()
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
for _, mediaMessage := range content {
|
||
switch mediaMessage.Type {
|
||
case "text":
|
||
tokenNum += getTokenNum(tokenEncoder, mediaMessage.GetText())
|
||
case "image":
|
||
//imageTokenNum, err := getClaudeImageToken(mediaMsg.Source, model, stream)
|
||
//if err != nil {
|
||
// return 0, err
|
||
//}
|
||
tokenNum += 1000
|
||
case "tool_use":
|
||
if mediaMessage.Input != nil {
|
||
tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
|
||
inputJSON, _ := json.Marshal(mediaMessage.Input)
|
||
tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
|
||
}
|
||
case "tool_result":
|
||
if mediaMessage.Content != nil {
|
||
contentJSON, _ := json.Marshal(mediaMessage.Content)
|
||
tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// Add a constant for message formatting (this may need adjustment based on Claude's exact formatting)
|
||
tokenNum += len(messages) * 2 // Assuming 2 tokens per message for formatting
|
||
|
||
return tokenNum, nil
|
||
}
|
||
|
||
func CountTokenClaudeTools(tools []dto.Tool, model string) (int, error) {
|
||
tokenEncoder := getTokenEncoder(model)
|
||
tokenNum := 0
|
||
|
||
for _, tool := range tools {
|
||
tokenNum += getTokenNum(tokenEncoder, tool.Name)
|
||
tokenNum += getTokenNum(tokenEncoder, tool.Description)
|
||
|
||
schemaJSON, err := json.Marshal(tool.InputSchema)
|
||
if err != nil {
|
||
return 0, errors.New(fmt.Sprintf("marshal_tool_schema_fail: %s", err.Error()))
|
||
}
|
||
tokenNum += getTokenNum(tokenEncoder, string(schemaJSON))
|
||
}
|
||
|
||
// Add a constant for tool formatting (this may need adjustment based on Claude's exact formatting)
|
||
tokenNum += len(tools) * 3 // Assuming 3 tokens per tool for formatting
|
||
|
||
return tokenNum, nil
|
||
}
|
||
|
||
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
|
||
audioToken := 0
|
||
textToken := 0
|
||
switch request.Type {
|
||
case dto.RealtimeEventTypeSessionUpdate:
|
||
if request.Session != nil {
|
||
msgTokens := CountTextToken(request.Session.Instructions, model)
|
||
textToken += msgTokens
|
||
}
|
||
case dto.RealtimeEventResponseAudioDelta:
|
||
// count audio token
|
||
atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
|
||
if err != nil {
|
||
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
|
||
}
|
||
audioToken += atk
|
||
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
|
||
// count text token
|
||
tkm := CountTextToken(request.Delta, model)
|
||
textToken += tkm
|
||
case dto.RealtimeEventInputAudioBufferAppend:
|
||
// count audio token
|
||
atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
|
||
if err != nil {
|
||
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
|
||
}
|
||
audioToken += atk
|
||
case dto.RealtimeEventConversationItemCreated:
|
||
if request.Item != nil {
|
||
switch request.Item.Type {
|
||
case "message":
|
||
for _, content := range request.Item.Content {
|
||
if content.Type == "input_text" {
|
||
tokens := CountTextToken(content.Text, model)
|
||
textToken += tokens
|
||
}
|
||
}
|
||
}
|
||
}
|
||
case dto.RealtimeEventTypeResponseDone:
|
||
// count tools token
|
||
if !info.IsFirstRequest {
|
||
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
|
||
for _, tool := range info.RealtimeTools {
|
||
toolTokens := CountTokenInput(tool, model)
|
||
textToken += 8
|
||
textToken += toolTokens
|
||
}
|
||
}
|
||
}
|
||
}
|
||
return textToken, audioToken, nil
|
||
}
|
||
|
||
//func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
|
||
// //recover when panic
|
||
// tokenEncoder := getTokenEncoder(model)
|
||
// // Reference:
|
||
// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||
// // https://github.com/pkoukk/tiktoken-go/issues/6
|
||
// //
|
||
// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||
// var tokensPerMessage int
|
||
// var tokensPerName int
|
||
//
|
||
// tokensPerMessage = 3
|
||
// tokensPerName = 1
|
||
//
|
||
// tokenNum := 0
|
||
// for _, message := range messages {
|
||
// tokenNum += tokensPerMessage
|
||
// tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||
// if message.Content != nil {
|
||
// if message.Name != nil {
|
||
// tokenNum += tokensPerName
|
||
// tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
||
// }
|
||
// arrayContent := message.ParseContent()
|
||
// for _, m := range arrayContent {
|
||
// if m.Type == dto.ContentTypeImageURL {
|
||
// imageUrl := m.GetImageMedia()
|
||
// imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
|
||
// if err != nil {
|
||
// return 0, err
|
||
// }
|
||
// tokenNum += imageTokenNum
|
||
// log.Printf("image token num: %d", imageTokenNum)
|
||
// } else if m.Type == dto.ContentTypeInputAudio {
|
||
// // TODO: 音频token数量计算
|
||
// tokenNum += 100
|
||
// } else if m.Type == dto.ContentTypeFile {
|
||
// tokenNum += 5000
|
||
// } else if m.Type == dto.ContentTypeVideoUrl {
|
||
// tokenNum += 5000
|
||
// } else {
|
||
// tokenNum += getTokenNum(tokenEncoder, m.Text)
|
||
// }
|
||
// }
|
||
// }
|
||
// }
|
||
// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||
// return tokenNum, nil
|
||
//}
|
||
|
||
func CountTokenInput(input any, model string) int {
|
||
switch v := input.(type) {
|
||
case string:
|
||
return CountTextToken(v, model)
|
||
case []string:
|
||
text := ""
|
||
for _, s := range v {
|
||
text += s
|
||
}
|
||
return CountTextToken(text, model)
|
||
case []interface{}:
|
||
text := ""
|
||
for _, item := range v {
|
||
text += fmt.Sprintf("%v", item)
|
||
}
|
||
return CountTextToken(text, model)
|
||
}
|
||
return CountTokenInput(fmt.Sprintf("%v", input), model)
|
||
}
|
||
|
||
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
||
tokens := 0
|
||
for _, message := range messages {
|
||
tkm := CountTokenInput(message.Delta.GetContentString(), model)
|
||
tokens += tkm
|
||
if message.Delta.ToolCalls != nil {
|
||
for _, tool := range message.Delta.ToolCalls {
|
||
tkm := CountTokenInput(tool.Function.Name, model)
|
||
tokens += tkm
|
||
tkm = CountTokenInput(tool.Function.Arguments, model)
|
||
tokens += tkm
|
||
}
|
||
}
|
||
}
|
||
return tokens
|
||
}
|
||
|
||
func CountTTSToken(text string, model string) int {
|
||
if strings.HasPrefix(model, "tts") {
|
||
return utf8.RuneCountInString(text)
|
||
} else {
|
||
return CountTextToken(text, model)
|
||
}
|
||
}
|
||
|
||
func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
|
||
if audioBase64 == "" {
|
||
return 0, nil
|
||
}
|
||
duration, err := parseAudio(audioBase64, audioFormat)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return int(duration / 60 * 100 / 0.06), nil
|
||
}
|
||
|
||
func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
|
||
if audioBase64 == "" {
|
||
return 0, nil
|
||
}
|
||
duration, err := parseAudio(audioBase64, audioFormat)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return int(duration / 60 * 200 / 0.24), nil
|
||
}
|
||
|
||
//func CountAudioToken(sec float64, audioType string) {
|
||
// if audioType == "input" {
|
||
//
|
||
// }
|
||
//}
|
||
|
||
// CountTextToken returns the number of tokens in text for the given model,
// using the model's cached encoder. Empty text counts as zero.
// (The former comment claimed an error is returned for sensitive words, but
// the signature returns only an int — that claim was stale and is removed.)
func CountTextToken(text string, model string) int {
	if text == "" {
		return 0
	}
	tokenEncoder := getTokenEncoder(model)
	return getTokenNum(tokenEncoder, text)
}
|