feat: add multi-file type support for Gemini and Claude

- Add file data DTO for structured file handling
- Implement file decoder service
- Update Claude and Gemini relay channels to handle various file types
- Reorganize worker service to cf_worker for clarity
- Update token counter and image service for new file types
This commit is contained in:
CalciumIon
2024-12-29 00:00:24 +08:00
parent d75ecfc63e
commit 2b38e8ed8d
11 changed files with 89 additions and 20 deletions

View File

@@ -9,9 +9,12 @@ import (
"strings"
)
func DoImageRequest(originUrl string) (resp *http.Response, err error) {
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
if setting.EnableWorker() {
common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
if !strings.HasPrefix(originUrl, "https") {
return nil, fmt.Errorf("only support https url")
}
workerUrl := setting.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/"
@@ -20,7 +23,7 @@ func DoImageRequest(originUrl string) (resp *http.Response, err error) {
data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`)
return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data))
} else {
common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
return http.Get(originUrl)
}
}

39
service/file_decoder.go Normal file
View File

@@ -0,0 +1,39 @@
package service
import (
"encoding/base64"
"fmt"
"io"
"one-api/constant"
"one-api/dto"
)
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
resp, err := DoDownloadRequest(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// Always use LimitReader to prevent oversized downloads
fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
if err != nil {
return nil, err
}
// Check actual size after reading
if len(fileBytes) > maxFileSize {
return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
}
// Convert to base64
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
return &dto.LocalFileData{
Base64Data: base64Data,
MimeType: resp.Header.Get("Content-Type"),
Size: int64(len(fileBytes)),
}, nil
}

View File

@@ -33,12 +33,12 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
// GetImageFromUrl 获取图片的类型和base64编码的数据
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
resp, err := DoImageRequest(url)
resp, err := DoDownloadRequest(url)
if err != nil {
return
return "", "", err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return
return "", "", fmt.Errorf("invalid content type: %s, required image/*", resp.Header.Get("Content-Type"))
}
defer resp.Body.Close()
buffer := bytes.NewBuffer(nil)
@@ -52,7 +52,7 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
}
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
response, err := DoImageRequest(imageUrl)
response, err := DoDownloadRequest(imageUrl)
if err != nil {
common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
return image.Config{}, "", err
@@ -64,6 +64,12 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
return image.Config{}, "", err
}
mimeType := response.Header.Get("Content-Type")
if !strings.HasPrefix(mimeType, "image/") {
return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType)
}
var readData []byte
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))

View File

@@ -80,7 +80,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
return len(tokenEncoder.Encode(text, nil, nil))
}
func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
baseTokens := 85
if model == "glm-4v" {
return 1047, nil
@@ -96,6 +96,9 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
if !constant.GetMediaToken {
return 256, nil
}
if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic {
return 256, nil
}
// 同步One API的图片计费逻辑
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
imageUrl.Detail = "high"
@@ -155,9 +158,9 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
return tiles*tileTokens + baseTokens, nil
}
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) {
func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
tkm := 0
msgTokens, err := CountTokenMessages(request.Messages, model, request.Stream)
msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
if err != nil {
return 0, err
}
@@ -179,7 +182,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
}
}
toolTokens, err := CountTokenInput(countStr, model)
toolTokens, err := CountTokenInput(countStr, request.Model)
if err != nil {
return 0, err
}
@@ -256,7 +259,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
return textToken, audioToken, nil
}
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
//recover when panic
tokenEncoder := getTokenEncoder(model)
// Reference:
@@ -290,7 +293,7 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL {
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(&imageUrl, model, stream)
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
if err != nil {
return 0, err
}