From e19b244e73325717170c8fe8cdef90ca37fcc2b6 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Wed, 26 Feb 2025 16:54:43 +0800 Subject: [PATCH] feat: Add Gemini safety settings configuration support (close #703) --- model/option.go | 3 + relay/channel/gemini/constant.go | 8 ++ relay/channel/gemini/relay-gemini.go | 33 ++---- setting/model_setting.go | 45 +++++++ web/src/components/ModelSetting.js | 82 +++++++++++++ .../pages/Setting/Model/SettingGeminiModel.js | 112 ++++++++++++++++++ web/src/pages/Setting/index.js | 6 + 7 files changed, 267 insertions(+), 22 deletions(-) create mode 100644 setting/model_setting.go create mode 100644 web/src/components/ModelSetting.js create mode 100644 web/src/pages/Setting/Model/SettingGeminiModel.js diff --git a/model/option.go b/model/option.go index 3e9e9541..3897ea36 100644 --- a/model/option.go +++ b/model/option.go @@ -115,6 +115,7 @@ func InitOptionMap() { common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString() + common.OptionMap["GeminiSafetySettings"] = setting.GeminiSafetySettingsJsonString() common.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() @@ -351,6 +352,8 @@ func updateOptionMap(key string, value string) (err error) { setting.SensitiveWordsFromString(value) case "AutomaticDisableKeywords": setting.AutomaticDisableKeywordsFromString(value) + case "GeminiSafetySettings": + setting.GeminiSafetySettingFromJsonString(value) case "StreamCacheQueueLength": setting.StreamCacheQueueLength, _ = strconv.Atoi(value) } diff --git a/relay/channel/gemini/constant.go b/relay/channel/gemini/constant.go index b7c1f0cf..ed668d74 100644 --- a/relay/channel/gemini/constant.go +++ b/relay/channel/gemini/constant.go @@ -20,4 +20,12 @@ var ModelList = []string{ "imagen-3.0-generate-002", } +var SafetySettingList = []string{ + "HARM_CATEGORY_HARASSMENT", + "HARM_CATEGORY_VIOLENCE", + "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "HARM_CATEGORY_DANGEROUS_CONTENT", + "HARM_CATEGORY_CIVIC_INTEGRITY", +} + var ChannelName = "google gemini" diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 8068709e..7b7c9cb7 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -11,6 +11,7 @@ import ( "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" + "one-api/setting" "strings" "unicode/utf8" @@ -22,28 +23,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque geminiRequest := GeminiChatRequest{ Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), - SafetySettings: []GeminiChatSafetySettings{ - { - Category: "HARM_CATEGORY_HARASSMENT", - Threshold: common.GeminiSafetySetting, - }, - { - Category: "HARM_CATEGORY_HATE_SPEECH", - Threshold: common.GeminiSafetySetting, - }, - { - Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", - Threshold: common.GeminiSafetySetting, - }, - { - Category: "HARM_CATEGORY_DANGEROUS_CONTENT", - Threshold: common.GeminiSafetySetting, - }, - { - Category: "HARM_CATEGORY_CIVIC_INTEGRITY", - Threshold: common.GeminiSafetySetting, - }, - }, + //SafetySettings: []GeminiChatSafetySettings{}, GenerationConfig: GeminiChatGenerationConfig{ Temperature: textRequest.Temperature, TopP: textRequest.TopP, @@ -52,6 +32,15 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque }, } + safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList)) + for _, category := range SafetySettingList { + safetySettings = append(safetySettings, GeminiChatSafetySettings{ + Category: category, + Threshold: setting.GetGeminiSafetySetting(category), + }) + } + geminiRequest.SafetySettings = safetySettings + // openaiContent.FuncToToolCalls() if textRequest.Tools != nil { functions := make([]dto.FunctionCall, 0, len(textRequest.Tools)) diff --git a/setting/model_setting.go b/setting/model_setting.go new file mode 100644 index 00000000..c0f9bd1b --- /dev/null +++ b/setting/model_setting.go @@ -0,0 +1,45 @@ +package setting + +import ( + "encoding/json" + "one-api/common" +) + +var geminiSafetySettings = map[string]string{ + "default": "OFF", + "HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE", +} + +func GetGeminiSafetySetting(key string) string { + if value, ok := geminiSafetySettings[key]; ok { + return value + } + return geminiSafetySettings["default"] +} + +func GeminiSafetySettingFromJsonString(jsonString string) { + geminiSafetySettings = map[string]string{} + err := json.Unmarshal([]byte(jsonString), &geminiSafetySettings) + if err != nil { + geminiSafetySettings = map[string]string{ + "default": "OFF", + "HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE", + } + } + // check must have default + if _, ok := geminiSafetySettings["default"]; !ok { + geminiSafetySettings["default"] = common.GeminiSafetySetting + } +} + +func GeminiSafetySettingsJsonString() string { + // check must have default + if _, ok := geminiSafetySettings["default"]; !ok { + geminiSafetySettings["default"] = common.GeminiSafetySetting + } + jsonString, err := json.Marshal(geminiSafetySettings) + if err != nil { + return "{}" + } + return string(jsonString) +} diff --git a/web/src/components/ModelSetting.js b/web/src/components/ModelSetting.js new file mode 100644 index 00000000..34cba0db --- /dev/null +++ b/web/src/components/ModelSetting.js @@ -0,0 +1,82 @@ +import React, { useEffect, useState } from 'react'; +import { Card, Spin, Tabs } from '@douyinfe/semi-ui'; +import SettingsGeneral from '../pages/Setting/Operation/SettingsGeneral.js'; +import SettingsDrawing from '../pages/Setting/Operation/SettingsDrawing.js'; +import SettingsSensitiveWords from '../pages/Setting/Operation/SettingsSensitiveWords.js'; +import SettingsLog from '../pages/Setting/Operation/SettingsLog.js'; +import SettingsDataDashboard from '../pages/Setting/Operation/SettingsDataDashboard.js'; +import SettingsMonitoring from '../pages/Setting/Operation/SettingsMonitoring.js'; +import SettingsCreditLimit from '../pages/Setting/Operation/SettingsCreditLimit.js'; +import SettingsMagnification from '../pages/Setting/Operation/SettingsMagnification.js'; +import ModelSettingsVisualEditor from '../pages/Setting/Operation/ModelSettingsVisualEditor.js'; +import GroupRatioSettings from '../pages/Setting/Operation/GroupRatioSettings.js'; +import ModelRatioSettings from '../pages/Setting/Operation/ModelRatioSettings.js'; + + +import { API, showError, showSuccess } from '../helpers'; +import SettingsChats from '../pages/Setting/Operation/SettingsChats.js'; +import { useTranslation } from 'react-i18next'; +import SettingGeminiModel from '../pages/Setting/Model/SettingGeminiModel.js'; + +const ModelSetting = () => { + const { t } = useTranslation(); + let [inputs, setInputs] = useState({ + GeminiSafetySettings: '', + }); + + let [loading, setLoading] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if ( + item.key === 'GeminiSafetySettings' + ) { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } + if ( + item.key.endsWith('Enabled') + ) { + newInputs[item.key] = item.value === 'true' ? true : false; + } else { + newInputs[item.key] = item.value; + } + }); + + setInputs(newInputs); + } else { + showError(message); + } + }; + async function onRefresh() { + try { + setLoading(true); + await getOptions(); + // showSuccess('刷新成功'); + } catch (error) { + showError('刷新失败'); + } finally { + setLoading(false); + } + } + + useEffect(() => { + onRefresh(); + }, []); + + return ( + <> + + {/* Gemini */} + + + + + + ); +}; + +export default ModelSetting; diff --git a/web/src/pages/Setting/Model/SettingGeminiModel.js b/web/src/pages/Setting/Model/SettingGeminiModel.js new file mode 100644 index 00000000..3075b05f --- /dev/null +++ b/web/src/pages/Setting/Model/SettingGeminiModel.js @@ -0,0 +1,112 @@ +import React, { useEffect, useState, useRef } from 'react'; +import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui'; +import { + compareObjects, + API, + showError, + showSuccess, + showWarning, verifyJSON +} from '../../../helpers'; +import { useTranslation } from 'react-i18next'; + +const GEMINI_SETTING_EXAMPLE = { + 'default': 'OFF', + 'HARM_CATEGORY_CIVIC_INTEGRITY': 'BLOCK_NONE', +}; + +export default function SettingGeminiModel(props) { + const { t } = useTranslation(); + + const [loading, setLoading] = useState(false); + const [inputs, setInputs] = useState({ + GeminiSafetySettings: '', + }); + const refForm = useRef(); + const [inputsRow, setInputsRow] = useState(inputs); + + function onSubmit() { + const updateArray = compareObjects(inputs, inputsRow); + if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); + const requestQueue = updateArray.map((item) => { + let value = ''; + if (typeof inputs[item.key] === 'boolean') { + value = String(inputs[item.key]); + } else { + value = inputs[item.key]; + } + return API.put('/api/option/', { + key: item.key, + value, + }); + }); + setLoading(true); + Promise.all(requestQueue) + .then((res) => { + if (requestQueue.length === 1) { + if (res.includes(undefined)) return; + } else if (requestQueue.length > 1) { + if (res.includes(undefined)) return showError(t('部分保存失败,请重试')); + } + showSuccess(t('保存成功')); + props.refresh(); + }) + .catch(() => { + showError(t('保存失败,请重试')); + }) + .finally(() => { + setLoading(false); + }); + } + + useEffect(() => { + const currentInputs = {}; + for (let key in props.options) { + if (Object.keys(inputs).includes(key)) { + currentInputs[key] = props.options[key]; + } + } + setInputs(currentInputs); + setInputsRow(structuredClone(currentInputs)); + refForm.current.setValues(currentInputs); + }, [props.options]); + + return ( + <> + +
(refForm.current = formAPI)} + style={{ marginBottom: 15 }} + > + + + + verifyJSON(value), + message: t('不是合法的 JSON 字符串') + } + ]} + onChange={(value) => setInputs({ ...inputs, GeminiSafetySettings: value })} + /> + + + + + + +
+
+ + ); +} diff --git a/web/src/pages/Setting/index.js b/web/src/pages/Setting/index.js index b5c5e268..17a85088 100644 --- a/web/src/pages/Setting/index.js +++ b/web/src/pages/Setting/index.js @@ -9,6 +9,7 @@ import OtherSetting from '../../components/OtherSetting'; import PersonalSetting from '../../components/PersonalSetting'; import OperationSetting from '../../components/OperationSetting'; import RateLimitSetting from '../../components/RateLimitSetting.js'; +import ModelSetting from '../../components/ModelSetting.js'; const Setting = () => { const { t } = useTranslation(); @@ -34,6 +35,11 @@ const Setting = () => { content: , itemKey: 'ratelimit', }); + panes.push({ + tab: t('模型相关设置'), + content: , + itemKey: 'models', + }); panes.push({ tab: t('系统设置'), content: ,