820 lines
27 KiB
Go
820 lines
27 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"strings"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
type contentModerationTestSettingRepo struct {
|
||
values map[string]string
|
||
}
|
||
|
||
func (r *contentModerationTestSettingRepo) Get(ctx context.Context, key string) (*Setting, error) {
|
||
if value, ok := r.values[key]; ok {
|
||
return &Setting{Key: key, Value: value}, nil
|
||
}
|
||
return nil, ErrSettingNotFound
|
||
}
|
||
|
||
func (r *contentModerationTestSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
|
||
if value, ok := r.values[key]; ok {
|
||
return value, nil
|
||
}
|
||
return "", ErrSettingNotFound
|
||
}
|
||
|
||
func (r *contentModerationTestSettingRepo) Set(ctx context.Context, key, value string) error {
|
||
if r.values == nil {
|
||
r.values = map[string]string{}
|
||
}
|
||
r.values[key] = value
|
||
return nil
|
||
}
|
||
|
||
func (r *contentModerationTestSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||
out := map[string]string{}
|
||
for _, key := range keys {
|
||
if value, ok := r.values[key]; ok {
|
||
out[key] = value
|
||
}
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func (r *contentModerationTestSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||
if r.values == nil {
|
||
r.values = map[string]string{}
|
||
}
|
||
for key, value := range settings {
|
||
r.values[key] = value
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (r *contentModerationTestSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
|
||
out := make(map[string]string, len(r.values))
|
||
for key, value := range r.values {
|
||
out[key] = value
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func (r *contentModerationTestSettingRepo) Delete(ctx context.Context, key string) error {
|
||
delete(r.values, key)
|
||
return nil
|
||
}
|
||
|
||
type contentModerationTestRepo struct {
|
||
logs []ContentModerationLog
|
||
}
|
||
|
||
func (r *contentModerationTestRepo) CreateLog(ctx context.Context, log *ContentModerationLog) error {
|
||
if log != nil {
|
||
r.logs = append(r.logs, *log)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (r *contentModerationTestRepo) ListLogs(ctx context.Context, filter ContentModerationLogFilter) ([]ContentModerationLog, *pagination.PaginationResult, error) {
|
||
return nil, nil, nil
|
||
}
|
||
|
||
func (r *contentModerationTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
|
||
return 0, nil
|
||
}
|
||
|
||
func (r *contentModerationTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*ContentModerationCleanupResult, error) {
|
||
return &ContentModerationCleanupResult{}, nil
|
||
}
|
||
|
||
// contentModerationTestHashCache is an in-memory fake of the flagged-input
// hash cache. Besides the hash set itself, it logs the argument of every
// record/check/delete call so tests can assert on them.
type contentModerationTestHashCache struct {
	// hashes is the current set of flagged input hashes.
	hashes map[string]struct{}

	// Per-operation call logs, in call order: RecordFlaggedInputHash appends
	// to recorded, HasFlaggedInputHash to checked, and DeleteFlaggedInputHash
	// to deleted.
	recorded []string
	checked  []string
	deleted  []string

	// When hasResultUsed is true, HasFlaggedInputHash answers with hasResult
	// instead of consulting hashes.
	hasResult     bool
	hasResultUsed bool
}
|
||
|
||
type contentModerationTestUserRepo struct {
|
||
user *User
|
||
updated []User
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) Create(ctx context.Context, user *User) error {
|
||
panic("unexpected Create call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) GetByID(ctx context.Context, id int64) (*User, error) {
|
||
if r.user == nil {
|
||
return nil, ErrUserNotFound
|
||
}
|
||
clone := *r.user
|
||
return &clone, nil
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) GetByEmail(ctx context.Context, email string) (*User, error) {
|
||
panic("unexpected GetByEmail call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) GetFirstAdmin(ctx context.Context) (*User, error) {
|
||
panic("unexpected GetFirstAdmin call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) Update(ctx context.Context, user *User) error {
|
||
if user == nil {
|
||
return nil
|
||
}
|
||
clone := *user
|
||
r.updated = append(r.updated, clone)
|
||
r.user = &clone
|
||
return nil
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) Delete(ctx context.Context, id int64) error {
|
||
panic("unexpected Delete call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
|
||
panic("unexpected GetUserAvatar call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
|
||
panic("unexpected UpsertUserAvatar call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
|
||
panic("unexpected DeleteUserAvatar call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||
panic("unexpected List call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
||
panic("unexpected ListWithFilters call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
|
||
panic("unexpected GetLatestUsedAtByUserIDs call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
|
||
panic("unexpected GetLatestUsedAtByUserID call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
|
||
panic("unexpected UpdateUserLastActiveAt call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||
panic("unexpected UpdateBalance call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||
panic("unexpected DeductBalance call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
|
||
panic("unexpected UpdateConcurrency call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error) {
|
||
panic("unexpected BatchSetConcurrency call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error) {
|
||
panic("unexpected BatchAddConcurrency call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||
panic("unexpected ExistsByEmail call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||
panic("unexpected RemoveGroupFromAllowedGroups call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
||
panic("unexpected AddGroupToAllowedGroups call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
||
panic("unexpected RemoveGroupFromUserAllowedGroups call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
|
||
panic("unexpected ListUserAuthIdentities call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error {
|
||
panic("unexpected UnbindUserAuthProvider call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||
panic("unexpected UpdateTotpSecret call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) EnableTotp(ctx context.Context, userID int64) error {
|
||
panic("unexpected EnableTotp call")
|
||
}
|
||
|
||
func (r *contentModerationTestUserRepo) DisableTotp(ctx context.Context, userID int64) error {
|
||
panic("unexpected DisableTotp call")
|
||
}
|
||
|
||
// contentModerationTestAuthCacheInvalidator is an auth-cache invalidator fake
// that records the user IDs passed to InvalidateAuthCacheByUserID. Key-based
// and group-based invalidation are accepted and ignored.
type contentModerationTestAuthCacheInvalidator struct {
	userIDs []int64 // user IDs in the order they were invalidated
}

// InvalidateAuthCacheByKey ignores its arguments.
func (i *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByKey(_ context.Context, _ string) {}

// InvalidateAuthCacheByUserID appends userID to i.userIDs.
func (i *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) {
	i.userIDs = append(i.userIDs, userID)
}

// InvalidateAuthCacheByGroupID ignores its arguments.
func (i *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByGroupID(_ context.Context, _ int64) {}
|
||
|
||
func (c *contentModerationTestHashCache) RecordFlaggedInputHash(ctx context.Context, inputHash string) error {
|
||
if c.hashes == nil {
|
||
c.hashes = map[string]struct{}{}
|
||
}
|
||
c.hashes[inputHash] = struct{}{}
|
||
c.recorded = append(c.recorded, inputHash)
|
||
return nil
|
||
}
|
||
|
||
func (c *contentModerationTestHashCache) HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
|
||
c.checked = append(c.checked, inputHash)
|
||
if c.hasResultUsed {
|
||
return c.hasResult, nil
|
||
}
|
||
_, ok := c.hashes[inputHash]
|
||
return ok, nil
|
||
}
|
||
|
||
func (c *contentModerationTestHashCache) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
|
||
c.deleted = append(c.deleted, inputHash)
|
||
if c.hashes == nil {
|
||
return false, nil
|
||
}
|
||
if _, ok := c.hashes[inputHash]; !ok {
|
||
return false, nil
|
||
}
|
||
delete(c.hashes, inputHash)
|
||
return true, nil
|
||
}
|
||
|
||
func (c *contentModerationTestHashCache) ClearFlaggedInputHashes(ctx context.Context) (int64, error) {
|
||
deleted := int64(len(c.hashes))
|
||
c.hashes = map[string]struct{}{}
|
||
return deleted, nil
|
||
}
|
||
|
||
func (c *contentModerationTestHashCache) CountFlaggedInputHashes(ctx context.Context) (int64, error) {
|
||
return int64(len(c.hashes)), nil
|
||
}
|
||
|
||
func TestBuildContentModerationLog_RedactsInputExcerpt(t *testing.T) {
|
||
svc := &ContentModerationService{}
|
||
cfg := defaultContentModerationConfig()
|
||
input := ContentModerationCheckInput{
|
||
RequestID: "req-1",
|
||
Endpoint: "/v1/chat/completions",
|
||
Provider: "openai",
|
||
}
|
||
|
||
log := svc.buildLog(input, cfg, ContentModerationActionAllow, true, "sexual", 0.8, map[string]float64{"sexual": 0.8}, "hello sk-proj-1234567890abcdef", nil, nil, "")
|
||
|
||
require.NotContains(t, log.InputExcerpt, "sk-proj-1234567890abcdef")
|
||
require.Contains(t, log.InputExcerpt, "[已脱敏]")
|
||
}
|
||
|
||
func TestRedactContentModerationSecrets_LongHexAndTokens(t *testing.T) {
|
||
input := "你哈市多大事cf5bbdc4cd508f3aaf0d2070d529d4a4ac29099f8ecc357f696df28e1df91554 token=abc123456789xyz Bearer eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.signaturepart"
|
||
|
||
out := redactContentModerationSecrets(input)
|
||
|
||
require.NotContains(t, out, "cf5bbdc4cd508f3aaf0d2070d529d4a4ac29099f8ecc357f696df28e1df91554")
|
||
require.NotContains(t, out, "abc123456789xyz")
|
||
require.NotContains(t, out, "eyJhbGciOiJIUzI1NiJ9")
|
||
require.Contains(t, out, "[已脱敏]")
|
||
}
|
||
|
||
func TestContentModerationConfigNormalize_NonHitRetentionMaxThreeDays(t *testing.T) {
|
||
cfg := defaultContentModerationConfig()
|
||
cfg.NonHitRetentionDays = 30
|
||
|
||
cfg.normalize()
|
||
|
||
require.Equal(t, 3, cfg.NonHitRetentionDays)
|
||
}
|
||
|
||
func TestExtractContentModerationInput_AnthropicImageSourceOnlyParticipatesInMemory(t *testing.T) {
|
||
body := []byte(`{
|
||
"messages": [
|
||
{"role":"user","content":"old"},
|
||
{"role":"assistant","content":"ok"},
|
||
{"role":"user","content":[
|
||
{"type":"text","text":"检查这张图"},
|
||
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"aGVsbG8="}}
|
||
]}
|
||
]
|
||
}`)
|
||
|
||
input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body)
|
||
require.Equal(t, "检查这张图", input.Text)
|
||
require.Equal(t, []string{"data:image/png;base64,aGVsbG8="}, input.Images)
|
||
|
||
log := (&ContentModerationService{}).buildLog(ContentModerationCheckInput{}, defaultContentModerationConfig(), ContentModerationActionAllow, false, "", 0, nil, input.ExcerptText(), nil, nil, "")
|
||
require.Equal(t, "检查这张图", log.InputExcerpt)
|
||
require.NotContains(t, log.InputExcerpt, "aGVsbG8=")
|
||
}
|
||
|
||
func TestExtractContentModerationInput_AnthropicKeepsEphemeralUserTextAndSkipsSystemReminders(t *testing.T) {
|
||
body := []byte(`{
|
||
"messages": [
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{"type": "text", "text": "<system-reminder>工具说明</system-reminder>"},
|
||
{"type": "text", "text": "<system-reminder>Ainder>\n\n"},
|
||
{"type": "text", "text": "hid", "cache_control": {"type": "ephemeral"}}
|
||
]
|
||
}
|
||
]
|
||
}`)
|
||
|
||
input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body)
|
||
|
||
require.Equal(t, "hid", input.Text)
|
||
require.Empty(t, input.Images)
|
||
}
|
||
|
||
func TestExtractContentModerationInput_OpenAIChatUsesLastUserMessage(t *testing.T) {
|
||
body := []byte(`{
|
||
"model":"gpt-5.5",
|
||
"messages":[
|
||
{"role":"system","content":"system prompt"},
|
||
{"role":"user","content":"old user"},
|
||
{"role":"assistant","content":"ok"},
|
||
{"role":"user","content":[{"type":"text","text":"latest user"},{"type":"image_url","image_url":{"url":"https://example.com/a.png"}}]}
|
||
]
|
||
}`)
|
||
|
||
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIChat, body)
|
||
|
||
require.Equal(t, "latest user", input.Text)
|
||
require.Equal(t, []string{"https://example.com/a.png"}, input.Images)
|
||
require.NotContains(t, input.Text, "old user")
|
||
require.NotContains(t, input.Text, "system prompt")
|
||
}
|
||
|
||
func TestExtractContentModerationInput_OpenAIImagesIncludesPromptAndImages(t *testing.T) {
|
||
body := []byte(`{
|
||
"prompt":"replace background",
|
||
"images":[
|
||
{"image_url":"https://example.com/source.png"},
|
||
{"image_url":"data:image/png;base64,aGVsbG8="}
|
||
]
|
||
}`)
|
||
|
||
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, body)
|
||
|
||
require.Equal(t, "replace background", input.Text)
|
||
require.Equal(t, []string{"https://example.com/source.png", "data:image/png;base64,aGVsbG8="}, input.Images)
|
||
}
|
||
|
||
func TestExtractContentModerationInput_OpenAIResponsesCodexPayloadUsesLastUserMessage(t *testing.T) {
|
||
body := []byte(`{
|
||
"model":"gpt-5.5",
|
||
"instructions":"instructions.....",
|
||
"input":[
|
||
{"type":"message","role":"developer","content":[{"type":"input_text","text":"developer permissions sk-proj-1234567890abcdef"}]},
|
||
{"type":"message","role":"user","content":[{"type":"input_text","text":"first user prompt"}]},
|
||
{"type":"message","role":"user","content":[{"type":"input_text","text":"last user prompt"}]}
|
||
],
|
||
"prompt_cache_key":"cache-key"
|
||
}`)
|
||
|
||
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIResponses, body)
|
||
|
||
require.Equal(t, "last user prompt", input.Text)
|
||
require.Empty(t, input.Images)
|
||
require.NotContains(t, input.Text, "developer permissions")
|
||
require.NotContains(t, input.Text, "first user prompt")
|
||
}
|
||
|
||
func TestContentModerationCheck_OpenAIResponsesRecordsNonHitForCodexPayload(t *testing.T) {
|
||
var moderationRequest moderationAPIRequest
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
require.Equal(t, "/v1/moderations", r.URL.Path)
|
||
require.NoError(t, json.NewDecoder(r.Body).Decode(&moderationRequest))
|
||
_ = json.NewEncoder(w).Encode(moderationAPIResponse{
|
||
Results: []moderationAPIResult{{
|
||
CategoryScores: map[string]float64{"sexual": 0.01},
|
||
}},
|
||
})
|
||
}))
|
||
defer server.Close()
|
||
|
||
cfg := defaultContentModerationConfig()
|
||
cfg.Enabled = true
|
||
cfg.Mode = ContentModerationModePreBlock
|
||
cfg.BaseURL = server.URL
|
||
cfg.APIKeys = []string{"sk-test"}
|
||
cfg.RecordNonHits = true
|
||
rawCfg, err := json.Marshal(cfg)
|
||
require.NoError(t, err)
|
||
|
||
repo := &contentModerationTestRepo{}
|
||
svc := NewContentModerationService(
|
||
&contentModerationTestSettingRepo{values: map[string]string{
|
||
SettingKeyRiskControlEnabled: "true",
|
||
SettingKeyContentModerationConfig: string(rawCfg),
|
||
}},
|
||
repo,
|
||
&contentModerationTestHashCache{},
|
||
nil,
|
||
nil,
|
||
nil,
|
||
nil,
|
||
)
|
||
|
||
body := []byte(`{
|
||
"model":"gpt-5.5",
|
||
"input":[
|
||
{"type":"message","role":"developer","content":[{"type":"input_text","text":"developer instructions should not be audited"}]},
|
||
{"type":"message","role":"user","content":[{"type":"input_text","text":"first user prompt"}]},
|
||
{"type":"message","role":"user","content":[{"type":"input_text","text":"last user prompt"}]}
|
||
]
|
||
}`)
|
||
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
|
||
UserID: 1001,
|
||
Endpoint: "/responses",
|
||
Provider: "openai",
|
||
Model: "gpt-5.5",
|
||
Protocol: ContentModerationProtocolOpenAIResponses,
|
||
Body: body,
|
||
})
|
||
|
||
require.NoError(t, err)
|
||
require.False(t, decision.Blocked)
|
||
require.Len(t, repo.logs, 1)
|
||
require.False(t, repo.logs[0].Flagged)
|
||
require.Equal(t, ContentModerationActionAllow, repo.logs[0].Action)
|
||
require.Equal(t, "/responses", repo.logs[0].Endpoint)
|
||
require.Equal(t, "last user prompt", repo.logs[0].InputExcerpt)
|
||
require.Equal(t, "last user prompt", moderationRequest.Input)
|
||
}
|
||
|
||
func TestContentModerationCheck_PreBlockBlocksCodexResponsesLatestUserInput(t *testing.T) {
|
||
var moderationRequest moderationAPIRequest
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
require.Equal(t, "/v1/moderations", r.URL.Path)
|
||
require.NoError(t, json.NewDecoder(r.Body).Decode(&moderationRequest))
|
||
_ = json.NewEncoder(w).Encode(moderationAPIResponse{
|
||
Results: []moderationAPIResult{{
|
||
CategoryScores: map[string]float64{"sexual": 0.9},
|
||
}},
|
||
})
|
||
}))
|
||
defer server.Close()
|
||
|
||
cfg := defaultContentModerationConfig()
|
||
cfg.Enabled = true
|
||
cfg.Mode = ContentModerationModePreBlock
|
||
cfg.BaseURL = server.URL
|
||
cfg.APIKeys = []string{"sk-test"}
|
||
cfg.BlockStatus = http.StatusUnavailableForLegalReasons
|
||
cfg.BlockMessage = "内容审计测试阻断"
|
||
rawCfg, err := json.Marshal(cfg)
|
||
require.NoError(t, err)
|
||
|
||
repo := &contentModerationTestRepo{}
|
||
svc := NewContentModerationService(
|
||
&contentModerationTestSettingRepo{values: map[string]string{
|
||
SettingKeyRiskControlEnabled: "true",
|
||
SettingKeyContentModerationConfig: string(rawCfg),
|
||
}},
|
||
repo,
|
||
&contentModerationTestHashCache{},
|
||
nil,
|
||
nil,
|
||
nil,
|
||
nil,
|
||
)
|
||
|
||
body := []byte(`{
|
||
"model":"gpt-5.5",
|
||
"instructions":"instructions.....",
|
||
"input":[
|
||
{"type":"message","role":"developer","content":[{"type":"input_text","text":"developer instructions should not be audited"}]},
|
||
{"type":"message","role":"user","content":[{"type":"input_text","text":"environment context"}]},
|
||
{"type":"message","role":"user","content":[{"type":"input_text","text":"latest blocked prompt"}]}
|
||
]
|
||
}`)
|
||
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
|
||
UserID: 1001,
|
||
Endpoint: "/responses",
|
||
Provider: "openai",
|
||
Model: "gpt-5.5",
|
||
Protocol: ContentModerationProtocolOpenAIResponses,
|
||
Body: body,
|
||
})
|
||
|
||
require.NoError(t, err)
|
||
require.True(t, decision.Blocked)
|
||
require.Equal(t, ContentModerationActionBlock, decision.Action)
|
||
require.Equal(t, http.StatusUnavailableForLegalReasons, decision.StatusCode)
|
||
require.Equal(t, "内容审计测试阻断", decision.Message)
|
||
require.Len(t, repo.logs, 1)
|
||
require.True(t, repo.logs[0].Flagged)
|
||
require.Equal(t, ContentModerationActionBlock, repo.logs[0].Action)
|
||
require.Equal(t, ContentModerationModePreBlock, repo.logs[0].Mode)
|
||
require.Equal(t, "latest blocked prompt", repo.logs[0].InputExcerpt)
|
||
require.Equal(t, "latest blocked prompt", moderationRequest.Input)
|
||
}
|
||
|
||
func TestBuildContentModerationTestAuditResult_UsesConfiguredThresholdsOnly(t *testing.T) {
|
||
result := buildContentModerationTestAuditResult(&moderationAPIResult{
|
||
Flagged: true,
|
||
CategoryScores: map[string]float64{
|
||
"harassment": 0.65,
|
||
},
|
||
}, nil)
|
||
|
||
require.NotNil(t, result)
|
||
require.False(t, result.Flagged)
|
||
require.Equal(t, "harassment", result.HighestCategory)
|
||
require.Equal(t, 0.65, result.HighestScore)
|
||
require.Equal(t, 0.65, result.CompositeScore)
|
||
require.Equal(t, 0.98, result.Thresholds["harassment"])
|
||
}
|
||
|
||
func TestContentModerationCheck_PreHashUsesRedisHashCache(t *testing.T) {
|
||
cfg := defaultContentModerationConfig()
|
||
cfg.Enabled = true
|
||
cfg.PreHashCheckEnabled = true
|
||
cfg.APIKeys = []string{"sk-test"}
|
||
cfg.BlockStatus = http.StatusConflict
|
||
cfg.BlockMessage = "命中历史风险输入"
|
||
rawCfg, err := json.Marshal(cfg)
|
||
require.NoError(t, err)
|
||
|
||
hashCache := &contentModerationTestHashCache{hashes: map[string]struct{}{}}
|
||
content := ContentModerationInput{Text: "blocked prompt"}
|
||
content.Normalize()
|
||
hashCache.hashes[content.Hash()] = struct{}{}
|
||
|
||
svc := NewContentModerationService(
|
||
&contentModerationTestSettingRepo{values: map[string]string{
|
||
SettingKeyRiskControlEnabled: "true",
|
||
SettingKeyContentModerationConfig: string(rawCfg),
|
||
}},
|
||
&contentModerationTestRepo{},
|
||
hashCache,
|
||
nil,
|
||
nil,
|
||
nil,
|
||
nil,
|
||
)
|
||
|
||
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
|
||
Protocol: ContentModerationProtocolOpenAIChat,
|
||
Body: []byte(`{"messages":[{"role":"user","content":"blocked prompt"}]}`),
|
||
})
|
||
require.NoError(t, err)
|
||
require.True(t, decision.Blocked)
|
||
require.Equal(t, ContentModerationActionHashBlock, decision.Action)
|
||
require.Equal(t, http.StatusConflict, decision.StatusCode)
|
||
require.Equal(t, content.Hash(), decision.InputHash)
|
||
require.Contains(t, decision.Message, "命中历史风险输入")
|
||
require.Contains(t, decision.Message, content.Hash())
|
||
require.Len(t, hashCache.checked, 1)
|
||
}
|
||
|
||
func TestContentModerationCheck_PreBlockFlaggedWritesRedisHashCache(t *testing.T) {
|
||
requestCount := 0
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
requestCount++
|
||
_ = json.NewEncoder(w).Encode(moderationAPIResponse{
|
||
Results: []moderationAPIResult{{
|
||
CategoryScores: map[string]float64{"sexual": 0.9},
|
||
}},
|
||
})
|
||
}))
|
||
defer server.Close()
|
||
|
||
cfg := defaultContentModerationConfig()
|
||
cfg.Enabled = true
|
||
cfg.Mode = ContentModerationModePreBlock
|
||
cfg.PreHashCheckEnabled = true
|
||
cfg.BaseURL = server.URL
|
||
cfg.APIKeys = []string{"sk-test"}
|
||
cfg.BlockStatus = http.StatusConflict
|
||
cfg.BlockMessage = "命中风险输入"
|
||
rawCfg, err := json.Marshal(cfg)
|
||
require.NoError(t, err)
|
||
|
||
repo := &contentModerationTestRepo{}
|
||
hashCache := &contentModerationTestHashCache{}
|
||
svc := NewContentModerationService(
|
||
&contentModerationTestSettingRepo{values: map[string]string{
|
||
SettingKeyRiskControlEnabled: "true",
|
||
SettingKeyContentModerationConfig: string(rawCfg),
|
||
}},
|
||
repo,
|
||
hashCache,
|
||
nil,
|
||
nil,
|
||
nil,
|
||
nil,
|
||
)
|
||
|
||
body := []byte(`{"messages":[{"role":"user","content":"repeat blocked prompt"}]}`)
|
||
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
|
||
Protocol: ContentModerationProtocolOpenAIChat,
|
||
Body: body,
|
||
})
|
||
require.NoError(t, err)
|
||
require.True(t, decision.Blocked)
|
||
require.Equal(t, ContentModerationActionBlock, decision.Action)
|
||
require.Equal(t, 1, requestCount)
|
||
require.Len(t, hashCache.recorded, 1)
|
||
require.Len(t, repo.logs, 1)
|
||
|
||
decision, err = svc.Check(context.Background(), ContentModerationCheckInput{
|
||
Protocol: ContentModerationProtocolOpenAIChat,
|
||
Body: body,
|
||
})
|
||
require.NoError(t, err)
|
||
require.True(t, decision.Blocked)
|
||
require.Equal(t, ContentModerationActionHashBlock, decision.Action)
|
||
require.Equal(t, hashCache.recorded[0], decision.InputHash)
|
||
require.Equal(t, 1, requestCount)
|
||
require.Len(t, repo.logs, 1)
|
||
}
|
||
|
||
func TestContentModerationDeleteFlaggedInputHash_NormalizesAndDeletes(t *testing.T) {
|
||
existingHash := strings.Repeat("a", 64)
|
||
hashCache := &contentModerationTestHashCache{hashes: map[string]struct{}{
|
||
existingHash: {},
|
||
}}
|
||
svc := &ContentModerationService{hashCache: hashCache}
|
||
|
||
result, err := svc.DeleteFlaggedInputHash(context.Background(), strings.ToUpper(existingHash))
|
||
|
||
require.NoError(t, err)
|
||
require.Equal(t, existingHash, result.InputHash)
|
||
require.True(t, result.Deleted)
|
||
require.NotContains(t, hashCache.hashes, existingHash)
|
||
require.Equal(t, []string{existingHash}, hashCache.deleted)
|
||
|
||
result, err = svc.DeleteFlaggedInputHash(context.Background(), existingHash)
|
||
|
||
require.NoError(t, err)
|
||
require.Equal(t, existingHash, result.InputHash)
|
||
require.False(t, result.Deleted)
|
||
}
|
||
|
||
func TestContentModerationClearFlaggedInputHashesAndStatusCount(t *testing.T) {
|
||
cfg := defaultContentModerationConfig()
|
||
cfg.Enabled = true
|
||
rawCfg, err := json.Marshal(cfg)
|
||
require.NoError(t, err)
|
||
|
||
hashCache := &contentModerationTestHashCache{hashes: map[string]struct{}{
|
||
strings.Repeat("a", 64): {},
|
||
strings.Repeat("b", 64): {},
|
||
}}
|
||
svc := &ContentModerationService{
|
||
settingRepo: &contentModerationTestSettingRepo{values: map[string]string{
|
||
SettingKeyRiskControlEnabled: "true",
|
||
SettingKeyContentModerationConfig: string(rawCfg),
|
||
}},
|
||
hashCache: hashCache,
|
||
keyHealth: make(map[string]*contentModerationKeyHealth),
|
||
}
|
||
|
||
status, err := svc.GetStatus(context.Background())
|
||
require.NoError(t, err)
|
||
require.Equal(t, int64(2), status.FlaggedHashCount)
|
||
|
||
result, err := svc.ClearFlaggedInputHashes(context.Background())
|
||
require.NoError(t, err)
|
||
require.Equal(t, int64(2), result.Deleted)
|
||
|
||
status, err = svc.GetStatus(context.Background())
|
||
require.NoError(t, err)
|
||
require.Equal(t, int64(0), status.FlaggedHashCount)
|
||
}
|
||
|
||
func TestContentModerationCheck_AsyncFlaggedWritesRedisHashCache(t *testing.T) {
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
_ = json.NewEncoder(w).Encode(moderationAPIResponse{
|
||
Results: []moderationAPIResult{{
|
||
CategoryScores: map[string]float64{"sexual": 0.9},
|
||
}},
|
||
})
|
||
}))
|
||
defer server.Close()
|
||
|
||
cfg := defaultContentModerationConfig()
|
||
cfg.Enabled = true
|
||
cfg.Mode = ContentModerationModeObserve
|
||
cfg.BaseURL = server.URL
|
||
cfg.APIKeys = []string{"sk-test"}
|
||
rawCfg, err := json.Marshal(cfg)
|
||
require.NoError(t, err)
|
||
|
||
repo := &contentModerationTestRepo{}
|
||
hashCache := &contentModerationTestHashCache{}
|
||
svc := NewContentModerationService(
|
||
&contentModerationTestSettingRepo{values: map[string]string{
|
||
SettingKeyRiskControlEnabled: "true",
|
||
SettingKeyContentModerationConfig: string(rawCfg),
|
||
}},
|
||
repo,
|
||
hashCache,
|
||
nil,
|
||
nil,
|
||
nil,
|
||
nil,
|
||
)
|
||
|
||
decision := svc.checkSync(context.Background(), ContentModerationCheckInput{
|
||
Protocol: ContentModerationProtocolOpenAIChat,
|
||
Body: []byte(`{"messages":[{"role":"user","content":"bad prompt"}]}`),
|
||
}, cfg, ContentModerationInput{Text: "bad prompt"}, strings.Repeat("b", 64), contentModerationIntPtr(25), false)
|
||
|
||
require.False(t, decision.Blocked)
|
||
require.Len(t, hashCache.recorded, 1)
|
||
require.Len(t, repo.logs, 1)
|
||
}
|
||
|
||
func TestBuildContentModerationAccountDisabledEmailBody_ContainsBanDetails(t *testing.T) {
|
||
userID := int64(1001)
|
||
cfg := defaultContentModerationConfig()
|
||
cfg.BanThreshold = 10
|
||
body := buildContentModerationAccountDisabledEmailBody("Sub2API <Admin>", &ContentModerationLog{
|
||
UserID: &userID,
|
||
UserEmail: "user@example.com",
|
||
GroupName: "vip_2",
|
||
HighestCategory: "sexual",
|
||
HighestScore: 0.926,
|
||
ViolationCount: 10,
|
||
}, cfg)
|
||
|
||
require.Contains(t, body, "账户已被自动禁用")
|
||
require.Contains(t, body, "封禁详情")
|
||
require.Contains(t, body, "账户当前处于封禁状态,所有 API 请求将被拒绝")
|
||
require.Contains(t, body, "10 次(阈值 10)")
|
||
require.Contains(t, body, "sexual / 0.926")
|
||
require.Contains(t, body, "Sub2API <Admin>")
|
||
}
|
||
|
||
func TestContentModerationUnbanUser_ActivatesUserAndInvalidatesAuthCache(t *testing.T) {
|
||
userRepo := &contentModerationTestUserRepo{user: &User{ID: 1001, Email: "user@example.com", Status: StatusDisabled}}
|
||
invalidator := &contentModerationTestAuthCacheInvalidator{}
|
||
repo := &contentModerationTestRepo{}
|
||
svc := NewContentModerationService(nil, repo, nil, nil, userRepo, invalidator, nil)
|
||
|
||
result, err := svc.UnbanUser(context.Background(), 1001)
|
||
|
||
require.NoError(t, err)
|
||
require.Equal(t, int64(1001), result.UserID)
|
||
require.Equal(t, StatusActive, result.Status)
|
||
require.Len(t, userRepo.updated, 1)
|
||
require.Equal(t, StatusActive, userRepo.updated[0].Status)
|
||
require.Equal(t, []int64{1001}, invalidator.userIDs)
|
||
}
|
||
|
||
func TestContentModerationUnbanUser_ActiveUserOnlyInvalidatesAuthCache(t *testing.T) {
|
||
userRepo := &contentModerationTestUserRepo{user: &User{ID: 1001, Email: "user@example.com", Status: StatusActive}}
|
||
invalidator := &contentModerationTestAuthCacheInvalidator{}
|
||
repo := &contentModerationTestRepo{}
|
||
svc := NewContentModerationService(nil, repo, nil, nil, userRepo, invalidator, nil)
|
||
|
||
result, err := svc.UnbanUser(context.Background(), 1001)
|
||
|
||
require.NoError(t, err)
|
||
require.Equal(t, StatusActive, result.Status)
|
||
require.Empty(t, userRepo.updated)
|
||
require.Equal(t, []int64{1001}, invalidator.userIDs)
|
||
}
|
||
|
||
// contentModerationIntPtr returns a pointer to a freshly allocated copy of v,
// which is handy for filling optional *int arguments in tests.
func contentModerationIntPtr(v int) *int {
	p := new(int)
	*p = v
	return p
}
|