Files
sub2api-ht/backend/internal/service/content_moderation.go
2026-05-07 15:11:38 +08:00

2049 lines
67 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// Package-level constants for the content moderation service.
const (
// Moderation modes compared against ContentModerationConfig.Mode. In Check,
// "off" skips auditing, "observe" enqueues the audit asynchronously, and any
// other mode audits inline; checkSync only blocks when the mode is "pre_block".
ContentModerationModeOff = "off"
ContentModerationModeObserve = "observe"
ContentModerationModePreBlock = "pre_block"
// API key update modes used by UpdateConfig: "replace" overwrites the stored
// key list, any other mode appends to it.
contentModerationAPIKeysModeAppend = "append"
contentModerationAPIKeysModeReplace = "replace"
// Action values written to decisions and moderation logs.
ContentModerationActionAllow = "allow"
ContentModerationActionBlock = "block"
ContentModerationActionHashBlock = "hash_block"
ContentModerationActionError = "error"
// Protocol identifiers; Check passes the request's protocol to
// ExtractContentModerationInput together with the request body.
ContentModerationProtocolAnthropicMessages = "anthropic_messages"
ContentModerationProtocolOpenAIResponses = "openai_responses"
ContentModerationProtocolOpenAIChat = "openai_chat_completions"
ContentModerationProtocolGemini = "gemini"
ContentModerationProtocolOpenAIImages = "openai_images"
// Defaults and upper bounds for configuration values.
// NOTE(review): most of these are presumably applied by the config
// normalize/validate code, which is not visible in this part of the file.
defaultContentModerationBaseURL = "https://api.openai.com"
defaultContentModerationModel = "omni-moderation-latest"
defaultContentModerationTimeoutMS = 3000
// The worker adds 10 seconds to this bound to derive its per-iteration deadline.
maxContentModerationTimeoutMS = 30000
// Normalize truncates the input text to this many runes.
maxModerationInputRunes = 12000
maxModerationExcerptRunes = 240
defaultContentModerationWorkerCount = 4
// The constructor starts this many worker goroutines.
maxContentModerationWorkerCount = 32
// enqueueAsync falls back to this soft cap when the config has no queue size.
defaultContentModerationQueueSize = 32768
// Capacity of the async task channel allocated by the constructor.
maxContentModerationQueueSize = 100000
defaultContentModerationBanThreshold = 10
defaultContentModerationViolationWindowHours = 720
defaultContentModerationBlockHTTPStatus = http.StatusForbidden
defaultContentModerationBlockMessage = "内容审计命中风险规则,请调整输入后重试"
defaultContentModerationRetryCount = 2
maxContentModerationRetryCount = 5
defaultContentModerationHitRetentionDays = 180
defaultContentModerationNonHitRetentionDays = 3
maxContentModerationRetentionDays = 3650
maxContentModerationNonHitRetentionDays = 3
// Freeze durations for API keys.
// NOTE(review): presumably applied by the key-health bookkeeping after
// rate-limit, auth and other HTTP errors respectively; confirm at the use sites.
contentModerationKeyRateLimitFreezeDuration = time.Minute
contentModerationKeyAuthFreezeDuration = 10 * time.Minute
contentModerationKeyHTTPErrorFreezeDuration = 10 * time.Second
// Image limits. The test image limit is defined as equal to the input image limit.
maxContentModerationInputImages = 1
maxContentModerationTestImages = maxContentModerationInputImages
maxContentModerationTestImageBytes = 8 * 1024 * 1024
maxContentModerationTestImageDataURLBytes = 12 * 1024 * 1024
// Cleanup scheduling values.
// NOTE(review): presumably consumed by cleanupWorker, which is not visible here.
contentModerationCleanupInterval = 24 * time.Hour
contentModerationCleanupTimeout = 30 * time.Minute
contentModerationCleanupDelay = 5 * time.Minute
)
// contentModerationCategoryOrder lists the moderation category names in a
// fixed order. ContentModerationCategories hands out copies of this slice, and
// the same names are the keys of ContentModerationDefaultThresholds.
var contentModerationCategoryOrder = []string{
"harassment",
"harassment/threatening",
"hate",
"hate/threatening",
"illicit",
"illicit/violent",
"self-harm",
"self-harm/intent",
"self-harm/instructions",
"sexual",
"sexual/minors",
"violence",
"violence/graphic",
}
// ContentModerationDefaultThresholds returns the default score threshold for
// every moderation category, keyed by category name.
//
// A new map is built on every call, so callers may modify the result freely.
func ContentModerationDefaultThresholds() map[string]float64 {
	defaults := make(map[string]float64, 13)

	defaults["harassment"] = 0.98
	defaults["harassment/threatening"] = 0.90

	defaults["hate"] = 0.65
	defaults["hate/threatening"] = 0.65

	defaults["illicit"] = 0.95
	defaults["illicit/violent"] = 0.95

	defaults["self-harm"] = 0.65
	defaults["self-harm/intent"] = 0.85
	defaults["self-harm/instructions"] = 0.65

	defaults["sexual"] = 0.65
	defaults["sexual/minors"] = 0.65

	defaults["violence"] = 0.95
	defaults["violence/graphic"] = 0.95

	return defaults
}
// ContentModerationCategories returns the moderation category names in their
// canonical order. The result is a copy, so callers cannot alter the
// package-level list through it.
func ContentModerationCategories() []string {
	categories := make([]string, 0, len(contentModerationCategoryOrder))
	return append(categories, contentModerationCategoryOrder...)
}
// ContentModerationConfig is the content moderation configuration. UpdateConfig
// JSON-encodes it and stores it under SettingKeyContentModerationConfig.
type ContentModerationConfig struct {
// Enabled gates all auditing; Check allows every request when it is false.
Enabled bool `json:"enabled"`
// Mode is compared against the ContentModerationMode* constants.
Mode string `json:"mode"`
BaseURL string `json:"base_url"`
Model string `json:"model"`
// APIKey is a single-key field; UpdateConfig clears it whenever it rewrites APIKeys.
APIKey string `json:"api_key,omitempty"`
APIKeys []string `json:"api_keys,omitempty"`
TimeoutMS int `json:"timeout_ms"`
// SampleRate is consulted through shouldSample in Check; its unit is not visible here.
SampleRate int `json:"sample_rate"`
// AllGroups and GroupIDs define the audited scope (evaluated by includesGroup).
AllGroups bool `json:"all_groups"`
GroupIDs []int64 `json:"group_ids"`
// RecordNonHits makes checkSync log non-flagged results and audit API errors too.
RecordNonHits bool `json:"record_non_hits"`
// Thresholds is passed to evaluateModerationScores together with the category scores.
Thresholds map[string]float64 `json:"thresholds"`
// WorkerCount limits active workers: a worker whose id is >= WorkerCount stays idle.
WorkerCount int `json:"worker_count"`
// QueueSize is the soft cap on pending async tasks enforced by enqueueAsync.
QueueSize int `json:"queue_size"`
// BlockStatus and BlockMessage are returned in blocking decisions.
BlockStatus int `json:"block_status"`
BlockMessage string `json:"block_message"`
EmailOnHit bool `json:"email_on_hit"`
AutoBanEnabled bool `json:"auto_ban_enabled"`
BanThreshold int `json:"ban_threshold"`
ViolationWindowHours int `json:"violation_window_hours"`
RetryCount int `json:"retry_count"`
HitRetentionDays int `json:"hit_retention_days"`
NonHitRetentionDays int `json:"non_hit_retention_days"`
// PreHashCheckEnabled makes Check consult the flagged-hash cache before auditing.
PreHashCheckEnabled bool `json:"pre_hash_check_enabled"`
}
// ContentModerationConfigView is the configuration representation returned by
// GetConfig and UpdateConfig. Unlike ContentModerationConfig it has no raw API
// key fields; it carries masked key values, a key count and per-key statuses.
// NOTE(review): it is populated by configView, which is not visible here.
type ContentModerationConfigView struct {
Enabled bool `json:"enabled"`
Mode string `json:"mode"`
BaseURL string `json:"base_url"`
Model string `json:"model"`
APIKeyConfigured bool `json:"api_key_configured"`
APIKeyMasked string `json:"api_key_masked"`
APIKeyCount int `json:"api_key_count"`
APIKeyMasks []string `json:"api_key_masks"`
APIKeyStatuses []ContentModerationAPIKeyStatus `json:"api_key_statuses"`
TimeoutMS int `json:"timeout_ms"`
SampleRate int `json:"sample_rate"`
AllGroups bool `json:"all_groups"`
GroupIDs []int64 `json:"group_ids"`
RecordNonHits bool `json:"record_non_hits"`
WorkerCount int `json:"worker_count"`
QueueSize int `json:"queue_size"`
BlockStatus int `json:"block_status"`
BlockMessage string `json:"block_message"`
EmailOnHit bool `json:"email_on_hit"`
AutoBanEnabled bool `json:"auto_ban_enabled"`
BanThreshold int `json:"ban_threshold"`
ViolationWindowHours int `json:"violation_window_hours"`
RetryCount int `json:"retry_count"`
HitRetentionDays int `json:"hit_retention_days"`
NonHitRetentionDays int `json:"non_hit_retention_days"`
PreHashCheckEnabled bool `json:"pre_hash_check_enabled"`
}
// ContentModerationAPIKeyStatus reports the health of one moderation API key.
// The key itself is never included, only its hash and a masked form.
type ContentModerationAPIKeyStatus struct {
Index int `json:"index"`
KeyHash string `json:"key_hash"`
Masked string `json:"masked"`
Status string `json:"status"`
FailureCount int `json:"failure_count"`
SuccessCount int64 `json:"success_count"`
LastError string `json:"last_error"`
LastCheckedAt *time.Time `json:"last_checked_at,omitempty"`
FrozenUntil *time.Time `json:"frozen_until,omitempty"`
LastLatencyMS int `json:"last_latency_ms"`
LastHTTPStatus int `json:"last_http_status"`
// LastTested is set to true by TestAPIKeys for every key it exercised.
LastTested bool `json:"last_tested"`
// Configured is true in TestAPIKeys results when the stored keys were tested
// rather than keys supplied with the request.
Configured bool `json:"configured"`
}
// TestContentModerationAPIKeysInput is the request for TestAPIKeys.
type TestContentModerationAPIKeysInput struct {
// APIKeys are the keys to test; when none remain after normalization the
// stored keys are tested instead.
APIKeys []string `json:"api_keys"`
// BaseURL, Model and TimeoutMS override the stored configuration for this
// test only, when non-blank (or positive for TimeoutMS).
BaseURL string `json:"base_url"`
Model string `json:"model"`
TimeoutMS int `json:"timeout_ms"`
// Prompt and Images are turned into the moderation test input.
Prompt string `json:"prompt"`
Images []string `json:"images"`
}
// TestContentModerationAPIKeysResult is the response of TestAPIKeys.
type TestContentModerationAPIKeysResult struct {
// Items holds one status per tested key.
Items []ContentModerationAPIKeyStatus `json:"items"`
// AuditResult is built from the first successful moderation call, if any.
AuditResult *ContentModerationTestAuditResult `json:"audit_result,omitempty"`
// ImageCount is the image count reported by buildModerationTestInput.
ImageCount int `json:"image_count"`
}
// ContentModerationTestAuditResult summarizes the moderation outcome of a
// TestAPIKeys call.
// NOTE(review): it is produced by buildContentModerationTestAuditResult, which
// is not visible here; the meaning of CompositeScore should be confirmed there.
type ContentModerationTestAuditResult struct {
Flagged bool `json:"flagged"`
HighestCategory string `json:"highest_category"`
HighestScore float64 `json:"highest_score"`
CompositeScore float64 `json:"composite_score"`
CategoryScores map[string]float64 `json:"category_scores"`
Thresholds map[string]float64 `json:"thresholds"`
}
// UpdateContentModerationConfigInput is a partial update for UpdateConfig.
// Pointer fields that are nil leave the corresponding setting unchanged.
type UpdateContentModerationConfigInput struct {
Enabled *bool `json:"enabled"`
Mode *string `json:"mode"`
BaseURL *string `json:"base_url"`
Model *string `json:"model"`
// APIKey adds a single key when it is non-nil and non-blank.
APIKey *string `json:"api_key"`
// APIKeys replaces or extends the key list, depending on APIKeysMode.
APIKeys *[]string `json:"api_keys"`
// APIKeysMode is normalized by normalizeContentModerationAPIKeysMode;
// "replace" overwrites the stored keys, otherwise keys are appended.
APIKeysMode string `json:"api_keys_mode"`
// DeleteAPIKeyHashes removes stored keys by hash; ignored in replace mode.
DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"`
// ClearAPIKey removes every stored key and takes precedence over the other
// API key fields.
ClearAPIKey bool `json:"clear_api_key"`
TimeoutMS *int `json:"timeout_ms"`
SampleRate *int `json:"sample_rate"`
AllGroups *bool `json:"all_groups"`
GroupIDs *[]int64 `json:"group_ids"`
RecordNonHits *bool `json:"record_non_hits"`
WorkerCount *int `json:"worker_count"`
QueueSize *int `json:"queue_size"`
BlockStatus *int `json:"block_status"`
BlockMessage *string `json:"block_message"`
EmailOnHit *bool `json:"email_on_hit"`
AutoBanEnabled *bool `json:"auto_ban_enabled"`
BanThreshold *int `json:"ban_threshold"`
ViolationWindowHours *int `json:"violation_window_hours"`
RetryCount *int `json:"retry_count"`
HitRetentionDays *int `json:"hit_retention_days"`
NonHitRetentionDays *int `json:"non_hit_retention_days"`
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
}
// ContentModerationCheckInput describes one request to be audited by Check.
type ContentModerationCheckInput struct {
RequestID string
UserID int64
UserEmail string
APIKeyID int64
APIKeyName string
// GroupID is checked against the configured scope; it may be nil.
GroupID *int64
GroupName string
Endpoint string
Provider string
Model string
// Protocol and Body are handed to ExtractContentModerationInput to obtain
// the text and images to audit.
Protocol string
Body []byte
}
// ContentModerationInput is the content extracted from a request for auditing:
// a text portion plus a list of image references (sent as image URLs by
// ModerationInput).
type ContentModerationInput struct {
Text string
Images []string
}
// Normalize cleans the input in place: the text is passed through
// normalizeContentModerationText and then cut to at most
// maxModerationInputRunes runes, and the image list is passed through
// normalizeModerationImages. Calling it on a nil receiver is a no-op.
func (in *ContentModerationInput) Normalize() {
	if in != nil {
		cleaned := normalizeContentModerationText(in.Text)
		in.Text = trimRunes(cleaned, maxModerationInputRunes)
		in.Images = normalizeModerationImages(in.Images)
	}
}
// IsEmpty reports whether there is nothing to audit: no images and a text
// that is empty or consists only of whitespace.
func (in ContentModerationInput) IsEmpty() bool {
	if len(in.Images) > 0 {
		return false
	}
	return strings.TrimSpace(in.Text) == ""
}
// ModerationInput builds the value sent as the moderation request input.
//
// When no images remain after limitContentModerationImages, the plain text
// string is returned. Otherwise the result is a []moderationAPIInputPart with
// an optional leading "text" part (omitted for blank text) followed by one
// "image_url" part per image.
func (in ContentModerationInput) ModerationInput() any {
	selected := limitContentModerationImages(in.Images)
	if len(selected) == 0 {
		return in.Text
	}

	payload := make([]moderationAPIInputPart, 0, 1+len(selected))
	if text := in.Text; strings.TrimSpace(text) != "" {
		payload = append(payload, moderationAPIInputPart{Type: "text", Text: text})
	}
	for i := range selected {
		ref := &moderationAPIImageURLRef{URL: selected[i]}
		payload = append(payload, moderationAPIInputPart{Type: "image_url", ImageURL: ref})
	}
	return payload
}
// ExcerptText returns the text that checkSync hands to buildLog as the excerpt
// source. It is the input text unchanged; images are not represented.
func (in ContentModerationInput) ExcerptText() string {
return in.Text
}
// Hash returns a hex-encoded SHA-256 fingerprint of the input.
//
// The digest covers "text:" followed by the text, and then, for each image in
// order, "\nimage:" followed by the hex SHA-256 of that image string. Check
// uses the value as the flagged-hash cache key and for sampling.
func (in ContentModerationInput) Hash() string {
	digest := sha256.New()
	_, _ = digest.Write([]byte("text:" + in.Text))
	for i := range in.Images {
		sum := sha256.Sum256([]byte(in.Images[i]))
		_, _ = digest.Write([]byte("\nimage:" + hex.EncodeToString(sum[:])))
	}
	return hex.EncodeToString(digest.Sum(nil))
}
// ContentModerationDecision is the outcome of Check for one request.
type ContentModerationDecision struct {
Allowed bool `json:"allowed"`
Blocked bool `json:"blocked"`
Flagged bool `json:"flagged"`
// Message and StatusCode carry the configured block message and status on
// blocking decisions.
Message string `json:"message"`
StatusCode int `json:"status_code"`
// InputHash is set on hash-block decisions.
InputHash string `json:"input_hash,omitempty"`
HighestCategory string `json:"highest_category"`
HighestScore float64 `json:"highest_score"`
CategoryScores map[string]float64 `json:"category_scores"`
// Action is one of the ContentModerationAction* constants.
Action string `json:"action"`
}
// ContentModerationLog is one audit record, stored through
// ContentModerationRepository.CreateLog and returned by ListLogs.
// NOTE(review): records are assembled by buildLog, which is not visible here.
type ContentModerationLog struct {
ID int64 `json:"id"`
RequestID string `json:"request_id"`
UserID *int64 `json:"user_id,omitempty"`
UserEmail string `json:"user_email"`
APIKeyID *int64 `json:"api_key_id,omitempty"`
APIKeyName string `json:"api_key_name"`
GroupID *int64 `json:"group_id,omitempty"`
GroupName string `json:"group_name"`
Endpoint string `json:"endpoint"`
Provider string `json:"provider"`
Model string `json:"model"`
Mode string `json:"mode"`
Action string `json:"action"`
Flagged bool `json:"flagged"`
HighestCategory string `json:"highest_category"`
HighestScore float64 `json:"highest_score"`
CategoryScores map[string]float64 `json:"category_scores"`
ThresholdSnapshot map[string]float64 `json:"threshold_snapshot"`
InputExcerpt string `json:"input_excerpt"`
UpstreamLatencyMS *int `json:"upstream_latency_ms,omitempty"`
Error string `json:"error"`
ViolationCount int `json:"violation_count"`
AutoBanned bool `json:"auto_banned"`
EmailSent bool `json:"email_sent"`
UserStatus string `json:"user_status"`
QueueDelayMS *int `json:"queue_delay_ms,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// ContentModerationLogFilter selects and pages moderation logs for ListLogs.
// ListLogs fills in pagination defaults before passing it to the repository.
// NOTE(review): how Result, Search, From and To are matched is decided by the
// repository implementation, which is not visible here.
type ContentModerationLogFilter struct {
Pagination pagination.PaginationParams
Result string
GroupID *int64
Endpoint string
Search string
From *time.Time
To *time.Time
}
// ContentModerationCleanupResult is returned by
// ContentModerationRepository.CleanupExpiredLogs and carries separate deletion
// counts for hit and non-hit logs.
type ContentModerationCleanupResult struct {
DeletedHit int64 `json:"deleted_hit"`
DeletedNonHit int64 `json:"deleted_non_hit"`
FinishedAt time.Time `json:"finished_at"`
}
// ContentModerationRuntimeStatus is the runtime snapshot returned by GetStatus:
// configuration flags, worker and queue usage, async counters, API key
// statuses, the flagged-hash count and the result of the last cleanup.
type ContentModerationRuntimeStatus struct {
Enabled bool `json:"enabled"`
RiskControlEnabled bool `json:"risk_control_enabled"`
Mode string `json:"mode"`
WorkerCount int `json:"worker_count"`
MaxWorkers int `json:"max_workers"`
ActiveWorkers int `json:"active_workers"`
IdleWorkers int `json:"idle_workers"`
QueueSize int `json:"queue_size"`
QueueLength int `json:"queue_length"`
QueueUsagePercent float64 `json:"queue_usage_percent"`
Enqueued int64 `json:"enqueued"`
Dropped int64 `json:"dropped"`
Processed int64 `json:"processed"`
Errors int64 `json:"errors"`
APIKeyStatuses []ContentModerationAPIKeyStatus `json:"api_key_statuses"`
FlaggedHashCount int64 `json:"flagged_hash_count"`
LastCleanupAt *time.Time `json:"last_cleanup_at,omitempty"`
LastCleanupDeletedHit int64 `json:"last_cleanup_deleted_hit"`
LastCleanupDeletedNonHit int64 `json:"last_cleanup_deleted_non_hit"`
}
// ContentModerationUnbanUserResult is returned by UnbanUser; Status is the
// user's status after the call (StatusActive on success).
type ContentModerationUnbanUserResult struct {
UserID int64 `json:"user_id"`
Status string `json:"status"`
}
// ContentModerationDeleteHashResult is returned by DeleteFlaggedInputHash.
// InputHash is the normalized hash; Deleted is the value reported by the cache.
type ContentModerationDeleteHashResult struct {
InputHash string `json:"input_hash"`
Deleted bool `json:"deleted"`
}
// ContentModerationClearHashesResult is returned by ClearFlaggedInputHashes;
// Deleted is the count reported by the hash cache.
type ContentModerationClearHashesResult struct {
Deleted int64 `json:"deleted"`
}
// ContentModerationRepository is the persistence interface for moderation logs.
type ContentModerationRepository interface {
// CreateLog stores one moderation log entry.
CreateLog(ctx context.Context, log *ContentModerationLog) error
// ListLogs returns the logs matching filter together with pagination details.
ListLogs(ctx context.Context, filter ContentModerationLogFilter) ([]ContentModerationLog, *pagination.PaginationResult, error)
// CountFlaggedByUserSince returns a count for the given user and start time.
// NOTE(review): presumably the number of flagged logs since that time; confirm
// against the implementation.
CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error)
// CleanupExpiredLogs takes separate cut-off times for hit and non-hit logs
// and reports how many of each were deleted.
CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*ContentModerationCleanupResult, error)
}
// ContentModerationHashCache stores the hashes of inputs that were flagged.
// checkSync records a hash after a flagged audit, and Check consults the cache
// when PreHashCheckEnabled is set.
type ContentModerationHashCache interface {
// RecordFlaggedInputHash remembers inputHash as flagged.
RecordFlaggedInputHash(ctx context.Context, inputHash string) error
// HasFlaggedInputHash reports whether inputHash is recorded.
HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error)
// DeleteFlaggedInputHash removes one hash; the boolean is surfaced to callers
// as ContentModerationDeleteHashResult.Deleted.
DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error)
// ClearFlaggedInputHashes removes all hashes and returns a count that is
// surfaced as ContentModerationClearHashesResult.Deleted.
ClearFlaggedInputHashes(ctx context.Context) (int64, error)
// CountFlaggedInputHashes returns the number of recorded hashes.
CountFlaggedInputHashes(ctx context.Context) (int64, error)
}
// ContentModerationService audits request content against a moderation API,
// either inline or through a pool of background workers fed by asyncQueue.
// Instances are created with NewContentModerationService.
type ContentModerationService struct {
settingRepo SettingRepository
repo ContentModerationRepository
hashCache ContentModerationHashCache
groupRepo GroupRepository
userRepo UserRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
emailService *EmailService
httpClient *http.Client
// asyncQueue carries observe-mode audits from enqueueAsync to the workers.
asyncQueue chan contentModerationTask
// workerCount is the number of worker goroutines started by the constructor.
workerCount int
apiKeyCursor atomic.Uint64
// asyncActive counts workers currently running an audit.
asyncActive atomic.Int64
// asyncEnqueued and asyncDropped count tasks accepted or rejected by enqueueAsync.
asyncEnqueued atomic.Int64
asyncDropped atomic.Int64
// asyncProcessed counts tasks completed by workers; asyncErrors counts failed
// audit calls on the async path.
asyncProcessed atomic.Int64
asyncErrors atomic.Int64
lastCleanupUnix atomic.Int64
lastCleanupDeletedHit atomic.Int64
lastCleanupDeletedNonHit atomic.Int64
// NOTE(review): keyHealthMu presumably guards keyHealth; confirm at the use sites.
keyHealthMu sync.Mutex
keyHealth map[string]*contentModerationKeyHealth
}
// contentModerationTask is one queued asynchronous audit.
type contentModerationTask struct {
input ContentModerationCheckInput
content ContentModerationInput
// inputHash is the hash of content computed by Check before enqueueing.
inputHash string
// enqueuedAt is used by the worker to compute the queue delay.
enqueuedAt time.Time
}
// contentModerationKeyHealth is the in-memory health record kept per API key in
// ContentModerationService.keyHealth.
// NOTE(review): the map is presumably keyed by Hash; confirm where entries are
// created.
type contentModerationKeyHealth struct {
Hash string
Masked string
FailureCount int
SuccessCount int64
LastError string
LastCheckedAt time.Time
FrozenUntil time.Time
LastLatencyMS int
LastHTTPStatus int
LastTested bool
}
// NewContentModerationService wires up a ContentModerationService.
//
// The async queue is always allocated with capacity
// maxContentModerationQueueSize. Background goroutines
// (maxContentModerationWorkerCount workers plus the cleanup worker) are only
// started when both settingRepo and repo are non-nil; without them the
// returned service performs no background work.
func NewContentModerationService(
	settingRepo SettingRepository,
	repo ContentModerationRepository,
	hashCache ContentModerationHashCache,
	groupRepo GroupRepository,
	userRepo UserRepository,
	authCacheInvalidator APIKeyAuthCacheInvalidator,
	emailService *EmailService,
) *ContentModerationService {
	svc := new(ContentModerationService)

	// Collaborators.
	svc.settingRepo = settingRepo
	svc.repo = repo
	svc.hashCache = hashCache
	svc.groupRepo = groupRepo
	svc.userRepo = userRepo
	svc.authCacheInvalidator = authCacheInvalidator
	svc.emailService = emailService

	// Runtime state.
	svc.httpClient = &http.Client{}
	svc.workerCount = maxContentModerationWorkerCount
	svc.asyncQueue = make(chan contentModerationTask, maxContentModerationQueueSize)
	svc.keyHealth = make(map[string]*contentModerationKeyHealth)

	if settingRepo == nil || repo == nil {
		return svc
	}

	for id := 0; id < svc.workerCount; id++ {
		go svc.worker(id)
	}
	go svc.cleanupWorker()
	return svc
}
// GetConfig loads the stored configuration and returns its view
// representation. A load failure is returned unchanged.
func (s *ContentModerationService) GetConfig(ctx context.Context) (*ContentModerationConfigView, error) {
	cfg, loadErr := s.loadConfig(ctx)
	if loadErr == nil {
		return s.configView(cfg), nil
	}
	return nil, loadErr
}
// UpdateConfig applies a partial update to the stored content moderation
// configuration, validates and normalizes the result, persists it as JSON
// under SettingKeyContentModerationConfig and returns the updated view.
//
// Nil pointer fields in input leave the corresponding setting untouched.
// String settings are whitespace-trimmed and group IDs are normalized.
//
// API key handling:
//   - ClearAPIKey removes every stored key and ignores the other key fields.
//   - DeleteAPIKeyHashes removes stored keys by hash, except in replace mode.
//   - APIKeys replaces the stored list in replace mode and is appended otherwise.
//   - A non-blank APIKey is appended to the resulting list.
//
// Every key-mutating step starts from cfg.apiKeys() and then clears the
// single-key APIKey field. The APIKey branch previously started from the raw
// cfg.APIKeys slice instead, so a key held only in cfg.APIKey was silently
// dropped when a single key was added.
//
// It returns the load, validation, marshal or save error, whichever comes
// first.
func (s *ContentModerationService) UpdateConfig(ctx context.Context, input UpdateContentModerationConfigInput) (*ContentModerationConfigView, error) {
	cfg, err := s.loadConfig(ctx)
	if err != nil {
		return nil, err
	}

	// Scalar settings: copy over whatever the caller supplied.
	if input.Enabled != nil {
		cfg.Enabled = *input.Enabled
	}
	if input.Mode != nil {
		cfg.Mode = strings.TrimSpace(*input.Mode)
	}
	if input.BaseURL != nil {
		cfg.BaseURL = strings.TrimSpace(*input.BaseURL)
	}
	if input.Model != nil {
		cfg.Model = strings.TrimSpace(*input.Model)
	}
	if input.TimeoutMS != nil {
		cfg.TimeoutMS = *input.TimeoutMS
	}
	if input.SampleRate != nil {
		cfg.SampleRate = *input.SampleRate
	}
	if input.WorkerCount != nil {
		cfg.WorkerCount = *input.WorkerCount
	}
	if input.QueueSize != nil {
		cfg.QueueSize = *input.QueueSize
	}
	if input.BlockStatus != nil {
		cfg.BlockStatus = *input.BlockStatus
	}
	if input.BlockMessage != nil {
		cfg.BlockMessage = strings.TrimSpace(*input.BlockMessage)
	}
	if input.EmailOnHit != nil {
		cfg.EmailOnHit = *input.EmailOnHit
	}
	if input.AutoBanEnabled != nil {
		cfg.AutoBanEnabled = *input.AutoBanEnabled
	}
	if input.BanThreshold != nil {
		cfg.BanThreshold = *input.BanThreshold
	}
	if input.ViolationWindowHours != nil {
		cfg.ViolationWindowHours = *input.ViolationWindowHours
	}
	if input.RetryCount != nil {
		cfg.RetryCount = *input.RetryCount
	}
	if input.HitRetentionDays != nil {
		cfg.HitRetentionDays = *input.HitRetentionDays
	}
	if input.NonHitRetentionDays != nil {
		cfg.NonHitRetentionDays = *input.NonHitRetentionDays
	}
	if input.PreHashCheckEnabled != nil {
		cfg.PreHashCheckEnabled = *input.PreHashCheckEnabled
	}
	if input.AllGroups != nil {
		cfg.AllGroups = *input.AllGroups
	}
	if input.GroupIDs != nil {
		cfg.GroupIDs = normalizeInt64IDs(*input.GroupIDs)
	}
	if input.RecordNonHits != nil {
		cfg.RecordNonHits = *input.RecordNonHits
	}

	// API keys.
	if input.ClearAPIKey {
		cfg.APIKey = ""
		// Kept as an empty, non-nil slice exactly as before.
		cfg.APIKeys = []string{}
	} else {
		apiKeysMode := normalizeContentModerationAPIKeysMode(input.APIKeysMode)
		if input.DeleteAPIKeyHashes != nil && apiKeysMode != contentModerationAPIKeysModeReplace {
			cfg.APIKeys = deleteModerationAPIKeysByHash(cfg.apiKeys(), *input.DeleteAPIKeyHashes)
			cfg.APIKey = ""
		}
		if input.APIKeys != nil {
			if apiKeysMode == contentModerationAPIKeysModeReplace {
				cfg.APIKeys = normalizeModerationAPIKeys(*input.APIKeys)
			} else {
				cfg.APIKeys = normalizeModerationAPIKeys(append(cfg.apiKeys(), *input.APIKeys...))
			}
			cfg.APIKey = ""
		}
		if input.APIKey != nil && strings.TrimSpace(*input.APIKey) != "" {
			// Start from cfg.apiKeys() like the branches above so that a key
			// stored only in cfg.APIKey survives the reset below.
			cfg.APIKeys = normalizeModerationAPIKeys(append(cfg.apiKeys(), *input.APIKey))
			cfg.APIKey = ""
		}
	}

	if err := s.validateConfig(ctx, cfg); err != nil {
		return nil, err
	}
	cfg.normalize()

	raw, err := json.Marshal(cfg)
	if err != nil {
		return nil, fmt.Errorf("marshal content moderation config: %w", err)
	}
	if err := s.settingRepo.Set(ctx, SettingKeyContentModerationConfig, string(raw)); err != nil {
		return nil, fmt.Errorf("save content moderation config: %w", err)
	}
	return s.configView(cfg), nil
}
// TestAPIKeys exercises moderation API keys with a live moderation call and
// reports a status per key.
//
// Keys supplied in input are tested when any remain after normalization;
// otherwise the stored keys are used and the resulting statuses are marked as
// configured. Non-blank BaseURL and Model and a positive TimeoutMS override
// the stored configuration for this call only.
//
// When the stored keys are used and the request carries audit input (as
// decided by contentModerationTestHasAuditInput), only the next usable key is
// called; if none is usable the current statuses of all stored keys are
// returned without making a call.
//
// Each call updates the key's health record. AuditResult is taken from the
// first successful call. An error is returned only when the configuration
// cannot be loaded or the test input cannot be built.
func (s *ContentModerationService) TestAPIKeys(ctx context.Context, input TestContentModerationAPIKeysInput) (*TestContentModerationAPIKeysResult, error) {
	cfg, err := s.loadConfig(ctx)
	if err != nil {
		return nil, err
	}

	keys := normalizeModerationAPIKeys(input.APIKeys)
	configured := len(keys) == 0
	if configured {
		keys = cfg.apiKeys()
	}

	// Per-call overrides of the stored configuration.
	if baseURL := input.BaseURL; strings.TrimSpace(baseURL) != "" {
		cfg.BaseURL = baseURL
	}
	if model := input.Model; strings.TrimSpace(model) != "" {
		cfg.Model = model
	}
	if timeout := input.TimeoutMS; timeout > 0 {
		cfg.TimeoutMS = timeout
	}
	cfg.normalize()

	testInput, imageCount, err := buildModerationTestInput(input.Prompt, input.Images)
	if err != nil {
		return nil, err
	}

	hasAuditInput := contentModerationTestHasAuditInput(input.Prompt, input.Images)
	if configured && hasAuditInput {
		picked, usable := s.nextUsableAPIKey(cfg)
		if !usable {
			return &TestContentModerationAPIKeysResult{
				Items:      s.apiKeyStatuses(keys),
				ImageCount: imageCount,
			}, nil
		}
		keys = []string{picked}
	}
	if len(keys) == 0 {
		return &TestContentModerationAPIKeysResult{Items: []ContentModerationAPIKeyStatus{}, ImageCount: imageCount}, nil
	}

	var auditResult *ContentModerationTestAuditResult
	items := make([]ContentModerationAPIKeyStatus, 0, len(keys))
	for i, key := range keys {
		var httpStatus int
		began := time.Now()
		result, callErr := s.callModerationOnceWithInput(ctx, cfg, key, testInput, &httpStatus)
		elapsedMS := int(time.Since(began).Milliseconds())
		keyHash := moderationAPIKeyHash(key)

		if callErr == nil {
			s.markAPIKeySuccess(key, elapsedMS, httpStatus)
			if auditResult == nil {
				auditResult = buildContentModerationTestAuditResult(result, cfg.Thresholds)
			}
		} else {
			s.markAPIKeyError(key, callErr.Error(), elapsedMS, httpStatus)
		}

		status := s.apiKeyStatusForHash(i, keyHash, maskSecretTail(key), configured)
		status.LastTested = true
		items = append(items, status)
	}

	return &TestContentModerationAPIKeysResult{
		Items:       items,
		AuditResult: auditResult,
		ImageCount:  imageCount,
	}, nil
}
// Check decides whether a request may proceed.
//
// It walks through a series of gates; each gate that does not apply logs a
// "content_moderation.skip_*" event and allows the request. The error result
// is always nil: every failure path (service unavailable, config load error,
// hash cache error) results in an allow decision.
//
// Gates, in order: service availability, risk-control feature flag, config
// load, config enabled, mode not "off", group in scope, non-empty extracted
// input, flagged-hash pre-check (may block), sampling, presence of audit API
// keys. After the gates, "observe" mode enqueues an asynchronous audit and
// allows the request, while any other mode audits inline through checkSync.
func (s *ContentModerationService) Check(ctx context.Context, input ContentModerationCheckInput) (*ContentModerationDecision, error) {
// Shared allow decision returned by every skip path.
allow := &ContentModerationDecision{Allowed: true, Action: ContentModerationActionAllow}
if s == nil || s.settingRepo == nil || s.repo == nil {
slog.Info("content_moderation.skip_unavailable",
"user_id", input.UserID,
"api_key_id", input.APIKeyID,
"group_id", contentModerationLogGroupID(input.GroupID),
"endpoint", input.Endpoint,
"protocol", input.Protocol)
return allow, nil
}
if !s.isRiskControlEnabled(ctx) {
slog.Info("content_moderation.skip_feature_disabled",
"user_id", input.UserID,
"api_key_id", input.APIKeyID,
"group_id", contentModerationLogGroupID(input.GroupID),
"endpoint", input.Endpoint,
"protocol", input.Protocol)
return allow, nil
}
cfg, err := s.loadConfig(ctx)
if err != nil {
// Fail open: a configuration problem must not reject traffic.
slog.Warn("content_moderation.skip_config_load_failed",
"user_id", input.UserID,
"api_key_id", input.APIKeyID,
"group_id", contentModerationLogGroupID(input.GroupID),
"endpoint", input.Endpoint,
"protocol", input.Protocol,
"error", err)
return allow, nil
}
// Scope is computed up front so it can be included in the config log line.
inScope := cfg.includesGroup(input.GroupID)
slog.Info("content_moderation.config_loaded",
"user_id", input.UserID,
"api_key_id", input.APIKeyID,
"group_id", contentModerationLogGroupID(input.GroupID),
"group_name", input.GroupName,
"endpoint", input.Endpoint,
"provider", input.Provider,
"protocol", input.Protocol,
"model", input.Model,
"enabled", cfg.Enabled,
"mode", cfg.Mode,
"all_groups", cfg.AllGroups,
"configured_group_ids", cfg.GroupIDs,
"in_scope", inScope,
"sample_rate", cfg.SampleRate,
"api_key_count", len(cfg.apiKeys()),
"pre_hash_check_enabled", cfg.PreHashCheckEnabled,
"record_non_hits", cfg.RecordNonHits)
if !cfg.Enabled {
slog.Info("content_moderation.skip_config_disabled",
"user_id", input.UserID,
"api_key_id", input.APIKeyID,
"group_id", contentModerationLogGroupID(input.GroupID),
"endpoint", input.Endpoint,
"protocol", input.Protocol)
return allow, nil
}
if cfg.Mode == ContentModerationModeOff {
slog.Info("content_moderation.skip_mode_off",
"user_id", input.UserID,
"api_key_id", input.APIKeyID,
"group_id", contentModerationLogGroupID(input.GroupID),
"endpoint", input.Endpoint,
"protocol", input.Protocol)
return allow, nil
}
if !inScope {
slog.Info("content_moderation.skip_group_out_of_scope",
"user_id", input.UserID,
"api_key_id", input.APIKeyID,
"group_id", contentModerationLogGroupID(input.GroupID),
"group_name", input.GroupName,
"endpoint", input.Endpoint,
"protocol", input.Protocol,
"all_groups", cfg.AllGroups,
"configured_group_ids", cfg.GroupIDs)
return allow, nil
}
// Extract the auditable text and images from the request body.
content := ExtractContentModerationInput(input.Protocol, input.Body)
// NOTE(review): emptiness is tested before Normalize runs, so an input that
// only becomes empty through normalization is not skipped here.
if content.IsEmpty() {
slog.Info("content_moderation.skip_empty_input",
"user_id", input.UserID,
"api_key_id", input.APIKeyID,
"group_id", contentModerationLogGroupID(input.GroupID),
"endpoint", input.Endpoint,
"protocol", input.Protocol,
"body_bytes", len(input.Body))
return allow, nil
}
content.Normalize()
slog.Info("content_moderation.input_extracted",
"user_id", input.UserID,
"api_key_id", input.APIKeyID,
"group_id", contentModerationLogGroupID(input.GroupID),
"endpoint", input.Endpoint,
"protocol", input.Protocol,
"text_runes", len([]rune(content.Text)),
"image_count", len(content.Images))
// The hash of the normalized input is used for the pre-check, for sampling
// and for recording flagged inputs.
hashText := content.Hash()
if cfg.PreHashCheckEnabled && s.hashCache != nil {
matched, err := s.hashCache.HasFlaggedInputHash(ctx, hashText)
if err != nil {
// A cache failure is logged only; the flow continues with the matched
// value returned alongside the error.
slog.Warn("content_moderation.hash_check_failed", "user_id", input.UserID, "endpoint", input.Endpoint, "error", err)
}
if matched {
slog.Info("content_moderation.hash_block",
"user_id", input.UserID,
"api_key_id", input.APIKeyID,
"group_id", contentModerationLogGroupID(input.GroupID),
"endpoint", input.Endpoint,
"protocol", input.Protocol,
"input_hash", hashText)
// The hash is appended to a non-empty block message.
message := cfg.BlockMessage
if message != "" {
message = fmt.Sprintf("%shash: %s", message, hashText)
}
// Hash blocks apply regardless of mode (both observe and pre_block reach here).
return &ContentModerationDecision{
Allowed: false,
Blocked: true,
Flagged: true,
Message: message,
StatusCode: cfg.BlockStatus,
InputHash: hashText,
Action: ContentModerationActionHashBlock,
}, nil
}
}
// Sampling is decided from the input hash by shouldSample.
if !cfg.shouldSample(hashText) {
slog.Info("content_moderation.skip_sample_rate",
"user_id", input.UserID,
"api_key_id", input.APIKeyID,
"group_id", contentModerationLogGroupID(input.GroupID),
"endpoint", input.Endpoint,
"protocol", input.Protocol,
"sample_rate", cfg.SampleRate)
return allow, nil
}
if len(cfg.apiKeys()) == 0 {
slog.Warn("content_moderation.skip_no_audit_api_keys",
"user_id", input.UserID,
"api_key_id", input.APIKeyID,
"group_id", contentModerationLogGroupID(input.GroupID),
"endpoint", input.Endpoint,
"protocol", input.Protocol)
return allow, nil
}
// Observe mode never delays the request: the audit is queued for a worker.
if cfg.Mode == ContentModerationModeObserve {
slog.Info("content_moderation.enqueue_observe",
"user_id", input.UserID,
"api_key_id", input.APIKeyID,
"group_id", contentModerationLogGroupID(input.GroupID),
"endpoint", input.Endpoint,
"protocol", input.Protocol,
"queue_len", len(s.asyncQueue))
s.enqueueAsync(input, cfg, content, hashText)
return allow, nil
}
// Inline audit with blocking permitted and no queue delay.
return s.checkSync(ctx, input, cfg, content, hashText, nil, true), nil
}
// checkSync calls the moderation API for content and turns the result into a
// decision. It is used inline by Check (queueDelay nil, allowBlock true) and
// by the background workers (queueDelay set, allowBlock false).
//
// Parameters:
//   - hashText: hash of content, recorded in the hash cache when flagged.
//   - queueDelay: time the task spent queued in milliseconds; nil for inline
//     calls. A non-nil value also makes API failures count in asyncErrors.
//   - allowBlock: whether a flagged result may block; blocking additionally
//     requires the "pre_block" mode.
//
// A moderation API failure never blocks: it yields an allow decision and,
// when RecordNonHits is set, an "error" log entry. Flagged results (and
// non-flagged ones when RecordNonHits is set) are written to the log
// repository after side effects have been applied.
//
// Failures to persist a log entry used to be discarded silently; they are now
// reported through slog so lost audit records are visible. The returned
// decision is unaffected by such failures.
func (s *ContentModerationService) checkSync(ctx context.Context, input ContentModerationCheckInput, cfg *ContentModerationConfig, content ContentModerationInput, hashText string, queueDelay *int, allowBlock bool) *ContentModerationDecision {
	allow := &ContentModerationDecision{Allowed: true, Action: ContentModerationActionAllow}

	start := time.Now()
	result, err := s.callModeration(ctx, cfg, content.ModerationInput())
	latency := int(time.Since(start).Milliseconds())
	if err != nil {
		slog.Warn("content_moderation.audit_api_failed",
			"user_id", input.UserID,
			"api_key_id", input.APIKeyID,
			"group_id", contentModerationLogGroupID(input.GroupID),
			"endpoint", input.Endpoint,
			"protocol", input.Protocol,
			"mode", cfg.Mode,
			"allow_block", allowBlock,
			"queue_delay_ms", queueDelay,
			"latency_ms", latency,
			"error", err)
		// Only the async path feeds the asyncErrors counter.
		if queueDelay != nil {
			s.asyncErrors.Add(1)
		}
		if cfg.RecordNonHits {
			entry := s.buildLog(input, cfg, ContentModerationActionError, false, "", 0, nil, content.ExcerptText(), &latency, queueDelay, err.Error())
			if logErr := s.repo.CreateLog(ctx, entry); logErr != nil {
				slog.Warn("content_moderation.create_log_failed",
					"user_id", input.UserID,
					"endpoint", input.Endpoint,
					"action", ContentModerationActionError,
					"error", logErr)
			}
		}
		// Fail open: an unavailable moderation API must not reject traffic.
		return allow
	}

	flagged, highestCategory, highestScore := evaluateModerationScores(result.CategoryScores, cfg.Thresholds)
	action := ContentModerationActionAllow
	blocked := false
	if allowBlock && flagged && cfg.Mode == ContentModerationModePreBlock {
		action = ContentModerationActionBlock
		blocked = true
	}
	slog.Info("content_moderation.audit_result",
		"user_id", input.UserID,
		"api_key_id", input.APIKeyID,
		"group_id", contentModerationLogGroupID(input.GroupID),
		"group_name", input.GroupName,
		"endpoint", input.Endpoint,
		"protocol", input.Protocol,
		"mode", cfg.Mode,
		"allow_block", allowBlock,
		"flagged", flagged,
		"blocked", blocked,
		"action", action,
		"highest_category", highestCategory,
		"highest_score", highestScore,
		"latency_ms", latency,
		"queue_delay_ms", queueDelay)

	if flagged || cfg.RecordNonHits {
		entry := s.buildLog(input, cfg, action, flagged, highestCategory, highestScore, result.CategoryScores, content.ExcerptText(), &latency, queueDelay, "")
		if flagged && s.hashCache != nil {
			if err := s.hashCache.RecordFlaggedInputHash(ctx, hashText); err != nil {
				slog.Warn("content_moderation.record_hash_failed", "user_id", input.UserID, "endpoint", input.Endpoint, "error", err)
			}
		}
		// Side effects run before persisting so that whatever they set on the
		// entry is stored with it.
		// NOTE(review): this is also invoked for non-flagged entries;
		// applyFlaggedSideEffects presumably ignores those - confirm there.
		s.applyFlaggedSideEffects(ctx, cfg, entry)
		if logErr := s.repo.CreateLog(ctx, entry); logErr != nil {
			slog.Warn("content_moderation.create_log_failed",
				"user_id", input.UserID,
				"endpoint", input.Endpoint,
				"action", action,
				"error", logErr)
		}
	}

	if blocked {
		return &ContentModerationDecision{
			Allowed:         false,
			Blocked:         true,
			Flagged:         true,
			Message:         cfg.BlockMessage,
			StatusCode:      cfg.BlockStatus,
			HighestCategory: highestCategory,
			HighestScore:    highestScore,
			CategoryScores:  result.CategoryScores,
			Action:          action,
		}
	}
	return &ContentModerationDecision{
		Allowed:         true,
		Flagged:         flagged,
		Message:         "",
		HighestCategory: highestCategory,
		HighestScore:    highestScore,
		CategoryScores:  result.CategoryScores,
		Action:          action,
	}
}
// enqueueAsync queues an audit for the background workers without ever
// blocking the caller.
//
// The configured QueueSize (or defaultContentModerationQueueSize when cfg is
// nil or has no positive size) acts as a soft cap on the number of pending
// tasks; the channel capacity is the hard cap. A task that cannot be queued
// is dropped, logged and counted in asyncDropped; accepted tasks are counted
// in asyncEnqueued. It is a no-op on a nil service or a nil queue.
func (s *ContentModerationService) enqueueAsync(input ContentModerationCheckInput, cfg *ContentModerationConfig, content ContentModerationInput, hashText string) {
	if s == nil || s.asyncQueue == nil {
		return
	}

	limit := defaultContentModerationQueueSize
	if cfg != nil && cfg.QueueSize > 0 {
		limit = cfg.QueueSize
	}
	if pending := len(s.asyncQueue); pending >= limit {
		slog.Warn("content_moderation.async_queue_full", "user_id", input.UserID, "endpoint", input.Endpoint, "queue_size", limit)
		s.asyncDropped.Add(1)
		return
	}

	pendingTask := contentModerationTask{
		input:      input,
		content:    content,
		inputHash:  hashText,
		enqueuedAt: time.Now(),
	}
	select {
	case s.asyncQueue <- pendingTask:
		s.asyncEnqueued.Add(1)
	default:
		// The channel itself is full even though the soft cap was not reached.
		slog.Warn("content_moderation.async_queue_full", "user_id", input.UserID, "endpoint", input.Endpoint)
		s.asyncDropped.Add(1)
	}
}
// worker is the loop run by each background goroutine; it never returns.
//
// Every iteration reloads the configuration under a fresh context whose
// deadline is maxContentModerationTimeoutMS plus ten seconds. The worker
// sleeps for a second when the config cannot be loaded, moderation is
// disabled or "off", no API keys are configured, or its id is not below the
// configured WorkerCount. Otherwise it waits up to a second for a task and
// audits it through checkSync with blocking disabled.
//
// Tasks whose group is no longer in scope are discarded without being
// counted as processed. A panic while handling a task is recovered and
// logged so the worker keeps running.
func (s *ContentModerationService) worker(id int) {
	for {
		ctx, cancel := context.WithTimeout(context.Background(), maxContentModerationTimeoutMS*time.Millisecond+10*time.Second)

		cfg, err := s.loadConfig(ctx)
		// err is tested first so cfg is never touched after a failed load.
		standby := err != nil ||
			!cfg.Enabled ||
			cfg.Mode == ContentModerationModeOff ||
			len(cfg.apiKeys()) == 0 ||
			id >= cfg.WorkerCount
		if standby {
			cancel()
			time.Sleep(time.Second)
			continue
		}

		task, got := s.dequeueAsyncTask(ctx, time.Second)
		if !got {
			cancel()
			continue
		}

		// Handled in a closure so the deferred calls run once per task.
		func() {
			defer cancel()
			defer func() {
				if r := recover(); r != nil {
					slog.Error("content_moderation.worker_panic", "worker_id", id, "recover", r)
				}
			}()

			if !cfg.includesGroup(task.input.GroupID) {
				return
			}

			s.asyncActive.Add(1)
			defer s.asyncActive.Add(-1)

			waitedMS := int(time.Since(task.enqueuedAt).Milliseconds())
			_ = s.checkSync(ctx, task.input, cfg, task.content, task.inputHash, &waitedMS, false)
			s.asyncProcessed.Add(1)
		}()
	}
}
// dequeueAsyncTask waits for the next queued task.
//
// It returns the task and true when one is received. It returns a zero task
// and false when the service or queue is nil, when ctx is done, when idleWait
// elapses first, or when the queue channel has been closed. A non-positive
// idleWait is treated as one second.
func (s *ContentModerationService) dequeueAsyncTask(ctx context.Context, idleWait time.Duration) (contentModerationTask, bool) {
	if s == nil || s.asyncQueue == nil {
		return contentModerationTask{}, false
	}

	wait := idleWait
	if wait <= 0 {
		wait = time.Second
	}
	idle := time.NewTimer(wait)
	defer idle.Stop()

	select {
	case <-ctx.Done():
		return contentModerationTask{}, false
	case <-idle.C:
		return contentModerationTask{}, false
	case next, open := <-s.asyncQueue:
		return next, open
	}
}
// ListLogs returns a page of moderation logs matching filter.
//
// Pagination defaults are applied before the repository is queried: the page
// defaults to 1, the page size defaults to 20 and is capped at 100, and the
// sort order defaults to descending.
//
// The service can be constructed without a log repository (Check explicitly
// tolerates that), so a nil service or repository yields an internal-server
// error here instead of a nil-pointer panic.
func (s *ContentModerationService) ListLogs(ctx context.Context, filter ContentModerationLogFilter) ([]ContentModerationLog, *pagination.PaginationResult, error) {
	if s == nil || s.repo == nil {
		return nil, nil, infraerrors.InternalServer("CONTENT_MODERATION_REPOSITORY_UNAVAILABLE", "内容审计日志仓储不可用")
	}

	page := &filter.Pagination
	if page.Page <= 0 {
		page.Page = 1
	}
	switch {
	case page.PageSize <= 0:
		page.PageSize = 20
	case page.PageSize > 100:
		page.PageSize = 100
	}
	if page.SortOrder == "" {
		page.SortOrder = pagination.SortOrderDesc
	}
	return s.repo.ListLogs(ctx, filter)
}
// UnbanUser sets the given user's status back to StatusActive.
//
// The user is only written back when the status actually changes. The auth
// cache for the user is invalidated in either case when an invalidator is
// configured.
//
// Errors: an internal-server error when the user repository is unavailable,
// a bad-request error for a non-positive userID, a not-found error when the
// user does not exist, and a wrapped repository error otherwise.
func (s *ContentModerationService) UnbanUser(ctx context.Context, userID int64) (*ContentModerationUnbanUserResult, error) {
	if s == nil || s.userRepo == nil {
		return nil, infraerrors.InternalServer("CONTENT_MODERATION_USER_REPOSITORY_UNAVAILABLE", "用户仓储不可用")
	}
	if userID <= 0 {
		return nil, infraerrors.BadRequest("INVALID_USER_ID", "用户 ID 无效")
	}

	user, err := s.userRepo.GetByID(ctx, userID)
	switch {
	case err == nil:
		// Found; continue below.
	case errors.Is(err, ErrUserNotFound):
		return nil, infraerrors.NotFound("USER_NOT_FOUND", "用户不存在")
	default:
		return nil, fmt.Errorf("get content moderation unban user: %w", err)
	}

	if user.Status != StatusActive {
		user.Status = StatusActive
		if updateErr := s.userRepo.Update(ctx, user); updateErr != nil {
			return nil, fmt.Errorf("update content moderation unban user: %w", updateErr)
		}
	}

	if invalidator := s.authCacheInvalidator; invalidator != nil {
		invalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}

	result := &ContentModerationUnbanUserResult{UserID: userID, Status: StatusActive}
	return result, nil
}
// DeleteFlaggedInputHash removes one flagged input hash from the hash cache.
//
// The hash is normalized first; a bad-request error is returned when it
// normalizes to the empty string. An internal-server error is returned when
// the hash cache is unavailable. The result carries the normalized hash and
// the cache's "deleted" value.
func (s *ContentModerationService) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (*ContentModerationDeleteHashResult, error) {
	normalized := normalizeContentModerationHash(inputHash)
	switch {
	case normalized == "":
		return nil, infraerrors.BadRequest("INVALID_CONTENT_MODERATION_HASH", "风险输入哈希无效")
	case s == nil || s.hashCache == nil:
		return nil, infraerrors.InternalServer("CONTENT_MODERATION_HASH_CACHE_UNAVAILABLE", "内容审计哈希缓存不可用")
	}

	removed, err := s.hashCache.DeleteFlaggedInputHash(ctx, normalized)
	if err != nil {
		return nil, fmt.Errorf("delete content moderation flagged hash: %w", err)
	}

	result := &ContentModerationDeleteHashResult{InputHash: normalized, Deleted: removed}
	return result, nil
}
// ClearFlaggedInputHashes asks the hash cache to clear its flagged input
// hashes and reports the value the cache returns as Deleted.
//
// It returns an internal-server error when the hash cache is unavailable and
// a wrapped error when the cache call fails.
func (s *ContentModerationService) ClearFlaggedInputHashes(ctx context.Context) (*ContentModerationClearHashesResult, error) {
	if s == nil || s.hashCache == nil {
		return nil, infraerrors.InternalServer("CONTENT_MODERATION_HASH_CACHE_UNAVAILABLE", "内容审计哈希缓存不可用")
	}

	removed, err := s.hashCache.ClearFlaggedInputHashes(ctx)
	if err != nil {
		return nil, fmt.Errorf("clear content moderation flagged hashes: %w", err)
	}

	result := &ContentModerationClearHashesResult{Deleted: removed}
	return result, nil
}
// GetStatus assembles a runtime snapshot of the moderation subsystem:
// configuration summary, worker and queue usage, async counters, per-key
// health, flagged-hash count and the outcome of the last cleanup run.
//
// A nil service yields an empty status. A config load failure is returned as
// is. A failure to count flagged hashes is only logged and leaves the count
// at zero.
func (s *ContentModerationService) GetStatus(ctx context.Context) (*ContentModerationRuntimeStatus, error) {
	if s == nil {
		return &ContentModerationRuntimeStatus{}, nil
	}

	cfg, err := s.loadConfig(ctx)
	if err != nil {
		return nil, err
	}
	riskEnabled := s.isRiskControlEnabled(ctx)

	// Clamp the active counter into [0, WorkerCount] before reporting it.
	busy := int(s.asyncActive.Load())
	if busy < 0 {
		busy = 0
	}
	if busy > cfg.WorkerCount {
		busy = cfg.WorkerCount
	}

	var pending int
	if s.asyncQueue != nil {
		pending = len(s.asyncQueue)
	}
	var usagePercent float64
	if cfg.QueueSize > 0 {
		usagePercent = float64(pending) * 100 / float64(cfg.QueueSize)
	}

	var flaggedHashes int64
	if s.hashCache != nil {
		total, countErr := s.hashCache.CountFlaggedInputHashes(ctx)
		if countErr != nil {
			slog.Warn("content_moderation.hash_count_failed", "error", countErr)
		} else {
			flaggedHashes = total
		}
	}

	// A non-positive stored timestamp is reported as "no cleanup yet" (nil).
	var cleanedAt *time.Time
	if sec := s.lastCleanupUnix.Load(); sec > 0 {
		at := time.Unix(sec, 0)
		cleanedAt = &at
	}

	status := &ContentModerationRuntimeStatus{
		Enabled:                  cfg.Enabled,
		RiskControlEnabled:       riskEnabled,
		Mode:                     cfg.Mode,
		WorkerCount:              cfg.WorkerCount,
		MaxWorkers:               maxContentModerationWorkerCount,
		ActiveWorkers:            busy,
		IdleWorkers:              cfg.WorkerCount - busy,
		QueueSize:                cfg.QueueSize,
		QueueLength:              pending,
		QueueUsagePercent:        usagePercent,
		Enqueued:                 s.asyncEnqueued.Load(),
		Dropped:                  s.asyncDropped.Load(),
		Processed:                s.asyncProcessed.Load(),
		Errors:                   s.asyncErrors.Load(),
		APIKeyStatuses:           s.apiKeyStatuses(cfg.apiKeys()),
		FlaggedHashCount:         flaggedHashes,
		LastCleanupAt:            cleanedAt,
		LastCleanupDeletedHit:    s.lastCleanupDeletedHit.Load(),
		LastCleanupDeletedNonHit: s.lastCleanupDeletedNonHit.Load(),
	}
	return status, nil
}
// cleanupWorker runs log cleanup forever: the first run happens after
// contentModerationCleanupDelay, and each later run is scheduled
// contentModerationCleanupInterval after the previous one finishes.
// The loop has no exit condition.
func (s *ContentModerationService) cleanupWorker() {
	schedule := time.NewTimer(contentModerationCleanupDelay)
	defer schedule.Stop()

	for range schedule.C {
		s.runCleanupOnce()
		schedule.Reset(contentModerationCleanupInterval)
	}
}
// runCleanupOnce deletes expired moderation logs once.
//
// Retention cut-offs come from the current config (separate windows for hit
// and non-hit logs). The run is bounded by contentModerationCleanupTimeout.
// Failures are logged and otherwise ignored. On success the finish time and
// deleted counts are stored for status reporting.
func (s *ContentModerationService) runCleanupOnce() {
	if s == nil || s.repo == nil || s.settingRepo == nil {
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), contentModerationCleanupTimeout)
	defer cancel()

	cfg, err := s.loadConfig(ctx)
	if err != nil {
		slog.Warn("content_moderation.cleanup_load_config_failed", "error", err)
		return
	}

	reference := time.Now()
	hitCutoff := reference.AddDate(0, 0, -cfg.HitRetentionDays)
	nonHitCutoff := reference.AddDate(0, 0, -cfg.NonHitRetentionDays)

	outcome, err := s.repo.CleanupExpiredLogs(ctx, hitCutoff, nonHitCutoff)
	if err != nil {
		slog.Warn("content_moderation.cleanup_failed", "error", err)
		return
	}
	if outcome != nil {
		s.lastCleanupUnix.Store(outcome.FinishedAt.Unix())
		s.lastCleanupDeletedHit.Store(outcome.DeletedHit)
		s.lastCleanupDeletedNonHit.Store(outcome.DeletedNonHit)
	}
}
// loadConfig reads the stored moderation config and returns it normalized.
//
// The stored JSON is decoded on top of the defaults. A missing setting
// (ErrSettingNotFound) or a blank value yields the normalized defaults. Any
// other read error is wrapped; invalid JSON becomes a bad-request error.
func (s *ContentModerationService) loadConfig(ctx context.Context) (*ContentModerationConfig, error) {
	cfg := defaultContentModerationConfig()

	stored, err := s.settingRepo.GetValue(ctx, SettingKeyContentModerationConfig)
	switch {
	case errors.Is(err, ErrSettingNotFound):
		// Treat "not found" like an empty value: defaults only.
		stored = ""
	case err != nil:
		return nil, fmt.Errorf("get content moderation config: %w", err)
	}

	if strings.TrimSpace(stored) != "" {
		if decodeErr := json.Unmarshal([]byte(stored), cfg); decodeErr != nil {
			return nil, infraerrors.BadRequest("INVALID_CONTENT_MODERATION_CONFIG", "内容审计配置不是有效 JSON")
		}
	}

	cfg.normalize()
	return cfg, nil
}
// isRiskControlEnabled reports whether the risk-control setting is stored as
// exactly "true". Any read error counts as disabled.
func (s *ContentModerationService) isRiskControlEnabled(ctx context.Context) bool {
	value, err := s.settingRepo.GetValue(ctx, SettingKeyRiskControlEnabled)
	return err == nil && value == "true"
}
// validateConfig normalizes cfg in place and then checks it for values that
// normalization cannot repair.
//
// It returns a bad-request error when cfg is nil, the mode is not one of
// off/observe/pre_block, the base URL is not an absolute http(s) URL with a
// host, the block status is outside 400-599, or (when specific groups are
// selected and a group repository is configured) any group lookup fails.
// It returns nil when the config is acceptable.
func (s *ContentModerationService) validateConfig(ctx context.Context, cfg *ContentModerationConfig) error {
	if cfg == nil {
		return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_CONFIG", "内容审计配置不能为空")
	}
	cfg.normalize()

	switch cfg.Mode {
	case ContentModerationModeOff, ContentModerationModeObserve, ContentModerationModePreBlock:
	default:
		return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_MODE", "内容审计模式无效")
	}

	// ParseRequestURI alone also accepts bare absolute paths ("/v1") and any
	// scheme. The moderation endpoint is called over HTTP, so require an
	// http/https scheme and a host here instead of failing at request time.
	baseURL, err := url.ParseRequestURI(cfg.BaseURL)
	if err != nil || baseURL.Host == "" || (baseURL.Scheme != "http" && baseURL.Scheme != "https") {
		return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_BASE_URL", "OpenAI Base URL 无效")
	}

	if cfg.BlockStatus < 400 || cfg.BlockStatus > 599 {
		return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_BLOCK_STATUS", "拦截 HTTP 状态码必须在 400-599 之间")
	}

	if !cfg.AllGroups && len(cfg.GroupIDs) > 0 && s.groupRepo != nil {
		for _, groupID := range cfg.GroupIDs {
			// Any lookup error is reported as a missing group.
			if _, err := s.groupRepo.GetByIDLite(ctx, groupID); err != nil {
				return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_GROUP", fmt.Sprintf("审计分组不存在: %d", groupID))
			}
		}
	}
	return nil
}
// callModeration sends input to the moderation API, retrying with the next
// usable API key on failure.
//
// The number of attempts is RetryCount+1, bounded to
// [1, maxContentModerationRetryCount+1]. Each attempt's outcome is recorded
// in the key health state. Retrying stops early when no usable key remains
// or when the API answered HTTP 400. Between attempts the call waits
// 100ms x attempt number, returning ctx.Err() if ctx ends during the wait.
// The last error seen is returned when all attempts fail.
func (s *ContentModerationService) callModeration(ctx context.Context, cfg *ContentModerationConfig, input any) (*moderationAPIResult, error) {
	maxAttempts := cfg.RetryCount + 1
	if maxAttempts <= 0 {
		maxAttempts = 1
	}
	if maxAttempts > maxContentModerationRetryCount+1 {
		maxAttempts = maxContentModerationRetryCount + 1
	}

	var lastErr error
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		apiKey, available := s.nextUsableAPIKey(cfg)
		if !available {
			return nil, errors.New("no moderation api key available")
		}

		began := time.Now()
		status := 0
		result, err := s.callModerationOnceWithInput(ctx, cfg, apiKey, input, &status)
		elapsedMS := int(time.Since(began).Milliseconds())
		if err == nil {
			s.markAPIKeySuccess(apiKey, elapsedMS, status)
			return result, nil
		}

		s.markAPIKeyError(apiKey, err.Error(), elapsedMS, status)
		lastErr = err

		// A 400 will not improve with another key; the final attempt needs no wait.
		if status == http.StatusBadRequest || attempt == maxAttempts {
			break
		}

		backoff := time.Duration(100*attempt) * time.Millisecond
		select {
		case <-time.After(backoff):
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}
	return nil, lastErr
}
// callModerationOnceWithInput performs a single POST to
// <BaseURL>/v1/moderations with the configured model and the given input.
//
// The request is bounded by cfg.TimeoutMS. When httpStatus is non-nil it
// receives the response status code once a response arrives. A non-2xx
// response becomes an error that includes up to 512 bytes of the body. On
// success the first entry of the response's results is returned; an empty
// result list is an error.
func (s *ContentModerationService) callModerationOnceWithInput(ctx context.Context, cfg *ContentModerationConfig, apiKey string, input any, httpStatus *int) (*moderationAPIResult, error) {
	endpoint, err := url.JoinPath(strings.TrimRight(cfg.BaseURL, "/"), "/v1/moderations")
	if err != nil {
		return nil, err
	}

	body, err := json.Marshal(moderationAPIRequest{Model: cfg.Model, Input: input})
	if err != nil {
		return nil, err
	}

	reqCtx, cancel := context.WithTimeout(ctx, time.Duration(cfg.TimeoutMS)*time.Millisecond)
	defer cancel()

	req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+apiKey)
	req.Header.Set("Content-Type", "application/json")

	// Fall back to the default client when none was injected.
	httpClient := s.httpClient
	if httpClient == nil {
		httpClient = http.DefaultClient
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer func() { _ = resp.Body.Close() }()

	if httpStatus != nil {
		*httpStatus = resp.StatusCode
	}
	if code := resp.StatusCode; code < 200 || code >= 300 {
		// Best effort: a read failure just shortens the error text.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		return nil, fmt.Errorf("moderation api status %d: %s", code, strings.TrimSpace(string(snippet)))
	}

	var decoded moderationAPIResponse
	if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
		return nil, err
	}
	if len(decoded.Results) == 0 {
		return nil, errors.New("moderation api returned empty results")
	}
	return &decoded.Results[0], nil
}
// buildLog creates a moderation log entry from the request metadata, the
// active config and the evaluation outcome.
//
// UserID and APIKeyID are only set when positive. Score and threshold maps
// and the group ID are copied. The input excerpt is passed through secret
// redaction and then cut to maxModerationExcerptRunes runes.
func (s *ContentModerationService) buildLog(input ContentModerationCheckInput, cfg *ContentModerationConfig, action string, flagged bool, highestCategory string, highestScore float64, scores map[string]float64, text string, latency *int, queueDelay *int, errText string) *ContentModerationLog {
	excerpt := trimRunes(redactContentModerationSecrets(text), maxModerationExcerptRunes)

	entry := &ContentModerationLog{
		RequestID:         input.RequestID,
		UserEmail:         input.UserEmail,
		APIKeyName:        input.APIKeyName,
		GroupID:           cloneInt64Ptr(input.GroupID),
		GroupName:         input.GroupName,
		Endpoint:          input.Endpoint,
		Provider:          input.Provider,
		Model:             input.Model,
		Mode:              cfg.Mode,
		Action:            action,
		Flagged:           flagged,
		HighestCategory:   highestCategory,
		HighestScore:      highestScore,
		CategoryScores:    cloneFloatMap(scores),
		ThresholdSnapshot: cloneFloatMap(cfg.Thresholds),
		InputExcerpt:      excerpt,
		UpstreamLatencyMS: latency,
		QueueDelayMS:      queueDelay,
		Error:             errText,
	}

	// Non-positive IDs are recorded as absent (nil).
	if input.UserID > 0 {
		uid := input.UserID
		entry.UserID = &uid
	}
	if input.APIKeyID > 0 {
		kid := input.APIKeyID
		entry.APIKeyID = &kid
	}
	return entry
}
// applyFlaggedSideEffects runs the follow-up actions for a flagged log entry
// and records their outcome on the entry (ViolationCount, AutoBanned,
// EmailSent).
//
// It does nothing unless the entry is flagged and carries a positive user
// ID. Steps, in order:
//  1. Count the user's flagged entries within the violation window and add
//     one for the current entry. If counting fails the failure is logged and
//     the count stays at 1.
//  2. If auto-ban is enabled and the count reaches the threshold, disable
//     the user (when not already disabled) and invalidate their auth cache.
//     A lookup or update failure is logged and aborts the remaining steps.
//  3. If an email service and a user email are available, send the
//     violation notice (when EmailOnHit is set) and, only when the ban was
//     applied by this call, the account-disabled notice.
func (s *ContentModerationService) applyFlaggedSideEffects(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) {
	if s == nil || cfg == nil || log == nil || !log.Flagged || log.UserID == nil || *log.UserID <= 0 {
		return
	}

	count := 1
	if s.repo != nil && cfg.ViolationWindowHours > 0 {
		since := time.Now().Add(-time.Duration(cfg.ViolationWindowHours) * time.Hour)
		n, err := s.repo.CountFlaggedByUserSince(ctx, *log.UserID, since)
		if err != nil {
			// Keep the best-effort fallback of 1, but leave a trace: with an
			// undercounted total the auto-ban threshold may not be reached.
			slog.Warn("content_moderation.count_flagged_failed", "user_id", *log.UserID, "error", err)
		} else {
			count = n + 1
		}
	}
	log.ViolationCount = count

	autoBanJustApplied := false
	if cfg.AutoBanEnabled && cfg.BanThreshold > 0 && count >= cfg.BanThreshold && s.userRepo != nil {
		user, err := s.userRepo.GetByID(ctx, *log.UserID)
		if err != nil {
			slog.Warn("content_moderation.ban_get_user_failed", "user_id", *log.UserID, "error", err)
			return
		}
		if user.Status != StatusDisabled {
			user.Status = StatusDisabled
			if err := s.userRepo.Update(ctx, user); err != nil {
				slog.Warn("content_moderation.ban_update_user_failed", "user_id", *log.UserID, "error", err)
				return
			}
			if s.authCacheInvalidator != nil {
				s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, *log.UserID)
			}
			autoBanJustApplied = true
		}
		// Set even when the user was already disabled before this call.
		log.AutoBanned = true
	}

	if s.emailService == nil || strings.TrimSpace(log.UserEmail) == "" {
		return
	}

	emailSent := false
	if cfg.EmailOnHit {
		if err := s.sendViolationEmail(ctx, cfg, log); err != nil {
			slog.Warn("content_moderation.email_failed", "user_id", *log.UserID, "email", log.UserEmail, "error", err)
		} else {
			emailSent = true
		}
	}
	// The disabled notice is sent once, by the call that applied the ban.
	if autoBanJustApplied {
		if err := s.sendAccountDisabledEmail(ctx, cfg, log); err != nil {
			slog.Warn("content_moderation.ban_email_failed", "user_id", *log.UserID, "email", log.UserEmail, "error", err)
		} else {
			emailSent = true
		}
	}
	log.EmailSent = emailSent
}
// sendViolationEmail sends the risk-control notice to the user address on
// the log entry. The subject carries the sanitized site name; the body is
// built from the site name, the log entry and the config.
func (s *ContentModerationService) sendViolationEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error {
	site := s.siteName(ctx)
	return s.emailService.SendEmail(
		ctx,
		log.UserEmail,
		fmt.Sprintf("[%s] 账户风控提醒 / Risk Control Notice", sanitizeEmailHeader(site)),
		buildContentModerationViolationEmailBody(site, log, cfg),
	)
}
// sendAccountDisabledEmail sends the account-disabled notice to the user
// address on the log entry. The subject carries the sanitized site name; the
// body is built from the site name, the log entry and the config.
func (s *ContentModerationService) sendAccountDisabledEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error {
	site := s.siteName(ctx)
	return s.emailService.SendEmail(
		ctx,
		log.UserEmail,
		fmt.Sprintf("[%s] 账户已被禁用 / Account Disabled", sanitizeEmailHeader(site)),
		buildContentModerationAccountDisabledEmailBody(site, log, cfg),
	)
}
// siteName returns the trimmed site name from settings, or "Sub2API" when
// the service or setting repository is nil, the read fails, or the stored
// name is blank.
func (s *ContentModerationService) siteName(ctx context.Context) string {
	const fallback = "Sub2API"
	if s == nil || s.settingRepo == nil {
		return fallback
	}

	stored, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
	if err != nil {
		return fallback
	}
	if trimmed := strings.TrimSpace(stored); trimmed != "" {
		return trimmed
	}
	return fallback
}
// defaultContentModerationConfig returns a fresh config holding the built-in
// defaults: moderation disabled, pre-block mode, all groups, full sampling,
// email on hit and auto-ban switched on.
func defaultContentModerationConfig() *ContentModerationConfig {
	cfg := new(ContentModerationConfig)

	// Feature switches.
	cfg.Enabled = false
	cfg.RecordNonHits = false
	cfg.PreHashCheckEnabled = false
	cfg.EmailOnHit = true
	cfg.AutoBanEnabled = true

	// Upstream moderation endpoint.
	cfg.Mode = ContentModerationModePreBlock
	cfg.BaseURL = defaultContentModerationBaseURL
	cfg.Model = defaultContentModerationModel
	cfg.TimeoutMS = defaultContentModerationTimeoutMS
	cfg.RetryCount = defaultContentModerationRetryCount

	// Scope and scoring.
	cfg.SampleRate = 100
	cfg.AllGroups = true
	cfg.GroupIDs = []int64{}
	cfg.Thresholds = ContentModerationDefaultThresholds()

	// Async processing.
	cfg.WorkerCount = defaultContentModerationWorkerCount
	cfg.QueueSize = defaultContentModerationQueueSize

	// Response on block.
	cfg.BlockStatus = defaultContentModerationBlockHTTPStatus
	cfg.BlockMessage = defaultContentModerationBlockMessage

	// Enforcement and retention.
	cfg.BanThreshold = defaultContentModerationBanThreshold
	cfg.ViolationWindowHours = defaultContentModerationViolationWindowHours
	cfg.HitRetentionDays = defaultContentModerationHitRetentionDays
	cfg.NonHitRetentionDays = defaultContentModerationNonHitRetentionDays

	return cfg
}
// normalize rewrites cfg in place so every field holds a usable value.
//
//   - The legacy single APIKey is folded into APIKeys, which is then
//     normalized; APIKey is cleared.
//   - Empty Mode, BaseURL, Model and blank BlockMessage get their defaults;
//     BaseURL, Model and BlockMessage are trimmed (BaseURL also loses
//     trailing slashes).
//   - Numeric settings are defaulted when non-positive and capped at their
//     maximum where one exists; SampleRate is clamped to [0, 100] and
//     RetryCount to [0, maxContentModerationRetryCount].
//   - GroupIDs are normalized and Thresholds are merged over the defaults.
func (cfg *ContentModerationConfig) normalize() {
	// API keys.
	keys := cfg.APIKeys
	if cfg.APIKey != "" {
		keys = append(keys, cfg.APIKey)
		cfg.APIKey = ""
	}
	cfg.APIKeys = normalizeModerationAPIKeys(keys)

	// Text settings.
	if cfg.Mode == "" {
		cfg.Mode = ContentModerationModePreBlock
	}

	baseURL := cfg.BaseURL
	if baseURL == "" {
		baseURL = defaultContentModerationBaseURL
	}
	cfg.BaseURL = strings.TrimRight(strings.TrimSpace(baseURL), "/")

	model := cfg.Model
	if model == "" {
		model = defaultContentModerationModel
	}
	cfg.Model = strings.TrimSpace(model)

	if message := strings.TrimSpace(cfg.BlockMessage); message != "" {
		cfg.BlockMessage = message
	} else {
		cfg.BlockMessage = strings.TrimSpace(defaultContentModerationBlockMessage)
	}

	// Bounded numeric settings.
	switch {
	case cfg.TimeoutMS <= 0:
		cfg.TimeoutMS = defaultContentModerationTimeoutMS
	case cfg.TimeoutMS > maxContentModerationTimeoutMS:
		cfg.TimeoutMS = maxContentModerationTimeoutMS
	}

	switch {
	case cfg.SampleRate < 0:
		cfg.SampleRate = 0
	case cfg.SampleRate > 100:
		cfg.SampleRate = 100
	}

	switch {
	case cfg.WorkerCount <= 0:
		cfg.WorkerCount = defaultContentModerationWorkerCount
	case cfg.WorkerCount > maxContentModerationWorkerCount:
		cfg.WorkerCount = maxContentModerationWorkerCount
	}

	switch {
	case cfg.QueueSize <= 0:
		cfg.QueueSize = defaultContentModerationQueueSize
	case cfg.QueueSize > maxContentModerationQueueSize:
		cfg.QueueSize = maxContentModerationQueueSize
	}

	if cfg.RetryCount < 0 {
		cfg.RetryCount = 0
	}
	if cfg.RetryCount > maxContentModerationRetryCount {
		cfg.RetryCount = maxContentModerationRetryCount
	}

	if cfg.HitRetentionDays <= 0 {
		cfg.HitRetentionDays = defaultContentModerationHitRetentionDays
	}
	if cfg.HitRetentionDays > maxContentModerationRetentionDays {
		cfg.HitRetentionDays = maxContentModerationRetentionDays
	}

	if cfg.NonHitRetentionDays <= 0 {
		cfg.NonHitRetentionDays = defaultContentModerationNonHitRetentionDays
	}
	if cfg.NonHitRetentionDays > maxContentModerationNonHitRetentionDays {
		cfg.NonHitRetentionDays = maxContentModerationNonHitRetentionDays
	}

	// Settings that only have a default.
	if cfg.BlockStatus <= 0 {
		cfg.BlockStatus = defaultContentModerationBlockHTTPStatus
	}
	if cfg.BanThreshold <= 0 {
		cfg.BanThreshold = defaultContentModerationBanThreshold
	}
	if cfg.ViolationWindowHours <= 0 {
		cfg.ViolationWindowHours = defaultContentModerationViolationWindowHours
	}

	// Collections.
	cfg.GroupIDs = normalizeInt64IDs(cfg.GroupIDs)
	cfg.Thresholds = mergeContentModerationThresholds(ContentModerationDefaultThresholds(), cfg.Thresholds)
}
// includesGroup reports whether moderation applies to the given group: always
// when AllGroups is set, otherwise only when groupID is non-nil and listed in
// GroupIDs.
func (cfg *ContentModerationConfig) includesGroup(groupID *int64) bool {
	if cfg.AllGroups {
		return true
	}
	if groupID != nil {
		wanted := *groupID
		for i := range cfg.GroupIDs {
			if cfg.GroupIDs[i] == wanted {
				return true
			}
		}
	}
	return false
}
// contentModerationLogGroupID returns the pointed-to group ID, or 0 when
// groupID is nil.
func contentModerationLogGroupID(groupID *int64) int64 {
	if groupID != nil {
		return *groupID
	}
	return 0
}
// shouldSample decides whether an input with the given hex hash is sampled.
//
// A rate of 100 or more always samples and a rate of 0 or less never does.
// Otherwise the first two bytes of the decoded hash, read as a big-endian
// uint16 modulo 100, are compared against the rate, so the same hash always
// gets the same decision. A hash that does not decode to at least two bytes
// is sampled.
func (cfg *ContentModerationConfig) shouldSample(hashText string) bool {
	switch {
	case cfg.SampleRate >= 100:
		return true
	case cfg.SampleRate <= 0:
		return false
	}

	decoded, err := hex.DecodeString(hashText)
	if err != nil || len(decoded) < 2 {
		return true
	}

	bucket := int(binary.BigEndian.Uint16(decoded[:2]) % 100)
	return bucket < cfg.SampleRate
}
// apiKeys returns the normalized list of configured API keys, or nil for a
// nil config.
func (cfg *ContentModerationConfig) apiKeys() []string {
	if cfg != nil {
		return normalizeModerationAPIKeys(cfg.APIKeys)
	}
	return nil
}
// nextUsableAPIKey picks the next API key that is not currently frozen.
//
// Selection advances a shared cursor by one per probe and inspects at most
// len(keys) candidates. It returns "", false when no key is configured or
// every probed key is frozen.
func (s *ContentModerationService) nextUsableAPIKey(cfg *ContentModerationConfig) (string, bool) {
	candidates := cfg.apiKeys()
	total := len(candidates)
	if total == 0 {
		return "", false
	}

	checkedAt := time.Now()
	for range candidates {
		slot := int(s.apiKeyCursor.Add(1)-1) % total
		if candidate := candidates[slot]; !s.isAPIKeyFrozen(candidate, checkedAt) {
			return candidate, true
		}
	}
	return "", false
}
// isAPIKeyFrozen reports whether the key has a health record whose
// FrozenUntil lies after now. Blank keys and a nil service are never frozen.
func (s *ContentModerationService) isAPIKeyFrozen(key string, now time.Time) bool {
	keyHash := moderationAPIKeyHash(key)
	if s == nil || keyHash == "" {
		return false
	}

	s.keyHealthMu.Lock()
	defer s.keyHealthMu.Unlock()

	health := s.keyHealth[keyHash]
	if health == nil {
		return false
	}
	return health.FrozenUntil.After(now)
}
// markAPIKeySuccess records a successful call for the key: the failure
// count, last error and freeze are cleared, the success count is
// incremented, and the latest latency, HTTP status and check time are
// stored. Blank keys and a nil service are ignored.
func (s *ContentModerationService) markAPIKeySuccess(key string, latencyMS int, httpStatus int) {
	keyHash := moderationAPIKeyHash(key)
	if s == nil || keyHash == "" {
		return
	}

	s.keyHealthMu.Lock()
	defer s.keyHealthMu.Unlock()

	health := s.ensureAPIKeyHealthLocked(keyHash, maskSecretTail(key))

	// Reset the failure state.
	health.FailureCount = 0
	health.LastError = ""
	health.FrozenUntil = time.Time{}

	// Record the latest observation.
	health.SuccessCount++
	health.LastCheckedAt = time.Now()
	health.LastLatencyMS = latencyMS
	health.LastHTTPStatus = httpStatus
	health.LastTested = true
}
// markAPIKeyError records a failed call for the key: the error text (cut to
// 180 runes), latency, HTTP status and check time are stored. When the HTTP
// status maps to a positive freeze duration, the failure count is incremented
// and the key is frozen for that duration. Blank keys and a nil service are
// ignored.
func (s *ContentModerationService) markAPIKeyError(key string, errText string, latencyMS int, httpStatus int) {
	keyHash := moderationAPIKeyHash(key)
	if s == nil || keyHash == "" {
		return
	}

	s.keyHealthMu.Lock()
	defer s.keyHealthMu.Unlock()

	health := s.ensureAPIKeyHealthLocked(keyHash, maskSecretTail(key))
	freeze := contentModerationFreezeDurationForHTTPStatus(httpStatus)
	shouldFreeze := freeze > 0

	if shouldFreeze {
		health.FailureCount++
	}
	health.LastError = trimRunes(errText, 180)
	health.LastCheckedAt = time.Now()
	health.LastLatencyMS = latencyMS
	health.LastHTTPStatus = httpStatus
	health.LastTested = true
	if shouldFreeze {
		health.FrozenUntil = time.Now().Add(freeze)
	}
}
// contentModerationFreezeDurationForHTTPStatus maps an HTTP status to how
// long an API key is frozen after a failed call:
//
//   - 0 (no response status) and 400: no freeze
//   - 401, 403: the auth freeze duration
//   - 429, 529: the rate-limit freeze duration
//   - anything else: the generic HTTP error freeze duration
func contentModerationFreezeDurationForHTTPStatus(httpStatus int) time.Duration {
	if httpStatus == 0 || httpStatus == http.StatusBadRequest {
		return 0
	}
	if httpStatus == http.StatusUnauthorized || httpStatus == http.StatusForbidden {
		return contentModerationKeyAuthFreezeDuration
	}
	if httpStatus == http.StatusTooManyRequests || httpStatus == 529 {
		return contentModerationKeyRateLimitFreezeDuration
	}
	return contentModerationKeyHTTPErrorFreezeDuration
}
// ensureAPIKeyHealthLocked returns the health record for hash, creating the
// map and the record when missing. A non-blank masked value replaces the
// stored mask. The method itself does no locking; the "Locked" suffix
// suggests the caller is expected to hold keyHealthMu, as the callers in
// this file do.
func (s *ContentModerationService) ensureAPIKeyHealthLocked(hash string, masked string) *contentModerationKeyHealth {
	if s.keyHealth == nil {
		s.keyHealth = make(map[string]*contentModerationKeyHealth)
	}

	health, found := s.keyHealth[hash]
	if !found || health == nil {
		health = &contentModerationKeyHealth{Hash: hash}
		s.keyHealth[hash] = health
	}

	if strings.TrimSpace(masked) != "" {
		health.Masked = masked
	}
	return health
}
// configView converts cfg into the outward-facing view. Raw API keys are
// never copied; the view carries their count, masked forms (the first mask
// also separately as APIKeyMasked) and per-key health statuses. GroupIDs are
// copied.
func (s *ContentModerationService) configView(cfg *ContentModerationConfig) *ContentModerationConfigView {
	keys := cfg.apiKeys()

	masked := make([]string, len(keys))
	for i, key := range keys {
		masked[i] = maskSecretTail(key)
	}
	var firstMask string
	if len(masked) > 0 {
		firstMask = masked[0]
	}

	view := &ContentModerationConfigView{
		Enabled:              cfg.Enabled,
		Mode:                 cfg.Mode,
		BaseURL:              cfg.BaseURL,
		Model:                cfg.Model,
		APIKeyConfigured:     len(keys) > 0,
		APIKeyMasked:         firstMask,
		APIKeyCount:          len(keys),
		APIKeyMasks:          masked,
		APIKeyStatuses:       s.apiKeyStatuses(keys),
		TimeoutMS:            cfg.TimeoutMS,
		SampleRate:           cfg.SampleRate,
		AllGroups:            cfg.AllGroups,
		GroupIDs:             append([]int64(nil), cfg.GroupIDs...),
		RecordNonHits:        cfg.RecordNonHits,
		WorkerCount:          cfg.WorkerCount,
		QueueSize:            cfg.QueueSize,
		BlockStatus:          cfg.BlockStatus,
		BlockMessage:         cfg.BlockMessage,
		EmailOnHit:           cfg.EmailOnHit,
		AutoBanEnabled:       cfg.AutoBanEnabled,
		BanThreshold:         cfg.BanThreshold,
		ViolationWindowHours: cfg.ViolationWindowHours,
		RetryCount:           cfg.RetryCount,
		HitRetentionDays:     cfg.HitRetentionDays,
		NonHitRetentionDays:  cfg.NonHitRetentionDays,
		PreHashCheckEnabled:  cfg.PreHashCheckEnabled,
	}
	return view
}
// apiKeyStatuses returns one status entry per key, in input order, each
// marked as configured.
func (s *ContentModerationService) apiKeyStatuses(keys []string) []ContentModerationAPIKeyStatus {
	statuses := make([]ContentModerationAPIKeyStatus, len(keys))
	for i, key := range keys {
		statuses[i] = s.apiKeyStatusForHash(i, moderationAPIKeyHash(key), maskSecretTail(key), true)
	}
	return statuses
}
// apiKeyStatusForHash builds the status entry for one key hash.
//
// Without a health record (or with an empty hash or nil service) the status
// is "unknown". Otherwise the recorded counters and last observation are
// copied and the status is, in priority order: "frozen" while FrozenUntil is
// in the future, "error" when a last error is recorded, "ok" when the key
// has succeeded or been tested; if none applies it stays "unknown".
func (s *ContentModerationService) apiKeyStatusForHash(index int, hash string, masked string, configured bool) ContentModerationAPIKeyStatus {
	snapshot := ContentModerationAPIKeyStatus{
		Index:      index,
		KeyHash:    hash,
		Masked:     masked,
		Status:     "unknown",
		Configured: configured,
	}
	if s == nil || hash == "" {
		return snapshot
	}

	now := time.Now()
	s.keyHealthMu.Lock()
	defer s.keyHealthMu.Unlock()

	health := s.keyHealth[hash]
	if health == nil {
		return snapshot
	}

	snapshot.FailureCount = health.FailureCount
	snapshot.SuccessCount = health.SuccessCount
	snapshot.LastError = health.LastError
	snapshot.LastLatencyMS = health.LastLatencyMS
	snapshot.LastHTTPStatus = health.LastHTTPStatus
	snapshot.LastTested = health.LastTested
	if !health.LastCheckedAt.IsZero() {
		checkedAt := health.LastCheckedAt
		snapshot.LastCheckedAt = &checkedAt
	}

	switch {
	case health.FrozenUntil.After(now):
		frozenUntil := health.FrozenUntil
		snapshot.FrozenUntil = &frozenUntil
		snapshot.Status = "frozen"
	case health.LastError != "":
		snapshot.Status = "error"
	case health.SuccessCount > 0 || health.LastTested:
		snapshot.Status = "ok"
	}
	return snapshot
}
// moderationAPIKeyHash returns the lowercase hex SHA-256 of the trimmed key,
// or "" when the key is blank.
func moderationAPIKeyHash(key string) string {
	trimmed := strings.TrimSpace(key)
	if trimmed == "" {
		return ""
	}
	digest := sha256.Sum256([]byte(trimmed))
	return hex.EncodeToString(digest[:])
}
// buildModerationTestInput builds the moderation API input for a manual test
// from a prompt and optional image data URLs.
//
// The prompt is normalized and cut to maxModerationInputRunes. Blank images
// are skipped; each remaining image is validated, and exceeding
// maxContentModerationTestImages is a bad-request error. The result is:
//   - "hello" when there is neither text nor an image,
//   - the prompt string when there are no images,
//   - otherwise a list of parts (optional text part first, then one
//     image_url part per image).
//
// The int result is the number of images included.
func buildModerationTestInput(prompt string, images []string) (any, int, error) {
	text := trimRunes(normalizeContentModerationText(prompt), maxModerationInputRunes)

	accepted := make([]string, 0, len(images))
	for _, candidate := range images {
		dataURL := strings.TrimSpace(candidate)
		if dataURL == "" {
			continue
		}
		if len(accepted) >= maxContentModerationTestImages {
			return nil, 0, infraerrors.BadRequest("TOO_MANY_MODERATION_TEST_IMAGES", fmt.Sprintf("最多上传 %d 张测试图片", maxContentModerationTestImages))
		}
		if err := validateModerationTestImageDataURL(dataURL); err != nil {
			return nil, 0, err
		}
		accepted = append(accepted, dataURL)
	}

	switch {
	case len(accepted) == 0 && text == "":
		return "hello", 0, nil
	case len(accepted) == 0:
		return text, 0, nil
	}

	parts := make([]moderationAPIInputPart, 0, len(accepted)+1)
	if text != "" {
		parts = append(parts, moderationAPIInputPart{Type: "text", Text: text})
	}
	for _, dataURL := range accepted {
		ref := &moderationAPIImageURLRef{URL: dataURL}
		parts = append(parts, moderationAPIInputPart{Type: "image_url", ImageURL: ref})
	}
	return parts, len(accepted), nil
}
// contentModerationTestHasAuditInput reports whether a test request carries
// anything to audit: a prompt that is non-empty after normalization, or at
// least one non-blank image entry.
func contentModerationTestHasAuditInput(prompt string, images []string) bool {
	hasInput := normalizeContentModerationText(prompt) != ""
	for i := 0; !hasInput && i < len(images); i++ {
		hasInput = strings.TrimSpace(images[i]) != ""
	}
	return hasInput
}
// validateModerationTestImageDataURL checks that value is an acceptable
// base64 image data URL for a moderation test.
//
// It returns a bad-request error when the whole value exceeds
// maxContentModerationTestImageDataURLBytes, does not start with
// "data:image/", lacks a comma or a ";base64" marker in the header, does not
// decode as standard base64, or decodes to more than
// maxContentModerationTestImageBytes.
func validateModerationTestImageDataURL(value string) error {
	if len(value) > maxContentModerationTestImageDataURLBytes {
		return infraerrors.BadRequest("MODERATION_TEST_IMAGE_TOO_LARGE", "测试图片不能超过 8MB")
	}
	if !strings.HasPrefix(value, "data:image/") {
		return infraerrors.BadRequest("INVALID_MODERATION_TEST_IMAGE", "测试图片必须是 data:image/* base64")
	}

	// Split once at the first comma: header before it, payload after it.
	header, payload, hasComma := strings.Cut(value, ",")
	if !hasComma || !strings.Contains(header, ";base64") {
		return infraerrors.BadRequest("INVALID_MODERATION_TEST_IMAGE", "测试图片必须是 base64 data URL")
	}

	decoded, err := base64.StdEncoding.DecodeString(payload)
	switch {
	case err != nil:
		return infraerrors.BadRequest("INVALID_MODERATION_TEST_IMAGE", "测试图片 base64 无效")
	case len(decoded) > maxContentModerationTestImageBytes:
		return infraerrors.BadRequest("MODERATION_TEST_IMAGE_TOO_LARGE", "测试图片不能超过 8MB")
	}
	return nil
}
// buildContentModerationTestAuditResult evaluates a moderation API result
// against the given thresholds (merged over the defaults) and returns the
// verdict with a copy of the scores. The composite score equals the highest
// score. A nil result yields nil.
func buildContentModerationTestAuditResult(result *moderationAPIResult, thresholds map[string]float64) *ContentModerationTestAuditResult {
	if result == nil {
		return nil
	}

	// Copy so the returned map is independent of the API result.
	scoreCopy := make(map[string]float64, len(result.CategoryScores))
	for name, value := range result.CategoryScores {
		scoreCopy[name] = value
	}

	effective := mergeContentModerationThresholds(ContentModerationDefaultThresholds(), thresholds)
	hit, topCategory, topScore := evaluateModerationScores(scoreCopy, effective)

	return &ContentModerationTestAuditResult{
		Flagged:         hit,
		HighestCategory: topCategory,
		HighestScore:    topScore,
		CompositeScore:  topScore,
		CategoryScores:  scoreCopy,
		Thresholds:      effective,
	}
}
// Wire types for the moderation HTTP API (POST /v1/moderations).
type (
	// moderationAPIRequest is the JSON request body. Input is either a plain
	// string or a list of moderationAPIInputPart values.
	moderationAPIRequest struct {
		Model string `json:"model"`
		Input any    `json:"input"`
	}

	// moderationAPIInputPart is one element of a multi-part input. Within
	// this file Type is set to "text" (with Text) or "image_url" (with
	// ImageURL).
	moderationAPIInputPart struct {
		Type     string                    `json:"type"`
		Text     string                    `json:"text,omitempty"`
		ImageURL *moderationAPIImageURLRef `json:"image_url,omitempty"`
	}

	// moderationAPIImageURLRef wraps the URL of an image input part.
	moderationAPIImageURLRef struct {
		URL string `json:"url"`
	}

	// moderationAPIResponse is the part of the JSON response that is decoded.
	moderationAPIResponse struct {
		Results []moderationAPIResult `json:"results"`
	}

	// moderationAPIResult is a single moderation verdict: the upstream
	// flagged bit and the per-category scores.
	moderationAPIResult struct {
		Flagged        bool               `json:"flagged"`
		CategoryScores map[string]float64 `json:"category_scores"`
	}
)
// evaluateModerationScores compares category scores against thresholds.
//
// It returns whether any category listed in contentModerationCategoryOrder
// has a score greater than or equal to its threshold, plus the category with
// the highest score and that score. Known categories are inspected in their
// fixed order first; categories present only in scores can still become the
// highest but never affect the flagged result. A category missing from
// scores reads as 0, and one missing from thresholds reads as threshold 0.
func evaluateModerationScores(scores map[string]float64, thresholds map[string]float64) (bool, string, float64) {
	var (
		hit      bool
		topName  string
		topScore float64
	)

	// track keeps the first category seen and then only strictly higher ones.
	track := func(name string, value float64) {
		if topName == "" || value > topScore {
			topName, topScore = name, value
		}
	}

	for _, name := range contentModerationCategoryOrder {
		value := scores[name]
		track(name, value)
		if value >= thresholds[name] {
			hit = true
		}
	}
	for name, value := range scores {
		track(name, value)
	}

	return hit, topName, topScore
}
// mergeContentModerationThresholds returns a copy of base with overrides
// applied. Only categories listed in contentModerationCategoryOrder are taken
// from override, and each overriding value is clamped to [0, 1]. Neither
// input map is modified.
func mergeContentModerationThresholds(base map[string]float64, override map[string]float64) map[string]float64 {
	merged := cloneFloatMap(base)
	if merged == nil {
		merged = map[string]float64{}
	}

	for _, name := range contentModerationCategoryOrder {
		value, present := override[name]
		if !present {
			continue
		}
		switch {
		case value < 0:
			value = 0
		case value > 1:
			value = 1
		}
		merged[name] = value
	}
	return merged
}
// normalizeInt64IDs returns the positive IDs from ids, de-duplicated and
// sorted ascending. The result is never nil.
func normalizeInt64IDs(ids []int64) []int64 {
	if len(ids) == 0 {
		return []int64{}
	}

	unique := make([]int64, 0, len(ids))
	known := make(map[int64]struct{}, len(ids))
	for i := range ids {
		candidate := ids[i]
		if candidate <= 0 {
			continue
		}
		if _, dup := known[candidate]; !dup {
			known[candidate] = struct{}{}
			unique = append(unique, candidate)
		}
	}

	sort.Slice(unique, func(a, b int) bool { return unique[a] < unique[b] })
	return unique
}
// normalizeModerationAPIKeys trims every key, drops blank ones and removes
// duplicates while keeping first-seen order. The result is never nil.
func normalizeModerationAPIKeys(keys []string) []string {
	if len(keys) == 0 {
		return []string{}
	}

	unique := make([]string, 0, len(keys))
	known := make(map[string]struct{}, len(keys))
	for i := range keys {
		trimmed := strings.TrimSpace(keys[i])
		if trimmed == "" {
			continue
		}
		if _, dup := known[trimmed]; !dup {
			known[trimmed] = struct{}{}
			unique = append(unique, trimmed)
		}
	}
	return unique
}
// deleteModerationAPIKeysByHash returns the normalized keys without those
// whose key hash appears in hashes. Hashes that do not normalize to a valid
// value are ignored; when none remains, the normalized keys are returned
// unchanged.
func deleteModerationAPIKeysByHash(keys []string, hashes []string) []string {
	normalized := normalizeModerationAPIKeys(keys)

	doomed := make(map[string]struct{}, len(hashes))
	for i := range hashes {
		if valid := normalizeContentModerationHash(hashes[i]); valid != "" {
			doomed[valid] = struct{}{}
		}
	}
	if len(doomed) == 0 {
		return normalized
	}

	kept := make([]string, 0, len(normalized))
	for _, key := range normalized {
		if _, remove := doomed[moderationAPIKeyHash(key)]; !remove {
			kept = append(kept, key)
		}
	}
	return kept
}
// normalizeContentModerationAPIKeysMode maps the mode to "replace" when it
// equals that value after trimming and lowercasing, and to "append" in every
// other case.
func normalizeContentModerationAPIKeysMode(mode string) string {
	if strings.ToLower(strings.TrimSpace(mode)) == contentModerationAPIKeysModeReplace {
		return contentModerationAPIKeysModeReplace
	}
	return contentModerationAPIKeysModeAppend
}
// normalizeContentModerationHash trims and lowercases inputHash and returns
// it when it is a hex string of SHA-256 length (64 characters); otherwise it
// returns "".
func normalizeContentModerationHash(inputHash string) string {
	candidate := strings.ToLower(strings.TrimSpace(inputHash))
	if len(candidate) != 2*sha256.Size {
		return ""
	}
	_, err := hex.DecodeString(candidate)
	if err != nil {
		return ""
	}
	return candidate
}
// cloneFloatMap returns an independent copy of in. A nil input yields an
// empty, non-nil map.
func cloneFloatMap(in map[string]float64) map[string]float64 {
	// Ranging over a nil map is a no-op, so nil needs no special case.
	copied := make(map[string]float64, len(in))
	for key, value := range in {
		copied[key] = value
	}
	return copied
}
// cloneInt64Ptr returns a pointer to a fresh copy of *in, or nil when in is
// nil.
func cloneInt64Ptr(in *int64) *int64 {
	if in != nil {
		copied := *in
		return &copied
	}
	return nil
}
// trimRunes cuts text to at most max runes (not bytes). It returns "" when
// max is not positive and text unchanged when it is already short enough.
func trimRunes(text string, max int) string {
	if max <= 0 {
		return ""
	}
	if decoded := []rune(text); len(decoded) > max {
		return string(decoded[:max])
	}
	return text
}
// maskSecretTail returns a display-safe form of secret.
//
// The secret is trimmed first. A blank secret yields "". A secret of four
// characters or fewer yields "****" so nothing of it is revealed. Longer
// secrets yield eight asterisks followed by the last four characters.
//
// Lengths are measured in runes rather than bytes, so a multi-byte UTF-8
// character is never cut in half and short non-ASCII secrets are fully
// masked. Output for ASCII secrets is unchanged.
func maskSecretTail(secret string) string {
	secret = strings.TrimSpace(secret)
	if secret == "" {
		return ""
	}
	runes := []rune(secret)
	if len(runes) <= 4 {
		return "****"
	}
	return strings.Repeat("*", 8) + string(runes[len(runes)-4:])
}