refactor: update function signatures to include context and improve file handling #1599

This commit is contained in:
CaIon
2025-08-15 18:40:54 +08:00
parent b57e97d2a1
commit 0bb43aa464
18 changed files with 105 additions and 52 deletions

View File

@@ -3,7 +3,8 @@ package constant
type ContextKey string type ContextKey string
const ( const (
ContextKeyPromptTokens ContextKey = "prompt_tokens" ContextKeyTokenCountMeta ContextKey = "token_count_meta"
ContextKeyPromptTokens ContextKey = "prompt_tokens"
ContextKeyOriginalModel ContextKey = "original_model" ContextKeyOriginalModel ContextKey = "original_model"
ContextKeyRequestStartTime ContextKey = "request_start_time" ContextKeyRequestStartTime ContextKey = "request_start_time"

View File

@@ -133,6 +133,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
return return
} }
common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil { if newAPIError != nil {
return return

View File

@@ -231,7 +231,7 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
data = common.Interface2String(media.Source.Data) data = common.Interface2String(media.Source.Data)
} }
if data != "" { if data != "" {
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data}) fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
} }
} }
} }
@@ -263,7 +263,7 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
data = common.Interface2String(media.Source.Data) data = common.Interface2String(media.Source.Data)
} }
if data != "" { if data != "" {
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data}) fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
} }
} }
case "tool_use": case "tool_use":

View File

@@ -35,23 +35,23 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
if part.InlineData != nil && part.InlineData.Data != "" { if part.InlineData != nil && part.InlineData.Data != "" {
if strings.HasPrefix(part.InlineData.MimeType, "image/") { if strings.HasPrefix(part.InlineData.MimeType, "image/") {
files = append(files, &types.FileMeta{ files = append(files, &types.FileMeta{
FileType: types.FileTypeImage, FileType: types.FileTypeImage,
Data: part.InlineData.Data, OriginData: part.InlineData.Data,
}) })
} else if strings.HasPrefix(part.InlineData.MimeType, "audio/") { } else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
files = append(files, &types.FileMeta{ files = append(files, &types.FileMeta{
FileType: types.FileTypeAudio, FileType: types.FileTypeAudio,
Data: part.InlineData.Data, OriginData: part.InlineData.Data,
}) })
} else if strings.HasPrefix(part.InlineData.MimeType, "video/") { } else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
files = append(files, &types.FileMeta{ files = append(files, &types.FileMeta{
FileType: types.FileTypeVideo, FileType: types.FileTypeVideo,
Data: part.InlineData.Data, OriginData: part.InlineData.Data,
}) })
} else { } else {
files = append(files, &types.FileMeta{ files = append(files, &types.FileMeta{
FileType: types.FileTypeFile, FileType: types.FileTypeFile,
Data: part.InlineData.Data, OriginData: part.InlineData.Data,
}) })
} }
} }

View File

@@ -118,7 +118,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
meta := &types.FileMeta{ meta := &types.FileMeta{
FileType: types.FileTypeImage, FileType: types.FileTypeImage,
} }
meta.Data = imageUrl.Url meta.OriginData = imageUrl.Url
meta.Detail = imageUrl.Detail meta.Detail = imageUrl.Detail
fileMeta = append(fileMeta, meta) fileMeta = append(fileMeta, meta)
} }
@@ -128,7 +128,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
meta := &types.FileMeta{ meta := &types.FileMeta{
FileType: types.FileTypeAudio, FileType: types.FileTypeAudio,
} }
meta.Data = inputAudio.Data meta.OriginData = inputAudio.Data
fileMeta = append(fileMeta, meta) fileMeta = append(fileMeta, meta)
} }
} else if m.Type == ContentTypeFile { } else if m.Type == ContentTypeFile {
@@ -137,7 +137,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
meta := &types.FileMeta{ meta := &types.FileMeta{
FileType: types.FileTypeFile, FileType: types.FileTypeFile,
} }
meta.Data = file.FileData meta.OriginData = file.FileData
fileMeta = append(fileMeta, meta) fileMeta = append(fileMeta, meta)
} }
} else if m.Type == ContentTypeVideoUrl { } else if m.Type == ContentTypeVideoUrl {
@@ -146,7 +146,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
meta := &types.FileMeta{ meta := &types.FileMeta{
FileType: types.FileTypeVideo, FileType: types.FileTypeVideo,
} }
meta.Data = videoUrl.Url meta.OriginData = videoUrl.Url
fileMeta = append(fileMeta, meta) fileMeta = append(fileMeta, meta)
} }
} else { } else {
@@ -784,14 +784,14 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
for _, input := range inputs { for _, input := range inputs {
if input.Type == "input_image" { if input.Type == "input_image" {
fileMeta = append(fileMeta, &types.FileMeta{ fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeImage, FileType: types.FileTypeImage,
Data: input.ImageUrl, OriginData: input.ImageUrl,
Detail: input.Detail, Detail: input.Detail,
}) })
} else if input.Type == "input_file" { } else if input.Type == "input_file" {
fileMeta = append(fileMeta, &types.FileMeta{ fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeFile, FileType: types.FileTypeFile,
Data: input.FileUrl, OriginData: input.FileUrl,
}) })
} else { } else {
texts = append(texts, input.Text) texts = append(texts, input.Text)

View File

@@ -63,7 +63,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
var claudeReq *dto.ClaudeRequest var claudeReq *dto.ClaudeRequest
var err error var err error
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request) claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -78,7 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if a.RequestMode == RequestModeCompletion { if a.RequestMode == RequestModeCompletion {
return RequestOpenAI2ClaudeComplete(*request), nil return RequestOpenAI2ClaudeComplete(*request), nil
} else { } else {
return RequestOpenAI2ClaudeMessage(*request) return RequestOpenAI2ClaudeMessage(c, *request)
} }
} }

View File

@@ -71,7 +71,7 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla
return &claudeRequest return &claudeRequest
} }
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) { func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
claudeTools := make([]any, 0, len(textRequest.Tools)) claudeTools := make([]any, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools { for _, tool := range textRequest.Tools {
@@ -355,7 +355,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
// 判断是否是url // 判断是否是url
if strings.HasPrefix(imageUrl.Url, "http") { if strings.HasPrefix(imageUrl.Url, "http") {
// 是url获取图片的类型和base64编码的数据 // 是url获取图片的类型和base64编码的数据
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url) fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Claude")
if err != nil { if err != nil {
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error()) return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
} }

View File

@@ -142,7 +142,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
geminiRequest, err := CovertGemini2OpenAI(*request, info) geminiRequest, err := CovertGemini2OpenAI(c, *request, info)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -178,7 +178,7 @@ func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.Rel
} }
// Setting safety to the lowest possible values since Gemini is already powerless enough // Setting safety to the lowest possible values since Gemini is already powerless enough
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) { func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
geminiRequest := dto.GeminiChatRequest{ geminiRequest := dto.GeminiChatRequest{
Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)), Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
@@ -390,7 +390,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
// 判断是否是url // 判断是否是url
if strings.HasPrefix(part.GetImageMedia().Url, "http") { if strings.HasPrefix(part.GetImageMedia().Url, "http") {
// 是url获取文件的类型和base64编码的数据 // 是url获取文件的类型和base64编码的数据
fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url) fileData, err := service.GetFileBase64FromUrl(c, part.GetImageMedia().Url, "formatting image for Gemini")
if err != nil { if err != nil {
return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err) return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
} }

View File

@@ -31,7 +31,7 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{ openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
IncludeUsage: true, IncludeUsage: true,
} }
return requestOpenAI2Ollama(openaiRequest.(*dto.GeneralOpenAIRequest)) return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest))
} }
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -69,7 +69,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
return requestOpenAI2Ollama(request) return requestOpenAI2Ollama(c, request)
} }
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {

View File

@@ -14,7 +14,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) { func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
messages := make([]dto.Message, 0, len(request.Messages)) messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages { for _, message := range request.Messages {
if !message.IsStringContent() { if !message.IsStringContent() {
@@ -24,7 +24,7 @@ func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, er
imageUrl := mediaMessage.GetImageMedia() imageUrl := mediaMessage.GetImageMedia()
// check if not base64 // check if not base64
if strings.HasPrefix(imageUrl.Url, "http") { if strings.HasPrefix(imageUrl.Url, "http") {
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url) fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -182,7 +182,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
if a.RequestMode == RequestModeClaude { if a.RequestMode == RequestModeClaude {
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request) claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -191,7 +191,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
info.UpstreamModelName = claudeReq.Model info.UpstreamModelName = claudeReq.Model
return vertexClaudeReq, nil return vertexClaudeReq, nil
} else if a.RequestMode == RequestModeGemini { } else if a.RequestMode == RequestModeGemini {
geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info) geminiRequest, err := gemini.CovertGemini2OpenAI(c, *request, info)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -42,16 +42,16 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload)) return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
} }
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) { func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) {
if setting.EnableWorker() { if setting.EnableWorker() {
common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl)) common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
req := &WorkerRequest{ req := &WorkerRequest{
URL: originUrl, URL: originUrl,
Key: setting.WorkerValidKey, Key: setting.WorkerValidKey,
} }
return DoWorkerRequest(req) return DoWorkerRequest(req)
} else { } else {
common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl)) common.SysLog(fmt.Sprintf("downloading from origin with worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
return http.Get(originUrl) return http.Get(originUrl)
} }
} }

View File

@@ -3,17 +3,29 @@ package service
import ( import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/logger"
"one-api/types"
"strings" "strings"
) )
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) { func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) {
contextKey := fmt.Sprintf("file_download_%s", common.GenerateHMAC(url))
// Check if the file has already been downloaded in this request
if cachedData, exists := c.Get(contextKey); exists {
if common.DebugEnabled {
logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url))
}
return cachedData.(*types.LocalFileData), nil
}
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024 var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
resp, err := DoDownloadRequest(url) resp, err := DoDownloadRequest(url, reason...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -77,12 +89,15 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
} }
} }
} }
data := &types.LocalFileData{
return &dto.LocalFileData{
Base64Data: base64Data, Base64Data: base64Data,
MimeType: mimeType, MimeType: mimeType,
Size: int64(len(fileBytes)), Size: int64(len(fileBytes)),
}, nil }
// Store the file data in the context to avoid re-downloading
c.Set(contextKey, data)
return data, nil
} }
func GetMimeTypeByExtension(ext string) string { func GetMimeTypeByExtension(ext string) string {

View File

@@ -154,16 +154,22 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
var err error var err error
var format string var format string
var b64str string var b64str string
if strings.HasPrefix(fileMeta.Data, "http") {
config, format, err = DecodeUrlImageData(fileMeta.Data) if fileMeta.ParsedData != nil {
config, format, b64str, err = DecodeBase64ImageData(fileMeta.ParsedData.Base64Data)
} else { } else {
common.SysLog(fmt.Sprintf("decoding image")) if strings.HasPrefix(fileMeta.OriginData, "http") {
config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data) config, format, err = DecodeUrlImageData(fileMeta.OriginData)
} else {
common.SysLog(fmt.Sprintf("decoding image"))
config, format, b64str, err = DecodeBase64ImageData(fileMeta.OriginData)
}
fileMeta.MimeType = format
} }
if err != nil { if err != nil {
return 0, err return 0, err
} }
fileMeta.MimeType = format
if config.Width == 0 || config.Height == 0 { if config.Width == 0 || config.Height == 0 {
// not an image // not an image
@@ -171,7 +177,7 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
// file type // file type
return 3 * baseTokens, nil return 3 * baseTokens, nil
} }
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.Data)) return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.OriginData))
} }
width := config.Width width := config.Width
@@ -268,6 +274,34 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
tkm += 3 tkm += 3
} }
shouldFetchFiles := true
if info.RelayFormat == types.RelayFormatOpenAIRealtime || info.RelayFormat == types.RelayFormatGemini {
shouldFetchFiles = false
}
if shouldFetchFiles {
for _, file := range meta.Files {
if strings.HasPrefix(file.OriginData, "http") {
localFileData, err := GetFileBase64FromUrl(c, file.OriginData, "token_counter")
if err != nil {
return 0, fmt.Errorf("error getting file base64 from url: %v", err)
}
if strings.HasPrefix(localFileData.MimeType, "image/") {
file.FileType = types.FileTypeImage
} else if strings.HasPrefix(localFileData.MimeType, "video/") {
file.FileType = types.FileTypeVideo
} else if strings.HasPrefix(localFileData.MimeType, "audio/") {
file.FileType = types.FileTypeAudio
} else {
file.FileType = types.FileTypeFile
}
file.MimeType = localFileData.MimeType
file.ParsedData = localFileData
}
}
}
for _, file := range meta.Files { for _, file := range meta.Files {
switch file.FileType { switch file.FileType {
case types.FileTypeImage: case types.FileTypeImage:

View File

@@ -1,4 +1,4 @@
package dto package types
type LocalFileData struct { type LocalFileData struct {
MimeType string MimeType string

View File

@@ -32,9 +32,10 @@ type TokenCountMeta struct {
type FileMeta struct { type FileMeta struct {
FileType FileType
MimeType string MimeType string
Data string OriginData string // url or base64 data
Detail string Detail string
ParsedData *LocalFileData
} }
type RequestMeta struct { type RequestMeta struct {