diff --git a/relay/relay-audio.go b/relay/relay-audio.go index a858bb91..b95c1eb6 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -13,6 +13,7 @@ import ( "one-api/relay/helper" "one-api/service" "one-api/setting" + "strings" ) func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) { @@ -27,8 +28,9 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. return nil, errors.New("model is required") } if setting.ShouldCheckPromptSensitive() { - err := service.CheckSensitiveInput(audioRequest.Input) + words, err := service.CheckSensitiveInput(audioRequest.Input) if err != nil { + common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ","))) return nil, err } } diff --git a/relay/relay-image.go b/relay/relay-image.go index 24e62073..6544042f 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -61,8 +61,9 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. // return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) //} if setting.ShouldCheckPromptSensitive() { - err := service.CheckSensitiveInput(imageRequest.Prompt) + words, err := service.CheckSensitiveInput(imageRequest.Prompt) if err != nil { + common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ","))) return nil, err } } diff --git a/relay/relay-text.go b/relay/relay-text.go index b438571c..c1a3e099 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -78,8 +78,9 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } if setting.ShouldCheckPromptSensitive() { - err = checkRequestSensitive(textRequest, relayInfo) + words, err := checkRequestSensitive(textRequest, relayInfo) if err != nil { + common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest) } } @@ -219,19 +220,20 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re return promptTokens, err } -func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error { +func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) { var err error + var words []string switch info.RelayMode { case relayconstant.RelayModeChatCompletions: - err = service.CheckSensitiveMessages(textRequest.Messages) + words, err = service.CheckSensitiveMessages(textRequest.Messages) case relayconstant.RelayModeCompletions: - err = service.CheckSensitiveInput(textRequest.Prompt) + words, err = service.CheckSensitiveInput(textRequest.Prompt) case relayconstant.RelayModeModerations: - err = service.CheckSensitiveInput(textRequest.Input) + words, err = service.CheckSensitiveInput(textRequest.Input) case relayconstant.RelayModeEmbeddings: - err = service.CheckSensitiveInput(textRequest.Input) + words, err = service.CheckSensitiveInput(textRequest.Input) } - return err + return words, err } // 预扣费并返回用户剩余配额 diff --git a/service/sensitive.go b/service/sensitive.go index 14ac9481..c4a4ad44 100644 --- a/service/sensitive.go +++ b/service/sensitive.go @@ -8,39 +8,30 @@ import ( "strings" ) -func CheckSensitiveMessages(messages []dto.Message) error { +func CheckSensitiveMessages(messages []dto.Message) ([]string, error) { for _, message := range messages { - if len(message.Content) > 0 { - if message.IsStringContent() { - stringContent := message.StringContent() - if ok, words := SensitiveWordContains(stringContent); ok { - return errors.New("sensitive words: " + strings.Join(words, ",")) - } - } - } else { - arrayContent := message.ParseContent() - for _, m := range arrayContent { - if m.Type == "image_url" { - // TODO: check image url - } else { - if ok, words := SensitiveWordContains(m.Text); ok { - return errors.New("sensitive words: " + strings.Join(words, ",")) - } + arrayContent := message.ParseContent() + for _, m := range arrayContent { + if m.Type == "image_url" { + // TODO: check image url + } else { + if ok, words := SensitiveWordContains(m.Text); ok { + return words, errors.New("sensitive words detected") } } } } - return nil + return nil, nil } -func CheckSensitiveText(text string) error { +func CheckSensitiveText(text string) ([]string, error) { if ok, words := SensitiveWordContains(text); ok { - return errors.New("sensitive words: " + strings.Join(words, ",")) + return words, errors.New("sensitive words detected") } - return nil + return nil, nil } -func CheckSensitiveInput(input any) error { +func CheckSensitiveInput(input any) ([]string, error) { switch v := input.(type) { case string: return CheckSensitiveText(v) @@ -60,7 +51,7 @@ func SensitiveWordContains(text string) (bool, []string) { return false, nil } checkText := strings.ToLower(text) - return AcSearch(checkText, setting.SensitiveWords, false) + return AcSearch(checkText, setting.SensitiveWords, true) } // SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本