refactor: update function signatures to include context and improve file handling #1599
This commit is contained in:
@@ -42,16 +42,16 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
|
||||
return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
|
||||
}
|
||||
|
||||
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
|
||||
func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) {
|
||||
if setting.EnableWorker() {
|
||||
common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
|
||||
common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
|
||||
req := &WorkerRequest{
|
||||
URL: originUrl,
|
||||
Key: setting.WorkerValidKey,
|
||||
}
|
||||
return DoWorkerRequest(req)
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
|
||||
common.SysLog(fmt.Sprintf("downloading from origin with worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
|
||||
return http.Get(originUrl)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,17 +3,29 @@ package service
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
|
||||
func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) {
|
||||
contextKey := fmt.Sprintf("file_download_%s", common.GenerateHMAC(url))
|
||||
|
||||
// Check if the file has already been downloaded in this request
|
||||
if cachedData, exists := c.Get(contextKey); exists {
|
||||
if common.DebugEnabled {
|
||||
logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url))
|
||||
}
|
||||
return cachedData.(*types.LocalFileData), nil
|
||||
}
|
||||
|
||||
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
|
||||
|
||||
resp, err := DoDownloadRequest(url)
|
||||
resp, err := DoDownloadRequest(url, reason...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -77,12 +89,15 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &dto.LocalFileData{
|
||||
data := &types.LocalFileData{
|
||||
Base64Data: base64Data,
|
||||
MimeType: mimeType,
|
||||
Size: int64(len(fileBytes)),
|
||||
}, nil
|
||||
}
|
||||
// Store the file data in the context to avoid re-downloading
|
||||
c.Set(contextKey, data)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func GetMimeTypeByExtension(ext string) string {
|
||||
|
||||
@@ -154,16 +154,22 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
||||
var err error
|
||||
var format string
|
||||
var b64str string
|
||||
if strings.HasPrefix(fileMeta.Data, "http") {
|
||||
config, format, err = DecodeUrlImageData(fileMeta.Data)
|
||||
|
||||
if fileMeta.ParsedData != nil {
|
||||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.ParsedData.Base64Data)
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("decoding image"))
|
||||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data)
|
||||
if strings.HasPrefix(fileMeta.OriginData, "http") {
|
||||
config, format, err = DecodeUrlImageData(fileMeta.OriginData)
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("decoding image"))
|
||||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.OriginData)
|
||||
}
|
||||
fileMeta.MimeType = format
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
fileMeta.MimeType = format
|
||||
|
||||
if config.Width == 0 || config.Height == 0 {
|
||||
// not an image
|
||||
@@ -171,7 +177,7 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
||||
// file type
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.Data))
|
||||
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.OriginData))
|
||||
}
|
||||
|
||||
width := config.Width
|
||||
@@ -268,6 +274,34 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
tkm += 3
|
||||
}
|
||||
|
||||
shouldFetchFiles := true
|
||||
|
||||
if info.RelayFormat == types.RelayFormatOpenAIRealtime || info.RelayFormat == types.RelayFormatGemini {
|
||||
shouldFetchFiles = false
|
||||
}
|
||||
|
||||
if shouldFetchFiles {
|
||||
for _, file := range meta.Files {
|
||||
if strings.HasPrefix(file.OriginData, "http") {
|
||||
localFileData, err := GetFileBase64FromUrl(c, file.OriginData, "token_counter")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting file base64 from url: %v", err)
|
||||
}
|
||||
if strings.HasPrefix(localFileData.MimeType, "image/") {
|
||||
file.FileType = types.FileTypeImage
|
||||
} else if strings.HasPrefix(localFileData.MimeType, "video/") {
|
||||
file.FileType = types.FileTypeVideo
|
||||
} else if strings.HasPrefix(localFileData.MimeType, "audio/") {
|
||||
file.FileType = types.FileTypeAudio
|
||||
} else {
|
||||
file.FileType = types.FileTypeFile
|
||||
}
|
||||
file.MimeType = localFileData.MimeType
|
||||
file.ParsedData = localFileData
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, file := range meta.Files {
|
||||
switch file.FileType {
|
||||
case types.FileTypeImage:
|
||||
|
||||
Reference in New Issue
Block a user