diff --git a/common/str.go b/common/str.go index ddf8375c..bab252c6 100644 --- a/common/str.go +++ b/common/str.go @@ -1,5 +1,13 @@ package common +import ( + "bytes" + "fmt" + goahocorasick "github.com/anknown/ahocorasick" + "one-api/constant" + "strings" +) + func SundaySearch(text string, pattern string) bool { // 计算偏移表 offset := make(map[rune]int) @@ -48,3 +56,25 @@ func RemoveDuplicate(s []string) []string { } return result } + +func InitAc() *goahocorasick.Machine { + m := new(goahocorasick.Machine) + dict := readRunes() + if err := m.Build(dict); err != nil { + fmt.Println(err) + return nil + } + return m +} + +func readRunes() [][]rune { + var dict [][]rune + + for _, word := range constant.SensitiveWords { + word = strings.ToLower(word) + l := bytes.TrimSpace([]byte(word)) + dict = append(dict, bytes.Runes(l)) + } + + return dict +} diff --git a/constant/sensitive.go b/constant/sensitive.go index 52975606..f56fd6e7 100644 --- a/constant/sensitive.go +++ b/constant/sensitive.go @@ -16,7 +16,7 @@ var StreamCacheQueueLength = 0 // SensitiveWords 敏感词 // var SensitiveWords []string var SensitiveWords = []string{ - "test", + "test_sensitive", } func SensitiveWordsToString() string { diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index d1d8e6eb..7195bcf3 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -370,7 +370,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT }, nil } fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) - completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false) + completionTokens, err := service.CountTokenText(claudeResponse.Completion, model) if err != nil { return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index ee9301d2..79644af3 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -256,7 +256,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo }, nil } fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) - completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, false) + completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model) usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 68251194..8de0b3eb 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -190,7 +190,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model if simpleResponse.Usage.TotalTokens == 0 { completionTokens := 0 for _, choice := range simpleResponse.Choices { - ctkm, _, _ := service.CountTokenText(string(choice.Message.Content), model, false) + ctkm, _ := service.CountTokenText(string(choice.Message.Content), model) completionTokens += ctkm } simpleResponse.Usage = dto.Usage{ diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 6933d6f1..47588a2e 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st }, nil } fullTextResponse := responsePaLM2OpenAI(&palmResponse) - completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, false) + completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model) usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, diff --git a/relay/relay-audio.go b/relay/relay-audio.go index ef89597d..09b67cfa 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -55,7 +55,13 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { promptTokens := 0 preConsumedTokens := common.PreConsumedQuota if strings.HasPrefix(audioRequest.Model, "tts-1") { - promptTokens, err, _ = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive()) + if constant.ShouldCheckPromptSensitive() { + err = service.CheckSensitiveInput(audioRequest.Input) + if err != nil { + return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest) + } + } + promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model) if err != nil { return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) } @@ -178,7 +184,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { if strings.HasPrefix(audioRequest.Model, "tts-1") { quota = promptTokens } else { - quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, false) + quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model) } quota = int(float64(quota) * ratio) if ratio != 0 && quota <= 0 { diff --git a/relay/relay-image.go b/relay/relay-image.go index ed090f53..6b6914e2 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" @@ -47,6 +48,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC return service.OpenAIErrorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) } + if constant.ShouldCheckPromptSensitive() { + err = service.CheckSensitiveInput(imageRequest.Prompt) + if err != nil { + return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest) + } + } + if strings.Contains(imageRequest.Size, "×") { return service.OpenAIErrorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest) } diff --git a/relay/relay-text.go b/relay/relay-text.go index 007b3b1d..02dcaaa6 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -98,13 +98,17 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { var ratio float64 var modelRatio float64 //err := service.SensitiveWordsCheck(textRequest) - promptTokens, err, sensitiveTrigger := getPromptTokens(textRequest, relayInfo) - // count messages token error 计算promptTokens错误 - if err != nil { - if sensitiveTrigger { + if constant.ShouldCheckPromptSensitive() { + err = checkRequestSensitive(textRequest, relayInfo) + if err != nil { return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest) } + } + + promptTokens, err := getPromptTokens(textRequest, relayInfo) + // count messages token error 计算promptTokens错误 + if err != nil { return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) } @@ -128,7 +132,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) } adaptor.Init(relayInfo, *textRequest) var requestBody io.Reader @@ -136,7 +140,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { if isModelMapped { jsonStr, err := json.Marshal(textRequest) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) } else { @@ -145,11 +149,11 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { } else { convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest) if err != nil { - return service.OpenAIErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) } jsonData, err := json.Marshal(convertedRequest) if err != nil { - return service.OpenAIErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonData) } @@ -182,26 +186,39 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return nil } -func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error, bool) { +func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) { var promptTokens int var err error - var sensitiveTrigger bool - checkSensitive := constant.ShouldCheckPromptSensitive() switch info.RelayMode { case relayconstant.RelayModeChatCompletions: - promptTokens, err, sensitiveTrigger = service.CountTokenChatRequest(*textRequest, textRequest.Model, checkSensitive) + promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model) case relayconstant.RelayModeCompletions: - promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive) + promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model) case relayconstant.RelayModeModerations: - promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive) + promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model) case relayconstant.RelayModeEmbeddings: - promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive) + promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model) default: err = errors.New("unknown relay mode") promptTokens = 0 } info.PromptTokens = promptTokens - return promptTokens, err, sensitiveTrigger + return promptTokens, err +} + +func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error { + var err error + switch info.RelayMode { + case relayconstant.RelayModeChatCompletions: + err = service.CheckSensitiveMessages(textRequest.Messages) + case relayconstant.RelayModeCompletions: + err = service.CheckSensitiveInput(textRequest.Prompt) + case relayconstant.RelayModeModerations: + err = service.CheckSensitiveInput(textRequest.Input) + case relayconstant.RelayModeEmbeddings: + err = service.CheckSensitiveInput(textRequest.Input) + } + return err } // 预扣费并返回用户剩余配额 diff --git a/service/sensitive.go b/service/sensitive.go index 51621c3b..a9b51983 100644 --- a/service/sensitive.go +++ b/service/sensitive.go @@ -1,13 +1,60 @@ package service import ( - "bytes" + "errors" "fmt" - "github.com/anknown/ahocorasick" + "one-api/common" "one-api/constant" + "one-api/dto" "strings" ) +func CheckSensitiveMessages(messages []dto.Message) error { + for _, message := range messages { + if len(message.Content) > 0 { + if message.IsStringContent() { + stringContent := message.StringContent() + if ok, words := SensitiveWordContains(stringContent); ok { + return errors.New("sensitive words: " + strings.Join(words, ",")) + } + } + } else { + arrayContent := message.ParseContent() + for _, m := range arrayContent { + if m.Type == "image_url" { + // TODO: check image url + } else { + if ok, words := SensitiveWordContains(m.Text); ok { + return errors.New("sensitive words: " + strings.Join(words, ",")) + } + } + } + } + } + return nil +} + +func CheckSensitiveText(text string) error { + if ok, words := SensitiveWordContains(text); ok { + return errors.New("sensitive words: " + strings.Join(words, ",")) + } + return nil +} + +func CheckSensitiveInput(input any) error { + switch v := input.(type) { + case string: + return CheckSensitiveText(v) + case []string: + text := "" + for _, s := range v { + text += s + } + return CheckSensitiveText(text) + } + return CheckSensitiveText(fmt.Sprintf("%v", input)) +} + // SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表 func SensitiveWordContains(text string) (bool, []string) { if len(constant.SensitiveWords) == 0 { @@ -15,7 +62,7 @@ func SensitiveWordContains(text string) (bool, []string) { } checkText := strings.ToLower(text) // 构建一个AC自动机 - m := initAc() + m := common.InitAc() hits := m.MultiPatternSearch([]rune(checkText), false) if len(hits) > 0 { words := make([]string, 0) @@ -33,7 +80,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, return false, nil, text } checkText := strings.ToLower(text) - m := initAc() + m := common.InitAc() hits := m.MultiPatternSearch([]rune(checkText), returnImmediately) if len(hits) > 0 { words := make([]string, 0) @@ -47,25 +94,3 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, } return false, nil, text } - -func initAc() *goahocorasick.Machine { - m := new(goahocorasick.Machine) - dict := readRunes() - if err := m.Build(dict); err != nil { - fmt.Println(err) - return nil - } - return m -} - -func readRunes() [][]rune { - var dict [][]rune - - for _, word := range constant.SensitiveWords { - word = strings.ToLower(word) - l := bytes.TrimSpace([]byte(word)) - dict = append(dict, bytes.Runes(l)) - } - - return dict -} diff --git a/service/token_counter.go b/service/token_counter.go index bca3d512..432f15e4 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -125,11 +125,11 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in return tiles*170 + 85, nil } -func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) { +func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) { tkm := 0 - msgTokens, err, b := CountTokenMessages(request.Messages, model, request.Stream, checkSensitive) + msgTokens, err := CountTokenMessages(request.Messages, model, request.Stream) if err != nil { - return 0, err, b + return 0, err } tkm += msgTokens if request.Tools != nil { @@ -137,7 +137,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check var openaiTools []dto.OpenAITools err := json.Unmarshal(toolsData, &openaiTools) if err != nil { - return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())), false + return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())) } countStr := "" for _, tool := range openaiTools { @@ -149,18 +149,18 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check countStr += fmt.Sprintf("%v", tool.Function.Parameters) } } - toolTokens, err, _ := CountTokenInput(countStr, model, false) + toolTokens, err := CountTokenInput(countStr, model) if err != nil { - return 0, err, false + return 0, err } tkm += 8 tkm += toolTokens } - return tkm, nil, false + return tkm, nil } -func CountTokenMessages(messages []dto.Message, model string, stream bool, checkSensitive bool) (int, error, bool) { +func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) { //recover when panic tokenEncoder := getTokenEncoder(model) // Reference: @@ -184,13 +184,6 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check if len(message.Content) > 0 { if message.IsStringContent() { stringContent := message.StringContent() - if checkSensitive { - contains, words := SensitiveWordContains(stringContent) - if contains { - err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", ")) - return 0, err, true - } - } tokenNum += getTokenNum(tokenEncoder, stringContent) if message.Name != nil { tokenNum += tokensPerName @@ -203,7 +196,7 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check imageUrl := m.ImageUrl.(dto.MessageImageUrl) imageTokenNum, err := getImageToken(&imageUrl, model, stream) if err != nil { - return 0, err, false + return 0, err } tokenNum += imageTokenNum log.Printf("image token num: %d", imageTokenNum) @@ -215,33 +208,33 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check } } tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> - return tokenNum, nil, false + return tokenNum, nil } -func CountTokenInput(input any, model string, check bool) (int, error, bool) { +func CountTokenInput(input any, model string) (int, error) { switch v := input.(type) { case string: - return CountTokenText(v, model, check) + return CountTokenText(v, model) case []string: text := "" for _, s := range v { text += s } - return CountTokenText(text, model, check) + return CountTokenText(text, model) } - return CountTokenInput(fmt.Sprintf("%v", input), model, check) + return CountTokenInput(fmt.Sprintf("%v", input), model) } func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int { tokens := 0 for _, message := range messages { - tkm, _, _ := CountTokenInput(message.Delta.GetContentString(), model, false) + tkm, _ := CountTokenInput(message.Delta.GetContentString(), model) tokens += tkm if message.Delta.ToolCalls != nil { for _, tool := range message.Delta.ToolCalls { - tkm, _, _ := CountTokenInput(tool.Function.Name, model, false) + tkm, _ := CountTokenInput(tool.Function.Name, model) tokens += tkm - tkm, _, _ = CountTokenInput(tool.Function.Arguments, model, false) + tkm, _ = CountTokenInput(tool.Function.Arguments, model) tokens += tkm } } @@ -249,29 +242,17 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, return tokens } -func CountAudioToken(text string, model string, check bool) (int, error, bool) { +func CountAudioToken(text string, model string) (int, error) { if strings.HasPrefix(model, "tts") { - contains, words := SensitiveWordContains(text) - if contains { - return utf8.RuneCountInString(text), fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")), true - } - return utf8.RuneCountInString(text), nil, false + return utf8.RuneCountInString(text), nil } else { - return CountTokenText(text, model, check) + return CountTokenText(text, model) } } // CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量 -func CountTokenText(text string, model string, check bool) (int, error, bool) { +func CountTokenText(text string, model string) (int, error) { var err error - var trigger bool - if check { - contains, words := SensitiveWordContains(text) - if contains { - err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")) - trigger = true - } - } tokenEncoder := getTokenEncoder(model) - return getTokenNum(tokenEncoder, text), err, trigger + return getTokenNum(tokenEncoder, text), err } diff --git a/service/usage_helpr.go b/service/usage_helpr.go index 460ac566..15e3226f 100644 --- a/service/usage_helpr.go +++ b/service/usage_helpr.go @@ -19,7 +19,7 @@ import ( func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) { usage := &dto.Usage{} usage.PromptTokens = promptTokens - ctkm, err, _ := CountTokenText(responseText, modeName, false) + ctkm, err := CountTokenText(responseText, modeName) usage.CompletionTokens = ctkm usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return usage, err