diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index c5a547ba..e5ee134a 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -99,7 +99,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	ai, err := CovertGemini2OpenAI(*request)
+	ai, err := CovertGemini2OpenAI(*request, info)
 	if err != nil {
 		return nil, err
 	}
diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go
index cbf55576..7f98b1b7 100644
--- a/relay/channel/gemini/dto.go
+++ b/relay/channel/gemini/dto.go
@@ -71,15 +71,16 @@ type GeminiChatTool struct {
 }
 
 type GeminiChatGenerationConfig struct {
-	Temperature      *float64 `json:"temperature,omitempty"`
-	TopP             float64  `json:"topP,omitempty"`
-	TopK             float64  `json:"topK,omitempty"`
-	MaxOutputTokens  uint     `json:"maxOutputTokens,omitempty"`
-	CandidateCount   int      `json:"candidateCount,omitempty"`
-	StopSequences    []string `json:"stopSequences,omitempty"`
-	ResponseMimeType string   `json:"responseMimeType,omitempty"`
-	ResponseSchema   any      `json:"responseSchema,omitempty"`
-	Seed             int64    `json:"seed,omitempty"`
+	Temperature        *float64 `json:"temperature,omitempty"`
+	TopP               float64  `json:"topP,omitempty"`
+	TopK               float64  `json:"topK,omitempty"`
+	MaxOutputTokens    uint     `json:"maxOutputTokens,omitempty"`
+	CandidateCount     int      `json:"candidateCount,omitempty"`
+	StopSequences      []string `json:"stopSequences,omitempty"`
+	ResponseMimeType   string   `json:"responseMimeType,omitempty"`
+	ResponseSchema     any      `json:"responseSchema,omitempty"`
+	Seed               int64    `json:"seed,omitempty"`
+	ResponseModalities []string `json:"responseModalities,omitempty"`
 }
 
 type GeminiChatCandidate struct {
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index 00b39cb2..03736f38 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -19,7 +19,7 @@ import (
 )
 
 // Setting safety to the lowest possible values since Gemini is already powerless enough
-func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) {
+func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
 	geminiRequest := GeminiChatRequest{
 		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
@@ -32,6 +32,13 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 		},
 	}
 
+	if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
+		geminiRequest.GenerationConfig.ResponseModalities = []string{
+			"TEXT",
+			"IMAGE",
+		}
+	}
+
 	safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
 	for _, category := range SafetySettingList {
 		safetySettings = append(safetySettings, GeminiChatSafetySettings{
@@ -546,9 +553,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 	return &fullTextResponse
 }
 
-func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
+func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
 	choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
 	isStop := false
+	hasImage := false
 	for _, candidate := range geminiResponse.Candidates {
 		if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
 			isStop = true
@@ -574,7 +582,13 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
 			}
 		}
 		for _, part := range candidate.Content.Parts {
-			if part.FunctionCall != nil {
+			if part.InlineData != nil {
+				if strings.HasPrefix(part.InlineData.MimeType, "image") {
+					imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
+					texts = append(texts, imgText)
+					hasImage = true
+				}
+			} else if part.FunctionCall != nil {
 				isTools = true
 				if call := getResponseToolCall(&part); call != nil {
 					call.SetIndex(len(choice.Delta.ToolCalls))
@@ -602,7 +616,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
 	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
 	response.Choices = choices
-	return &response, isStop
+	return &response, isStop, hasImage
 }
 
 func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -610,20 +624,23 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 	id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 	createAt := common.GetTimestamp()
 	var usage = &dto.Usage{}
+	var imageCount int
 
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var geminiResponse GeminiChatResponse
-		err := json.Unmarshal([]byte(data), &geminiResponse)
+		err := common.DecodeJsonStr(data, &geminiResponse)
 		if err != nil {
 			common.LogError(c, "error unmarshalling stream response: "+err.Error())
 			return false
 		}
-		response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
+		response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
+		if hasImage {
+			imageCount++
+		}
 		response.Id = id
 		response.Created = createAt
 		response.Model = info.UpstreamModelName
-		// responseText += response.Choices[0].Delta.GetContentString()
 
 		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
 			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
 			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
@@ -641,6 +658,12 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 
 	var response *dto.ChatCompletionsStreamResponse
 
+	if imageCount != 0 {
+		if usage.CompletionTokens == 0 {
+			usage.CompletionTokens = imageCount * 258
+		}
+	}
+
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
 	usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go
index b52e7e0a..77f29620 100644
--- a/relay/channel/vertex/adaptor.go
+++ b/relay/channel/vertex/adaptor.go
@@ -143,7 +143,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 		info.UpstreamModelName = claudeReq.Model
 		return vertexClaudeReq, nil
 	} else if a.RequestMode == RequestModeGemini {
-		geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
+		geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info)
 		if err != nil {
 			return nil, err
 		}
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index 8ab97f5e..fa87dc24 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -90,7 +90,7 @@ type RelayInfo struct {
 	RelayFormat       string
 	SendResponseCount int
 	ThinkingContentInfo
-	ClaudeConvertInfo
+	*ClaudeConvertInfo
 	*RerankerInfo
 }
 
@@ -120,7 +120,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
 	info := GenRelayInfo(c)
 	info.RelayFormat = RelayFormatClaude
 	info.ShouldIncludeUsage = false
-	info.ClaudeConvertInfo = ClaudeConvertInfo{
+	info.ClaudeConvertInfo = &ClaudeConvertInfo{
 		LastMessagesType: LastMessageTypeNone,
 	}
 	return info
diff --git a/setting/model_setting/gemini.go b/setting/model_setting/gemini.go
index 07e993bc..e6509232 100644
--- a/setting/model_setting/gemini.go
+++ b/setting/model_setting/gemini.go
@@ -6,8 +6,9 @@ import (
 // GeminiSettings defines the configuration for Gemini models
 type GeminiSettings struct {
-	SafetySettings  map[string]string `json:"safety_settings"`
-	VersionSettings map[string]string `json:"version_settings"`
+	SafetySettings         map[string]string `json:"safety_settings"`
+	VersionSettings        map[string]string `json:"version_settings"`
+	SupportedImagineModels []string          `json:"supported_imagine_models"`
 }
 
 // Default configuration
 var defaultGeminiSettings = GeminiSettings{
@@ -20,6 +21,10 @@ var defaultGeminiSettings = GeminiSettings{
 		"default":        "v1beta",
 		"gemini-1.0-pro": "v1",
 	},
+	SupportedImagineModels: []string{
+		"gemini-2.0-flash-exp-image-generation",
+		"gemini-2.0-flash-exp",
+	},
 }
 
 // Global instance
@@ -50,3 +55,12 @@ func GetGeminiVersionSetting(key string) string {
 	}
 	return geminiSettings.VersionSettings["default"]
 }
+
+func IsGeminiModelSupportImagine(model string) bool {
+	for _, v := range geminiSettings.SupportedImagineModels {
+		if v == model {
+			return true
+		}
+	}
+	return false
+}
diff --git a/web/src/components/ModelSetting.js b/web/src/components/ModelSetting.js
index ce89c337..a9e1b855 100644
--- a/web/src/components/ModelSetting.js
+++ b/web/src/components/ModelSetting.js
@@ -13,6 +13,7 @@ const ModelSetting = () => {
   let [inputs, setInputs] = useState({
     'gemini.safety_settings': '',
     'gemini.version_settings': '',
+    'gemini.supported_imagine_models': '',
     'claude.model_headers_settings': '',
     'claude.thinking_adapter_enabled': true,
     'claude.default_max_tokens': '',
@@ -34,7 +35,8 @@ const ModelSetting = () => {
         item.key === 'gemini.safety_settings' ||
         item.key === 'gemini.version_settings' ||
         item.key === 'claude.model_headers_settings'||
-        item.key === 'claude.default_max_tokens'
+        item.key === 'claude.default_max_tokens'||
+        item.key === 'gemini.supported_imagine_models'
       ) {
         item.value = JSON.stringify(JSON.parse(item.value), null, 2);
       }
diff --git a/web/src/pages/Setting/Model/SettingGeminiModel.js b/web/src/pages/Setting/Model/SettingGeminiModel.js
index 6139142c..844812e5 100644
--- a/web/src/pages/Setting/Model/SettingGeminiModel.js
+++ b/web/src/pages/Setting/Model/SettingGeminiModel.js
@@ -26,6 +26,7 @@ export default function SettingGeminiModel(props) {
   const [inputs, setInputs] = useState({
     'gemini.safety_settings': '',
     'gemini.version_settings': '',
+    'gemini.supported_imagine_models': [],
   });
   const refForm = useRef();
   const [inputsRow, setInputsRow] = useState(inputs);
@@ -125,6 +126,16 @@ export default function SettingGeminiModel(props) {
             />
+        <Row>
+          <Col span={16}>
+            <Form.TextArea
+              label={'Supported image generation models'}
+              field={'gemini.supported_imagine_models'}
+              onChange={(value) =>
+                setInputs({ ...inputs, 'gemini.supported_imagine_models': value })}
+            />
+          </Col>
+        </Row>
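
For reviewers who want to exercise the new allow-list, here is a minimal table-driven test sketch. It is not part of this diff: the file name (`setting/model_setting/gemini_test.go`) is hypothetical, and it assumes the global `geminiSettings` instance starts out equal to the `defaultGeminiSettings` shown above.

```go
package model_setting

import "testing"

// TestIsGeminiModelSupportImagine checks that only models on the
// SupportedImagineModels list opt in to TEXT+IMAGE response modalities.
// Sketch only: assumes geminiSettings still holds the default settings.
func TestIsGeminiModelSupportImagine(t *testing.T) {
	cases := map[string]bool{
		"gemini-2.0-flash-exp-image-generation": true,  // in the default list
		"gemini-2.0-flash-exp":                  true,  // in the default list
		"gemini-1.5-pro":                        false, // not listed
	}
	for model, want := range cases {
		if got := IsGeminiModelSupportImagine(model); got != want {
			t.Errorf("IsGeminiModelSupportImagine(%q) = %v, want %v", model, got, want)
		}
	}
}
```

A note on accounting: the `imageCount * 258` fallback in `GeminiChatStreamHandler` only fires when the stream's usage metadata reports zero completion tokens, so exact counts from `UsageMetadata` still win when present; 258 is the flat per-image token estimate this patch uses.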