Merge pull request #674 from Yan-Zero/main
fix: Gemini 函数调用的文本转义,以及其他文件类型的 Base64 支持
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -203,13 +204,13 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
},
|
||||
})
|
||||
} else {
|
||||
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
format, base64String, err := service.DecodeBase64FileData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: "image/" + format,
|
||||
MimeType: format,
|
||||
Data: base64String,
|
||||
},
|
||||
})
|
||||
@@ -279,57 +280,97 @@ func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interfac
|
||||
return v
|
||||
}
|
||||
|
||||
// func (g *GeminiChatResponse) GetResponseText() string {
|
||||
// if g == nil {
|
||||
// return ""
|
||||
// }
|
||||
// if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
|
||||
// return g.Candidates[0].Content.Parts[0].Text
|
||||
// }
|
||||
// return ""
|
||||
// }
|
||||
func unescapeString(s string) (string, error) {
|
||||
var result []rune
|
||||
escaped := false
|
||||
i := 0
|
||||
|
||||
for i < len(s) {
|
||||
r, size := utf8.DecodeRuneInString(s[i:]) // 正确解码UTF-8字符
|
||||
if r == utf8.RuneError {
|
||||
return "", fmt.Errorf("invalid UTF-8 encoding")
|
||||
}
|
||||
|
||||
if escaped {
|
||||
// 如果是转义符后的字符,检查其类型
|
||||
switch r {
|
||||
case '"':
|
||||
result = append(result, '"')
|
||||
case '\\':
|
||||
result = append(result, '\\')
|
||||
case '/':
|
||||
result = append(result, '/')
|
||||
case 'b':
|
||||
result = append(result, '\b')
|
||||
case 'f':
|
||||
result = append(result, '\f')
|
||||
case 'n':
|
||||
result = append(result, '\n')
|
||||
case 'r':
|
||||
result = append(result, '\r')
|
||||
case 't':
|
||||
result = append(result, '\t')
|
||||
case '\'':
|
||||
result = append(result, '\'')
|
||||
default:
|
||||
// 如果遇到一个非法的转义字符,直接按原样输出
|
||||
result = append(result, '\\', r)
|
||||
}
|
||||
escaped = false
|
||||
} else {
|
||||
if r == '\\' {
|
||||
escaped = true // 记录反斜杠作为转义符
|
||||
} else {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
i += size // 移动到下一个字符
|
||||
}
|
||||
|
||||
return string(result), nil
|
||||
}
|
||||
func unescapeMapOrSlice(data interface{}) interface{} {
|
||||
switch v := data.(type) {
|
||||
case map[string]interface{}:
|
||||
for k, val := range v {
|
||||
v[k] = unescapeMapOrSlice(val)
|
||||
}
|
||||
case []interface{}:
|
||||
for i, val := range v {
|
||||
v[i] = unescapeMapOrSlice(val)
|
||||
}
|
||||
case string:
|
||||
if unescaped, err := unescapeString(v); err != nil {
|
||||
return v
|
||||
} else {
|
||||
return unescaped
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func getToolCall(item *GeminiPart) *dto.ToolCall {
|
||||
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||
var argsBytes []byte
|
||||
var err error
|
||||
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
|
||||
argsBytes, err = json.Marshal(unescapeMapOrSlice(result))
|
||||
} else {
|
||||
argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
//common.SysError("getToolCall failed: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
return &dto.ToolCall{
|
||||
ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
||||
Type: "function",
|
||||
Function: dto.FunctionCall{
|
||||
// 不好评价,得去转义一下反斜杠,Gemini 的特性好像是,Google 返回的时候本身就会转义“\”
|
||||
Arguments: strings.ReplaceAll(string(argsBytes), "\\\\", "\\"),
|
||||
Arguments: string(argsBytes),
|
||||
Name: item.FunctionCall.FunctionName,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// func getToolCalls(candidate *GeminiChatCandidate, index int) []dto.ToolCall {
|
||||
// var toolCalls []dto.ToolCall
|
||||
|
||||
// item := candidate.Content.Parts[index]
|
||||
// if item.FunctionCall == nil {
|
||||
// return toolCalls
|
||||
// }
|
||||
// argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||
// if err != nil {
|
||||
// //common.SysError("getToolCalls failed: " + err.Error())
|
||||
// return toolCalls
|
||||
// }
|
||||
// toolCall := dto.ToolCall{
|
||||
// ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
||||
// Type: "function",
|
||||
// Function: dto.FunctionCall{
|
||||
// Arguments: string(argsBytes),
|
||||
// Name: item.FunctionCall.FunctionName,
|
||||
// },
|
||||
// }
|
||||
// toolCalls = append(toolCalls, toolCall)
|
||||
// return toolCalls
|
||||
// }
|
||||
|
||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
|
||||
@@ -5,11 +5,12 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/image/webp"
|
||||
"image"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
|
||||
@@ -31,6 +32,31 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
|
||||
return config, format, base64String, err
|
||||
}
|
||||
|
||||
func DecodeBase64FileData(base64String string) (string, string, error) {
|
||||
var mimeType string
|
||||
var idx int
|
||||
idx = strings.Index(base64String, ",")
|
||||
if idx == -1 {
|
||||
_, file_type, base64, err := DecodeBase64ImageData(base64String)
|
||||
return "image/" + file_type, base64, err
|
||||
}
|
||||
mimeType = base64String[:idx]
|
||||
base64String = base64String[idx+1:]
|
||||
idx = strings.Index(mimeType, ";")
|
||||
if idx == -1 {
|
||||
_, file_type, base64, err := DecodeBase64ImageData(base64String)
|
||||
return "image/" + file_type, base64, err
|
||||
}
|
||||
mimeType = mimeType[:idx]
|
||||
idx = strings.Index(mimeType, ":")
|
||||
if idx == -1 {
|
||||
_, file_type, base64, err := DecodeBase64ImageData(base64String)
|
||||
return "image/" + file_type, base64, err
|
||||
}
|
||||
mimeType = mimeType[idx+1:]
|
||||
return mimeType, base64String, nil
|
||||
}
|
||||
|
||||
// GetImageFromUrl 获取图片的类型和base64编码的数据
|
||||
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
resp, err := DoDownloadRequest(url)
|
||||
|
||||
Reference in New Issue
Block a user