Merge pull request #674 from Yan-Zero/main
fix: Gemini 函数调用的文本转义,以及其他文件类型的 Base64 支持
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -203,13 +204,13 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
format, base64String, err := service.DecodeBase64FileData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
||||||
}
|
}
|
||||||
parts = append(parts, GeminiPart{
|
parts = append(parts, GeminiPart{
|
||||||
InlineData: &GeminiInlineData{
|
InlineData: &GeminiInlineData{
|
||||||
MimeType: "image/" + format,
|
MimeType: format,
|
||||||
Data: base64String,
|
Data: base64String,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -279,57 +280,97 @@ func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interfac
|
|||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
// func (g *GeminiChatResponse) GetResponseText() string {
|
func unescapeString(s string) (string, error) {
|
||||||
// if g == nil {
|
var result []rune
|
||||||
// return ""
|
escaped := false
|
||||||
// }
|
i := 0
|
||||||
// if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
|
|
||||||
// return g.Candidates[0].Content.Parts[0].Text
|
for i < len(s) {
|
||||||
// }
|
r, size := utf8.DecodeRuneInString(s[i:]) // 正确解码UTF-8字符
|
||||||
// return ""
|
if r == utf8.RuneError {
|
||||||
// }
|
return "", fmt.Errorf("invalid UTF-8 encoding")
|
||||||
|
}
|
||||||
|
|
||||||
|
if escaped {
|
||||||
|
// 如果是转义符后的字符,检查其类型
|
||||||
|
switch r {
|
||||||
|
case '"':
|
||||||
|
result = append(result, '"')
|
||||||
|
case '\\':
|
||||||
|
result = append(result, '\\')
|
||||||
|
case '/':
|
||||||
|
result = append(result, '/')
|
||||||
|
case 'b':
|
||||||
|
result = append(result, '\b')
|
||||||
|
case 'f':
|
||||||
|
result = append(result, '\f')
|
||||||
|
case 'n':
|
||||||
|
result = append(result, '\n')
|
||||||
|
case 'r':
|
||||||
|
result = append(result, '\r')
|
||||||
|
case 't':
|
||||||
|
result = append(result, '\t')
|
||||||
|
case '\'':
|
||||||
|
result = append(result, '\'')
|
||||||
|
default:
|
||||||
|
// 如果遇到一个非法的转义字符,直接按原样输出
|
||||||
|
result = append(result, '\\', r)
|
||||||
|
}
|
||||||
|
escaped = false
|
||||||
|
} else {
|
||||||
|
if r == '\\' {
|
||||||
|
escaped = true // 记录反斜杠作为转义符
|
||||||
|
} else {
|
||||||
|
result = append(result, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i += size // 移动到下一个字符
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(result), nil
|
||||||
|
}
|
||||||
|
func unescapeMapOrSlice(data interface{}) interface{} {
|
||||||
|
switch v := data.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
for k, val := range v {
|
||||||
|
v[k] = unescapeMapOrSlice(val)
|
||||||
|
}
|
||||||
|
case []interface{}:
|
||||||
|
for i, val := range v {
|
||||||
|
v[i] = unescapeMapOrSlice(val)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
if unescaped, err := unescapeString(v); err != nil {
|
||||||
|
return v
|
||||||
|
} else {
|
||||||
|
return unescaped
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
func getToolCall(item *GeminiPart) *dto.ToolCall {
|
func getToolCall(item *GeminiPart) *dto.ToolCall {
|
||||||
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
var argsBytes []byte
|
||||||
|
var err error
|
||||||
|
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
|
||||||
|
argsBytes, err = json.Marshal(unescapeMapOrSlice(result))
|
||||||
|
} else {
|
||||||
|
argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//common.SysError("getToolCall failed: " + err.Error())
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &dto.ToolCall{
|
return &dto.ToolCall{
|
||||||
ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
||||||
Type: "function",
|
Type: "function",
|
||||||
Function: dto.FunctionCall{
|
Function: dto.FunctionCall{
|
||||||
// 不好评价,得去转义一下反斜杠,Gemini 的特性好像是,Google 返回的时候本身就会转义“\”
|
Arguments: string(argsBytes),
|
||||||
Arguments: strings.ReplaceAll(string(argsBytes), "\\\\", "\\"),
|
|
||||||
Name: item.FunctionCall.FunctionName,
|
Name: item.FunctionCall.FunctionName,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// func getToolCalls(candidate *GeminiChatCandidate, index int) []dto.ToolCall {
|
|
||||||
// var toolCalls []dto.ToolCall
|
|
||||||
|
|
||||||
// item := candidate.Content.Parts[index]
|
|
||||||
// if item.FunctionCall == nil {
|
|
||||||
// return toolCalls
|
|
||||||
// }
|
|
||||||
// argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
|
||||||
// if err != nil {
|
|
||||||
// //common.SysError("getToolCalls failed: " + err.Error())
|
|
||||||
// return toolCalls
|
|
||||||
// }
|
|
||||||
// toolCall := dto.ToolCall{
|
|
||||||
// ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
|
||||||
// Type: "function",
|
|
||||||
// Function: dto.FunctionCall{
|
|
||||||
// Arguments: string(argsBytes),
|
|
||||||
// Name: item.FunctionCall.FunctionName,
|
|
||||||
// },
|
|
||||||
// }
|
|
||||||
// toolCalls = append(toolCalls, toolCall)
|
|
||||||
// return toolCalls
|
|
||||||
// }
|
|
||||||
|
|
||||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||||
fullTextResponse := dto.OpenAITextResponse{
|
fullTextResponse := dto.OpenAITextResponse{
|
||||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||||
|
|||||||
@@ -5,11 +5,12 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"golang.org/x/image/webp"
|
|
||||||
"image"
|
"image"
|
||||||
"io"
|
"io"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/image/webp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
|
func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
|
||||||
@@ -31,6 +32,31 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
|
|||||||
return config, format, base64String, err
|
return config, format, base64String, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DecodeBase64FileData(base64String string) (string, string, error) {
|
||||||
|
var mimeType string
|
||||||
|
var idx int
|
||||||
|
idx = strings.Index(base64String, ",")
|
||||||
|
if idx == -1 {
|
||||||
|
_, file_type, base64, err := DecodeBase64ImageData(base64String)
|
||||||
|
return "image/" + file_type, base64, err
|
||||||
|
}
|
||||||
|
mimeType = base64String[:idx]
|
||||||
|
base64String = base64String[idx+1:]
|
||||||
|
idx = strings.Index(mimeType, ";")
|
||||||
|
if idx == -1 {
|
||||||
|
_, file_type, base64, err := DecodeBase64ImageData(base64String)
|
||||||
|
return "image/" + file_type, base64, err
|
||||||
|
}
|
||||||
|
mimeType = mimeType[:idx]
|
||||||
|
idx = strings.Index(mimeType, ":")
|
||||||
|
if idx == -1 {
|
||||||
|
_, file_type, base64, err := DecodeBase64ImageData(base64String)
|
||||||
|
return "image/" + file_type, base64, err
|
||||||
|
}
|
||||||
|
mimeType = mimeType[idx+1:]
|
||||||
|
return mimeType, base64String, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetImageFromUrl 获取图片的类型和base64编码的数据
|
// GetImageFromUrl 获取图片的类型和base64编码的数据
|
||||||
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||||
resp, err := DoDownloadRequest(url)
|
resp, err := DoDownloadRequest(url)
|
||||||
|
|||||||
Reference in New Issue
Block a user