From 5d6fac69c4babefdff649e526a2805e476d82323 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 16 Aug 2025 14:56:29 +0800 Subject: [PATCH] feat: implement file type detection from URL with enhanced MIME type handling --- service/file_decoder.go | 120 +++++++++++++++++++++++++++++++++++++-- service/token_counter.go | 21 +++---- 2 files changed, 127 insertions(+), 14 deletions(-) diff --git a/service/file_decoder.go b/service/file_decoder.go index bd14b963..359bd9ab 100644 --- a/service/file_decoder.go +++ b/service/file_decoder.go @@ -1,17 +1,131 @@ package service import ( + "bytes" "encoding/base64" "fmt" - "github.com/gin-gonic/gin" + "image" "io" + "net/http" "one-api/common" "one-api/constant" "one-api/logger" "one-api/types" "strings" + + "github.com/gin-gonic/gin" ) +// GetFileTypeFromUrl 获取文件类型,返回 mime type, 例如 image/jpeg, image/png, image/gif, image/bmp, image/tiff, application/pdf +// 如果获取失败,返回 application/octet-stream +func GetFileTypeFromUrl(c *gin.Context, url string, reason ...string) (string, error) { + response, err := DoDownloadRequest(url, reason...) + if err != nil { + common.SysLog(fmt.Sprintf("fail to get file type from url: %s, error: %s", url, err.Error())) + return "", err + } + defer response.Body.Close() + + if response.StatusCode != 200 { + logger.LogError(c, fmt.Sprintf("failed to download file from %s, status code: %d", url, response.StatusCode)) + return "", fmt.Errorf("failed to download file, status code: %d", response.StatusCode) + } + + if headerType := strings.TrimSpace(response.Header.Get("Content-Type")); headerType != "" { + if i := strings.Index(headerType, ";"); i != -1 { + headerType = headerType[:i] + } + if headerType != "application/octet-stream" { + return headerType, nil + } + } + + if cd := response.Header.Get("Content-Disposition"); cd != "" { + parts := strings.Split(cd, ";") + for _, part := range parts { + part = strings.TrimSpace(part) + if strings.HasPrefix(strings.ToLower(part), "filename=") { + name := strings.TrimSpace(strings.TrimPrefix(part, "filename=")) + if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' { + name = name[1 : len(name)-1] + } + if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) { + ext := strings.ToLower(name[dot+1:]) + if ext != "" { + mt := GetMimeTypeByExtension(ext) + if mt != "application/octet-stream" { + return mt, nil + } + } + } + break + } + } + } + + cleanedURL := url + if q := strings.Index(cleanedURL, "?"); q != -1 { + cleanedURL = cleanedURL[:q] + } + if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) { + last := cleanedURL[slash+1:] + if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) { + ext := strings.ToLower(last[dot+1:]) + if ext != "" { + mt := GetMimeTypeByExtension(ext) + if mt != "application/octet-stream" { + return mt, nil + } + } + } + } + + var readData []byte + limits := []int{512, 8 * 1024, 24 * 1024, 64 * 1024} + for _, limit := range limits { + logger.LogDebug(c, fmt.Sprintf("Trying to read %d bytes to determine file type", limit)) + if len(readData) < limit { + need := limit - len(readData) + tmp := make([]byte, need) + n, _ := io.ReadFull(response.Body, tmp) + if n > 0 { + readData = append(readData, tmp[:n]...) + } + } + + if len(readData) == 0 { + continue + } + + sniffed := http.DetectContentType(readData) + if sniffed != "" && sniffed != "application/octet-stream" { + return sniffed, nil + } + + if _, format, err := image.DecodeConfig(bytes.NewReader(readData)); err == nil { + switch strings.ToLower(format) { + case "jpeg", "jpg": + return "image/jpeg", nil + case "png": + return "image/png", nil + case "gif": + return "image/gif", nil + case "bmp": + return "image/bmp", nil + case "tiff": + return "image/tiff", nil + default: + if format != "" { + return "image/" + strings.ToLower(format), nil + } + } + } + } + + // Fallback + return "application/octet-stream", nil +} + func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) { contextKey := fmt.Sprintf("file_download_%s", common.GenerateHMAC(url)) @@ -50,9 +164,7 @@ func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types. mimeType = strings.Split(mimeType, ";")[0] } if mimeType == "application/octet-stream" { - if common.DebugEnabled { - println("MIME type is application/octet-stream, trying to guess from URL or filename") - } + logger.LogDebug(c, fmt.Sprintf("MIME type is application/octet-stream for URL: %s", url)) // try to guess the MIME type from the url last segment urlParts := strings.Split(url, "/") if len(urlParts) > 0 { diff --git a/service/token_counter.go b/service/token_counter.go index 2312f27e..bac6c067 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -283,21 +283,20 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco if shouldFetchFiles { for _, file := range meta.Files { if strings.HasPrefix(file.OriginData, "http") { - localFileData, err := GetFileBase64FromUrl(c, file.OriginData, "token_counter") + mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter") if err != nil { return 0, fmt.Errorf("error getting file base64 from url: %v", err) } - if strings.HasPrefix(localFileData.MimeType, "image/") { + if strings.HasPrefix(mineType, "image/") { file.FileType = types.FileTypeImage - } else if strings.HasPrefix(localFileData.MimeType, "video/") { + } else if strings.HasPrefix(mineType, "video/") { file.FileType = types.FileTypeVideo - } else if strings.HasPrefix(localFileData.MimeType, "audio/") { + } else if strings.HasPrefix(mineType, "audio/") { file.FileType = types.FileTypeAudio } else { file.FileType = types.FileTypeFile } - file.MimeType = localFileData.MimeType - file.ParsedData = localFileData + file.MimeType = mineType } } } @@ -306,7 +305,7 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco switch file.FileType { case types.FileTypeImage: if info.RelayFormat == types.RelayFormatGemini { - tkm += 240 + tkm += 256 } else { token, err := getImageToken(file, model, info.IsStream) if err != nil { @@ -315,11 +314,13 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco tkm += token } case types.FileTypeAudio: - tkm += 100 + tkm += 256 case types.FileTypeVideo: - tkm += 5000 + tkm += 4096 * 2 case types.FileTypeFile: - tkm += 5000 + tkm += 4096 + default: + tkm += 4096 // Default case for unknown file types } }