feat: add multi-file type support for Gemini and Claude
- Add file data DTO for structured file handling - Implement file decoder service - Update Claude and Gemini relay channels to handle various file types - Reorganize worker service to cf_worker for clarity - Update token counter and image service for new file types
This commit is contained in:
@@ -82,6 +82,7 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
|
||||
- `GEMINI_MODEL_MAP`: Specify Gemini model versions (v1/v1beta), format: "model:version", comma-separated
|
||||
- `COHERE_SAFETY_SETTING`: Cohere model [safety settings](https://docs.cohere.com/docs/safety-modes#overview), options: `NONE`, `CONTEXTUAL`, `STRICT`, default `NONE`
|
||||
- `GEMINI_VISION_MAX_IMAGE_NUM`: Gemini model maximum image number, default `16`, set to `-1` to disable
|
||||
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
|
||||
|
||||
## Deployment
|
||||
> [!TIP]
|
||||
|
||||
@@ -88,6 +88,7 @@
|
||||
- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
|
||||
- `COHERE_SAFETY_SETTING`:Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`,`STRICT`,默认为 `NONE`。
|
||||
- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认为 `16`,设置为 `-1` 则不限制。
|
||||
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`。
|
||||
## 部署
|
||||
> [!TIP]
|
||||
> 最新版Docker镜像:`calciumion/new-api:latest`
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
|
||||
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
|
||||
var MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
|
||||
|
||||
8
dto/file_data.go
Normal file
8
dto/file_data.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package dto
|
||||
|
||||
type LocalFileData struct {
|
||||
MimeType string
|
||||
Base64Data string
|
||||
Url string
|
||||
Size int64
|
||||
}
|
||||
@@ -225,9 +225,12 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url)
|
||||
claudeMediaMessage.Source.MediaType = mimeType
|
||||
claudeMediaMessage.Source.Data = data
|
||||
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
||||
}
|
||||
claudeMediaMessage.Source.MediaType = fileData.MimeType
|
||||
claudeMediaMessage.Source.Data = fileData.Base64Data
|
||||
} else {
|
||||
_, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
|
||||
if err != nil {
|
||||
|
||||
@@ -192,11 +192,14 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
fileData, err := service.GetFileBase64FromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: data,
|
||||
MimeType: fileData.MimeType,
|
||||
Data: fileData.Base64Data,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
|
||||
@@ -230,7 +230,7 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
|
||||
var err error
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model)
|
||||
promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
||||
case relayconstant.RelayModeModerations:
|
||||
|
||||
@@ -9,9 +9,12 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func DoImageRequest(originUrl string) (resp *http.Response, err error) {
|
||||
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
|
||||
if setting.EnableWorker() {
|
||||
common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
|
||||
common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
|
||||
if !strings.HasPrefix(originUrl, "https") {
|
||||
return nil, fmt.Errorf("only support https url")
|
||||
}
|
||||
workerUrl := setting.WorkerUrl
|
||||
if !strings.HasSuffix(workerUrl, "/") {
|
||||
workerUrl += "/"
|
||||
@@ -20,7 +23,7 @@ func DoImageRequest(originUrl string) (resp *http.Response, err error) {
|
||||
data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`)
|
||||
return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data))
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
|
||||
common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
|
||||
return http.Get(originUrl)
|
||||
}
|
||||
}
|
||||
39
service/file_decoder.go
Normal file
39
service/file_decoder.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
)
|
||||
|
||||
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
|
||||
|
||||
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
|
||||
resp, err := DoDownloadRequest(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Always use LimitReader to prevent oversized downloads
|
||||
fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check actual size after reading
|
||||
if len(fileBytes) > maxFileSize {
|
||||
return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
|
||||
}
|
||||
|
||||
// Convert to base64
|
||||
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
|
||||
|
||||
return &dto.LocalFileData{
|
||||
Base64Data: base64Data,
|
||||
MimeType: resp.Header.Get("Content-Type"),
|
||||
Size: int64(len(fileBytes)),
|
||||
}, nil
|
||||
}
|
||||
@@ -33,12 +33,12 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
|
||||
|
||||
// GetImageFromUrl 获取图片的类型和base64编码的数据
|
||||
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
resp, err := DoImageRequest(url)
|
||||
resp, err := DoDownloadRequest(url)
|
||||
if err != nil {
|
||||
return
|
||||
return "", "", err
|
||||
}
|
||||
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
|
||||
return
|
||||
return "", "", fmt.Errorf("invalid content type: %s, required image/*", resp.Header.Get("Content-Type"))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
@@ -52,7 +52,7 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
}
|
||||
|
||||
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
||||
response, err := DoImageRequest(imageUrl)
|
||||
response, err := DoDownloadRequest(imageUrl)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
|
||||
return image.Config{}, "", err
|
||||
@@ -64,6 +64,12 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
|
||||
mimeType := response.Header.Get("Content-Type")
|
||||
|
||||
if !strings.HasPrefix(mimeType, "image/") {
|
||||
return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType)
|
||||
}
|
||||
|
||||
var readData []byte
|
||||
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
|
||||
common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
|
||||
|
||||
@@ -80,7 +80,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||
return len(tokenEncoder.Encode(text, nil, nil))
|
||||
}
|
||||
|
||||
func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
|
||||
func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
|
||||
baseTokens := 85
|
||||
if model == "glm-4v" {
|
||||
return 1047, nil
|
||||
@@ -96,6 +96,9 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
|
||||
if !constant.GetMediaToken {
|
||||
return 256, nil
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic {
|
||||
return 256, nil
|
||||
}
|
||||
// 同步One API的图片计费逻辑
|
||||
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
|
||||
imageUrl.Detail = "high"
|
||||
@@ -155,9 +158,9 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
|
||||
return tiles*tileTokens + baseTokens, nil
|
||||
}
|
||||
|
||||
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) {
|
||||
func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
|
||||
tkm := 0
|
||||
msgTokens, err := CountTokenMessages(request.Messages, model, request.Stream)
|
||||
msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -179,7 +182,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
|
||||
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||
}
|
||||
}
|
||||
toolTokens, err := CountTokenInput(countStr, model)
|
||||
toolTokens, err := CountTokenInput(countStr, request.Model)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -256,7 +259,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
return textToken, audioToken, nil
|
||||
}
|
||||
|
||||
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
|
||||
func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
|
||||
//recover when panic
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
// Reference:
|
||||
@@ -290,7 +293,7 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
|
||||
for _, m := range arrayContent {
|
||||
if m.Type == dto.ContentTypeImageURL {
|
||||
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
|
||||
imageTokenNum, err := getImageToken(&imageUrl, model, stream)
|
||||
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user