From 0bb43aa46417432eed95ab98428ae2abf8233e86 Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 15 Aug 2025 18:40:54 +0800 Subject: [PATCH] refactor: update function signatures to include context and improve file handling #1599 --- constant/context_key.go | 3 +- controller/relay.go | 2 ++ dto/claude.go | 4 +-- dto/gemini.go | 16 +++++----- dto/openai_request.go | 18 +++++------ relay/channel/aws/adaptor.go | 2 +- relay/channel/claude/adaptor.go | 2 +- relay/channel/claude/relay-claude.go | 4 +-- relay/channel/gemini/adaptor.go | 2 +- relay/channel/gemini/relay-gemini.go | 4 +-- relay/channel/ollama/adaptor.go | 4 +-- relay/channel/ollama/relay-ollama.go | 4 +-- relay/channel/vertex/adaptor.go | 4 +-- service/cf_worker.go | 6 ++-- service/file_decoder.go | 27 ++++++++++++---- service/token_counter.go | 46 ++++++++++++++++++++++++---- {dto => types}/file_data.go | 2 +- types/request_meta.go | 7 +++-- 18 files changed, 105 insertions(+), 52 deletions(-) rename {dto => types}/file_data.go (88%) diff --git a/constant/context_key.go b/constant/context_key.go index 569a0373..3945243c 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -3,7 +3,8 @@ package constant type ContextKey string const ( - ContextKeyPromptTokens ContextKey = "prompt_tokens" + ContextKeyTokenCountMeta ContextKey = "token_count_meta" + ContextKeyPromptTokens ContextKey = "prompt_tokens" ContextKeyOriginalModel ContextKey = "original_model" ContextKeyRequestStartTime ContextKey = "request_start_time" diff --git a/controller/relay.go b/controller/relay.go index b0c995fb..f0d84ea0 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -133,6 +133,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { return } + common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta) + preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if newAPIError != nil { return diff --git a/dto/claude.go b/dto/claude.go index 48bef659..ad2f3ebc 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -231,7 +231,7 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta { data = common.Interface2String(media.Source.Data) } if data != "" { - fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data}) + fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data}) } } } @@ -263,7 +263,7 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta { data = common.Interface2String(media.Source.Data) } if data != "" { - fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data}) + fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data}) } } case "tool_use": diff --git a/dto/gemini.go b/dto/gemini.go index b327de62..9f5b34ea 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -35,23 +35,23 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta { if part.InlineData != nil && part.InlineData.Data != "" { if strings.HasPrefix(part.InlineData.MimeType, "image/") { files = append(files, &types.FileMeta{ - FileType: types.FileTypeImage, - Data: part.InlineData.Data, + FileType: types.FileTypeImage, + OriginData: part.InlineData.Data, }) } else if strings.HasPrefix(part.InlineData.MimeType, "audio/") { files = append(files, &types.FileMeta{ - FileType: types.FileTypeAudio, - Data: part.InlineData.Data, + FileType: types.FileTypeAudio, + OriginData: part.InlineData.Data, }) } else if strings.HasPrefix(part.InlineData.MimeType, "video/") { files = append(files, &types.FileMeta{ - FileType: types.FileTypeVideo, - Data: part.InlineData.Data, + FileType: types.FileTypeVideo, + OriginData: part.InlineData.Data, }) } else { files = append(files, &types.FileMeta{ - FileType: types.FileTypeFile, - Data: part.InlineData.Data, + FileType: types.FileTypeFile, + OriginData: part.InlineData.Data, }) } } diff --git a/dto/openai_request.go b/dto/openai_request.go index 12aa54f4..36240852 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -118,7 +118,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta { meta := &types.FileMeta{ FileType: types.FileTypeImage, } - meta.Data = imageUrl.Url + meta.OriginData = imageUrl.Url meta.Detail = imageUrl.Detail fileMeta = append(fileMeta, meta) } @@ -128,7 +128,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta { meta := &types.FileMeta{ FileType: types.FileTypeAudio, } - meta.Data = inputAudio.Data + meta.OriginData = inputAudio.Data fileMeta = append(fileMeta, meta) } } else if m.Type == ContentTypeFile { @@ -137,7 +137,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta { meta := &types.FileMeta{ FileType: types.FileTypeFile, } - meta.Data = file.FileData + meta.OriginData = file.FileData fileMeta = append(fileMeta, meta) } } else if m.Type == ContentTypeVideoUrl { @@ -146,7 +146,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta { meta := &types.FileMeta{ FileType: types.FileTypeVideo, } - meta.Data = videoUrl.Url + meta.OriginData = videoUrl.Url fileMeta = append(fileMeta, meta) } } else { @@ -784,14 +784,14 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta { for _, input := range inputs { if input.Type == "input_image" { fileMeta = append(fileMeta, &types.FileMeta{ - FileType: types.FileTypeImage, - Data: input.ImageUrl, - Detail: input.Detail, + FileType: types.FileTypeImage, + OriginData: input.ImageUrl, + Detail: input.Detail, }) } else if input.Type == "input_file" { fileMeta = append(fileMeta, &types.FileMeta{ - FileType: types.FileTypeFile, - Data: input.FileUrl, + FileType: types.FileTypeFile, + OriginData: input.FileUrl, }) } else { texts = append(texts, input.Text) diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index d7910725..1526a7f7 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -63,7 +63,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn var claudeReq *dto.ClaudeRequest var err error - claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request) + claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request) if err != nil { return nil, err } diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 41583d30..c5f6efcc 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -78,7 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if a.RequestMode == RequestModeCompletion { return RequestOpenAI2ClaudeComplete(*request), nil } else { - return RequestOpenAI2ClaudeMessage(*request) + return RequestOpenAI2ClaudeMessage(c, *request) } } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 57670bcf..ad363352 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -71,7 +71,7 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla return &claudeRequest } -func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) { +func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) { claudeTools := make([]any, 0, len(textRequest.Tools)) for _, tool := range textRequest.Tools { @@ -355,7 +355,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla // 判断是否是url if strings.HasPrefix(imageUrl.Url, "http") { // 是url,获取图片的类型和base64编码的数据 - fileData, err := service.GetFileBase64FromUrl(imageUrl.Url) + fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Claude") if err != nil { return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error()) } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 99b6645e..43e48f34 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -142,7 +142,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn return nil, errors.New("request is nil") } - geminiRequest, err := CovertGemini2OpenAI(*request, info) + geminiRequest, err := CovertGemini2OpenAI(c, *request, info) if err != nil { return nil, err } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index af5e8233..b0336af4 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -178,7 +178,7 @@ func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.Rel } // Setting safety to the lowest possible values since Gemini is already powerless enough -func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) { +func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) { geminiRequest := dto.GeminiChatRequest{ Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)), @@ -390,7 +390,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon // 判断是否是url if strings.HasPrefix(part.GetImageMedia().Url, "http") { // 是url,获取文件的类型和base64编码的数据 - fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url) + fileData, err := service.GetFileBase64FromUrl(c, part.GetImageMedia().Url, "formatting image for Gemini") if err != nil { return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err) } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 1a0caf75..d6b5b697 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -31,7 +31,7 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{ IncludeUsage: true, } - return requestOpenAI2Ollama(openaiRequest.(*dto.GeneralOpenAIRequest)) + return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest)) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -69,7 +69,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } - return requestOpenAI2Ollama(request) + return requestOpenAI2Ollama(c, request) } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index 066581fa..be2029f5 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -14,7 +14,7 @@ import ( "github.com/gin-gonic/gin" ) -func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) { +func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) { messages := make([]dto.Message, 0, len(request.Messages)) for _, message := range request.Messages { if !message.IsStringContent() { @@ -24,7 +24,7 @@ func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, er imageUrl := mediaMessage.GetImageMedia() // check if not base64 if strings.HasPrefix(imageUrl.Url, "http") { - fileData, err := service.GetFileBase64FromUrl(imageUrl.Url) + fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama") if err != nil { return nil, err } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 35e4490b..6cc48d7b 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -182,7 +182,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn return nil, errors.New("request is nil") } if a.RequestMode == RequestModeClaude { - claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request) + claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request) if err != nil { return nil, err } @@ -191,7 +191,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn info.UpstreamModelName = claudeReq.Model return vertexClaudeReq, nil } else if a.RequestMode == RequestModeGemini { - geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info) + geminiRequest, err := gemini.CovertGemini2OpenAI(c, *request, info) if err != nil { return nil, err } diff --git a/service/cf_worker.go b/service/cf_worker.go index ae6e1ffe..4a7b4376 100644 --- a/service/cf_worker.go +++ b/service/cf_worker.go @@ -42,16 +42,16 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload)) } -func DoDownloadRequest(originUrl string) (resp *http.Response, err error) { +func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) { if setting.EnableWorker() { - common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl)) + common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", "))) req := &WorkerRequest{ URL: originUrl, Key: setting.WorkerValidKey, } return DoWorkerRequest(req) } else { - common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl)) + common.SysLog(fmt.Sprintf("downloading from origin with worker: %s, reason: %s", originUrl, strings.Join(reason, ", "))) return http.Get(originUrl) } } diff --git a/service/file_decoder.go b/service/file_decoder.go index c1d4fb0c..bd14b963 100644 --- a/service/file_decoder.go +++ b/service/file_decoder.go @@ -3,17 +3,29 @@ package service import ( "encoding/base64" "fmt" + "github.com/gin-gonic/gin" "io" "one-api/common" "one-api/constant" - "one-api/dto" + "one-api/logger" + "one-api/types" "strings" ) -func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) { +func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) { + contextKey := fmt.Sprintf("file_download_%s", common.GenerateHMAC(url)) + + // Check if the file has already been downloaded in this request + if cachedData, exists := c.Get(contextKey); exists { + if common.DebugEnabled { + logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url)) + } + return cachedData.(*types.LocalFileData), nil + } + var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024 - resp, err := DoDownloadRequest(url) + resp, err := DoDownloadRequest(url, reason...) if err != nil { return nil, err } @@ -77,12 +89,15 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) { } } } - - return &dto.LocalFileData{ + data := &types.LocalFileData{ Base64Data: base64Data, MimeType: mimeType, Size: int64(len(fileBytes)), - }, nil + } + // Store the file data in the context to avoid re-downloading + c.Set(contextKey, data) + + return data, nil } func GetMimeTypeByExtension(ext string) string { diff --git a/service/token_counter.go b/service/token_counter.go index 314fa593..2312f27e 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -154,16 +154,22 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er var err error var format string var b64str string - if strings.HasPrefix(fileMeta.Data, "http") { - config, format, err = DecodeUrlImageData(fileMeta.Data) + + if fileMeta.ParsedData != nil { + config, format, b64str, err = DecodeBase64ImageData(fileMeta.ParsedData.Base64Data) } else { - common.SysLog(fmt.Sprintf("decoding image")) - config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data) + if strings.HasPrefix(fileMeta.OriginData, "http") { + config, format, err = DecodeUrlImageData(fileMeta.OriginData) + } else { + common.SysLog(fmt.Sprintf("decoding image")) + config, format, b64str, err = DecodeBase64ImageData(fileMeta.OriginData) + } + fileMeta.MimeType = format } + if err != nil { return 0, err } - fileMeta.MimeType = format if config.Width == 0 || config.Height == 0 { // not an image @@ -171,7 +177,7 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er // file type return 3 * baseTokens, nil } - return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.Data)) + return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.OriginData)) } width := config.Width @@ -268,6 +274,34 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco tkm += 3 } + shouldFetchFiles := true + + if info.RelayFormat == types.RelayFormatOpenAIRealtime || info.RelayFormat == types.RelayFormatGemini { + shouldFetchFiles = false + } + + if shouldFetchFiles { + for _, file := range meta.Files { + if strings.HasPrefix(file.OriginData, "http") { + localFileData, err := GetFileBase64FromUrl(c, file.OriginData, "token_counter") + if err != nil { + return 0, fmt.Errorf("error getting file base64 from url: %v", err) + } + if strings.HasPrefix(localFileData.MimeType, "image/") { + file.FileType = types.FileTypeImage + } else if strings.HasPrefix(localFileData.MimeType, "video/") { + file.FileType = types.FileTypeVideo + } else if strings.HasPrefix(localFileData.MimeType, "audio/") { + file.FileType = types.FileTypeAudio + } else { + file.FileType = types.FileTypeFile + } + file.MimeType = localFileData.MimeType + file.ParsedData = localFileData + } + } + } + for _, file := range meta.Files { switch file.FileType { case types.FileTypeImage: diff --git a/dto/file_data.go b/types/file_data.go similarity index 88% rename from dto/file_data.go rename to types/file_data.go index d5cf0f68..f1c82e21 100644 --- a/dto/file_data.go +++ b/types/file_data.go @@ -1,4 +1,4 @@ -package dto +package types type LocalFileData struct { MimeType string diff --git a/types/request_meta.go b/types/request_meta.go index 427bacb9..18f80832 100644 --- a/types/request_meta.go +++ b/types/request_meta.go @@ -32,9 +32,10 @@ type TokenCountMeta struct { type FileMeta struct { FileType - MimeType string - Data string - Detail string + MimeType string + OriginData string // url or base64 data + Detail string + ParsedData *LocalFileData } type RequestMeta struct {