Files
sub2api-ht/backend/internal/service/content_moderation_test.go
2026-05-07 14:31:19 +08:00

1001 lines
34 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// contentModerationTestSettingRepo is an in-memory settings store used as a
// test double. All state lives in the values map; writers allocate it lazily,
// so the zero value is usable.
type contentModerationTestSettingRepo struct {
	values map[string]string
}

// Get returns the setting stored under key, or ErrSettingNotFound when the
// key is absent.
func (r *contentModerationTestSettingRepo) Get(ctx context.Context, key string) (*Setting, error) {
	stored, found := r.values[key]
	if !found {
		return nil, ErrSettingNotFound
	}
	return &Setting{Key: key, Value: stored}, nil
}

// GetValue returns the raw string stored under key, or ErrSettingNotFound
// when the key is absent.
func (r *contentModerationTestSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
	stored, found := r.values[key]
	if !found {
		return "", ErrSettingNotFound
	}
	return stored, nil
}

// Set stores value under key, creating the backing map on first use.
func (r *contentModerationTestSettingRepo) Set(ctx context.Context, key, value string) error {
	if r.values == nil {
		r.values = make(map[string]string)
	}
	r.values[key] = value
	return nil
}

// GetMultiple returns the subset of keys that are present; missing keys are
// silently omitted from the result.
func (r *contentModerationTestSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
	found := make(map[string]string)
	for _, k := range keys {
		v, ok := r.values[k]
		if !ok {
			continue
		}
		found[k] = v
	}
	return found, nil
}

// SetMultiple stores every entry of settings, creating the backing map on
// first use.
func (r *contentModerationTestSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
	if r.values == nil {
		r.values = make(map[string]string)
	}
	for k, v := range settings {
		r.values[k] = v
	}
	return nil
}

// GetAll returns a copy of every stored setting, so callers cannot mutate
// the repo's own map through the result.
func (r *contentModerationTestSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
	snapshot := make(map[string]string, len(r.values))
	for k, v := range r.values {
		snapshot[k] = v
	}
	return snapshot, nil
}

// Delete removes key if present. Deleting from a nil map is a no-op in Go,
// so no allocation guard is needed.
func (r *contentModerationTestSettingRepo) Delete(ctx context.Context, key string) error {
	delete(r.values, key)
	return nil
}
// contentModerationTestRepo is a test double for the moderation log store.
// It keeps every created log in memory and returns empty results for the
// read and cleanup operations.
type contentModerationTestRepo struct {
	logs []ContentModerationLog
}

// CreateLog appends a copy of log to the in-memory list; a nil log is ignored.
func (r *contentModerationTestRepo) CreateLog(ctx context.Context, log *ContentModerationLog) error {
	if log == nil {
		return nil
	}
	r.logs = append(r.logs, *log)
	return nil
}

// ListLogs is a stub that always returns no logs, no pagination and no error.
func (r *contentModerationTestRepo) ListLogs(ctx context.Context, filter ContentModerationLogFilter) ([]ContentModerationLog, *pagination.PaginationResult, error) {
	return nil, nil, nil
}

// CountFlaggedByUserSince is a stub that always reports zero flagged entries.
func (r *contentModerationTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
	return 0, nil
}

// CleanupExpiredLogs is a stub that returns an empty cleanup result.
func (r *contentModerationTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*ContentModerationCleanupResult, error) {
	return new(ContentModerationCleanupResult), nil
}
// contentModerationTestHashCache is an in-memory stand-in for the flagged
// input hash cache that also records which hashes each operation was called
// with.
type contentModerationTestHashCache struct {
	// hashes is the set of flagged input hashes currently stored.
	hashes map[string]struct{}
	// recorded, checked and deleted collect the hash passed to each record,
	// lookup and delete call respectively, in call order.
	recorded, checked, deleted []string
	// hasResult is the canned lookup answer; it is only returned when
	// hasResultUsed is true.
	hasResult, hasResultUsed bool
}
// contentModerationTestUserRepo is a test double for the user repository.
// Only GetByID and Update are implemented; every other method panics so an
// unexpected call is surfaced immediately.
type contentModerationTestUserRepo struct {
	// user is the single user this repo knows about (nil means "no user").
	user *User
	// updated records a copy of every user passed to Update, in call order.
	updated []User
}

// Create is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) Create(ctx context.Context, user *User) error {
	panic("unexpected Create call")
}

// GetByID returns a copy of the configured user regardless of id, or
// ErrUserNotFound when no user is configured.
func (r *contentModerationTestUserRepo) GetByID(ctx context.Context, id int64) (*User, error) {
	src := r.user
	if src == nil {
		return nil, ErrUserNotFound
	}
	snapshot := *src
	return &snapshot, nil
}

// GetByEmail is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) GetByEmail(ctx context.Context, email string) (*User, error) {
	panic("unexpected GetByEmail call")
}

// GetFirstAdmin is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) GetFirstAdmin(ctx context.Context) (*User, error) {
	panic("unexpected GetFirstAdmin call")
}

// Update records a copy of user and makes that copy the repo's current user.
// A nil user is ignored.
func (r *contentModerationTestUserRepo) Update(ctx context.Context, user *User) error {
	if user != nil {
		saved := *user
		r.updated = append(r.updated, saved)
		r.user = &saved
	}
	return nil
}

// Delete is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) Delete(ctx context.Context, id int64) error {
	panic("unexpected Delete call")
}

// GetUserAvatar is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
	panic("unexpected GetUserAvatar call")
}

// UpsertUserAvatar is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
	panic("unexpected UpsertUserAvatar call")
}

// DeleteUserAvatar is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
	panic("unexpected DeleteUserAvatar call")
}

// List is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
	panic("unexpected List call")
}

// ListWithFilters is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) {
	panic("unexpected ListWithFilters call")
}

// GetLatestUsedAtByUserIDs is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
	panic("unexpected GetLatestUsedAtByUserIDs call")
}

// GetLatestUsedAtByUserID is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
	panic("unexpected GetLatestUsedAtByUserID call")
}

// UpdateUserLastActiveAt is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
	panic("unexpected UpdateUserLastActiveAt call")
}

// UpdateBalance is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
	panic("unexpected UpdateBalance call")
}

// DeductBalance is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error {
	panic("unexpected DeductBalance call")
}

// UpdateConcurrency is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
	panic("unexpected UpdateConcurrency call")
}

// BatchSetConcurrency is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error) {
	panic("unexpected BatchSetConcurrency call")
}

// BatchAddConcurrency is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error) {
	panic("unexpected BatchAddConcurrency call")
}

// ExistsByEmail is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
	panic("unexpected ExistsByEmail call")
}

// RemoveGroupFromAllowedGroups is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
	panic("unexpected RemoveGroupFromAllowedGroups call")
}

// AddGroupToAllowedGroups is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
	panic("unexpected AddGroupToAllowedGroups call")
}

// RemoveGroupFromUserAllowedGroups is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
	panic("unexpected RemoveGroupFromUserAllowedGroups call")
}

// ListUserAuthIdentities is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
	panic("unexpected ListUserAuthIdentities call")
}

// UnbindUserAuthProvider is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error {
	panic("unexpected UnbindUserAuthProvider call")
}

// UpdateTotpSecret is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
	panic("unexpected UpdateTotpSecret call")
}

// EnableTotp is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) EnableTotp(ctx context.Context, userID int64) error {
	panic("unexpected EnableTotp call")
}

// DisableTotp is not expected to be called by these tests.
func (r *contentModerationTestUserRepo) DisableTotp(ctx context.Context, userID int64) error {
	panic("unexpected DisableTotp call")
}
// contentModerationTestAuthCacheInvalidator is a test double that remembers
// which user IDs were invalidated; key- and group-based invalidation are
// no-ops.
type contentModerationTestAuthCacheInvalidator struct {
	// userIDs lists every ID passed to InvalidateAuthCacheByUserID, in call order.
	userIDs []int64
}

// InvalidateAuthCacheByKey does nothing.
func (inv *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByKey(ctx context.Context, key string) {
}

// InvalidateAuthCacheByUserID records userID so tests can assert on it.
func (inv *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
	inv.userIDs = append(inv.userIDs, userID)
}

// InvalidateAuthCacheByGroupID does nothing.
func (inv *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
}
// RecordFlaggedInputHash adds inputHash to the stored set (allocating the set
// on first use) and appends it to the recorded call log.
func (c *contentModerationTestHashCache) RecordFlaggedInputHash(ctx context.Context, inputHash string) error {
	if c.hashes == nil {
		c.hashes = make(map[string]struct{})
	}
	c.recorded = append(c.recorded, inputHash)
	c.hashes[inputHash] = struct{}{}
	return nil
}
// HasFlaggedInputHash logs the lookup in checked and reports whether
// inputHash is stored. When hasResultUsed is set, the canned hasResult value
// is returned instead of consulting the stored set.
func (c *contentModerationTestHashCache) HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
	c.checked = append(c.checked, inputHash)
	if !c.hasResultUsed {
		_, present := c.hashes[inputHash]
		return present, nil
	}
	return c.hasResult, nil
}
// DeleteFlaggedInputHash logs the request in deleted, removes inputHash from
// the stored set if it is there, and reports whether anything was removed.
func (c *contentModerationTestHashCache) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
	c.deleted = append(c.deleted, inputHash)
	// A lookup on a nil map reports "absent", so the nil-map and
	// missing-key cases are handled by the same check.
	_, present := c.hashes[inputHash]
	if present {
		delete(c.hashes, inputHash)
	}
	return present, nil
}
// ClearFlaggedInputHashes replaces the stored set with an empty one and
// returns how many hashes it held beforehand.
func (c *contentModerationTestHashCache) ClearFlaggedInputHashes(ctx context.Context) (int64, error) {
	removed := int64(len(c.hashes))
	c.hashes = make(map[string]struct{})
	return removed, nil
}
// CountFlaggedInputHashes returns the number of hashes currently stored.
func (c *contentModerationTestHashCache) CountFlaggedInputHashes(ctx context.Context) (int64, error) {
	n := len(c.hashes)
	return int64(n), nil
}
// TestBuildContentModerationLog_RedactsInputExcerpt checks that an API-key-like
// token in the excerpt passed to buildLog does not survive into the log's
// InputExcerpt and is replaced by the redaction marker.
func TestBuildContentModerationLog_RedactsInputExcerpt(t *testing.T) {
	const secret = "sk-proj-1234567890abcdef"

	svc := &ContentModerationService{}
	cfg := defaultContentModerationConfig()
	checkInput := ContentModerationCheckInput{
		RequestID: "req-1",
		Endpoint:  "/v1/chat/completions",
		Provider:  "openai",
	}

	entry := svc.buildLog(
		checkInput,
		cfg,
		ContentModerationActionAllow,
		true,
		"sexual",
		0.8,
		map[string]float64{"sexual": 0.8},
		"hello "+secret,
		nil,
		nil,
		"",
	)

	require.NotContains(t, entry.InputExcerpt, secret)
	require.Contains(t, entry.InputExcerpt, "[已脱敏]")
}
// TestRedactContentModerationSecrets_LongHexAndTokens checks that a long hex
// string, a token= value, a bearer JWT and a URL are all removed from the
// redacted output, and that the redaction marker is present.
func TestRedactContentModerationSecrets_LongHexAndTokens(t *testing.T) {
	raw := "你哈市多大事cf5bbdc4cd508f3aaf0d2070d529d4a4ac29099f8ecc357f696df28e1df91554 token=abc123456789xyz Bearer eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.signaturepart https://example.com/private/path?token=abc123"

	redacted := redactContentModerationSecrets(raw)

	mustBeGone := []string{
		"cf5bbdc4cd508f3aaf0d2070d529d4a4ac29099f8ecc357f696df28e1df91554",
		"abc123456789xyz",
		"eyJhbGciOiJIUzI1NiJ9",
		"https://example.com/private/path",
	}
	for _, fragment := range mustBeGone {
		require.NotContains(t, redacted, fragment)
	}
	require.Contains(t, redacted, "[已脱敏]")
}
// TestContentModerationConfigNormalize_NonHitRetentionMaxThreeDays checks that
// normalize brings an oversized NonHitRetentionDays (30) down to 3.
func TestContentModerationConfigNormalize_NonHitRetentionMaxThreeDays(t *testing.T) {
	conf := defaultContentModerationConfig()
	conf.NonHitRetentionDays = 30

	conf.normalize()

	require.Equal(t, 3, conf.NonHitRetentionDays)
}
// TestContentModerationUpdateConfig_AppendsAndDeletesAPIKeys checks the
// default (no APIKeysMode) update path: the key named by hash is removed, the
// new key is added, and a submitted key that already exists is not duplicated.
func TestContentModerationUpdateConfig_AppendsAndDeletesAPIKeys(t *testing.T) {
	// Seed the settings store with two existing keys.
	initial := defaultContentModerationConfig()
	initial.APIKeys = []string{"sk-old-a", "sk-old-b"}
	encoded, err := json.Marshal(initial)
	require.NoError(t, err)

	store := &contentModerationTestSettingRepo{
		values: map[string]string{SettingKeyContentModerationConfig: string(encoded)},
	}
	svc := NewContentModerationService(store, nil, nil, nil, nil, nil, nil)

	toAdd := []string{"sk-new-c", "sk-old-b"}
	toDelete := []string{moderationAPIKeyHash("sk-old-a")}
	view, err := svc.UpdateConfig(context.Background(), UpdateContentModerationConfigInput{
		APIKeys:            &toAdd,
		DeleteAPIKeyHashes: &toDelete,
	})
	require.NoError(t, err)
	require.Equal(t, 2, view.APIKeyCount)
	wantMasks := []string{maskSecretTail("sk-old-b"), maskSecretTail("sk-new-c")}
	require.Equal(t, wantMasks, view.APIKeyMasks)

	// The persisted config must reflect the same key list.
	var persisted ContentModerationConfig
	err = json.Unmarshal([]byte(store.values[SettingKeyContentModerationConfig]), &persisted)
	require.NoError(t, err)
	require.Equal(t, []string{"sk-old-b", "sk-new-c"}, persisted.apiKeys())
}
// TestContentModerationUpdateConfig_ReplacesAPIKeysWhenRequested checks that
// with APIKeysMode set to replace, the submitted keys become the whole key
// list and the previously stored keys are gone.
func TestContentModerationUpdateConfig_ReplacesAPIKeysWhenRequested(t *testing.T) {
	// Seed the settings store with two existing keys.
	initial := defaultContentModerationConfig()
	initial.APIKeys = []string{"sk-old-a", "sk-old-b"}
	encoded, err := json.Marshal(initial)
	require.NoError(t, err)

	store := &contentModerationTestSettingRepo{
		values: map[string]string{SettingKeyContentModerationConfig: string(encoded)},
	}
	svc := NewContentModerationService(store, nil, nil, nil, nil, nil, nil)

	replacement := []string{"sk-new-only"}
	toDelete := []string{moderationAPIKeyHash("sk-old-a")}
	view, err := svc.UpdateConfig(context.Background(), UpdateContentModerationConfigInput{
		APIKeys:            &replacement,
		APIKeysMode:        contentModerationAPIKeysModeReplace,
		DeleteAPIKeyHashes: &toDelete,
	})
	require.NoError(t, err)
	require.Equal(t, 1, view.APIKeyCount)
	require.Equal(t, []string{maskSecretTail("sk-new-only")}, view.APIKeyMasks)

	// The persisted config must contain only the replacement key.
	var persisted ContentModerationConfig
	err = json.Unmarshal([]byte(store.values[SettingKeyContentModerationConfig]), &persisted)
	require.NoError(t, err)
	require.Equal(t, []string{"sk-new-only"}, persisted.apiKeys())
}
// TestExtractContentModerationInput_AnthropicImageSourceOnlyParticipatesInMemory
// checks that an Anthropic base64 image block is extracted as a data URL in
// Images, while the log excerpt built from the input holds only the text and
// never the base64 payload.
func TestExtractContentModerationInput_AnthropicImageSourceOnlyParticipatesInMemory(t *testing.T) {
	payload := []byte(`{
"messages": [
{"role":"user","content":"old"},
{"role":"assistant","content":"ok"},
{"role":"user","content":[
{"type":"text","text":"检查这张图"},
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"aGVsbG8="}}
]}
]
}`)

	extracted := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, payload)
	require.Equal(t, "检查这张图", extracted.Text)
	require.Equal(t, []string{"data:image/png;base64,aGVsbG8="}, extracted.Images)

	svc := &ContentModerationService{}
	entry := svc.buildLog(
		ContentModerationCheckInput{},
		defaultContentModerationConfig(),
		ContentModerationActionAllow,
		false,
		"",
		0,
		nil,
		extracted.ExcerptText(),
		nil,
		nil,
		"",
	)
	require.Equal(t, "检查这张图", entry.InputExcerpt)
	require.NotContains(t, entry.InputExcerpt, "aGVsbG8=")
}
// TestExtractContentModerationInput_AnthropicKeepsEphemeralUserTextAndSkipsSystemReminders
// checks that text blocks starting with <system-reminder> are dropped, while
// a user text block carrying an ephemeral cache_control marker is kept.
func TestExtractContentModerationInput_AnthropicKeepsEphemeralUserTextAndSkipsSystemReminders(t *testing.T) {
	payload := []byte(`{
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "<system-reminder>工具说明</system-reminder>"},
{"type": "text", "text": "<system-reminder>Ainder>\n\n"},
{"type": "text", "text": "hid", "cache_control": {"type": "ephemeral"}}
]
}
]
}`)

	extracted := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, payload)

	require.Equal(t, "hid", extracted.Text)
	require.Empty(t, extracted.Images)
}
// TestExtractContentModerationInput_OpenAIChatUsesLastUserMessage checks that
// for the OpenAI chat protocol only the final user message contributes text
// and images; earlier user and system content is excluded.
func TestExtractContentModerationInput_OpenAIChatUsesLastUserMessage(t *testing.T) {
	payload := []byte(`{
"model":"gpt-5.5",
"messages":[
{"role":"system","content":"system prompt"},
{"role":"user","content":"old user"},
{"role":"assistant","content":"ok"},
{"role":"user","content":[{"type":"text","text":"latest user"},{"type":"image_url","image_url":{"url":"https://example.com/a.png"}}]}
]
}`)

	extracted := ExtractContentModerationInput(ContentModerationProtocolOpenAIChat, payload)

	require.Equal(t, "latest user", extracted.Text)
	require.Equal(t, []string{"https://example.com/a.png"}, extracted.Images)
	for _, excluded := range []string{"old user", "system prompt"} {
		require.NotContains(t, extracted.Text, excluded)
	}
}
// TestExtractContentModerationInput_OpenAIImagesIncludesPromptAndImages checks
// that for the OpenAI images protocol the prompt becomes Text and every
// image_url entry is returned, in order, in Images.
func TestExtractContentModerationInput_OpenAIImagesIncludesPromptAndImages(t *testing.T) {
	payload := []byte(`{
"prompt":"replace background",
"images":[
{"image_url":"https://example.com/source.png"},
{"image_url":"data:image/png;base64,aGVsbG8="}
]
}`)

	extracted := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, payload)

	wantImages := []string{"https://example.com/source.png", "data:image/png;base64,aGVsbG8="}
	require.Equal(t, "replace background", extracted.Text)
	require.Equal(t, wantImages, extracted.Images)
}
// TestContentModerationInput_NormalizeRandomSamplesOneImageForModerationAPI
// checks that Normalize reduces two images to exactly one of the originals and
// that ModerationInput then has two entries.
func TestContentModerationInput_NormalizeRandomSamplesOneImageForModerationAPI(t *testing.T) {
	candidates := []string{
		"data:image/png;base64,Zmlyc3Q=",
		"data:image/png;base64,c2Vjb25k",
	}
	// Give the input its own copy so candidates stays intact for the
	// membership assertion below.
	owned := make([]string, len(candidates))
	copy(owned, candidates)
	in := ContentModerationInput{
		Text:   "check image",
		Images: owned,
	}

	in.Normalize()

	require.Len(t, in.Images, 1)
	require.Contains(t, candidates, in.Images[0])
	require.Len(t, in.ModerationInput(), 2)
}
// TestBuildModerationTestInputRejectsMultipleImages checks that
// buildModerationTestInput fails with the "at most one test image" message
// when two images are supplied.
func TestBuildModerationTestInputRejectsMultipleImages(t *testing.T) {
	twoImages := []string{
		"data:image/png;base64,Zmlyc3Q=",
		"data:image/png;base64,c2Vjb25k",
	}

	_, _, err := buildModerationTestInput("check image", twoImages)

	require.Error(t, err)
	require.Contains(t, err.Error(), "最多上传 1 张测试图片")
}
// TestExtractContentModerationInput_OpenAIResponsesCodexPayloadUsesLastUserMessage
// checks that for a Responses-API payload only the last user message is
// extracted; developer content and earlier user messages are excluded.
func TestExtractContentModerationInput_OpenAIResponsesCodexPayloadUsesLastUserMessage(t *testing.T) {
	payload := []byte(`{
"model":"gpt-5.5",
"instructions":"instructions.....",
"input":[
{"type":"message","role":"developer","content":[{"type":"input_text","text":"developer permissions sk-proj-1234567890abcdef"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"first user prompt"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"last user prompt"}]}
],
"prompt_cache_key":"cache-key"
}`)

	extracted := ExtractContentModerationInput(ContentModerationProtocolOpenAIResponses, payload)

	require.Equal(t, "last user prompt", extracted.Text)
	require.Empty(t, extracted.Images)
	for _, excluded := range []string{"developer permissions", "first user prompt"} {
		require.NotContains(t, extracted.Text, excluded)
	}
}
// TestContentModerationCheck_OpenAIResponsesRecordsNonHitForCodexPayload checks
// that, with RecordNonHits enabled and a low upstream score, Check allows the
// request, writes exactly one non-flagged log, and sends only the last user
// message to the moderation endpoint.
func TestContentModerationCheck_OpenAIResponsesRecordsNonHitForCodexPayload(t *testing.T) {
	var moderationRequest moderationAPIRequest
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on the HTTP server's goroutine. require.* ends in
		// t.FailNow, which must only be called from the test goroutine, so
		// failures are reported with t.Errorf instead.
		if r.URL.Path != "/v1/moderations" {
			t.Errorf("unexpected moderation request path: got %q, want %q", r.URL.Path, "/v1/moderations")
		}
		if err := json.NewDecoder(r.Body).Decode(&moderationRequest); err != nil {
			t.Errorf("decode moderation request: %v", err)
			http.Error(w, "invalid moderation request", http.StatusBadRequest)
			return
		}
		// Best-effort write: a failure here surfaces through the assertions
		// on the Check result below.
		_ = json.NewEncoder(w).Encode(moderationAPIResponse{
			Results: []moderationAPIResult{{
				CategoryScores: map[string]float64{"sexual": 0.01},
			}},
		})
	}))
	defer server.Close()

	cfg := defaultContentModerationConfig()
	cfg.Enabled = true
	cfg.Mode = ContentModerationModePreBlock
	cfg.BaseURL = server.URL
	cfg.APIKeys = []string{"sk-test"}
	cfg.RecordNonHits = true
	rawCfg, err := json.Marshal(cfg)
	require.NoError(t, err)

	repo := &contentModerationTestRepo{}
	svc := NewContentModerationService(
		&contentModerationTestSettingRepo{values: map[string]string{
			SettingKeyRiskControlEnabled:      "true",
			SettingKeyContentModerationConfig: string(rawCfg),
		}},
		repo,
		&contentModerationTestHashCache{},
		nil,
		nil,
		nil,
		nil,
	)

	body := []byte(`{
"model":"gpt-5.5",
"input":[
{"type":"message","role":"developer","content":[{"type":"input_text","text":"developer instructions should not be audited"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"first user prompt"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"last user prompt"}]}
]
}`)
	decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
		UserID:   1001,
		Endpoint: "/responses",
		Provider: "openai",
		Model:    "gpt-5.5",
		Protocol: ContentModerationProtocolOpenAIResponses,
		Body:     body,
	})
	require.NoError(t, err)
	require.False(t, decision.Blocked)
	require.Len(t, repo.logs, 1)
	require.False(t, repo.logs[0].Flagged)
	require.Equal(t, ContentModerationActionAllow, repo.logs[0].Action)
	require.Equal(t, "/responses", repo.logs[0].Endpoint)
	require.Equal(t, "last user prompt", repo.logs[0].InputExcerpt)
	require.Equal(t, "last user prompt", moderationRequest.Input)
}
// TestContentModerationCheck_PreBlockBlocksCodexResponsesLatestUserInput checks
// that in pre-block mode a high upstream score blocks the request with the
// configured status and message, writes one flagged log, and that only the
// last user message is sent to the moderation endpoint.
func TestContentModerationCheck_PreBlockBlocksCodexResponsesLatestUserInput(t *testing.T) {
	var moderationRequest moderationAPIRequest
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on the HTTP server's goroutine. require.* ends in
		// t.FailNow, which must only be called from the test goroutine, so
		// failures are reported with t.Errorf instead.
		if r.URL.Path != "/v1/moderations" {
			t.Errorf("unexpected moderation request path: got %q, want %q", r.URL.Path, "/v1/moderations")
		}
		if err := json.NewDecoder(r.Body).Decode(&moderationRequest); err != nil {
			t.Errorf("decode moderation request: %v", err)
			http.Error(w, "invalid moderation request", http.StatusBadRequest)
			return
		}
		// Best-effort write: a failure here surfaces through the assertions
		// on the Check result below.
		_ = json.NewEncoder(w).Encode(moderationAPIResponse{
			Results: []moderationAPIResult{{
				CategoryScores: map[string]float64{"sexual": 0.9},
			}},
		})
	}))
	defer server.Close()

	cfg := defaultContentModerationConfig()
	cfg.Enabled = true
	cfg.Mode = ContentModerationModePreBlock
	cfg.BaseURL = server.URL
	cfg.APIKeys = []string{"sk-test"}
	cfg.BlockStatus = http.StatusUnavailableForLegalReasons
	cfg.BlockMessage = "内容审计测试阻断"
	rawCfg, err := json.Marshal(cfg)
	require.NoError(t, err)

	repo := &contentModerationTestRepo{}
	svc := NewContentModerationService(
		&contentModerationTestSettingRepo{values: map[string]string{
			SettingKeyRiskControlEnabled:      "true",
			SettingKeyContentModerationConfig: string(rawCfg),
		}},
		repo,
		&contentModerationTestHashCache{},
		nil,
		nil,
		nil,
		nil,
	)

	body := []byte(`{
"model":"gpt-5.5",
"instructions":"instructions.....",
"input":[
{"type":"message","role":"developer","content":[{"type":"input_text","text":"developer instructions should not be audited"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"environment context"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"latest blocked prompt"}]}
]
}`)
	decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
		UserID:   1001,
		Endpoint: "/responses",
		Provider: "openai",
		Model:    "gpt-5.5",
		Protocol: ContentModerationProtocolOpenAIResponses,
		Body:     body,
	})
	require.NoError(t, err)
	require.True(t, decision.Blocked)
	require.Equal(t, ContentModerationActionBlock, decision.Action)
	require.Equal(t, http.StatusUnavailableForLegalReasons, decision.StatusCode)
	require.Equal(t, "内容审计测试阻断", decision.Message)
	require.Len(t, repo.logs, 1)
	require.True(t, repo.logs[0].Flagged)
	require.Equal(t, ContentModerationActionBlock, repo.logs[0].Action)
	require.Equal(t, ContentModerationModePreBlock, repo.logs[0].Mode)
	require.Equal(t, "latest blocked prompt", repo.logs[0].InputExcerpt)
	require.Equal(t, "latest blocked prompt", moderationRequest.Input)
}
func TestBuildContentModerationTestAuditResult_UsesConfiguredThresholdsOnly(t *testing.T) {
result := buildContentModerationTestAuditResult(&moderationAPIResult{
Flagged: true,
CategoryScores: map[string]float64{
"harassment": 0.65,
},
}, nil)
require.NotNil(t, result)
require.False(t, result.Flagged)
require.Equal(t, "harassment", result.HighestCategory)
require.Equal(t, 0.65, result.HighestScore)
require.Equal(t, 0.65, result.CompositeScore)
require.Equal(t, 0.98, result.Thresholds["harassment"])
}
// TestContentModerationCallModeration_400DoesNotFreezeAPIKey checks that an
// HTTP 400 from the moderation endpoint is not retried (one request despite
// RetryCount 5) and leaves the key in "error" state with no failure count and
// no freeze.
func TestContentModerationCallModeration_400DoesNotFreezeAPIKey(t *testing.T) {
	const apiKey = "sk-test"

	hits := 0
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hits++
		w.WriteHeader(http.StatusBadRequest)
		_, _ = w.Write([]byte(`{"error":{"message":"Number of images (5) exceeds maximum of 1","type":"invalid_request_error","param":"input","code":"too_many_images"}}`))
	}))
	t.Cleanup(upstream.Close)

	conf := defaultContentModerationConfig()
	conf.BaseURL = upstream.URL
	conf.APIKeys = []string{apiKey}
	conf.RetryCount = 5
	svc := NewContentModerationService(nil, nil, nil, nil, nil, nil, nil)

	_, err := svc.callModeration(context.Background(), conf, "hello")
	require.Error(t, err)
	require.Equal(t, 1, hits)

	keyStatus := svc.apiKeyStatusForHash(0, moderationAPIKeyHash(apiKey), maskSecretTail(apiKey), true)
	require.Equal(t, "error", keyStatus.Status)
	require.Equal(t, http.StatusBadRequest, keyStatus.LastHTTPStatus)
	require.Zero(t, keyStatus.FailureCount)
	require.Nil(t, keyStatus.FrozenUntil)
}
// TestContentModerationCallModeration_FreezesByHTTPStatus checks, per upstream
// status code, that a failed moderation call marks the key "frozen" with one
// failure and a freeze deadline whose remaining time falls inside the
// expected window.
func TestContentModerationCallModeration_FreezesByHTTPStatus(t *testing.T) {
	type freezeCase struct {
		name       string
		statusCode int
		minFreeze  time.Duration
		maxFreeze  time.Duration
	}
	cases := []freezeCase{
		{name: "401 freezes ten minutes", statusCode: http.StatusUnauthorized, minFreeze: 9*time.Minute + 55*time.Second, maxFreeze: 10*time.Minute + time.Second},
		{name: "403 freezes ten minutes", statusCode: http.StatusForbidden, minFreeze: 9*time.Minute + 55*time.Second, maxFreeze: 10*time.Minute + time.Second},
		{name: "429 freezes one minute", statusCode: http.StatusTooManyRequests, minFreeze: 55 * time.Second, maxFreeze: time.Minute + time.Second},
		{name: "529 freezes one minute", statusCode: 529, minFreeze: 55 * time.Second, maxFreeze: time.Minute + time.Second},
		{name: "500 freezes ten seconds", statusCode: http.StatusInternalServerError, minFreeze: 5 * time.Second, maxFreeze: 11 * time.Second},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(tc.statusCode)
				_, _ = w.Write([]byte(`{"error":{"message":"upstream error"}}`))
			}))
			t.Cleanup(upstream.Close)

			conf := defaultContentModerationConfig()
			conf.BaseURL = upstream.URL
			conf.APIKeys = []string{"sk-test"}
			conf.RetryCount = 0
			svc := NewContentModerationService(nil, nil, nil, nil, nil, nil, nil)

			_, err := svc.callModeration(context.Background(), conf, "hello")
			require.Error(t, err)

			keyStatus := svc.apiKeyStatusForHash(0, moderationAPIKeyHash("sk-test"), maskSecretTail("sk-test"), true)
			require.Equal(t, "frozen", keyStatus.Status)
			require.Equal(t, tc.statusCode, keyStatus.LastHTTPStatus)
			require.Equal(t, 1, keyStatus.FailureCount)
			require.NotNil(t, keyStatus.FrozenUntil)

			// The window bounds allow for the time spent between the freeze
			// being set and this measurement.
			remaining := time.Until(*keyStatus.FrozenUntil)
			require.GreaterOrEqual(t, remaining, tc.minFreeze)
			require.LessOrEqual(t, remaining, tc.maxFreeze)
		})
	}
}
func TestContentModerationTestAPIKeys_400DoesNotFreezeAPIKey(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error":{"message":"invalid moderation request"}}`))
}))
defer server.Close()
svc := NewContentModerationService(
&contentModerationTestSettingRepo{values: map[string]string{}},
nil,
nil,
nil,
nil,
nil,
nil,
)
result, err := svc.TestAPIKeys(context.Background(), TestContentModerationAPIKeysInput{
APIKeys: []string{"sk-test"},
BaseURL: server.URL,
Prompt: "hello",
})
require.NoError(t, err)
require.Len(t, result.Items, 1)
require.Equal(t, "error", result.Items[0].Status)
require.Equal(t, http.StatusBadRequest, result.Items[0].LastHTTPStatus)
require.Zero(t, result.Items[0].FailureCount)
require.Nil(t, result.Items[0].FrozenUntil)
}
// TestContentModerationCheck_PreHashUsesRedisHashCache checks that when the
// pre-hash check is enabled and the input's hash is already in the cache,
// Check blocks with the hash-block action, the configured status, and a
// message that includes both the configured text and the hash.
func TestContentModerationCheck_PreHashUsesRedisHashCache(t *testing.T) {
	conf := defaultContentModerationConfig()
	conf.Enabled = true
	conf.PreHashCheckEnabled = true
	conf.APIKeys = []string{"sk-test"}
	conf.BlockStatus = http.StatusConflict
	conf.BlockMessage = "命中历史风险输入"
	encoded, err := json.Marshal(conf)
	require.NoError(t, err)

	// Pre-populate the cache with the hash of the prompt sent below.
	content := ContentModerationInput{Text: "blocked prompt"}
	content.Normalize()
	cache := &contentModerationTestHashCache{
		hashes: map[string]struct{}{content.Hash(): {}},
	}

	settings := &contentModerationTestSettingRepo{values: map[string]string{
		SettingKeyRiskControlEnabled:      "true",
		SettingKeyContentModerationConfig: string(encoded),
	}}
	svc := NewContentModerationService(settings, &contentModerationTestRepo{}, cache, nil, nil, nil, nil)

	decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
		Protocol: ContentModerationProtocolOpenAIChat,
		Body:     []byte(`{"messages":[{"role":"user","content":"blocked prompt"}]}`),
	})
	require.NoError(t, err)
	require.True(t, decision.Blocked)
	require.Equal(t, ContentModerationActionHashBlock, decision.Action)
	require.Equal(t, http.StatusConflict, decision.StatusCode)
	require.Equal(t, content.Hash(), decision.InputHash)
	require.Contains(t, decision.Message, "命中历史风险输入")
	require.Contains(t, decision.Message, content.Hash())
	require.Len(t, cache.checked, 1)
}
// TestContentModerationCheck_PreBlockFlaggedWritesRedisHashCache checks that a
// flagged pre-block result records the input hash, and that repeating the same
// request is then blocked by the hash cache without another upstream call or
// another log entry.
func TestContentModerationCheck_PreBlockFlaggedWritesRedisHashCache(t *testing.T) {
	upstreamCalls := 0
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		upstreamCalls++
		_ = json.NewEncoder(w).Encode(moderationAPIResponse{
			Results: []moderationAPIResult{{
				CategoryScores: map[string]float64{"sexual": 0.9},
			}},
		})
	}))
	t.Cleanup(upstream.Close)

	conf := defaultContentModerationConfig()
	conf.Enabled = true
	conf.Mode = ContentModerationModePreBlock
	conf.PreHashCheckEnabled = true
	conf.BaseURL = upstream.URL
	conf.APIKeys = []string{"sk-test"}
	conf.BlockStatus = http.StatusConflict
	conf.BlockMessage = "命中风险输入"
	encoded, err := json.Marshal(conf)
	require.NoError(t, err)

	logs := &contentModerationTestRepo{}
	cache := &contentModerationTestHashCache{}
	settings := &contentModerationTestSettingRepo{values: map[string]string{
		SettingKeyRiskControlEnabled:      "true",
		SettingKeyContentModerationConfig: string(encoded),
	}}
	svc := NewContentModerationService(settings, logs, cache, nil, nil, nil, nil)

	request := ContentModerationCheckInput{
		Protocol: ContentModerationProtocolOpenAIChat,
		Body:     []byte(`{"messages":[{"role":"user","content":"repeat blocked prompt"}]}`),
	}

	// First call: goes upstream, is blocked, and seeds the hash cache.
	decision, err := svc.Check(context.Background(), request)
	require.NoError(t, err)
	require.True(t, decision.Blocked)
	require.Equal(t, ContentModerationActionBlock, decision.Action)
	require.Equal(t, 1, upstreamCalls)
	require.Len(t, cache.recorded, 1)
	require.Len(t, logs.logs, 1)

	// Second call: served from the hash cache, so the upstream call count
	// and the log count stay at one.
	decision, err = svc.Check(context.Background(), request)
	require.NoError(t, err)
	require.True(t, decision.Blocked)
	require.Equal(t, ContentModerationActionHashBlock, decision.Action)
	require.Equal(t, cache.recorded[0], decision.InputHash)
	require.Equal(t, 1, upstreamCalls)
	require.Len(t, logs.logs, 1)
}
// TestContentModerationDeleteFlaggedInputHash_NormalizesAndDeletes checks that
// an upper-cased hash is lowered before the cache is asked to delete it, and
// that deleting the same hash a second time reports Deleted == false.
func TestContentModerationDeleteFlaggedInputHash_NormalizesAndDeletes(t *testing.T) {
	storedHash := strings.Repeat("a", 64)
	cache := &contentModerationTestHashCache{
		hashes: map[string]struct{}{storedHash: {}},
	}
	svc := &ContentModerationService{hashCache: cache}

	// First delete uses the upper-case spelling of the stored hash.
	outcome, err := svc.DeleteFlaggedInputHash(context.Background(), strings.ToUpper(storedHash))
	require.NoError(t, err)
	require.Equal(t, storedHash, outcome.InputHash)
	require.True(t, outcome.Deleted)
	require.NotContains(t, cache.hashes, storedHash)
	require.Equal(t, []string{storedHash}, cache.deleted)

	// Second delete finds nothing left to remove.
	outcome, err = svc.DeleteFlaggedInputHash(context.Background(), storedHash)
	require.NoError(t, err)
	require.Equal(t, storedHash, outcome.InputHash)
	require.False(t, outcome.Deleted)
}
// TestContentModerationClearFlaggedInputHashesAndStatusCount checks that
// GetStatus reports the number of cached hashes, that ClearFlaggedInputHashes
// returns how many it removed, and that the status count is zero afterwards.
func TestContentModerationClearFlaggedInputHashesAndStatusCount(t *testing.T) {
	conf := defaultContentModerationConfig()
	conf.Enabled = true
	encoded, err := json.Marshal(conf)
	require.NoError(t, err)

	cache := &contentModerationTestHashCache{hashes: map[string]struct{}{
		strings.Repeat("a", 64): {},
		strings.Repeat("b", 64): {},
	}}
	settings := &contentModerationTestSettingRepo{values: map[string]string{
		SettingKeyRiskControlEnabled:      "true",
		SettingKeyContentModerationConfig: string(encoded),
	}}
	svc := &ContentModerationService{
		settingRepo: settings,
		hashCache:   cache,
		keyHealth:   make(map[string]*contentModerationKeyHealth),
	}

	before, err := svc.GetStatus(context.Background())
	require.NoError(t, err)
	require.Equal(t, int64(2), before.FlaggedHashCount)

	cleared, err := svc.ClearFlaggedInputHashes(context.Background())
	require.NoError(t, err)
	require.Equal(t, int64(2), cleared.Deleted)

	after, err := svc.GetStatus(context.Background())
	require.NoError(t, err)
	require.Equal(t, int64(0), after.FlaggedHashCount)
}
// TestContentModerationCheck_AsyncFlaggedWritesRedisHashCache checks that in
// observe mode a flagged result from checkSync does not block the request but
// still records the input hash and writes one log entry.
func TestContentModerationCheck_AsyncFlaggedWritesRedisHashCache(t *testing.T) {
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_ = json.NewEncoder(w).Encode(moderationAPIResponse{
			Results: []moderationAPIResult{{
				CategoryScores: map[string]float64{"sexual": 0.9},
			}},
		})
	}))
	t.Cleanup(upstream.Close)

	conf := defaultContentModerationConfig()
	conf.Enabled = true
	conf.Mode = ContentModerationModeObserve
	conf.BaseURL = upstream.URL
	conf.APIKeys = []string{"sk-test"}
	encoded, err := json.Marshal(conf)
	require.NoError(t, err)

	logs := &contentModerationTestRepo{}
	cache := &contentModerationTestHashCache{}
	settings := &contentModerationTestSettingRepo{values: map[string]string{
		SettingKeyRiskControlEnabled:      "true",
		SettingKeyContentModerationConfig: string(encoded),
	}}
	svc := NewContentModerationService(settings, logs, cache, nil, nil, nil, nil)

	request := ContentModerationCheckInput{
		Protocol: ContentModerationProtocolOpenAIChat,
		Body:     []byte(`{"messages":[{"role":"user","content":"bad prompt"}]}`),
	}
	decision := svc.checkSync(
		context.Background(),
		request,
		conf,
		ContentModerationInput{Text: "bad prompt"},
		strings.Repeat("b", 64),
		contentModerationIntPtr(25),
		false,
	)

	require.False(t, decision.Blocked)
	require.Len(t, cache.recorded, 1)
	require.Len(t, logs.logs, 1)
}
// TestBuildContentModerationAccountDisabledEmailBody_ContainsBanDetails checks
// that the account-disabled email body contains the ban notice, the violation
// count with its threshold, the top category and score, and an HTML-escaped
// site name.
func TestBuildContentModerationAccountDisabledEmailBody_ContainsBanDetails(t *testing.T) {
	conf := defaultContentModerationConfig()
	conf.BanThreshold = 10

	uid := int64(1001)
	hit := &ContentModerationLog{
		UserID:          &uid,
		UserEmail:       "user@example.com",
		GroupName:       "vip_2",
		HighestCategory: "sexual",
		HighestScore:    0.926,
		ViolationCount:  10,
	}

	html := buildContentModerationAccountDisabledEmailBody("Sub2API <Admin>", hit, conf)

	wantFragments := []string{
		"账户已被自动禁用",
		"封禁详情",
		"账户当前处于封禁状态,所有 API 请求将被拒绝",
		"10 次(阈值 10",
		"sexual / 0.926",
		"Sub2API &lt;Admin&gt;",
	}
	for _, fragment := range wantFragments {
		require.Contains(t, html, fragment)
	}
}
// TestContentModerationUnbanUser_ActivatesUserAndInvalidatesAuthCache checks
// that unbanning a disabled user persists one update with active status and
// invalidates that user's auth cache.
func TestContentModerationUnbanUser_ActivatesUserAndInvalidatesAuthCache(t *testing.T) {
	users := &contentModerationTestUserRepo{
		user: &User{ID: 1001, Email: "user@example.com", Status: StatusDisabled},
	}
	authCache := &contentModerationTestAuthCacheInvalidator{}
	logs := &contentModerationTestRepo{}
	svc := NewContentModerationService(nil, logs, nil, nil, users, authCache, nil)

	outcome, err := svc.UnbanUser(context.Background(), 1001)

	require.NoError(t, err)
	require.Equal(t, int64(1001), outcome.UserID)
	require.Equal(t, StatusActive, outcome.Status)
	require.Len(t, users.updated, 1)
	require.Equal(t, StatusActive, users.updated[0].Status)
	require.Equal(t, []int64{1001}, authCache.userIDs)
}
// TestContentModerationUnbanUser_ActiveUserOnlyInvalidatesAuthCache checks
// that unbanning an already-active user performs no repository update but
// still invalidates that user's auth cache.
func TestContentModerationUnbanUser_ActiveUserOnlyInvalidatesAuthCache(t *testing.T) {
	users := &contentModerationTestUserRepo{
		user: &User{ID: 1001, Email: "user@example.com", Status: StatusActive},
	}
	authCache := &contentModerationTestAuthCacheInvalidator{}
	logs := &contentModerationTestRepo{}
	svc := NewContentModerationService(nil, logs, nil, nil, users, authCache, nil)

	outcome, err := svc.UnbanUser(context.Background(), 1001)

	require.NoError(t, err)
	require.Equal(t, StatusActive, outcome.Status)
	require.Empty(t, users.updated)
	require.Equal(t, []int64{1001}, authCache.userIDs)
}
// contentModerationIntPtr returns a pointer to a fresh int holding v, for
// filling optional *int arguments in tests.
func contentModerationIntPtr(v int) *int {
	p := new(int)
	*p = v
	return p
}