feat: Add CountToken configuration and update token counting logic

This commit is contained in:
CaIon
2025-11-22 17:15:34 +08:00
parent efb8f1f5b8
commit 0952973887
3 changed files with 48 additions and 35 deletions

View File

@@ -111,6 +111,7 @@ func initConstantEnv() {
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20) constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
// ForceStreamOption 覆盖请求参数强制返回usage信息 // ForceStreamOption 覆盖请求参数强制返回usage信息
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true) constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
constant.CountToken = GetEnvOrDefaultBool("CountToken", true)
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", false) constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", false)
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true) constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)

View File

@@ -4,6 +4,7 @@ var StreamingTimeout int
var DifyDebug bool var DifyDebug bool
var MaxFileDownloadMB int var MaxFileDownloadMB int
var ForceStreamOption bool var ForceStreamOption bool
var CountToken bool
var GetMediaToken bool var GetMediaToken bool
var GetMediaTokenNotStream bool var GetMediaTokenNotStream bool
var UpdateTask bool var UpdateTask bool

View File

@@ -143,6 +143,12 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
if fileMeta.Detail == "low" && !isPatchBased { if fileMeta.Detail == "low" && !isPatchBased {
return baseTokens, nil return baseTokens, nil
} }
// Whether to count image tokens at all
if !constant.GetMediaToken {
return 3 * baseTokens, nil
}
if !constant.GetMediaTokenNotStream && !stream { if !constant.GetMediaTokenNotStream && !stream {
return 3 * baseTokens, nil return 3 * baseTokens, nil
} }
@@ -150,10 +156,6 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
if fileMeta.Detail == "auto" || fileMeta.Detail == "" { if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
fileMeta.Detail = "high" fileMeta.Detail = "high"
} }
// Whether to count image tokens at all
if !constant.GetMediaToken {
return 3 * baseTokens, nil
}
// Decode image to get dimensions // Decode image to get dimensions
var config image.Config var config image.Config
@@ -256,16 +258,15 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
} }
func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) { func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
// 是否统计token
if !constant.CountToken {
return 0, nil
}
if meta == nil { if meta == nil {
return 0, errors.New("token count meta is nil") return 0, errors.New("token count meta is nil")
} }
if !constant.GetMediaToken {
return 0, nil
}
if !constant.GetMediaTokenNotStream && !info.IsStream {
return 0, nil
}
if info.RelayFormat == types.RelayFormatOpenAIRealtime { if info.RelayFormat == types.RelayFormatOpenAIRealtime {
return 0, nil return 0, nil
} }
@@ -316,9 +317,19 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
shouldFetchFiles = false shouldFetchFiles = false
} }
if shouldFetchFiles { // 是否本地计算媒体token数量
if !constant.GetMediaToken {
shouldFetchFiles = false
}
// 是否在非流模式下本地计算媒体token数量
if !constant.GetMediaTokenNotStream && !info.IsStream {
shouldFetchFiles = false
}
for _, file := range meta.Files { for _, file := range meta.Files {
if strings.HasPrefix(file.OriginData, "http") { if strings.HasPrefix(file.OriginData, "http") {
if shouldFetchFiles {
mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter") mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter")
if err != nil { if err != nil {
return 0, fmt.Errorf("error getting file base64 from url: %v", err) return 0, fmt.Errorf("error getting file base64 from url: %v", err)
@@ -333,6 +344,7 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
file.FileType = types.FileTypeFile file.FileType = types.FileTypeFile
} }
file.MimeType = mineType file.MimeType = mineType
}
} else if strings.HasPrefix(file.OriginData, "data:") { } else if strings.HasPrefix(file.OriginData, "data:") {
// get mime type from base64 header // get mime type from base64 header
parts := strings.SplitN(file.OriginData, ",", 2) parts := strings.SplitN(file.OriginData, ",", 2)
@@ -359,13 +371,12 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
} }
} }
} }
}
for i, file := range meta.Files { for i, file := range meta.Files {
switch file.FileType { switch file.FileType {
case types.FileTypeImage: case types.FileTypeImage:
if info.RelayFormat == types.RelayFormatGemini { if info.RelayFormat == types.RelayFormatGemini {
tkm += 256 tkm += 520 // gemini per input image tokens
} else { } else {
token, err := getImageToken(file, model, info.IsStream) token, err := getImageToken(file, model, info.IsStream)
if err != nil { if err != nil {