diff --git a/.env.example b/.env.example index 07602eca..bece06db 100644 --- a/.env.example +++ b/.env.example @@ -50,10 +50,6 @@ # CHANNEL_TEST_FREQUENCY=10 # 生成默认token # GENERATE_DEFAULT_TOKEN=false -# Gemini 安全设置 -# GEMINI_SAFETY_SETTING=BLOCK_NONE -# Gemini版本设置 -# GEMINI_MODEL_MAP=gemini-1.0-pro:v1 # Cohere 安全设置 # COHERE_SAFETY_SETTING=NONE # 是否统计图片token diff --git a/README.md b/README.md index 3d880f03..62b6ba15 100644 --- a/README.md +++ b/README.md @@ -94,7 +94,6 @@ - `GET_MEDIA_TOKEN`:是否统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。 - `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。 - `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。 -- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用"模型:版本"指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta) - `COHERE_SAFETY_SETTING`:Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认为 `NONE`。 - `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认为 `16`,设置为 `-1` 则不限制。 - `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`。 @@ -103,6 +102,10 @@ - `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`。 - `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`。 +## 已废弃的环境变量 +- ~~`GEMINI_MODEL_MAP`(已废弃)~~:改为到`设置-模型相关设置`中设置 +- ~~`GEMINI_SAFETY_SETTING`(已废弃)~~:改为到`设置-模型相关设置`中设置 + ## 部署 > [!TIP] diff --git a/constant/env.go b/constant/env.go index bffbfeea..d2a1d04d 100644 --- a/constant/env.go +++ b/constant/env.go @@ -1,10 +1,7 @@ package constant import ( - "fmt" "one-api/common" - "os" - "strings" ) var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60) @@ -23,9 +20,9 @@ var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) var AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview") -var GeminiModelMap = map[string]string{ - "gemini-1.0-pro": "v1", -} +//var GeminiModelMap = map[string]string{ +// "gemini-1.0-pro": "v1", +//} var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16) @@ -33,18 +30,18 @@ var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2) var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10) func InitEnv() { - modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP")) - if modelVersionMapStr == "" { - return - } - for _, pair := range strings.Split(modelVersionMapStr, ",") { - parts := strings.Split(pair, ":") - if len(parts) == 2 { - GeminiModelMap[parts[0]] = parts[1] - } else { - common.SysError(fmt.Sprintf("invalid model version map: %s", pair)) - } - } + //modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP")) + //if modelVersionMapStr == "" { + // return + //} + //for _, pair := range strings.Split(modelVersionMapStr, ",") { + // parts := strings.Split(pair, ":") + // if len(parts) == 2 { + // GeminiModelMap[parts[0]] = parts[1] + // } else { + // common.SysError(fmt.Sprintf("invalid model version map: %s", pair)) + // } + //} } // GenerateDefaultToken 是否生成初始令牌,默认关闭。 diff --git a/model/option.go b/model/option.go index 3897ea36..319d2406 100644 --- a/model/option.go +++ b/model/option.go @@ -3,6 +3,7 @@ package model import ( "one-api/common" "one-api/setting" + "one-api/setting/model_setting" "strconv" "strings" "time" @@ -115,7 +116,8 @@ func InitOptionMap() { common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString() - common.OptionMap["GeminiSafetySettings"] = setting.GeminiSafetySettingsJsonString() + common.OptionMap["GeminiSafetySettings"] = model_setting.GeminiSafetySettingsJsonString() + common.OptionMap["GeminiVersionSettings"] = model_setting.GeminiVersionSettingsJsonString() common.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() @@ -353,7 +355,9 @@ func updateOptionMap(key string, value string) (err error) { case "AutomaticDisableKeywords": setting.AutomaticDisableKeywordsFromString(value) case "GeminiSafetySettings": - setting.GeminiSafetySettingFromJsonString(value) + model_setting.GeminiSafetySettingFromJsonString(value) + case "GeminiVersionSettings": + model_setting.GeminiVersionSettingFromJsonString(value) case "StreamCacheQueueLength": setting.StreamCacheQueueLength, _ = strconv.Atoi(value) } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 32513c42..37c6c9df 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -7,11 +7,11 @@ import ( "io" "net/http" "one-api/common" - "one-api/constant" "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" + "one-api/setting/model_setting" "strings" @@ -64,15 +64,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - // 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1beta" - version, beta := constant.GeminiModelMap[info.UpstreamModelName] - if !beta { - if info.ApiVersion != "" { - version = info.ApiVersion - } else { - version = "v1beta" - } - } + version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName) if strings.HasPrefix(info.UpstreamModelName, "imagen") { return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 7b7c9cb7..5db01c20 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -11,7 +11,7 @@ import ( "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" - "one-api/setting" + "one-api/setting/model_setting" "strings" "unicode/utf8" @@ -36,7 +36,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque for _, category := range SafetySettingList { safetySettings = append(safetySettings, GeminiChatSafetySettings{ Category: category, - Threshold: setting.GetGeminiSafetySetting(category), + Threshold: model_setting.GetGeminiSafetySetting(category), }) } geminiRequest.SafetySettings = safetySettings diff --git a/setting/model_setting.go b/setting/model_setting/gemini.go similarity index 53% rename from setting/model_setting.go rename to setting/model_setting/gemini.go index c0f9bd1b..31cd50a0 100644 --- a/setting/model_setting.go +++ b/setting/model_setting/gemini.go @@ -1,4 +1,4 @@ -package setting +package model_setting import ( "encoding/json" @@ -43,3 +43,41 @@ func GeminiSafetySettingsJsonString() string { } return string(jsonString) } + +var geminiVersionSettings = map[string]string{ + "default": "v1beta", + "gemini-1.0-pro": "v1", +} + +func GetGeminiVersionSetting(key string) string { + if value, ok := geminiVersionSettings[key]; ok { + return value + } + return geminiVersionSettings["default"] +} + +func GeminiVersionSettingFromJsonString(jsonString string) { + geminiVersionSettings = map[string]string{} + err := json.Unmarshal([]byte(jsonString), &geminiVersionSettings) + if err != nil { + geminiVersionSettings = map[string]string{ + "default": "v1beta", + } + } + // check must have default + if _, ok := geminiVersionSettings["default"]; !ok { + geminiVersionSettings["default"] = "v1beta" + } +} + +func GeminiVersionSettingsJsonString() string { + // check must have default + if _, ok := geminiVersionSettings["default"]; !ok { + geminiVersionSettings["default"] = "v1beta" + } + jsonString, err := json.Marshal(geminiVersionSettings) + if err != nil { + return "{}" + } + return string(jsonString) +} diff --git a/web/src/components/ModelSetting.js b/web/src/components/ModelSetting.js index 34cba0db..0c76b012 100644 --- a/web/src/components/ModelSetting.js +++ b/web/src/components/ModelSetting.js @@ -22,6 +22,7 @@ const ModelSetting = () => { const { t } = useTranslation(); let [inputs, setInputs] = useState({ GeminiSafetySettings: '', + GeminiVersionSettings: '', }); let [loading, setLoading] = useState(false); @@ -33,7 +34,8 @@ const ModelSetting = () => { let newInputs = {}; data.forEach((item) => { if ( - item.key === 'GeminiSafetySettings' + item.key === 'GeminiSafetySettings' || + item.key === 'GeminiVersionSettings' ) { item.value = JSON.stringify(JSON.parse(item.value), null, 2); } diff --git a/web/src/pages/Setting/Model/SettingGeminiModel.js b/web/src/pages/Setting/Model/SettingGeminiModel.js index 3075b05f..fa63cb74 100644 --- a/web/src/pages/Setting/Model/SettingGeminiModel.js +++ b/web/src/pages/Setting/Model/SettingGeminiModel.js @@ -14,12 +14,18 @@ const GEMINI_SETTING_EXAMPLE = { 'HARM_CATEGORY_CIVIC_INTEGRITY': 'BLOCK_NONE', }; +const GEMINI_VERSION_EXAMPLE = { + 'default': 'v1beta', +}; + + export default function SettingGeminiModel(props) { const { t } = useTranslation(); const [loading, setLoading] = useState(false); const [inputs, setInputs] = useState({ GeminiSafetySettings: '', + GeminiVersionSettings: '', }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); @@ -99,6 +105,27 @@ export default function SettingGeminiModel(props) { /> + + + verifyJSON(value), + message: t('不是合法的 JSON 字符串') + } + ]} + onChange={(value) => setInputs({ ...inputs, GeminiVersionSettings: value })} + /> + + +