From bf80d71ddf649184f9a050388fc86edd1ba417d8 Mon Sep 17 00:00:00 2001
From: "1808837298@qq.com" <1808837298@qq.com>
Date: Wed, 26 Feb 2025 18:19:09 +0800
Subject: [PATCH] feat: Add Gemini version settings configuration support
(close #568)
---
.env.example | 4 --
README.md | 5 ++-
constant/env.go | 33 +++++++--------
model/option.go | 8 +++-
relay/channel/gemini/adaptor.go | 12 +-----
relay/channel/gemini/relay-gemini.go | 4 +-
.../gemini.go} | 40 ++++++++++++++++++-
web/src/components/ModelSetting.js | 4 +-
.../pages/Setting/Model/SettingGeminiModel.js | 27 +++++++++++++
9 files changed, 98 insertions(+), 39 deletions(-)
rename setting/{model_setting.go => model_setting/gemini.go} (53%)
diff --git a/.env.example b/.env.example
index 07602eca..bece06db 100644
--- a/.env.example
+++ b/.env.example
@@ -50,10 +50,6 @@
# CHANNEL_TEST_FREQUENCY=10
# 生成默认token
# GENERATE_DEFAULT_TOKEN=false
-# Gemini 安全设置
-# GEMINI_SAFETY_SETTING=BLOCK_NONE
-# Gemini版本设置
-# GEMINI_MODEL_MAP=gemini-1.0-pro:v1
# Cohere 安全设置
# COHERE_SAFETY_SETTING=NONE
# 是否统计图片token
diff --git a/README.md b/README.md
index 3d880f03..62b6ba15 100644
--- a/README.md
+++ b/README.md
@@ -94,7 +94,6 @@
- `GET_MEDIA_TOKEN`:是否统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。
- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
-- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用"模型:版本"指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
- `COHERE_SAFETY_SETTING`:Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认为 `NONE`。
- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认为 `16`,设置为 `-1` 则不限制。
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`。
@@ -103,6 +102,10 @@
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`。
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`。
+## 已废弃的环境变量
+- ~~`GEMINI_MODEL_MAP`(已废弃)~~:改为到`设置-模型相关设置`中设置
+- ~~`GEMINI_SAFETY_SETTING`(已废弃)~~:改为到`设置-模型相关设置`中设置
+
## 部署
> [!TIP]
diff --git a/constant/env.go b/constant/env.go
index bffbfeea..d2a1d04d 100644
--- a/constant/env.go
+++ b/constant/env.go
@@ -1,10 +1,7 @@
package constant
import (
- "fmt"
"one-api/common"
- "os"
- "strings"
)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
@@ -23,9 +20,9 @@ var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
-var GeminiModelMap = map[string]string{
- "gemini-1.0-pro": "v1",
-}
+//var GeminiModelMap = map[string]string{
+// "gemini-1.0-pro": "v1",
+//}
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
@@ -33,18 +30,18 @@ var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
func InitEnv() {
- modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
- if modelVersionMapStr == "" {
- return
- }
- for _, pair := range strings.Split(modelVersionMapStr, ",") {
- parts := strings.Split(pair, ":")
- if len(parts) == 2 {
- GeminiModelMap[parts[0]] = parts[1]
- } else {
- common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
- }
- }
+ //modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
+ //if modelVersionMapStr == "" {
+ // return
+ //}
+ //for _, pair := range strings.Split(modelVersionMapStr, ",") {
+ // parts := strings.Split(pair, ":")
+ // if len(parts) == 2 {
+ // GeminiModelMap[parts[0]] = parts[1]
+ // } else {
+ // common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
+ // }
+ //}
}
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
diff --git a/model/option.go b/model/option.go
index 3897ea36..319d2406 100644
--- a/model/option.go
+++ b/model/option.go
@@ -3,6 +3,7 @@ package model
import (
"one-api/common"
"one-api/setting"
+ "one-api/setting/model_setting"
"strconv"
"strings"
"time"
@@ -115,7 +116,8 @@ func InitOptionMap() {
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
- common.OptionMap["GeminiSafetySettings"] = setting.GeminiSafetySettingsJsonString()
+ common.OptionMap["GeminiSafetySettings"] = model_setting.GeminiSafetySettingsJsonString()
+ common.OptionMap["GeminiVersionSettings"] = model_setting.GeminiVersionSettingsJsonString()
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
@@ -353,7 +355,9 @@ func updateOptionMap(key string, value string) (err error) {
case "AutomaticDisableKeywords":
setting.AutomaticDisableKeywordsFromString(value)
case "GeminiSafetySettings":
- setting.GeminiSafetySettingFromJsonString(value)
+ model_setting.GeminiSafetySettingFromJsonString(value)
+ case "GeminiVersionSettings":
+ model_setting.GeminiVersionSettingFromJsonString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
}
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index 32513c42..37c6c9df 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -7,11 +7,11 @@ import (
"io"
"net/http"
"one-api/common"
- "one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
+ "one-api/setting/model_setting"
"strings"
@@ -64,15 +64,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- // 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1beta"
- version, beta := constant.GeminiModelMap[info.UpstreamModelName]
- if !beta {
- if info.ApiVersion != "" {
- version = info.ApiVersion
- } else {
- version = "v1beta"
- }
- }
+ version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index 7b7c9cb7..5db01c20 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -11,7 +11,7 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
- "one-api/setting"
+ "one-api/setting/model_setting"
"strings"
"unicode/utf8"
@@ -36,7 +36,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
for _, category := range SafetySettingList {
safetySettings = append(safetySettings, GeminiChatSafetySettings{
Category: category,
- Threshold: setting.GetGeminiSafetySetting(category),
+ Threshold: model_setting.GetGeminiSafetySetting(category),
})
}
geminiRequest.SafetySettings = safetySettings
diff --git a/setting/model_setting.go b/setting/model_setting/gemini.go
similarity index 53%
rename from setting/model_setting.go
rename to setting/model_setting/gemini.go
index c0f9bd1b..31cd50a0 100644
--- a/setting/model_setting.go
+++ b/setting/model_setting/gemini.go
@@ -1,4 +1,4 @@
-package setting
+package model_setting
import (
"encoding/json"
@@ -43,3 +43,41 @@ func GeminiSafetySettingsJsonString() string {
}
return string(jsonString)
}
+
+var geminiVersionSettings = map[string]string{
+ "default": "v1beta",
+ "gemini-1.0-pro": "v1",
+}
+
+func GetGeminiVersionSetting(key string) string {
+ if value, ok := geminiVersionSettings[key]; ok {
+ return value
+ }
+ return geminiVersionSettings["default"]
+}
+
+func GeminiVersionSettingFromJsonString(jsonString string) {
+ geminiVersionSettings = map[string]string{}
+ err := json.Unmarshal([]byte(jsonString), &geminiVersionSettings)
+ if err != nil {
+ geminiVersionSettings = map[string]string{
+ "default": "v1beta",
+ }
+ }
+ // check must have default
+ if _, ok := geminiVersionSettings["default"]; !ok {
+ geminiVersionSettings["default"] = "v1beta"
+ }
+}
+
+func GeminiVersionSettingsJsonString() string {
+ // check must have default
+ if _, ok := geminiVersionSettings["default"]; !ok {
+ geminiVersionSettings["default"] = "v1beta"
+ }
+ jsonString, err := json.Marshal(geminiVersionSettings)
+ if err != nil {
+ return "{}"
+ }
+ return string(jsonString)
+}
diff --git a/web/src/components/ModelSetting.js b/web/src/components/ModelSetting.js
index 34cba0db..0c76b012 100644
--- a/web/src/components/ModelSetting.js
+++ b/web/src/components/ModelSetting.js
@@ -22,6 +22,7 @@ const ModelSetting = () => {
const { t } = useTranslation();
let [inputs, setInputs] = useState({
GeminiSafetySettings: '',
+ GeminiVersionSettings: '',
});
let [loading, setLoading] = useState(false);
@@ -33,7 +34,8 @@ const ModelSetting = () => {
let newInputs = {};
data.forEach((item) => {
if (
- item.key === 'GeminiSafetySettings'
+ item.key === 'GeminiSafetySettings' ||
+ item.key === 'GeminiVersionSettings'
) {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
diff --git a/web/src/pages/Setting/Model/SettingGeminiModel.js b/web/src/pages/Setting/Model/SettingGeminiModel.js
index 3075b05f..fa63cb74 100644
--- a/web/src/pages/Setting/Model/SettingGeminiModel.js
+++ b/web/src/pages/Setting/Model/SettingGeminiModel.js
@@ -14,12 +14,18 @@ const GEMINI_SETTING_EXAMPLE = {
'HARM_CATEGORY_CIVIC_INTEGRITY': 'BLOCK_NONE',
};
+const GEMINI_VERSION_EXAMPLE = {
+ 'default': 'v1beta',
+};
+
+
export default function SettingGeminiModel(props) {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [inputs, setInputs] = useState({
GeminiSafetySettings: '',
+ GeminiVersionSettings: '',
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -99,6 +105,27 @@ export default function SettingGeminiModel(props) {
/>
+
+
+ verifyJSON(value),
+ message: t('不是合法的 JSON 字符串')
+ }
+ ]}
+ onChange={(value) => setInputs({ ...inputs, GeminiVersionSettings: value })}
+ />
+
+
+