feat: Add CountToken configuration and update token counting logic
This commit is contained in:
@@ -111,6 +111,7 @@ func initConstantEnv() {
|
|||||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||||
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||||
|
constant.CountToken = GetEnvOrDefaultBool("CountToken", true)
|
||||||
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||||
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", false)
|
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", false)
|
||||||
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
|
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ var StreamingTimeout int
|
|||||||
var DifyDebug bool
|
var DifyDebug bool
|
||||||
var MaxFileDownloadMB int
|
var MaxFileDownloadMB int
|
||||||
var ForceStreamOption bool
|
var ForceStreamOption bool
|
||||||
|
var CountToken bool
|
||||||
var GetMediaToken bool
|
var GetMediaToken bool
|
||||||
var GetMediaTokenNotStream bool
|
var GetMediaTokenNotStream bool
|
||||||
var UpdateTask bool
|
var UpdateTask bool
|
||||||
|
|||||||
@@ -143,6 +143,12 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
|||||||
if fileMeta.Detail == "low" && !isPatchBased {
|
if fileMeta.Detail == "low" && !isPatchBased {
|
||||||
return baseTokens, nil
|
return baseTokens, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Whether to count image tokens at all
|
||||||
|
if !constant.GetMediaToken {
|
||||||
|
return 3 * baseTokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
if !constant.GetMediaTokenNotStream && !stream {
|
if !constant.GetMediaTokenNotStream && !stream {
|
||||||
return 3 * baseTokens, nil
|
return 3 * baseTokens, nil
|
||||||
}
|
}
|
||||||
@@ -150,10 +156,6 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
|||||||
if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
|
if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
|
||||||
fileMeta.Detail = "high"
|
fileMeta.Detail = "high"
|
||||||
}
|
}
|
||||||
// Whether to count image tokens at all
|
|
||||||
if !constant.GetMediaToken {
|
|
||||||
return 3 * baseTokens, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode image to get dimensions
|
// Decode image to get dimensions
|
||||||
var config image.Config
|
var config image.Config
|
||||||
@@ -256,16 +258,15 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
||||||
|
// 是否统计token
|
||||||
|
if !constant.CountToken {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
if meta == nil {
|
if meta == nil {
|
||||||
return 0, errors.New("token count meta is nil")
|
return 0, errors.New("token count meta is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !constant.GetMediaToken {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
if !constant.GetMediaTokenNotStream && !info.IsStream {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
if info.RelayFormat == types.RelayFormatOpenAIRealtime {
|
if info.RelayFormat == types.RelayFormatOpenAIRealtime {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
@@ -316,9 +317,19 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
|||||||
shouldFetchFiles = false
|
shouldFetchFiles = false
|
||||||
}
|
}
|
||||||
|
|
||||||
if shouldFetchFiles {
|
// 是否本地计算媒体token数量
|
||||||
for _, file := range meta.Files {
|
if !constant.GetMediaToken {
|
||||||
if strings.HasPrefix(file.OriginData, "http") {
|
shouldFetchFiles = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 是否在非流模式下本地计算媒体token数量
|
||||||
|
if !constant.GetMediaTokenNotStream && !info.IsStream {
|
||||||
|
shouldFetchFiles = false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, file := range meta.Files {
|
||||||
|
if strings.HasPrefix(file.OriginData, "http") {
|
||||||
|
if shouldFetchFiles {
|
||||||
mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter")
|
mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("error getting file base64 from url: %v", err)
|
return 0, fmt.Errorf("error getting file base64 from url: %v", err)
|
||||||
@@ -333,28 +344,28 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
|||||||
file.FileType = types.FileTypeFile
|
file.FileType = types.FileTypeFile
|
||||||
}
|
}
|
||||||
file.MimeType = mineType
|
file.MimeType = mineType
|
||||||
} else if strings.HasPrefix(file.OriginData, "data:") {
|
}
|
||||||
// get mime type from base64 header
|
} else if strings.HasPrefix(file.OriginData, "data:") {
|
||||||
parts := strings.SplitN(file.OriginData, ",", 2)
|
// get mime type from base64 header
|
||||||
if len(parts) >= 1 {
|
parts := strings.SplitN(file.OriginData, ",", 2)
|
||||||
header := parts[0]
|
if len(parts) >= 1 {
|
||||||
// Extract mime type from "data:mime/type;base64" format
|
header := parts[0]
|
||||||
if strings.Contains(header, ":") && strings.Contains(header, ";") {
|
// Extract mime type from "data:mime/type;base64" format
|
||||||
mimeStart := strings.Index(header, ":") + 1
|
if strings.Contains(header, ":") && strings.Contains(header, ";") {
|
||||||
mimeEnd := strings.Index(header, ";")
|
mimeStart := strings.Index(header, ":") + 1
|
||||||
if mimeStart < mimeEnd {
|
mimeEnd := strings.Index(header, ";")
|
||||||
mineType := header[mimeStart:mimeEnd]
|
if mimeStart < mimeEnd {
|
||||||
if strings.HasPrefix(mineType, "image/") {
|
mineType := header[mimeStart:mimeEnd]
|
||||||
file.FileType = types.FileTypeImage
|
if strings.HasPrefix(mineType, "image/") {
|
||||||
} else if strings.HasPrefix(mineType, "video/") {
|
file.FileType = types.FileTypeImage
|
||||||
file.FileType = types.FileTypeVideo
|
} else if strings.HasPrefix(mineType, "video/") {
|
||||||
} else if strings.HasPrefix(mineType, "audio/") {
|
file.FileType = types.FileTypeVideo
|
||||||
file.FileType = types.FileTypeAudio
|
} else if strings.HasPrefix(mineType, "audio/") {
|
||||||
} else {
|
file.FileType = types.FileTypeAudio
|
||||||
file.FileType = types.FileTypeFile
|
} else {
|
||||||
}
|
file.FileType = types.FileTypeFile
|
||||||
file.MimeType = mineType
|
|
||||||
}
|
}
|
||||||
|
file.MimeType = mineType
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -365,7 +376,7 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
|||||||
switch file.FileType {
|
switch file.FileType {
|
||||||
case types.FileTypeImage:
|
case types.FileTypeImage:
|
||||||
if info.RelayFormat == types.RelayFormatGemini {
|
if info.RelayFormat == types.RelayFormatGemini {
|
||||||
tkm += 256
|
tkm += 520 // gemini per input image tokens
|
||||||
} else {
|
} else {
|
||||||
token, err := getImageToken(file, model, info.IsStream)
|
token, err := getImageToken(file, model, info.IsStream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user