feat(file): unify file handling with a new FileSource abstraction for URL and base64 data
This commit is contained in:
@@ -2,7 +2,6 @@ package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/gif"
|
||||
@@ -13,7 +12,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
@@ -130,90 +128,27 @@ func GetFileTypeFromUrl(c *gin.Context, url string, reason ...string) (string, e
|
||||
return "application/octet-stream", nil
|
||||
}
|
||||
|
||||
// GetFileBase64FromUrl 从 URL 获取文件的 base64 编码数据
|
||||
// Deprecated: 请使用 GetBase64Data 配合 types.NewURLFileSource 替代
|
||||
// 此函数保留用于向后兼容,内部已重构为调用统一的文件服务
|
||||
func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) {
|
||||
contextKey := fmt.Sprintf("file_download_%s", common.GenerateHMAC(url))
|
||||
|
||||
// Check if the file has already been downloaded in this request
|
||||
if cachedData, exists := c.Get(contextKey); exists {
|
||||
if common.DebugEnabled {
|
||||
logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url))
|
||||
}
|
||||
return cachedData.(*types.LocalFileData), nil
|
||||
}
|
||||
|
||||
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
|
||||
|
||||
resp, err := DoDownloadRequest(url, reason...)
|
||||
source := types.NewURLFileSource(url)
|
||||
cachedData, err := LoadFileSource(c, source, reason...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Always use LimitReader to prevent oversized downloads
|
||||
fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
|
||||
// 转换为旧的 LocalFileData 格式以保持兼容
|
||||
base64Data, err := cachedData.GetBase64Data()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Check actual size after reading
|
||||
if len(fileBytes) > maxFileSize {
|
||||
return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
|
||||
}
|
||||
|
||||
// Convert to base64
|
||||
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
|
||||
|
||||
mimeType := resp.Header.Get("Content-Type")
|
||||
if len(strings.Split(mimeType, ";")) > 1 {
|
||||
// If Content-Type has parameters, take the first part
|
||||
mimeType = strings.Split(mimeType, ";")[0]
|
||||
}
|
||||
if mimeType == "application/octet-stream" {
|
||||
logger.LogDebug(c, fmt.Sprintf("MIME type is application/octet-stream for URL: %s", url))
|
||||
// try to guess the MIME type from the url last segment
|
||||
urlParts := strings.Split(url, "/")
|
||||
if len(urlParts) > 0 {
|
||||
lastSegment := urlParts[len(urlParts)-1]
|
||||
if strings.Contains(lastSegment, ".") {
|
||||
// Extract the file extension
|
||||
filename := strings.Split(lastSegment, ".")
|
||||
if len(filename) > 1 {
|
||||
ext := strings.ToLower(filename[len(filename)-1])
|
||||
// Guess MIME type based on file extension
|
||||
mimeType = GetMimeTypeByExtension(ext)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// try to guess the MIME type from the file extension
|
||||
fileName := resp.Header.Get("Content-Disposition")
|
||||
if fileName != "" {
|
||||
// Extract the filename from the Content-Disposition header
|
||||
parts := strings.Split(fileName, ";")
|
||||
for _, part := range parts {
|
||||
if strings.HasPrefix(strings.TrimSpace(part), "filename=") {
|
||||
fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename="))
|
||||
// Remove quotes if present
|
||||
if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' {
|
||||
fileName = fileName[1 : len(fileName)-1]
|
||||
}
|
||||
// Guess MIME type based on file extension
|
||||
if ext := strings.ToLower(strings.TrimPrefix(fileName, ".")); ext != "" {
|
||||
mimeType = GetMimeTypeByExtension(ext)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
data := &types.LocalFileData{
|
||||
return &types.LocalFileData{
|
||||
Base64Data: base64Data,
|
||||
MimeType: mimeType,
|
||||
Size: int64(len(fileBytes)),
|
||||
}
|
||||
// Store the file data in the context to avoid re-downloading
|
||||
c.Set(contextKey, data)
|
||||
|
||||
return data, nil
|
||||
MimeType: cachedData.MimeType,
|
||||
Size: cachedData.Size,
|
||||
Url: url,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func GetMimeTypeByExtension(ext string) string {
|
||||
|
||||
451
service/file_service.go
Normal file
451
service/file_service.go
Normal file
@@ -0,0 +1,451 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
// FileService 统一的文件处理服务
|
||||
// 提供文件下载、解码、缓存等功能的统一入口
|
||||
|
||||
// getContextCacheKey 生成 context 缓存的 key
|
||||
func getContextCacheKey(url string) string {
|
||||
return fmt.Sprintf("file_cache_%s", common.GenerateHMAC(url))
|
||||
}
|
||||
|
||||
// LoadFileSource 加载文件源数据
|
||||
// 这是统一的入口,会自动处理缓存和不同的来源类型
|
||||
func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string) (*types.CachedFileData, error) {
|
||||
if source == nil {
|
||||
return nil, fmt.Errorf("file source is nil")
|
||||
}
|
||||
|
||||
// 如果已有缓存,直接返回
|
||||
if source.HasCache() {
|
||||
return source.GetCache(), nil
|
||||
}
|
||||
|
||||
var cachedData *types.CachedFileData
|
||||
var err error
|
||||
|
||||
if source.IsURL() {
|
||||
cachedData, err = loadFromURL(c, source.URL, reason...)
|
||||
} else {
|
||||
cachedData, err = loadFromBase64(source.Base64Data, source.MimeType)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 设置缓存
|
||||
source.SetCache(cachedData)
|
||||
|
||||
// 注册到 context 以便请求结束时自动清理
|
||||
if c != nil {
|
||||
registerSourceForCleanup(c, source)
|
||||
}
|
||||
|
||||
return cachedData, nil
|
||||
}
|
||||
|
||||
// registerSourceForCleanup 注册 FileSource 到 context 以便请求结束时清理
|
||||
func registerSourceForCleanup(c *gin.Context, source *types.FileSource) {
|
||||
key := string(constant.ContextKeyFileSourcesToCleanup)
|
||||
var sources []*types.FileSource
|
||||
if existing, exists := c.Get(key); exists {
|
||||
sources = existing.([]*types.FileSource)
|
||||
}
|
||||
sources = append(sources, source)
|
||||
c.Set(key, sources)
|
||||
}
|
||||
|
||||
// CleanupFileSources 清理请求中所有注册的 FileSource
|
||||
// 应在请求结束时调用(通常由中间件自动调用)
|
||||
func CleanupFileSources(c *gin.Context) {
|
||||
key := string(constant.ContextKeyFileSourcesToCleanup)
|
||||
if sources, exists := c.Get(key); exists {
|
||||
for _, source := range sources.([]*types.FileSource) {
|
||||
if cache := source.GetCache(); cache != nil {
|
||||
if cache.IsDisk() {
|
||||
common.DecrementDiskFiles(cache.Size)
|
||||
}
|
||||
cache.Close()
|
||||
}
|
||||
}
|
||||
c.Set(key, nil) // 清除引用
|
||||
}
|
||||
}
|
||||
|
||||
// loadFromURL 从 URL 加载文件
|
||||
// 支持磁盘缓存:当文件大小超过阈值且磁盘缓存可用时,将数据存储到磁盘
|
||||
func loadFromURL(c *gin.Context, url string, reason ...string) (*types.CachedFileData, error) {
|
||||
contextKey := getContextCacheKey(url)
|
||||
|
||||
// 检查 context 缓存
|
||||
if cachedData, exists := c.Get(contextKey); exists {
|
||||
if common.DebugEnabled {
|
||||
logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url))
|
||||
}
|
||||
return cachedData.(*types.CachedFileData), nil
|
||||
}
|
||||
|
||||
// 下载文件
|
||||
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
|
||||
|
||||
resp, err := DoDownloadRequest(url, reason...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to download file from %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to download file, status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 读取文件内容(限制大小)
|
||||
fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read file content: %w", err)
|
||||
}
|
||||
if len(fileBytes) > maxFileSize {
|
||||
return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
|
||||
}
|
||||
|
||||
// 转换为 base64
|
||||
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
|
||||
|
||||
// 智能获取 MIME 类型
|
||||
mimeType := smartDetectMimeType(resp, url, fileBytes)
|
||||
|
||||
// 判断是否使用磁盘缓存
|
||||
base64Size := int64(len(base64Data))
|
||||
var cachedData *types.CachedFileData
|
||||
|
||||
if shouldUseDiskCache(base64Size) {
|
||||
// 使用磁盘缓存
|
||||
diskPath, err := writeToDiskCache(base64Data)
|
||||
if err != nil {
|
||||
// 磁盘缓存失败,回退到内存
|
||||
logger.LogWarn(c, fmt.Sprintf("Failed to write to disk cache, falling back to memory: %v", err))
|
||||
cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes)))
|
||||
} else {
|
||||
cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(fileBytes)))
|
||||
common.IncrementDiskFiles(base64Size)
|
||||
if common.DebugEnabled {
|
||||
logger.LogDebug(c, fmt.Sprintf("File cached to disk: %s, size: %d bytes", diskPath, base64Size))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 使用内存缓存
|
||||
cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes)))
|
||||
}
|
||||
|
||||
// 如果是图片,尝试获取图片配置
|
||||
if strings.HasPrefix(mimeType, "image/") {
|
||||
config, format, err := decodeImageConfig(fileBytes)
|
||||
if err == nil {
|
||||
cachedData.ImageConfig = &config
|
||||
cachedData.ImageFormat = format
|
||||
// 如果通过图片解码获取了更准确的格式,更新 MIME 类型
|
||||
if mimeType == "application/octet-stream" || mimeType == "" {
|
||||
cachedData.MimeType = "image/" + format
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 存入 context 缓存
|
||||
c.Set(contextKey, cachedData)
|
||||
|
||||
return cachedData, nil
|
||||
}
|
||||
|
||||
// shouldUseDiskCache 判断是否应该使用磁盘缓存
|
||||
func shouldUseDiskCache(dataSize int64) bool {
|
||||
return common.ShouldUseDiskCache(dataSize)
|
||||
}
|
||||
|
||||
// writeToDiskCache 将数据写入磁盘缓存
|
||||
func writeToDiskCache(base64Data string) (string, error) {
|
||||
return common.WriteDiskCacheFileString(common.DiskCacheTypeFile, base64Data)
|
||||
}
|
||||
|
||||
// smartDetectMimeType 智能检测 MIME 类型
|
||||
// 优先级:Content-Type header > Content-Disposition filename > URL 路径 > 内容嗅探 > 图片解码
|
||||
func smartDetectMimeType(resp *http.Response, url string, fileBytes []byte) string {
|
||||
// 1. 尝试从 Content-Type header 获取
|
||||
mimeType := resp.Header.Get("Content-Type")
|
||||
if idx := strings.Index(mimeType, ";"); idx != -1 {
|
||||
mimeType = strings.TrimSpace(mimeType[:idx])
|
||||
}
|
||||
if mimeType != "" && mimeType != "application/octet-stream" {
|
||||
return mimeType
|
||||
}
|
||||
|
||||
// 2. 尝试从 Content-Disposition header 的 filename 获取
|
||||
if cd := resp.Header.Get("Content-Disposition"); cd != "" {
|
||||
parts := strings.Split(cd, ";")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(strings.ToLower(part), "filename=") {
|
||||
name := strings.TrimSpace(strings.TrimPrefix(part, "filename="))
|
||||
// 移除引号
|
||||
if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' {
|
||||
name = name[1 : len(name)-1]
|
||||
}
|
||||
if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) {
|
||||
ext := strings.ToLower(name[dot+1:])
|
||||
if ext != "" {
|
||||
mt := GetMimeTypeByExtension(ext)
|
||||
if mt != "application/octet-stream" {
|
||||
return mt
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 尝试从 URL 路径获取扩展名
|
||||
mt := guessMimeTypeFromURL(url)
|
||||
if mt != "application/octet-stream" {
|
||||
return mt
|
||||
}
|
||||
|
||||
// 4. 使用 http.DetectContentType 内容嗅探
|
||||
if len(fileBytes) > 0 {
|
||||
sniffed := http.DetectContentType(fileBytes)
|
||||
if sniffed != "" && sniffed != "application/octet-stream" {
|
||||
// 去除可能的 charset 参数
|
||||
if idx := strings.Index(sniffed, ";"); idx != -1 {
|
||||
sniffed = strings.TrimSpace(sniffed[:idx])
|
||||
}
|
||||
return sniffed
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 尝试作为图片解码获取格式
|
||||
if len(fileBytes) > 0 {
|
||||
if _, format, err := decodeImageConfig(fileBytes); err == nil && format != "" {
|
||||
return "image/" + strings.ToLower(format)
|
||||
}
|
||||
}
|
||||
|
||||
// 最终回退
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
// loadFromBase64 从 base64 字符串加载文件
|
||||
func loadFromBase64(base64String string, providedMimeType string) (*types.CachedFileData, error) {
|
||||
var mimeType string
|
||||
var cleanBase64 string
|
||||
|
||||
// 处理 data: 前缀
|
||||
if strings.HasPrefix(base64String, "data:") {
|
||||
// 格式: data:mime/type;base64,xxxxx
|
||||
idx := strings.Index(base64String, ",")
|
||||
if idx != -1 {
|
||||
header := base64String[:idx]
|
||||
cleanBase64 = base64String[idx+1:]
|
||||
|
||||
// 从 header 提取 MIME 类型
|
||||
if strings.Contains(header, ":") && strings.Contains(header, ";") {
|
||||
mimeStart := strings.Index(header, ":") + 1
|
||||
mimeEnd := strings.Index(header, ";")
|
||||
if mimeStart < mimeEnd {
|
||||
mimeType = header[mimeStart:mimeEnd]
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cleanBase64 = base64String
|
||||
}
|
||||
} else {
|
||||
cleanBase64 = base64String
|
||||
}
|
||||
|
||||
// 使用提供的 MIME 类型(如果有)
|
||||
if providedMimeType != "" {
|
||||
mimeType = providedMimeType
|
||||
}
|
||||
|
||||
// 解码 base64
|
||||
decodedData, err := base64.StdEncoding.DecodeString(cleanBase64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode base64 data: %w", err)
|
||||
}
|
||||
|
||||
// 判断是否使用磁盘缓存(对于 base64 内联数据也支持磁盘缓存)
|
||||
base64Size := int64(len(cleanBase64))
|
||||
var cachedData *types.CachedFileData
|
||||
|
||||
if shouldUseDiskCache(base64Size) {
|
||||
// 使用磁盘缓存
|
||||
diskPath, err := writeToDiskCache(cleanBase64)
|
||||
if err != nil {
|
||||
// 磁盘缓存失败,回退到内存
|
||||
cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData)))
|
||||
} else {
|
||||
cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(decodedData)))
|
||||
common.IncrementDiskFiles(base64Size)
|
||||
}
|
||||
} else {
|
||||
cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData)))
|
||||
}
|
||||
|
||||
// 如果是图片或 MIME 类型未知,尝试解码图片获取更多信息
|
||||
if mimeType == "" || strings.HasPrefix(mimeType, "image/") {
|
||||
config, format, err := decodeImageConfig(decodedData)
|
||||
if err == nil {
|
||||
cachedData.ImageConfig = &config
|
||||
cachedData.ImageFormat = format
|
||||
if mimeType == "" {
|
||||
cachedData.MimeType = "image/" + format
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cachedData, nil
|
||||
}
|
||||
|
||||
// GetImageConfig 获取图片配置(宽高等信息)
|
||||
// 会自动处理缓存,避免重复下载/解码
|
||||
func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, string, error) {
|
||||
cachedData, err := LoadFileSource(c, source, "get_image_config")
|
||||
if err != nil {
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
|
||||
if cachedData.ImageConfig != nil {
|
||||
return *cachedData.ImageConfig, cachedData.ImageFormat, nil
|
||||
}
|
||||
|
||||
// 如果缓存中没有图片配置,尝试解码
|
||||
base64Str, err := cachedData.GetBase64Data()
|
||||
if err != nil {
|
||||
return image.Config{}, "", fmt.Errorf("failed to get base64 data: %w", err)
|
||||
}
|
||||
decodedData, err := base64.StdEncoding.DecodeString(base64Str)
|
||||
if err != nil {
|
||||
return image.Config{}, "", fmt.Errorf("failed to decode base64 for image config: %w", err)
|
||||
}
|
||||
|
||||
config, format, err := decodeImageConfig(decodedData)
|
||||
if err != nil {
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
|
||||
// 更新缓存
|
||||
cachedData.ImageConfig = &config
|
||||
cachedData.ImageFormat = format
|
||||
|
||||
return config, format, nil
|
||||
}
|
||||
|
||||
// GetBase64Data 获取 base64 编码的数据
|
||||
// 会自动处理缓存,避免重复下载
|
||||
// 支持内存缓存和磁盘缓存
|
||||
func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) (string, string, error) {
|
||||
cachedData, err := LoadFileSource(c, source, reason...)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
base64Str, err := cachedData.GetBase64Data()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to get base64 data: %w", err)
|
||||
}
|
||||
return base64Str, cachedData.MimeType, nil
|
||||
}
|
||||
|
||||
// GetMimeType 获取文件的 MIME 类型
|
||||
func GetMimeType(c *gin.Context, source *types.FileSource) (string, error) {
|
||||
// 如果已经有缓存,直接返回
|
||||
if source.HasCache() {
|
||||
return source.GetCache().MimeType, nil
|
||||
}
|
||||
|
||||
// 如果是 URL,尝试只获取 header 而不下载完整文件
|
||||
if source.IsURL() {
|
||||
mimeType, err := GetFileTypeFromUrl(c, source.URL, "get_mime_type")
|
||||
if err == nil && mimeType != "" && mimeType != "application/octet-stream" {
|
||||
return mimeType, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 否则加载完整数据
|
||||
cachedData, err := LoadFileSource(c, source, "get_mime_type")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return cachedData.MimeType, nil
|
||||
}
|
||||
|
||||
// DetectFileType 检测文件类型(image/audio/video/file)
|
||||
func DetectFileType(mimeType string) types.FileType {
|
||||
if strings.HasPrefix(mimeType, "image/") {
|
||||
return types.FileTypeImage
|
||||
}
|
||||
if strings.HasPrefix(mimeType, "audio/") {
|
||||
return types.FileTypeAudio
|
||||
}
|
||||
if strings.HasPrefix(mimeType, "video/") {
|
||||
return types.FileTypeVideo
|
||||
}
|
||||
return types.FileTypeFile
|
||||
}
|
||||
|
||||
// decodeImageConfig 从字节数据解码图片配置
|
||||
func decodeImageConfig(data []byte) (image.Config, string, error) {
|
||||
reader := bytes.NewReader(data)
|
||||
|
||||
// 尝试标准格式
|
||||
config, format, err := image.DecodeConfig(reader)
|
||||
if err == nil {
|
||||
return config, format, nil
|
||||
}
|
||||
|
||||
// 尝试 webp
|
||||
reader.Seek(0, io.SeekStart)
|
||||
config, err = webp.DecodeConfig(reader)
|
||||
if err == nil {
|
||||
return config, "webp", nil
|
||||
}
|
||||
|
||||
return image.Config{}, "", fmt.Errorf("failed to decode image config: unsupported format")
|
||||
}
|
||||
|
||||
// guessMimeTypeFromURL 从 URL 猜测 MIME 类型
|
||||
func guessMimeTypeFromURL(url string) string {
|
||||
// 移除查询参数
|
||||
cleanedURL := url
|
||||
if q := strings.Index(cleanedURL, "?"); q != -1 {
|
||||
cleanedURL = cleanedURL[:q]
|
||||
}
|
||||
|
||||
// 获取最后一段
|
||||
if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) {
|
||||
last := cleanedURL[slash+1:]
|
||||
if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) {
|
||||
ext := strings.ToLower(last[dot+1:])
|
||||
return GetMimeTypeByExtension(ext)
|
||||
}
|
||||
}
|
||||
|
||||
return "application/octet-stream"
|
||||
}
|
||||
@@ -3,10 +3,6 @@ package service
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"log"
|
||||
"math"
|
||||
"path/filepath"
|
||||
@@ -23,8 +19,8 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
|
||||
if fileMeta == nil {
|
||||
func getImageToken(c *gin.Context, fileMeta *types.FileMeta, model string, stream bool) (int, error) {
|
||||
if fileMeta == nil || fileMeta.Source == nil {
|
||||
return 0, fmt.Errorf("image_url_is_nil")
|
||||
}
|
||||
|
||||
@@ -99,35 +95,20 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
||||
fileMeta.Detail = "high"
|
||||
}
|
||||
|
||||
// Decode image to get dimensions
|
||||
var config image.Config
|
||||
var err error
|
||||
var format string
|
||||
var b64str string
|
||||
|
||||
if fileMeta.ParsedData != nil {
|
||||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.ParsedData.Base64Data)
|
||||
} else {
|
||||
if strings.HasPrefix(fileMeta.OriginData, "http") {
|
||||
config, format, err = DecodeUrlImageData(fileMeta.OriginData)
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("decoding image"))
|
||||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.OriginData)
|
||||
}
|
||||
fileMeta.MimeType = format
|
||||
}
|
||||
|
||||
// 使用统一的文件服务获取图片配置
|
||||
config, format, err := GetImageConfig(c, fileMeta.Source)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
fileMeta.MimeType = format
|
||||
|
||||
if config.Width == 0 || config.Height == 0 {
|
||||
// not an image
|
||||
if format != "" && b64str != "" {
|
||||
// not an image, but might be a valid file
|
||||
if format != "" {
|
||||
// file type
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.OriginData))
|
||||
return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", fileMeta.GetIdentifier()))
|
||||
}
|
||||
|
||||
width := config.Width
|
||||
@@ -269,48 +250,24 @@ func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *rela
|
||||
shouldFetchFiles = false
|
||||
}
|
||||
|
||||
// 使用统一的文件服务获取文件类型
|
||||
for _, file := range meta.Files {
|
||||
if strings.HasPrefix(file.OriginData, "http") {
|
||||
if shouldFetchFiles {
|
||||
mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting file base64 from url: %v", err)
|
||||
}
|
||||
if strings.HasPrefix(mineType, "image/") {
|
||||
file.FileType = types.FileTypeImage
|
||||
} else if strings.HasPrefix(mineType, "video/") {
|
||||
file.FileType = types.FileTypeVideo
|
||||
} else if strings.HasPrefix(mineType, "audio/") {
|
||||
file.FileType = types.FileTypeAudio
|
||||
} else {
|
||||
file.FileType = types.FileTypeFile
|
||||
}
|
||||
file.MimeType = mineType
|
||||
}
|
||||
} else if strings.HasPrefix(file.OriginData, "data:") {
|
||||
// get mime type from base64 header
|
||||
parts := strings.SplitN(file.OriginData, ",", 2)
|
||||
if len(parts) >= 1 {
|
||||
header := parts[0]
|
||||
// Extract mime type from "data:mime/type;base64" format
|
||||
if strings.Contains(header, ":") && strings.Contains(header, ";") {
|
||||
mimeStart := strings.Index(header, ":") + 1
|
||||
mimeEnd := strings.Index(header, ";")
|
||||
if mimeStart < mimeEnd {
|
||||
mineType := header[mimeStart:mimeEnd]
|
||||
if strings.HasPrefix(mineType, "image/") {
|
||||
file.FileType = types.FileTypeImage
|
||||
} else if strings.HasPrefix(mineType, "video/") {
|
||||
file.FileType = types.FileTypeVideo
|
||||
} else if strings.HasPrefix(mineType, "audio/") {
|
||||
file.FileType = types.FileTypeAudio
|
||||
} else {
|
||||
file.FileType = types.FileTypeFile
|
||||
}
|
||||
file.MimeType = mineType
|
||||
}
|
||||
if file.Source == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 如果文件类型未知且需要获取,通过 MIME 类型检测
|
||||
if file.FileType == "" || (file.Source.IsURL() && shouldFetchFiles) {
|
||||
mimeType, err := GetMimeType(c, file.Source)
|
||||
if err != nil {
|
||||
if shouldFetchFiles {
|
||||
return 0, fmt.Errorf("error getting file type: %v", err)
|
||||
}
|
||||
// 如果不需要获取,使用默认类型
|
||||
continue
|
||||
}
|
||||
file.MimeType = mimeType
|
||||
file.FileType = DetectFileType(mimeType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -318,9 +275,9 @@ func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *rela
|
||||
switch file.FileType {
|
||||
case types.FileTypeImage:
|
||||
if common.IsOpenAITextModel(model) {
|
||||
token, err := getImageToken(file, model, info.IsStream)
|
||||
token, err := getImageToken(c, file, model, info.IsStream)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err)
|
||||
return 0, fmt.Errorf("error counting image token, media index[%d], identifier[%s], err: %v", i, file.GetIdentifier(), err)
|
||||
}
|
||||
tkm += token
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user