From 9edb9f7a71e2b0a0b1d809a9996d7392f0d4cd1c Mon Sep 17 00:00:00 2001
From: "1808837298@qq.com" <1808837298@qq.com>
Date: Thu, 13 Feb 2025 16:39:17 +0800
Subject: [PATCH] feat: Add automatic channel disabling based on configurable
keywords
- Introduce AutomaticDisableKeywords setting to dynamically control channel disabling
- Implement AC search for matching error messages against disable keywords
- Add frontend UI for configuring automatic disable keywords
- Update localization with new keyword-based channel disabling feature
- Refactor sensitive word and AC search logic to support multiple keyword lists
---
common/database.go | 1 +
controller/channel-test.go | 28 ++++++++---------
model/option.go | 3 ++
service/channel.go | 18 +++--------
service/sensitive.go | 14 ++-------
service/str.go | 31 ++++++++++++++++---
setting/operation_setting.go | 27 ++++++++++++++++
web/src/components/OperationSetting.js | 1 +
web/src/i18n/locales/en.json | 7 +++--
.../Setting/Operation/SettingsMonitoring.js | 17 ++++++++--
10 files changed, 97 insertions(+), 50 deletions(-)
diff --git a/common/database.go b/common/database.go
index ce7a9bc1..3c0a944b 100644
--- a/common/database.go
+++ b/common/database.go
@@ -3,5 +3,6 @@ package common
var UsingSQLite = false
var UsingPostgreSQL = false
var UsingMySQL = false
+var UsingClickHouse = false
var SQLitePath = "one-api.db?_busy_timeout=5000"
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 93f92f4c..7e74bec2 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -41,36 +41,34 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
-
+
requestPath := "/v1/chat/completions"
-
+
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(testModel), "embedding") ||
- strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
- strings.Contains(testModel, "bge-") || // bge 系列模型
+ strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
+ strings.Contains(testModel, "bge-") || // bge 系列模型
testModel == "text-embedding-v1" ||
- channel.Type == common.ChannelTypeMokaAI{ // 其他 embedding 模型
- requestPath = "/v1/embeddings" // 修改请求路径
+ channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
+ requestPath = "/v1/embeddings" // 修改请求路径
}
-
+
c.Request = &http.Request{
Method: "POST",
- URL: &url.URL{Path: requestPath}, // 使用动态路径
+ URL: &url.URL{Path: requestPath}, // 使用动态路径
Body: nil,
Header: make(http.Header),
}
if testModel == "" {
- common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 是 %s", string(*channel.TestModel)))
if channel.TestModel != nil && *channel.TestModel != "" {
testModel = *channel.TestModel
} else {
if len(channel.GetModels()) > 0 {
testModel = channel.GetModels()[0]
} else {
- testModel = "gpt-3.5-turbo"
+ testModel = "gpt-4o-mini"
}
- common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 为空:", string(testModel)))
}
}
@@ -102,7 +100,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
request := buildTestRequest(testModel)
meta.UpstreamModelName = testModel
- common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %s ", channel.Id, testModel, meta))
+ common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta))
adaptor.Init(meta)
@@ -173,9 +171,9 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(model), "embedding") ||
- strings.HasPrefix(model, "m3e") || // m3e 系列模型
- strings.Contains(model, "bge-") || // bge 系列模型
- model == "text-embedding-v1" { // 其他 embedding 模型
+ strings.HasPrefix(model, "m3e") || // m3e 系列模型
+ strings.Contains(model, "bge-") || // bge 系列模型
+ model == "text-embedding-v1" { // 其他 embedding 模型
// Embedding 请求
testRequest.Input = []string{"hello world"}
return testRequest
diff --git a/model/option.go b/model/option.go
index f1f2809d..0c4114a4 100644
--- a/model/option.go
+++ b/model/option.go
@@ -110,6 +110,7 @@ func InitOptionMap() {
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
+ common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
@@ -335,6 +336,8 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "SensitiveWords":
setting.SensitiveWordsFromString(value)
+ case "AutomaticDisableKeywords":
+ setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
}
diff --git a/service/channel.go b/service/channel.go
index 047550d8..73545b1e 100644
--- a/service/channel.go
+++ b/service/channel.go
@@ -6,6 +6,7 @@ import (
"one-api/common"
relaymodel "one-api/dto"
"one-api/model"
+ "one-api/setting"
"strings"
)
@@ -64,21 +65,10 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
case "forbidden":
return true
}
- if strings.HasPrefix(err.Error.Message, "Your credit balance is too low") { // anthropic
- return true
- } else if strings.HasPrefix(err.Error.Message, "This organization has been disabled.") {
- return true
- } else if strings.HasPrefix(err.Error.Message, "You exceeded your current quota") {
- return true
- } else if strings.HasPrefix(err.Error.Message, "Permission denied") {
- return true
- }
- if strings.Contains(err.Error.Message, "The security token included in the request is invalid") { // anthropic
- return true
- } else if strings.Contains(err.Error.Message, "Operation not allowed") {
- return true
- } else if strings.Contains(err.Error.Message, "Your account is not authorized") {
+ lowerMessage := strings.ToLower(err.Error.Message)
+ search, _ := AcSearch(lowerMessage, setting.AutomaticDisableKeywords, true)
+ if search {
return true
}
diff --git a/service/sensitive.go b/service/sensitive.go
index 321f55af..14ac9481 100644
--- a/service/sensitive.go
+++ b/service/sensitive.go
@@ -60,17 +60,7 @@ func SensitiveWordContains(text string) (bool, []string) {
return false, nil
}
checkText := strings.ToLower(text)
- // 构建一个AC自动机
- m := InitAc()
- hits := m.MultiPatternSearch([]rune(checkText), false)
- if len(hits) > 0 {
- words := make([]string, 0)
- for _, hit := range hits {
- words = append(words, string(hit.Word))
- }
- return true, words
- }
- return false, nil
+ return AcSearch(checkText, setting.SensitiveWords, false)
}
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
@@ -79,7 +69,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
return false, nil, text
}
checkText := strings.ToLower(text)
- m := InitAc()
+ m := InitAc(setting.SensitiveWords)
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
if len(hits) > 0 {
words := make([]string, 0)
diff --git a/service/str.go b/service/str.go
index 8137bf55..4390e99b 100644
--- a/service/str.go
+++ b/service/str.go
@@ -4,7 +4,6 @@ import (
"bytes"
"fmt"
goahocorasick "github.com/anknown/ahocorasick"
- "one-api/setting"
"strings"
)
@@ -57,9 +56,9 @@ func RemoveDuplicate(s []string) []string {
return result
}
-func InitAc() *goahocorasick.Machine {
+func InitAc(words []string) *goahocorasick.Machine {
m := new(goahocorasick.Machine)
- dict := readRunes()
+ dict := readRunes(words)
if err := m.Build(dict); err != nil {
fmt.Println(err)
return nil
@@ -67,10 +66,10 @@ func InitAc() *goahocorasick.Machine {
return m
}
-func readRunes() [][]rune {
+func readRunes(words []string) [][]rune {
var dict [][]rune
- for _, word := range setting.SensitiveWords {
+ for _, word := range words {
word = strings.ToLower(word)
l := bytes.TrimSpace([]byte(word))
dict = append(dict, bytes.Runes(l))
@@ -78,3 +77,25 @@ func readRunes() [][]rune {
return dict
}
+
+func AcSearch(findText string, dict []string, stopImmediately bool) (bool, []string) {
+ if len(dict) == 0 {
+ return false, nil
+ }
+ if len(findText) == 0 {
+ return false, nil
+ }
+ m := InitAc(dict)
+ if m == nil {
+ return false, nil
+ }
+ hits := m.MultiPatternSearch([]rune(findText), stopImmediately)
+ if len(hits) > 0 {
+ words := make([]string, 0)
+ for _, hit := range hits {
+ words = append(words, string(hit.Word))
+ }
+ return true, words
+ }
+ return false, nil
+}
diff --git a/setting/operation_setting.go b/setting/operation_setting.go
index 0f2b4ffd..9a28e987 100644
--- a/setting/operation_setting.go
+++ b/setting/operation_setting.go
@@ -1,3 +1,30 @@
package setting
+import "strings"
+
var DemoSiteEnabled = false
+
+var AutomaticDisableKeywords = []string{
+ "Your credit balance is too low",
+ "This organization has been disabled.",
+ "You exceeded your current quota",
+ "Permission denied",
+ "The security token included in the request is invalid",
+ "Operation not allowed",
+ "Your account is not authorized",
+}
+
+func AutomaticDisableKeywordsToString() string {
+ return strings.Join(AutomaticDisableKeywords, "\n")
+}
+
+func AutomaticDisableKeywordsFromString(s string) {
+ AutomaticDisableKeywords = []string{}
+ ak := strings.Split(s, "\n")
+ for _, k := range ak {
+ k = strings.TrimSpace(k)
+ if k != "" {
+ AutomaticDisableKeywords = append(AutomaticDisableKeywords, k)
+ }
+ }
+}
diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js
index 98b67c67..caa9cc2e 100644
--- a/web/src/components/OperationSetting.js
+++ b/web/src/components/OperationSetting.js
@@ -59,6 +59,7 @@ const OperationSetting = () => {
RetryTimes: 0,
Chats: "[]",
DemoSiteEnabled: false,
+ AutomaticDisableKeywords: '',
});
let [loading, setLoading] = useState(false);
diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json
index bac8c829..3d2c7a55 100644
--- a/web/src/i18n/locales/en.json
+++ b/web/src/i18n/locales/en.json
@@ -201,7 +201,7 @@
"相关 API 显示令牌额度而非用户额度": "Related APIs show token quota instead of user quota",
"保存通用设置": "Save General Settings",
"监控设置": "Monitoring Settings",
- "最长响应时间": "Maximum Response Time",
+ "测试所有渠道的最长响应时间": "Maximum response time for testing all channels",
"单位秒": "Unit: seconds",
"当运行通道全部测试时": "When running all channel tests",
"超过此时间将自动禁用通道": "Channels exceeding this time will be automatically disabled",
@@ -1246,5 +1246,8 @@
"请输入要设置的标签名称": "Please enter the tag name to be set",
"请输入标签名称": "Please enter the tag name",
"支持搜索用户的 ID、用户名、显示名称和邮箱地址": "Support searching for user ID, username, display name, and email address",
- "已注销": "Logged out"
+ "已注销": "Logged out",
+ "自动禁用关键词": "Automatic disable keywords",
+ "一行一个,不区分大小写": "One line per keyword, not case-sensitive",
+ "当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道": "When the upstream channel returns an error containing these keywords (not case-sensitive), automatically disable the channel"
}
\ No newline at end of file
diff --git a/web/src/pages/Setting/Operation/SettingsMonitoring.js b/web/src/pages/Setting/Operation/SettingsMonitoring.js
index 89196f98..18c368aa 100644
--- a/web/src/pages/Setting/Operation/SettingsMonitoring.js
+++ b/web/src/pages/Setting/Operation/SettingsMonitoring.js
@@ -5,7 +5,7 @@ import {
API,
showError,
showSuccess,
- showWarning,
+ showWarning, verifyJSON
} from '../../../helpers';
import { useTranslation } from 'react-i18next';
@@ -17,6 +17,7 @@ export default function SettingsMonitoring(props) {
QuotaRemindThreshold: '',
AutomaticDisableChannelEnabled: false,
AutomaticEnableChannelEnabled: false,
+ AutomaticDisableKeywords: '',
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -79,7 +80,7 @@ export default function SettingsMonitoring(props) {
+
+
+ setInputs({ ...inputs, AutomaticDisableKeywords: value })}
+ />
+
+