feat: Add automatic channel disabling based on configurable keywords

- Introduce AutomaticDisableKeywords setting to dynamically control channel disabling
- Implement AC search for matching error messages against disable keywords
- Add frontend UI for configuring automatic disable keywords
- Update localization with new keyword-based channel disabling feature
- Refactor sensitive word and AC search logic to support multiple keyword lists
This commit is contained in:
1808837298@qq.com
2025-02-13 16:39:17 +08:00
parent bc62d1bb81
commit 9edb9f7a71
10 changed files with 97 additions and 50 deletions

View File

@@ -3,5 +3,6 @@ package common
var UsingSQLite = false var UsingSQLite = false
var UsingPostgreSQL = false var UsingPostgreSQL = false
var UsingMySQL = false var UsingMySQL = false
var UsingClickHouse = false
var SQLitePath = "one-api.db?_busy_timeout=5000" var SQLitePath = "one-api.db?_busy_timeout=5000"

View File

@@ -46,31 +46,29 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
// 先判断是否为 Embedding 模型 // 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(testModel), "embedding") || if strings.Contains(strings.ToLower(testModel), "embedding") ||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型 strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
strings.Contains(testModel, "bge-") || // bge 系列模型 strings.Contains(testModel, "bge-") || // bge 系列模型
testModel == "text-embedding-v1" || testModel == "text-embedding-v1" ||
channel.Type == common.ChannelTypeMokaAI{ // 其他 embedding 模型 channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径 requestPath = "/v1/embeddings" // 修改请求路径
} }
c.Request = &http.Request{ c.Request = &http.Request{
Method: "POST", Method: "POST",
URL: &url.URL{Path: requestPath}, // 使用动态路径 URL: &url.URL{Path: requestPath}, // 使用动态路径
Body: nil, Body: nil,
Header: make(http.Header), Header: make(http.Header),
} }
if testModel == "" { if testModel == "" {
common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 是 %s", string(*channel.TestModel)))
if channel.TestModel != nil && *channel.TestModel != "" { if channel.TestModel != nil && *channel.TestModel != "" {
testModel = *channel.TestModel testModel = *channel.TestModel
} else { } else {
if len(channel.GetModels()) > 0 { if len(channel.GetModels()) > 0 {
testModel = channel.GetModels()[0] testModel = channel.GetModels()[0]
} else { } else {
testModel = "gpt-3.5-turbo" testModel = "gpt-4o-mini"
} }
common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 为空:", string(testModel)))
} }
} }
@@ -102,7 +100,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
request := buildTestRequest(testModel) request := buildTestRequest(testModel)
meta.UpstreamModelName = testModel meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %s ", channel.Id, testModel, meta)) common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta))
adaptor.Init(meta) adaptor.Init(meta)
@@ -173,9 +171,9 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
// 先判断是否为 Embedding 模型 // 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(model), "embedding") || if strings.Contains(strings.ToLower(model), "embedding") ||
strings.HasPrefix(model, "m3e") || // m3e 系列模型 strings.HasPrefix(model, "m3e") || // m3e 系列模型
strings.Contains(model, "bge-") || // bge 系列模型 strings.Contains(model, "bge-") || // bge 系列模型
model == "text-embedding-v1" { // 其他 embedding 模型 model == "text-embedding-v1" { // 其他 embedding 模型
// Embedding 请求 // Embedding 请求
testRequest.Input = []string{"hello world"} testRequest.Input = []string{"hello world"}
return testRequest return testRequest

View File

@@ -110,6 +110,7 @@ func InitOptionMap() {
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled) common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
common.OptionMapRWMutex.Unlock() common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase() loadOptionsFromDatabase()
@@ -335,6 +336,8 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "SensitiveWords": case "SensitiveWords":
setting.SensitiveWordsFromString(value) setting.SensitiveWordsFromString(value)
case "AutomaticDisableKeywords":
setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength": case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value) setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
} }

View File

@@ -6,6 +6,7 @@ import (
"one-api/common" "one-api/common"
relaymodel "one-api/dto" relaymodel "one-api/dto"
"one-api/model" "one-api/model"
"one-api/setting"
"strings" "strings"
) )
@@ -64,21 +65,10 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
case "forbidden": case "forbidden":
return true return true
} }
if strings.HasPrefix(err.Error.Message, "Your credit balance is too low") { // anthropic
return true
} else if strings.HasPrefix(err.Error.Message, "This organization has been disabled.") {
return true
} else if strings.HasPrefix(err.Error.Message, "You exceeded your current quota") {
return true
} else if strings.HasPrefix(err.Error.Message, "Permission denied") {
return true
}
if strings.Contains(err.Error.Message, "The security token included in the request is invalid") { // anthropic lowerMessage := strings.ToLower(err.Error.Message)
return true search, _ := AcSearch(lowerMessage, setting.AutomaticDisableKeywords, true)
} else if strings.Contains(err.Error.Message, "Operation not allowed") { if search {
return true
} else if strings.Contains(err.Error.Message, "Your account is not authorized") {
return true return true
} }

View File

@@ -60,17 +60,7 @@ func SensitiveWordContains(text string) (bool, []string) {
return false, nil return false, nil
} }
checkText := strings.ToLower(text) checkText := strings.ToLower(text)
// 构建一个AC自动机 return AcSearch(checkText, setting.SensitiveWords, false)
m := InitAc()
hits := m.MultiPatternSearch([]rune(checkText), false)
if len(hits) > 0 {
words := make([]string, 0)
for _, hit := range hits {
words = append(words, string(hit.Word))
}
return true, words
}
return false, nil
} }
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本 // SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
@@ -79,7 +69,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
return false, nil, text return false, nil, text
} }
checkText := strings.ToLower(text) checkText := strings.ToLower(text)
m := InitAc() m := InitAc(setting.SensitiveWords)
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately) hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
if len(hits) > 0 { if len(hits) > 0 {
words := make([]string, 0) words := make([]string, 0)

View File

@@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
goahocorasick "github.com/anknown/ahocorasick" goahocorasick "github.com/anknown/ahocorasick"
"one-api/setting"
"strings" "strings"
) )
@@ -57,9 +56,9 @@ func RemoveDuplicate(s []string) []string {
return result return result
} }
func InitAc() *goahocorasick.Machine { func InitAc(words []string) *goahocorasick.Machine {
m := new(goahocorasick.Machine) m := new(goahocorasick.Machine)
dict := readRunes() dict := readRunes(words)
if err := m.Build(dict); err != nil { if err := m.Build(dict); err != nil {
fmt.Println(err) fmt.Println(err)
return nil return nil
@@ -67,10 +66,10 @@ func InitAc() *goahocorasick.Machine {
return m return m
} }
func readRunes() [][]rune { func readRunes(words []string) [][]rune {
var dict [][]rune var dict [][]rune
for _, word := range setting.SensitiveWords { for _, word := range words {
word = strings.ToLower(word) word = strings.ToLower(word)
l := bytes.TrimSpace([]byte(word)) l := bytes.TrimSpace([]byte(word))
dict = append(dict, bytes.Runes(l)) dict = append(dict, bytes.Runes(l))
@@ -78,3 +77,25 @@ func readRunes() [][]rune {
return dict return dict
} }
func AcSearch(findText string, dict []string, stopImmediately bool) (bool, []string) {
if len(dict) == 0 {
return false, nil
}
if len(findText) == 0 {
return false, nil
}
m := InitAc(dict)
if m == nil {
return false, nil
}
hits := m.MultiPatternSearch([]rune(findText), stopImmediately)
if len(hits) > 0 {
words := make([]string, 0)
for _, hit := range hits {
words = append(words, string(hit.Word))
}
return true, words
}
return false, nil
}

View File

@@ -1,3 +1,30 @@
package setting package setting
import "strings"
var DemoSiteEnabled = false var DemoSiteEnabled = false
var AutomaticDisableKeywords = []string{
"Your credit balance is too low",
"This organization has been disabled.",
"You exceeded your current quota",
"Permission denied",
"The security token included in the request is invalid",
"Operation not allowed",
"Your account is not authorized",
}
func AutomaticDisableKeywordsToString() string {
return strings.Join(AutomaticDisableKeywords, "\n")
}
func AutomaticDisableKeywordsFromString(s string) {
AutomaticDisableKeywords = []string{}
ak := strings.Split(s, "\n")
for _, k := range ak {
k = strings.TrimSpace(k)
if k != "" {
AutomaticDisableKeywords = append(AutomaticDisableKeywords, k)
}
}
}

View File

@@ -59,6 +59,7 @@ const OperationSetting = () => {
RetryTimes: 0, RetryTimes: 0,
Chats: "[]", Chats: "[]",
DemoSiteEnabled: false, DemoSiteEnabled: false,
AutomaticDisableKeywords: '',
}); });
let [loading, setLoading] = useState(false); let [loading, setLoading] = useState(false);

View File

@@ -201,7 +201,7 @@
"相关 API 显示令牌额度而非用户额度": "Related APIs show token quota instead of user quota", "相关 API 显示令牌额度而非用户额度": "Related APIs show token quota instead of user quota",
"保存通用设置": "Save General Settings", "保存通用设置": "Save General Settings",
"监控设置": "Monitoring Settings", "监控设置": "Monitoring Settings",
"最长响应时间": "Maximum Response Time", "测试所有渠道的最长响应时间": "Maximum response time for testing all channels",
"单位秒": "Unit: seconds", "单位秒": "Unit: seconds",
"当运行通道全部测试时": "When running all channel tests", "当运行通道全部测试时": "When running all channel tests",
"超过此时间将自动禁用通道": "Channels exceeding this time will be automatically disabled", "超过此时间将自动禁用通道": "Channels exceeding this time will be automatically disabled",
@@ -1246,5 +1246,8 @@
"请输入要设置的标签名称": "Please enter the tag name to be set", "请输入要设置的标签名称": "Please enter the tag name to be set",
"请输入标签名称": "Please enter the tag name", "请输入标签名称": "Please enter the tag name",
"支持搜索用户的 ID、用户名、显示名称和邮箱地址": "Support searching for user ID, username, display name, and email address", "支持搜索用户的 ID、用户名、显示名称和邮箱地址": "Support searching for user ID, username, display name, and email address",
"已注销": "Logged out" "已注销": "Logged out",
"自动禁用关键词": "Automatic disable keywords",
"一行一个,不区分大小写": "One line per keyword, not case-sensitive",
"当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道": "When the upstream channel returns an error containing these keywords (not case-sensitive), automatically disable the channel"
} }

View File

@@ -5,7 +5,7 @@ import {
API, API,
showError, showError,
showSuccess, showSuccess,
showWarning, showWarning, verifyJSON
} from '../../../helpers'; } from '../../../helpers';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@@ -17,6 +17,7 @@ export default function SettingsMonitoring(props) {
QuotaRemindThreshold: '', QuotaRemindThreshold: '',
AutomaticDisableChannelEnabled: false, AutomaticDisableChannelEnabled: false,
AutomaticEnableChannelEnabled: false, AutomaticEnableChannelEnabled: false,
AutomaticDisableKeywords: '',
}); });
const refForm = useRef(); const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs); const [inputsRow, setInputsRow] = useState(inputs);
@@ -79,7 +80,7 @@ export default function SettingsMonitoring(props) {
<Row gutter={16}> <Row gutter={16}>
<Col span={8}> <Col span={8}>
<Form.InputNumber <Form.InputNumber
label={t('最长响应时间')} label={t('测试所有渠道的最长响应时间')}
step={1} step={1}
min={0} min={0}
suffix={t('秒')} suffix={t('秒')}
@@ -144,6 +145,18 @@ export default function SettingsMonitoring(props) {
/> />
</Col> </Col>
</Row> </Row>
<Row gutter={16}>
<Col span={16}>
<Form.TextArea
label={t('自动禁用关键词')}
placeholder={t('一行一个,不区分大小写')}
extraText={t('当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道')}
field={'AutomaticDisableKeywords'}
autosize={{ minRows: 6, maxRows: 12 }}
onChange={(value) => setInputs({ ...inputs, AutomaticDisableKeywords: value })}
/>
</Col>
</Row>
<Row> <Row>
<Button size='default' onClick={onSubmit}> <Button size='default' onClick={onSubmit}>
{t('保存监控设置')} {t('保存监控设置')}