feat(file): unify file handling with a new FileSource abstraction for URL and base64 data
This commit is contained in:
@@ -3,10 +3,6 @@ package service
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"log"
|
||||
"math"
|
||||
"path/filepath"
|
||||
@@ -23,8 +19,8 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
|
||||
if fileMeta == nil {
|
||||
func getImageToken(c *gin.Context, fileMeta *types.FileMeta, model string, stream bool) (int, error) {
|
||||
if fileMeta == nil || fileMeta.Source == nil {
|
||||
return 0, fmt.Errorf("image_url_is_nil")
|
||||
}
|
||||
|
||||
@@ -99,35 +95,20 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
||||
fileMeta.Detail = "high"
|
||||
}
|
||||
|
||||
// Decode image to get dimensions
|
||||
var config image.Config
|
||||
var err error
|
||||
var format string
|
||||
var b64str string
|
||||
|
||||
if fileMeta.ParsedData != nil {
|
||||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.ParsedData.Base64Data)
|
||||
} else {
|
||||
if strings.HasPrefix(fileMeta.OriginData, "http") {
|
||||
config, format, err = DecodeUrlImageData(fileMeta.OriginData)
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("decoding image"))
|
||||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.OriginData)
|
||||
}
|
||||
fileMeta.MimeType = format
|
||||
}
|
||||
|
||||
// Use the unified file service to obtain the image config.
|
||||
config, format, err := GetImageConfig(c, fileMeta.Source)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
fileMeta.MimeType = format
|
||||
|
||||
if config.Width == 0 || config.Height == 0 {
|
||||
// not an image
|
||||
if format != "" && b64str != "" {
|
||||
// not an image, but might be a valid file
|
||||
if format != "" {
|
||||
// file type
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.OriginData))
|
||||
return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", fileMeta.GetIdentifier()))
|
||||
}
|
||||
|
||||
width := config.Width
|
||||
@@ -269,48 +250,24 @@ func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *rela
|
||||
shouldFetchFiles = false
|
||||
}
|
||||
|
||||
// Use the unified file service to determine each file's type.
|
||||
for _, file := range meta.Files {
|
||||
if strings.HasPrefix(file.OriginData, "http") {
|
||||
if shouldFetchFiles {
|
||||
mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting file base64 from url: %v", err)
|
||||
}
|
||||
if strings.HasPrefix(mineType, "image/") {
|
||||
file.FileType = types.FileTypeImage
|
||||
} else if strings.HasPrefix(mineType, "video/") {
|
||||
file.FileType = types.FileTypeVideo
|
||||
} else if strings.HasPrefix(mineType, "audio/") {
|
||||
file.FileType = types.FileTypeAudio
|
||||
} else {
|
||||
file.FileType = types.FileTypeFile
|
||||
}
|
||||
file.MimeType = mineType
|
||||
}
|
||||
} else if strings.HasPrefix(file.OriginData, "data:") {
|
||||
// get mime type from base64 header
|
||||
parts := strings.SplitN(file.OriginData, ",", 2)
|
||||
if len(parts) >= 1 {
|
||||
header := parts[0]
|
||||
// Extract mime type from "data:mime/type;base64" format
|
||||
if strings.Contains(header, ":") && strings.Contains(header, ";") {
|
||||
mimeStart := strings.Index(header, ":") + 1
|
||||
mimeEnd := strings.Index(header, ";")
|
||||
if mimeStart < mimeEnd {
|
||||
mineType := header[mimeStart:mimeEnd]
|
||||
if strings.HasPrefix(mineType, "image/") {
|
||||
file.FileType = types.FileTypeImage
|
||||
} else if strings.HasPrefix(mineType, "video/") {
|
||||
file.FileType = types.FileTypeVideo
|
||||
} else if strings.HasPrefix(mineType, "audio/") {
|
||||
file.FileType = types.FileTypeAudio
|
||||
} else {
|
||||
file.FileType = types.FileTypeFile
|
||||
}
|
||||
file.MimeType = mineType
|
||||
}
|
||||
if file.Source == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Detect the type from the MIME type when it is unknown, or when the source is a URL and fetching is enabled.
|
||||
if file.FileType == "" || (file.Source.IsURL() && shouldFetchFiles) {
|
||||
mimeType, err := GetMimeType(c, file.Source)
|
||||
if err != nil {
|
||||
if shouldFetchFiles {
|
||||
return 0, fmt.Errorf("error getting file type: %v", err)
|
||||
}
|
||||
// Fetching is not required: ignore the error and leave the file's type as it is.
|
||||
continue
|
||||
}
|
||||
file.MimeType = mimeType
|
||||
file.FileType = DetectFileType(mimeType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -318,9 +275,9 @@ func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *rela
|
||||
switch file.FileType {
|
||||
case types.FileTypeImage:
|
||||
if common.IsOpenAITextModel(model) {
|
||||
token, err := getImageToken(file, model, info.IsStream)
|
||||
token, err := getImageToken(c, file, model, info.IsStream)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err)
|
||||
return 0, fmt.Errorf("error counting image token, media index[%d], identifier[%s], err: %v", i, file.GetIdentifier(), err)
|
||||
}
|
||||
tkm += token
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user