diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index c5a547ba..e5ee134a 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -99,7 +99,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
- ai, err := CovertGemini2OpenAI(*request)
+ ai, err := CovertGemini2OpenAI(*request, info)
if err != nil {
return nil, err
}
diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go
index cbf55576..7f98b1b7 100644
--- a/relay/channel/gemini/dto.go
+++ b/relay/channel/gemini/dto.go
@@ -71,15 +71,16 @@ type GeminiChatTool struct {
}
type GeminiChatGenerationConfig struct {
- Temperature *float64 `json:"temperature,omitempty"`
- TopP float64 `json:"topP,omitempty"`
- TopK float64 `json:"topK,omitempty"`
- MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
- CandidateCount int `json:"candidateCount,omitempty"`
- StopSequences []string `json:"stopSequences,omitempty"`
- ResponseMimeType string `json:"responseMimeType,omitempty"`
- ResponseSchema any `json:"responseSchema,omitempty"`
- Seed int64 `json:"seed,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"topP,omitempty"`
+ TopK float64 `json:"topK,omitempty"`
+ MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
+ CandidateCount int `json:"candidateCount,omitempty"`
+ StopSequences []string `json:"stopSequences,omitempty"`
+ ResponseMimeType string `json:"responseMimeType,omitempty"`
+ ResponseSchema any `json:"responseSchema,omitempty"`
+ Seed int64 `json:"seed,omitempty"`
+ ResponseModalities []string `json:"responseModalities,omitempty"`
}
type GeminiChatCandidate struct {
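Since every member of `GeminiChatGenerationConfig` carries `omitempty`, existing requests serialize exactly as before and only image-capable requests gain a `responseModalities` array. A minimal sketch of the wire shape, using a trimmed local copy of the struct rather than the patch's own type:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed local copy of GeminiChatGenerationConfig, just enough to show
// the wire shape of the new field.
type genConfig struct {
	ResponseModalities []string `json:"responseModalities,omitempty"`
}

func main() {
	cfg := genConfig{ResponseModalities: []string{"TEXT", "IMAGE"}}
	b, _ := json.Marshal(cfg)
	fmt.Println(string(b)) // {"responseModalities":["TEXT","IMAGE"]}
}
```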
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index 00b39cb2..03736f38 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -19,7 +19,7 @@ import (
)
// Setting safety to the lowest possible values since Gemini is already powerless enough
-func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) {
+func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
@@ -32,6 +32,13 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
},
}
+ if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
+ geminiRequest.GenerationConfig.ResponseModalities = []string{
+ "TEXT",
+ "IMAGE",
+ }
+ }
+
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
for _, category := range SafetySettingList {
safetySettings = append(safetySettings, GeminiChatSafetySettings{
@@ -546,9 +553,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
return &fullTextResponse
}
-func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
+func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
isStop := false
+ hasImage := false
for _, candidate := range geminiResponse.Candidates {
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
isStop = true
@@ -574,7 +582,13 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
}
}
for _, part := range candidate.Content.Parts {
- if part.FunctionCall != nil {
+ if part.InlineData != nil {
+ if strings.HasPrefix(part.InlineData.MimeType, "image") {
+ imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
+ texts = append(texts, imgText)
+ hasImage = true
+ }
+ } else if part.FunctionCall != nil {
isTools = true
if call := getResponseToolCall(&part); call != nil {
call.SetIndex(len(choice.Delta.ToolCalls))
@@ -602,7 +616,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Choices = choices
- return &response, isStop
+ return &response, isStop, hasImage
}
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
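For inline image parts, the stream converter now emits a markdown image whose source is a base64 data URI; Gemini returns `inlineData.data` already base64-encoded, so the patch can concatenate it directly. A minimal sketch of the string being built, where `mimeType` and `b64` are stand-in inputs:

```go
package main

import "fmt"

// Builds the markdown image emitted for an inline image part: a data URI
// carrying the base64 payload from Gemini's inlineData.data as-is.
func markdownImage(mimeType, b64 string) string {
	return fmt.Sprintf("![image](data:%s;base64,%s)", mimeType, b64)
}

func main() {
	// Stand-in payload; a real response carries a full base64-encoded image.
	fmt.Println(markdownImage("image/png", "iVBORw0KGgo="))
}
```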
@@ -610,20 +624,23 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
+ var imageCount int
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse
- err := json.Unmarshal([]byte(data), &geminiResponse)
+ err := common.DecodeJsonStr(data, &geminiResponse)
if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error())
return false
}
- response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
+ response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
+ if hasImage {
+ imageCount++
+ }
response.Id = id
response.Created = createAt
response.Model = info.UpstreamModelName
- // responseText += response.Choices[0].Delta.GetContentString()
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
@@ -641,6 +658,12 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
var response *dto.ChatCompletionsStreamResponse
+ if imageCount != 0 {
+ if usage.CompletionTokens == 0 {
+ usage.CompletionTokens = imageCount * 258
+ }
+ }
+
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
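On accounting: the handler counts every streamed chunk that carried an image part and, only when the upstream usage metadata reported zero completion tokens, bills a flat 258 tokens per image chunk (the constant hard-coded in this hunk). A sketch of that fallback, with the values inlined:

```go
package main

import "fmt"

// Fallback used above: when upstream usage metadata reports zero
// completion tokens, bill a flat 258 tokens per streamed image chunk.
func fallbackCompletionTokens(completionTokens, imageCount int) int {
	if imageCount != 0 && completionTokens == 0 {
		return imageCount * 258
	}
	return completionTokens
}

func main() {
	fmt.Println(fallbackCompletionTokens(0, 3))   // 774
	fmt.Println(fallbackCompletionTokens(120, 3)) // 120: reported usage wins
}
```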
diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go
index b52e7e0a..77f29620 100644
--- a/relay/channel/vertex/adaptor.go
+++ b/relay/channel/vertex/adaptor.go
@@ -143,7 +143,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
info.UpstreamModelName = claudeReq.Model
return vertexClaudeReq, nil
} else if a.RequestMode == RequestModeGemini {
- geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
+ geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info)
if err != nil {
return nil, err
}
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index 8ab97f5e..fa87dc24 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -90,7 +90,7 @@ type RelayInfo struct {
RelayFormat string
SendResponseCount int
ThinkingContentInfo
- ClaudeConvertInfo
+ *ClaudeConvertInfo
*RerankerInfo
}
@@ -120,7 +120,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatClaude
info.ShouldIncludeUsage = false
- info.ClaudeConvertInfo = ClaudeConvertInfo{
+ info.ClaudeConvertInfo = &ClaudeConvertInfo{
LastMessagesType: LastMessageTypeNone,
}
return info
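Switching the embedded `ClaudeConvertInfo` from a value to a pointer presumably has two effects: copies of `RelayInfo` now alias one shared convert state, and the field stays `nil` for relay formats that never call `GenRelayInfoClaude`. A minimal sketch of the aliasing difference, using local stand-in types rather than the real ones:

```go
package main

import "fmt"

// Local stand-ins for RelayInfo and ClaudeConvertInfo, value- vs
// pointer-embedded, to show why the patch switches to a pointer.
type convertInfo struct{ LastMessagesType int }

type relayInfoVal struct{ convertInfo }
type relayInfoPtr struct{ *convertInfo }

func main() {
	a := relayInfoPtr{&convertInfo{}}
	b := a // copying the outer struct still aliases one convert state
	b.LastMessagesType = 7
	fmt.Println(a.LastMessagesType) // 7

	v := relayInfoVal{}
	w := v // value embedding: the copy gets its own state
	w.LastMessagesType = 7
	fmt.Println(v.LastMessagesType) // 0
}
```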
diff --git a/setting/model_setting/gemini.go b/setting/model_setting/gemini.go
index 07e993bc..e6509232 100644
--- a/setting/model_setting/gemini.go
+++ b/setting/model_setting/gemini.go
@@ -6,8 +6,9 @@ import (
// GeminiSettings defines the configuration for Gemini models
type GeminiSettings struct {
- SafetySettings map[string]string `json:"safety_settings"`
- VersionSettings map[string]string `json:"version_settings"`
+ SafetySettings map[string]string `json:"safety_settings"`
+ VersionSettings map[string]string `json:"version_settings"`
+ SupportedImagineModels []string `json:"supported_imagine_models"`
}
// Default configuration
@@ -20,6 +21,10 @@ var defaultGeminiSettings = GeminiSettings{
"default": "v1beta",
"gemini-1.0-pro": "v1",
},
+ SupportedImagineModels: []string{
+ "gemini-2.0-flash-exp-image-generation",
+ "gemini-2.0-flash-exp",
+ },
}
// Global instance
@@ -50,3 +55,12 @@ func GetGeminiVersionSetting(key string) string {
}
return geminiSettings.VersionSettings["default"]
}
+
+func IsGeminiModelSupportImagine(model string) bool {
+ for _, v := range geminiSettings.SupportedImagineModels {
+ if v == model {
+ return true
+ }
+ }
+ return false
+}
diff --git a/web/src/components/ModelSetting.js b/web/src/components/ModelSetting.js
index ce89c337..a9e1b855 100644
--- a/web/src/components/ModelSetting.js
+++ b/web/src/components/ModelSetting.js
@@ -13,6 +13,7 @@ const ModelSetting = () => {
let [inputs, setInputs] = useState({
'gemini.safety_settings': '',
'gemini.version_settings': '',
+ 'gemini.supported_imagine_models': '',
'claude.model_headers_settings': '',
'claude.thinking_adapter_enabled': true,
'claude.default_max_tokens': '',
@@ -34,7 +35,8 @@ const ModelSetting = () => {
item.key === 'gemini.safety_settings' ||
item.key === 'gemini.version_settings' ||
item.key === 'claude.model_headers_settings' ||
- item.key === 'claude.default_max_tokens'
+ item.key === 'claude.default_max_tokens' ||
+ item.key === 'gemini.supported_imagine_models'
) {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
diff --git a/web/src/pages/Setting/Model/SettingGeminiModel.js b/web/src/pages/Setting/Model/SettingGeminiModel.js
index 6139142c..844812e5 100644
--- a/web/src/pages/Setting/Model/SettingGeminiModel.js
+++ b/web/src/pages/Setting/Model/SettingGeminiModel.js
@@ -26,6 +26,7 @@ export default function SettingGeminiModel(props) {
const [inputs, setInputs] = useState({
'gemini.safety_settings': '',
'gemini.version_settings': '',
+ 'gemini.supported_imagine_models': [],
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -125,6 +126,16 @@ export default function SettingGeminiModel(props) {
/>
+ <Row>
+ <Col span={16}>
+ <Form.TagInput
+ field={'gemini.supported_imagine_models'}
+ label={'Models that support image output'}
+ onChange={(value) =>
+ setInputs({ ...inputs, 'gemini.supported_imagine_models': value })}
+ />
+ </Col>
+ </Row>