refactor: update function signatures to include context and improve file handling #1599
This commit is contained in:
@@ -3,7 +3,8 @@ package constant
|
|||||||
type ContextKey string
|
type ContextKey string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ContextKeyPromptTokens ContextKey = "prompt_tokens"
|
ContextKeyTokenCountMeta ContextKey = "token_count_meta"
|
||||||
|
ContextKeyPromptTokens ContextKey = "prompt_tokens"
|
||||||
|
|
||||||
ContextKeyOriginalModel ContextKey = "original_model"
|
ContextKeyOriginalModel ContextKey = "original_model"
|
||||||
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
||||||
|
|||||||
@@ -133,6 +133,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
|
||||||
|
|
||||||
preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||||
if newAPIError != nil {
|
if newAPIError != nil {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
|||||||
data = common.Interface2String(media.Source.Data)
|
data = common.Interface2String(media.Source.Data)
|
||||||
}
|
}
|
||||||
if data != "" {
|
if data != "" {
|
||||||
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data})
|
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -263,7 +263,7 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
|||||||
data = common.Interface2String(media.Source.Data)
|
data = common.Interface2String(media.Source.Data)
|
||||||
}
|
}
|
||||||
if data != "" {
|
if data != "" {
|
||||||
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data})
|
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "tool_use":
|
case "tool_use":
|
||||||
|
|||||||
@@ -35,23 +35,23 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
|||||||
if part.InlineData != nil && part.InlineData.Data != "" {
|
if part.InlineData != nil && part.InlineData.Data != "" {
|
||||||
if strings.HasPrefix(part.InlineData.MimeType, "image/") {
|
if strings.HasPrefix(part.InlineData.MimeType, "image/") {
|
||||||
files = append(files, &types.FileMeta{
|
files = append(files, &types.FileMeta{
|
||||||
FileType: types.FileTypeImage,
|
FileType: types.FileTypeImage,
|
||||||
Data: part.InlineData.Data,
|
OriginData: part.InlineData.Data,
|
||||||
})
|
})
|
||||||
} else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
|
} else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
|
||||||
files = append(files, &types.FileMeta{
|
files = append(files, &types.FileMeta{
|
||||||
FileType: types.FileTypeAudio,
|
FileType: types.FileTypeAudio,
|
||||||
Data: part.InlineData.Data,
|
OriginData: part.InlineData.Data,
|
||||||
})
|
})
|
||||||
} else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
|
} else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
|
||||||
files = append(files, &types.FileMeta{
|
files = append(files, &types.FileMeta{
|
||||||
FileType: types.FileTypeVideo,
|
FileType: types.FileTypeVideo,
|
||||||
Data: part.InlineData.Data,
|
OriginData: part.InlineData.Data,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
files = append(files, &types.FileMeta{
|
files = append(files, &types.FileMeta{
|
||||||
FileType: types.FileTypeFile,
|
FileType: types.FileTypeFile,
|
||||||
Data: part.InlineData.Data,
|
OriginData: part.InlineData.Data,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
|||||||
meta := &types.FileMeta{
|
meta := &types.FileMeta{
|
||||||
FileType: types.FileTypeImage,
|
FileType: types.FileTypeImage,
|
||||||
}
|
}
|
||||||
meta.Data = imageUrl.Url
|
meta.OriginData = imageUrl.Url
|
||||||
meta.Detail = imageUrl.Detail
|
meta.Detail = imageUrl.Detail
|
||||||
fileMeta = append(fileMeta, meta)
|
fileMeta = append(fileMeta, meta)
|
||||||
}
|
}
|
||||||
@@ -128,7 +128,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
|||||||
meta := &types.FileMeta{
|
meta := &types.FileMeta{
|
||||||
FileType: types.FileTypeAudio,
|
FileType: types.FileTypeAudio,
|
||||||
}
|
}
|
||||||
meta.Data = inputAudio.Data
|
meta.OriginData = inputAudio.Data
|
||||||
fileMeta = append(fileMeta, meta)
|
fileMeta = append(fileMeta, meta)
|
||||||
}
|
}
|
||||||
} else if m.Type == ContentTypeFile {
|
} else if m.Type == ContentTypeFile {
|
||||||
@@ -137,7 +137,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
|||||||
meta := &types.FileMeta{
|
meta := &types.FileMeta{
|
||||||
FileType: types.FileTypeFile,
|
FileType: types.FileTypeFile,
|
||||||
}
|
}
|
||||||
meta.Data = file.FileData
|
meta.OriginData = file.FileData
|
||||||
fileMeta = append(fileMeta, meta)
|
fileMeta = append(fileMeta, meta)
|
||||||
}
|
}
|
||||||
} else if m.Type == ContentTypeVideoUrl {
|
} else if m.Type == ContentTypeVideoUrl {
|
||||||
@@ -146,7 +146,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
|||||||
meta := &types.FileMeta{
|
meta := &types.FileMeta{
|
||||||
FileType: types.FileTypeVideo,
|
FileType: types.FileTypeVideo,
|
||||||
}
|
}
|
||||||
meta.Data = videoUrl.Url
|
meta.OriginData = videoUrl.Url
|
||||||
fileMeta = append(fileMeta, meta)
|
fileMeta = append(fileMeta, meta)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -784,14 +784,14 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
|||||||
for _, input := range inputs {
|
for _, input := range inputs {
|
||||||
if input.Type == "input_image" {
|
if input.Type == "input_image" {
|
||||||
fileMeta = append(fileMeta, &types.FileMeta{
|
fileMeta = append(fileMeta, &types.FileMeta{
|
||||||
FileType: types.FileTypeImage,
|
FileType: types.FileTypeImage,
|
||||||
Data: input.ImageUrl,
|
OriginData: input.ImageUrl,
|
||||||
Detail: input.Detail,
|
Detail: input.Detail,
|
||||||
})
|
})
|
||||||
} else if input.Type == "input_file" {
|
} else if input.Type == "input_file" {
|
||||||
fileMeta = append(fileMeta, &types.FileMeta{
|
fileMeta = append(fileMeta, &types.FileMeta{
|
||||||
FileType: types.FileTypeFile,
|
FileType: types.FileTypeFile,
|
||||||
Data: input.FileUrl,
|
OriginData: input.FileUrl,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
texts = append(texts, input.Text)
|
texts = append(texts, input.Text)
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
|
|
||||||
var claudeReq *dto.ClaudeRequest
|
var claudeReq *dto.ClaudeRequest
|
||||||
var err error
|
var err error
|
||||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
|
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
if a.RequestMode == RequestModeCompletion {
|
if a.RequestMode == RequestModeCompletion {
|
||||||
return RequestOpenAI2ClaudeComplete(*request), nil
|
return RequestOpenAI2ClaudeComplete(*request), nil
|
||||||
} else {
|
} else {
|
||||||
return RequestOpenAI2ClaudeMessage(*request)
|
return RequestOpenAI2ClaudeMessage(c, *request)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla
|
|||||||
return &claudeRequest
|
return &claudeRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
|
func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
|
||||||
claudeTools := make([]any, 0, len(textRequest.Tools))
|
claudeTools := make([]any, 0, len(textRequest.Tools))
|
||||||
|
|
||||||
for _, tool := range textRequest.Tools {
|
for _, tool := range textRequest.Tools {
|
||||||
@@ -355,7 +355,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
|||||||
// 判断是否是url
|
// 判断是否是url
|
||||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||||
// 是url,获取图片的类型和base64编码的数据
|
// 是url,获取图片的类型和base64编码的数据
|
||||||
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
|
fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
geminiRequest, err := CovertGemini2OpenAI(*request, info)
|
geminiRequest, err := CovertGemini2OpenAI(c, *request, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
|
func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
|
||||||
|
|
||||||
geminiRequest := dto.GeminiChatRequest{
|
geminiRequest := dto.GeminiChatRequest{
|
||||||
Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
|
Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
|
||||||
@@ -390,7 +390,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
// 判断是否是url
|
// 判断是否是url
|
||||||
if strings.HasPrefix(part.GetImageMedia().Url, "http") {
|
if strings.HasPrefix(part.GetImageMedia().Url, "http") {
|
||||||
// 是url,获取文件的类型和base64编码的数据
|
// 是url,获取文件的类型和base64编码的数据
|
||||||
fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
|
fileData, err := service.GetFileBase64FromUrl(c, part.GetImageMedia().Url, "formatting image for Gemini")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
|
return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
|
openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
|
||||||
IncludeUsage: true,
|
IncludeUsage: true,
|
||||||
}
|
}
|
||||||
return requestOpenAI2Ollama(openaiRequest.(*dto.GeneralOpenAIRequest))
|
return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
@@ -69,7 +69,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
return requestOpenAI2Ollama(request)
|
return requestOpenAI2Ollama(c, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
||||||
messages := make([]dto.Message, 0, len(request.Messages))
|
messages := make([]dto.Message, 0, len(request.Messages))
|
||||||
for _, message := range request.Messages {
|
for _, message := range request.Messages {
|
||||||
if !message.IsStringContent() {
|
if !message.IsStringContent() {
|
||||||
@@ -24,7 +24,7 @@ func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, er
|
|||||||
imageUrl := mediaMessage.GetImageMedia()
|
imageUrl := mediaMessage.GetImageMedia()
|
||||||
// check if not base64
|
// check if not base64
|
||||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||||
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
|
fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -182,7 +182,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
if a.RequestMode == RequestModeClaude {
|
if a.RequestMode == RequestModeClaude {
|
||||||
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request)
|
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -191,7 +191,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
info.UpstreamModelName = claudeReq.Model
|
info.UpstreamModelName = claudeReq.Model
|
||||||
return vertexClaudeReq, nil
|
return vertexClaudeReq, nil
|
||||||
} else if a.RequestMode == RequestModeGemini {
|
} else if a.RequestMode == RequestModeGemini {
|
||||||
geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info)
|
geminiRequest, err := gemini.CovertGemini2OpenAI(c, *request, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,16 +42,16 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
|
|||||||
return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
|
return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
|
||||||
}
|
}
|
||||||
|
|
||||||
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
|
func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) {
|
||||||
if setting.EnableWorker() {
|
if setting.EnableWorker() {
|
||||||
common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
|
common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
|
||||||
req := &WorkerRequest{
|
req := &WorkerRequest{
|
||||||
URL: originUrl,
|
URL: originUrl,
|
||||||
Key: setting.WorkerValidKey,
|
Key: setting.WorkerValidKey,
|
||||||
}
|
}
|
||||||
return DoWorkerRequest(req)
|
return DoWorkerRequest(req)
|
||||||
} else {
|
} else {
|
||||||
common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
|
common.SysLog(fmt.Sprintf("downloading from origin with worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
|
||||||
return http.Get(originUrl)
|
return http.Get(originUrl)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,17 +3,29 @@ package service
|
|||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/logger"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
|
func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) {
|
||||||
|
contextKey := fmt.Sprintf("file_download_%s", common.GenerateHMAC(url))
|
||||||
|
|
||||||
|
// Check if the file has already been downloaded in this request
|
||||||
|
if cachedData, exists := c.Get(contextKey); exists {
|
||||||
|
if common.DebugEnabled {
|
||||||
|
logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url))
|
||||||
|
}
|
||||||
|
return cachedData.(*types.LocalFileData), nil
|
||||||
|
}
|
||||||
|
|
||||||
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
|
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
|
||||||
|
|
||||||
resp, err := DoDownloadRequest(url)
|
resp, err := DoDownloadRequest(url, reason...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -77,12 +89,15 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
data := &types.LocalFileData{
|
||||||
return &dto.LocalFileData{
|
|
||||||
Base64Data: base64Data,
|
Base64Data: base64Data,
|
||||||
MimeType: mimeType,
|
MimeType: mimeType,
|
||||||
Size: int64(len(fileBytes)),
|
Size: int64(len(fileBytes)),
|
||||||
}, nil
|
}
|
||||||
|
// Store the file data in the context to avoid re-downloading
|
||||||
|
c.Set(contextKey, data)
|
||||||
|
|
||||||
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetMimeTypeByExtension(ext string) string {
|
func GetMimeTypeByExtension(ext string) string {
|
||||||
|
|||||||
@@ -154,16 +154,22 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
|||||||
var err error
|
var err error
|
||||||
var format string
|
var format string
|
||||||
var b64str string
|
var b64str string
|
||||||
if strings.HasPrefix(fileMeta.Data, "http") {
|
|
||||||
config, format, err = DecodeUrlImageData(fileMeta.Data)
|
if fileMeta.ParsedData != nil {
|
||||||
|
config, format, b64str, err = DecodeBase64ImageData(fileMeta.ParsedData.Base64Data)
|
||||||
} else {
|
} else {
|
||||||
common.SysLog(fmt.Sprintf("decoding image"))
|
if strings.HasPrefix(fileMeta.OriginData, "http") {
|
||||||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data)
|
config, format, err = DecodeUrlImageData(fileMeta.OriginData)
|
||||||
|
} else {
|
||||||
|
common.SysLog(fmt.Sprintf("decoding image"))
|
||||||
|
config, format, b64str, err = DecodeBase64ImageData(fileMeta.OriginData)
|
||||||
|
}
|
||||||
|
fileMeta.MimeType = format
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
fileMeta.MimeType = format
|
|
||||||
|
|
||||||
if config.Width == 0 || config.Height == 0 {
|
if config.Width == 0 || config.Height == 0 {
|
||||||
// not an image
|
// not an image
|
||||||
@@ -171,7 +177,7 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
|||||||
// file type
|
// file type
|
||||||
return 3 * baseTokens, nil
|
return 3 * baseTokens, nil
|
||||||
}
|
}
|
||||||
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.Data))
|
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.OriginData))
|
||||||
}
|
}
|
||||||
|
|
||||||
width := config.Width
|
width := config.Width
|
||||||
@@ -268,6 +274,34 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
|||||||
tkm += 3
|
tkm += 3
|
||||||
}
|
}
|
||||||
|
|
||||||
|
shouldFetchFiles := true
|
||||||
|
|
||||||
|
if info.RelayFormat == types.RelayFormatOpenAIRealtime || info.RelayFormat == types.RelayFormatGemini {
|
||||||
|
shouldFetchFiles = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldFetchFiles {
|
||||||
|
for _, file := range meta.Files {
|
||||||
|
if strings.HasPrefix(file.OriginData, "http") {
|
||||||
|
localFileData, err := GetFileBase64FromUrl(c, file.OriginData, "token_counter")
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("error getting file base64 from url: %v", err)
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(localFileData.MimeType, "image/") {
|
||||||
|
file.FileType = types.FileTypeImage
|
||||||
|
} else if strings.HasPrefix(localFileData.MimeType, "video/") {
|
||||||
|
file.FileType = types.FileTypeVideo
|
||||||
|
} else if strings.HasPrefix(localFileData.MimeType, "audio/") {
|
||||||
|
file.FileType = types.FileTypeAudio
|
||||||
|
} else {
|
||||||
|
file.FileType = types.FileTypeFile
|
||||||
|
}
|
||||||
|
file.MimeType = localFileData.MimeType
|
||||||
|
file.ParsedData = localFileData
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, file := range meta.Files {
|
for _, file := range meta.Files {
|
||||||
switch file.FileType {
|
switch file.FileType {
|
||||||
case types.FileTypeImage:
|
case types.FileTypeImage:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package dto
|
package types
|
||||||
|
|
||||||
type LocalFileData struct {
|
type LocalFileData struct {
|
||||||
MimeType string
|
MimeType string
|
||||||
@@ -32,9 +32,10 @@ type TokenCountMeta struct {
|
|||||||
|
|
||||||
type FileMeta struct {
|
type FileMeta struct {
|
||||||
FileType
|
FileType
|
||||||
MimeType string
|
MimeType string
|
||||||
Data string
|
OriginData string // url or base64 data
|
||||||
Detail string
|
Detail string
|
||||||
|
ParsedData *LocalFileData
|
||||||
}
|
}
|
||||||
|
|
||||||
type RequestMeta struct {
|
type RequestMeta struct {
|
||||||
|
|||||||
Reference in New Issue
Block a user