feat: support gemini output text and inline images. (close #866)

This commit is contained in:
CaIon
2025-04-15 02:32:51 +08:00
parent 99efc1fbb6
commit 473e8e0eaf
8 changed files with 74 additions and 23 deletions

View File

@@ -99,7 +99,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
ai, err := CovertGemini2OpenAI(*request) ai, err := CovertGemini2OpenAI(*request, info)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -71,15 +71,16 @@ type GeminiChatTool struct {
} }
type GeminiChatGenerationConfig struct { type GeminiChatGenerationConfig struct {
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"` TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"` TopK float64 `json:"topK,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"` MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"` CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"` StopSequences []string `json:"stopSequences,omitempty"`
ResponseMimeType string `json:"responseMimeType,omitempty"` ResponseMimeType string `json:"responseMimeType,omitempty"`
ResponseSchema any `json:"responseSchema,omitempty"` ResponseSchema any `json:"responseSchema,omitempty"`
Seed int64 `json:"seed,omitempty"` Seed int64 `json:"seed,omitempty"`
ResponseModalities []string `json:"responseModalities,omitempty"`
} }
type GeminiChatCandidate struct { type GeminiChatCandidate struct {

View File

@@ -19,7 +19,7 @@ import (
) )
// Setting safety to the lowest possible values since Gemini is already powerless enough // Setting safety to the lowest possible values since Gemini is already powerless enough
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) { func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
geminiRequest := GeminiChatRequest{ geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
@@ -32,6 +32,13 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
}, },
} }
if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
geminiRequest.GenerationConfig.ResponseModalities = []string{
"TEXT",
"IMAGE",
}
}
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList)) safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
for _, category := range SafetySettingList { for _, category := range SafetySettingList {
safetySettings = append(safetySettings, GeminiChatSafetySettings{ safetySettings = append(safetySettings, GeminiChatSafetySettings{
@@ -546,9 +553,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
return &fullTextResponse return &fullTextResponse
} }
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) { func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates)) choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
isStop := false isStop := false
hasImage := false
for _, candidate := range geminiResponse.Candidates { for _, candidate := range geminiResponse.Candidates {
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" { if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
isStop = true isStop = true
@@ -574,7 +582,13 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
} }
} }
for _, part := range candidate.Content.Parts { for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil { if part.InlineData != nil {
if strings.HasPrefix(part.InlineData.MimeType, "image") {
imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
texts = append(texts, imgText)
hasImage = true
}
} else if part.FunctionCall != nil {
isTools = true isTools = true
if call := getResponseToolCall(&part); call != nil { if call := getResponseToolCall(&part); call != nil {
call.SetIndex(len(choice.Delta.ToolCalls)) call.SetIndex(len(choice.Delta.ToolCalls))
@@ -602,7 +616,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
var response dto.ChatCompletionsStreamResponse var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Choices = choices response.Choices = choices
return &response, isStop return &response, isStop, hasImage
} }
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -610,20 +624,23 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp() createAt := common.GetTimestamp()
var usage = &dto.Usage{} var usage = &dto.Usage{}
var imageCount int
helper.StreamScannerHandler(c, resp, info, func(data string) bool { helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse var geminiResponse GeminiChatResponse
err := json.Unmarshal([]byte(data), &geminiResponse) err := common.DecodeJsonStr(data, &geminiResponse)
if err != nil { if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error()) common.LogError(c, "error unmarshalling stream response: "+err.Error())
return false return false
} }
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse) response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
if hasImage {
imageCount++
}
response.Id = id response.Id = id
response.Created = createAt response.Created = createAt
response.Model = info.UpstreamModelName response.Model = info.UpstreamModelName
// responseText += response.Choices[0].Delta.GetContentString()
if geminiResponse.UsageMetadata.TotalTokenCount != 0 { if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
@@ -641,6 +658,12 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
var response *dto.ChatCompletionsStreamResponse var response *dto.ChatCompletionsStreamResponse
if imageCount != 0 {
if usage.CompletionTokens == 0 {
usage.CompletionTokens = imageCount * 258
}
}
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
usage.PromptTokensDetails.TextTokens = usage.PromptTokens usage.PromptTokensDetails.TextTokens = usage.PromptTokens
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens

View File

@@ -143,7 +143,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
info.UpstreamModelName = claudeReq.Model info.UpstreamModelName = claudeReq.Model
return vertexClaudeReq, nil return vertexClaudeReq, nil
} else if a.RequestMode == RequestModeGemini { } else if a.RequestMode == RequestModeGemini {
geminiRequest, err := gemini.CovertGemini2OpenAI(*request) geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -90,7 +90,7 @@ type RelayInfo struct {
RelayFormat string RelayFormat string
SendResponseCount int SendResponseCount int
ThinkingContentInfo ThinkingContentInfo
ClaudeConvertInfo *ClaudeConvertInfo
*RerankerInfo *RerankerInfo
} }
@@ -120,7 +120,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c) info := GenRelayInfo(c)
info.RelayFormat = RelayFormatClaude info.RelayFormat = RelayFormatClaude
info.ShouldIncludeUsage = false info.ShouldIncludeUsage = false
info.ClaudeConvertInfo = ClaudeConvertInfo{ info.ClaudeConvertInfo = &ClaudeConvertInfo{
LastMessagesType: LastMessageTypeNone, LastMessagesType: LastMessageTypeNone,
} }
return info return info

View File

@@ -6,8 +6,9 @@ import (
// GeminiSettings 定义Gemini模型的配置 // GeminiSettings 定义Gemini模型的配置
type GeminiSettings struct { type GeminiSettings struct {
SafetySettings map[string]string `json:"safety_settings"` SafetySettings map[string]string `json:"safety_settings"`
VersionSettings map[string]string `json:"version_settings"` VersionSettings map[string]string `json:"version_settings"`
SupportedImagineModels []string `json:"supported_imagine_models"`
} }
// 默认配置 // 默认配置
@@ -20,6 +21,10 @@ var defaultGeminiSettings = GeminiSettings{
"default": "v1beta", "default": "v1beta",
"gemini-1.0-pro": "v1", "gemini-1.0-pro": "v1",
}, },
SupportedImagineModels: []string{
"gemini-2.0-flash-exp-image-generation",
"gemini-2.0-flash-exp",
},
} }
// 全局实例 // 全局实例
@@ -50,3 +55,12 @@ func GetGeminiVersionSetting(key string) string {
} }
return geminiSettings.VersionSettings["default"] return geminiSettings.VersionSettings["default"]
} }
// IsGeminiModelSupportImagine reports whether model is listed in
// geminiSettings.SupportedImagineModels. The comparison is an exact,
// case-sensitive string match against each configured entry.
func IsGeminiModelSupportImagine(model string) bool {
	supported := geminiSettings.SupportedImagineModels
	for i := 0; i < len(supported); i++ {
		if supported[i] == model {
			return true
		}
	}
	// No configured entry matched (this includes an empty or nil list).
	return false
}

View File

@@ -13,6 +13,7 @@ const ModelSetting = () => {
let [inputs, setInputs] = useState({ let [inputs, setInputs] = useState({
'gemini.safety_settings': '', 'gemini.safety_settings': '',
'gemini.version_settings': '', 'gemini.version_settings': '',
'gemini.supported_imagine_models': '',
'claude.model_headers_settings': '', 'claude.model_headers_settings': '',
'claude.thinking_adapter_enabled': true, 'claude.thinking_adapter_enabled': true,
'claude.default_max_tokens': '', 'claude.default_max_tokens': '',
@@ -34,7 +35,8 @@ const ModelSetting = () => {
item.key === 'gemini.safety_settings' || item.key === 'gemini.safety_settings' ||
item.key === 'gemini.version_settings' || item.key === 'gemini.version_settings' ||
item.key === 'claude.model_headers_settings'|| item.key === 'claude.model_headers_settings'||
item.key === 'claude.default_max_tokens' item.key === 'claude.default_max_tokens'||
item.key === 'gemini.supported_imagine_models'
) { ) {
item.value = JSON.stringify(JSON.parse(item.value), null, 2); item.value = JSON.stringify(JSON.parse(item.value), null, 2);
} }

View File

@@ -26,6 +26,7 @@ export default function SettingGeminiModel(props) {
const [inputs, setInputs] = useState({ const [inputs, setInputs] = useState({
'gemini.safety_settings': '', 'gemini.safety_settings': '',
'gemini.version_settings': '', 'gemini.version_settings': '',
'gemini.supported_imagine_models': [],
}); });
const refForm = useRef(); const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs); const [inputsRow, setInputsRow] = useState(inputs);
@@ -125,6 +126,16 @@ export default function SettingGeminiModel(props) {
/> />
</Col> </Col>
</Row> </Row>
<Row>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.TextArea
field={'gemini.supported_imagine_models'}
label={t('支持的图像模型')}
placeholder={t('例如:') + '\n' + JSON.stringify(['gemini-2.0-flash-exp-image-generation'], null, 2)}
onChange={(value) => setInputs({ ...inputs, 'gemini.supported_imagine_models': value })}
/>
</Col>
</Row>
<Row> <Row>
<Button size='default' onClick={onSubmit}> <Button size='default' onClick={onSubmit}>