diff --git a/common/database.go b/common/database.go index ce7a9bc1..3c0a944b 100644 --- a/common/database.go +++ b/common/database.go @@ -3,5 +3,6 @@ package common var UsingSQLite = false var UsingPostgreSQL = false var UsingMySQL = false +var UsingClickHouse = false var SQLitePath = "one-api.db?_busy_timeout=5000" diff --git a/controller/channel-test.go b/controller/channel-test.go index 93f92f4c..7e74bec2 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -41,36 +41,34 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) - + requestPath := "/v1/chat/completions" - + // 先判断是否为 Embedding 模型 if strings.Contains(strings.ToLower(testModel), "embedding") || - strings.HasPrefix(testModel, "m3e") || // m3e 系列模型 - strings.Contains(testModel, "bge-") || // bge 系列模型 + strings.HasPrefix(testModel, "m3e") || // m3e 系列模型 + strings.Contains(testModel, "bge-") || // bge 系列模型 testModel == "text-embedding-v1" || - channel.Type == common.ChannelTypeMokaAI{ // 其他 embedding 模型 - requestPath = "/v1/embeddings" // 修改请求路径 + channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型 + requestPath = "/v1/embeddings" // 修改请求路径 } - + c.Request = &http.Request{ Method: "POST", - URL: &url.URL{Path: requestPath}, // 使用动态路径 + URL: &url.URL{Path: requestPath}, // 使用动态路径 Body: nil, Header: make(http.Header), } if testModel == "" { - common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 是 %s", string(*channel.TestModel))) if channel.TestModel != nil && *channel.TestModel != "" { testModel = *channel.TestModel } else { if len(channel.GetModels()) > 0 { testModel = channel.GetModels()[0] } else { - testModel = "gpt-3.5-turbo" + testModel = "gpt-4o-mini" } - common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 为空:", string(testModel))) } } @@ -102,7 +100,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr request := buildTestRequest(testModel) meta.UpstreamModelName = testModel - common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %s ", channel.Id, testModel, meta)) + common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta)) adaptor.Init(meta) @@ -173,9 +171,9 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest { // 先判断是否为 Embedding 模型 if strings.Contains(strings.ToLower(model), "embedding") || - strings.HasPrefix(model, "m3e") || // m3e 系列模型 - strings.Contains(model, "bge-") || // bge 系列模型 - model == "text-embedding-v1" { // 其他 embedding 模型 + strings.HasPrefix(model, "m3e") || // m3e 系列模型 + strings.Contains(model, "bge-") || // bge 系列模型 + model == "text-embedding-v1" { // 其他 embedding 模型 // Embedding 请求 testRequest.Input = []string{"hello world"} return testRequest diff --git a/model/option.go b/model/option.go index f1f2809d..0c4114a4 100644 --- a/model/option.go +++ b/model/option.go @@ -110,6 +110,7 @@ func InitOptionMap() { common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled) common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) + common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString() common.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() @@ -335,6 +336,8 @@ func updateOptionMap(key string, value string) (err error) { common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) case "SensitiveWords": setting.SensitiveWordsFromString(value) + case "AutomaticDisableKeywords": + setting.AutomaticDisableKeywordsFromString(value) case "StreamCacheQueueLength": setting.StreamCacheQueueLength, _ = strconv.Atoi(value) } diff --git a/service/channel.go b/service/channel.go index 047550d8..73545b1e 100644 --- a/service/channel.go +++ b/service/channel.go @@ -6,6 +6,7 @@ import ( "one-api/common" relaymodel "one-api/dto" "one-api/model" + "one-api/setting" "strings" ) @@ -64,21 +65,10 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus case "forbidden": return true } - if strings.HasPrefix(err.Error.Message, "Your credit balance is too low") { // anthropic - return true - } else if strings.HasPrefix(err.Error.Message, "This organization has been disabled.") { - return true - } else if strings.HasPrefix(err.Error.Message, "You exceeded your current quota") { - return true - } else if strings.HasPrefix(err.Error.Message, "Permission denied") { - return true - } - if strings.Contains(err.Error.Message, "The security token included in the request is invalid") { // anthropic - return true - } else if strings.Contains(err.Error.Message, "Operation not allowed") { - return true - } else if strings.Contains(err.Error.Message, "Your account is not authorized") { + lowerMessage := strings.ToLower(err.Error.Message) + search, _ := AcSearch(lowerMessage, setting.AutomaticDisableKeywords, true) + if search { return true } diff --git a/service/sensitive.go b/service/sensitive.go index 321f55af..14ac9481 100644 --- a/service/sensitive.go +++ b/service/sensitive.go @@ -60,17 +60,7 @@ func SensitiveWordContains(text string) (bool, []string) { return false, nil } checkText := strings.ToLower(text) - // 构建一个AC自动机 - m := InitAc() - hits := m.MultiPatternSearch([]rune(checkText), false) - if len(hits) > 0 { - words := make([]string, 0) - for _, hit := range hits { - words = append(words, string(hit.Word)) - } - return true, words - } - return false, nil + return AcSearch(checkText, setting.SensitiveWords, false) } // SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本 @@ -79,7 +69,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, return false, nil, text } checkText := strings.ToLower(text) - m := InitAc() + m := InitAc(setting.SensitiveWords) hits := m.MultiPatternSearch([]rune(checkText), returnImmediately) if len(hits) > 0 { words := make([]string, 0) diff --git a/service/str.go b/service/str.go index 8137bf55..4390e99b 100644 --- a/service/str.go +++ b/service/str.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" goahocorasick "github.com/anknown/ahocorasick" - "one-api/setting" "strings" ) @@ -57,9 +56,9 @@ func RemoveDuplicate(s []string) []string { return result } -func InitAc() *goahocorasick.Machine { +func InitAc(words []string) *goahocorasick.Machine { m := new(goahocorasick.Machine) - dict := readRunes() + dict := readRunes(words) if err := m.Build(dict); err != nil { fmt.Println(err) return nil @@ -67,10 +66,10 @@ func InitAc() *goahocorasick.Machine { return m } -func readRunes() [][]rune { +func readRunes(words []string) [][]rune { var dict [][]rune - for _, word := range setting.SensitiveWords { + for _, word := range words { word = strings.ToLower(word) l := bytes.TrimSpace([]byte(word)) dict = append(dict, bytes.Runes(l)) @@ -78,3 +77,25 @@ func readRunes() [][]rune { return dict } + +func AcSearch(findText string, dict []string, stopImmediately bool) (bool, []string) { + if len(dict) == 0 { + return false, nil + } + if len(findText) == 0 { + return false, nil + } + m := InitAc(dict) + if m == nil { + return false, nil + } + hits := m.MultiPatternSearch([]rune(findText), stopImmediately) + if len(hits) > 0 { + words := make([]string, 0) + for _, hit := range hits { + words = append(words, string(hit.Word)) + } + return true, words + } + return false, nil +} diff --git a/setting/operation_setting.go b/setting/operation_setting.go index 0f2b4ffd..9a28e987 100644 --- a/setting/operation_setting.go +++ b/setting/operation_setting.go @@ -1,3 +1,30 @@ package setting +import "strings" + var DemoSiteEnabled = false + +var AutomaticDisableKeywords = []string{ + "Your credit balance is too low", + "This organization has been disabled.", + "You exceeded your current quota", + "Permission denied", + "The security token included in the request is invalid", + "Operation not allowed", + "Your account is not authorized", +} + +func AutomaticDisableKeywordsToString() string { + return strings.Join(AutomaticDisableKeywords, "\n") +} + +func AutomaticDisableKeywordsFromString(s string) { + AutomaticDisableKeywords = []string{} + ak := strings.Split(s, "\n") + for _, k := range ak { + k = strings.TrimSpace(k) + if k != "" { + AutomaticDisableKeywords = append(AutomaticDisableKeywords, k) + } + } +} diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index 98b67c67..caa9cc2e 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -59,6 +59,7 @@ const OperationSetting = () => { RetryTimes: 0, Chats: "[]", DemoSiteEnabled: false, + AutomaticDisableKeywords: '', }); let [loading, setLoading] = useState(false); diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index bac8c829..3d2c7a55 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -201,7 +201,7 @@ "相关 API 显示令牌额度而非用户额度": "Related APIs show token quota instead of user quota", "保存通用设置": "Save General Settings", "监控设置": "Monitoring Settings", - "最长响应时间": "Maximum Response Time", + "测试所有渠道的最长响应时间": "Maximum response time for testing all channels", "单位秒": "Unit: seconds", "当运行通道全部测试时": "When running all channel tests", "超过此时间将自动禁用通道": "Channels exceeding this time will be automatically disabled", @@ -1246,5 +1246,8 @@ "请输入要设置的标签名称": "Please enter the tag name to be set", "请输入标签名称": "Please enter the tag name", "支持搜索用户的 ID、用户名、显示名称和邮箱地址": "Support searching for user ID, username, display name, and email address", - "已注销": "Logged out" + "已注销": "Logged out", + "自动禁用关键词": "Automatic disable keywords", + "一行一个,不区分大小写": "One line per keyword, not case-sensitive", + "当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道": "When the upstream channel returns an error containing these keywords (not case-sensitive), automatically disable the channel" } \ No newline at end of file diff --git a/web/src/pages/Setting/Operation/SettingsMonitoring.js b/web/src/pages/Setting/Operation/SettingsMonitoring.js index 89196f98..18c368aa 100644 --- a/web/src/pages/Setting/Operation/SettingsMonitoring.js +++ b/web/src/pages/Setting/Operation/SettingsMonitoring.js @@ -5,7 +5,7 @@ import { API, showError, showSuccess, - showWarning, + showWarning, verifyJSON } from '../../../helpers'; import { useTranslation } from 'react-i18next'; @@ -17,6 +17,7 @@ export default function SettingsMonitoring(props) { QuotaRemindThreshold: '', AutomaticDisableChannelEnabled: false, AutomaticEnableChannelEnabled: false, + AutomaticDisableKeywords: '', }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); @@ -79,7 +80,7 @@ export default function SettingsMonitoring(props) { + + + setInputs({ ...inputs, AutomaticDisableKeywords: value })} + /> + +