feat(risk-control): add content moderation audit
This commit is contained in:
@@ -125,6 +125,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
apikey.FieldID,
|
||||
apikey.FieldUserID,
|
||||
apikey.FieldGroupID,
|
||||
apikey.FieldName,
|
||||
apikey.FieldStatus,
|
||||
apikey.FieldIPWhitelist,
|
||||
apikey.FieldIPBlacklist,
|
||||
|
||||
@@ -69,6 +69,7 @@ func TestAPIKeyRepository_GetByKeyForAuth_PreservesMessagesDispatchModelConfig_S
|
||||
|
||||
got, err := repo.GetByKeyForAuth(ctx, key.Key)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, key.Name, got.Name)
|
||||
require.NotNil(t, got.Group)
|
||||
require.Equal(t, group.MessagesDispatchModelConfig, got.Group.MessagesDispatchModelConfig)
|
||||
}
|
||||
|
||||
71
backend/internal/repository/content_moderation_hash_cache.go
Normal file
71
backend/internal/repository/content_moderation_hash_cache.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// contentModerationFlaggedHashSetKey is the Redis set key under which flagged
// input hashes are recorded, looked up, removed and counted.
const contentModerationFlaggedHashSetKey = "content_moderation:flagged_hashes"
|
||||
|
||||
type contentModerationHashCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewContentModerationHashCache(rdb *redis.Client) service.ContentModerationHashCache {
|
||||
return &contentModerationHashCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *contentModerationHashCache) RecordFlaggedInputHash(ctx context.Context, inputHash string) error {
|
||||
inputHash = strings.TrimSpace(inputHash)
|
||||
if c == nil || c.rdb == nil || inputHash == "" {
|
||||
return nil
|
||||
}
|
||||
return c.rdb.SAdd(ctx, contentModerationFlaggedHashSetKey, inputHash).Err()
|
||||
}
|
||||
|
||||
func (c *contentModerationHashCache) HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
|
||||
inputHash = strings.TrimSpace(inputHash)
|
||||
if c == nil || c.rdb == nil || inputHash == "" {
|
||||
return false, nil
|
||||
}
|
||||
return c.rdb.SIsMember(ctx, contentModerationFlaggedHashSetKey, inputHash).Result()
|
||||
}
|
||||
|
||||
func (c *contentModerationHashCache) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
|
||||
inputHash = strings.TrimSpace(inputHash)
|
||||
if c == nil || c.rdb == nil || inputHash == "" {
|
||||
return false, nil
|
||||
}
|
||||
deleted, err := c.rdb.SRem(ctx, contentModerationFlaggedHashSetKey, inputHash).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return deleted > 0, nil
|
||||
}
|
||||
|
||||
func (c *contentModerationHashCache) ClearFlaggedInputHashes(ctx context.Context) (int64, error) {
|
||||
if c == nil || c.rdb == nil {
|
||||
return 0, nil
|
||||
}
|
||||
deleted, err := c.rdb.SCard(ctx, contentModerationFlaggedHashSetKey).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if deleted == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if err := c.rdb.Del(ctx, contentModerationFlaggedHashSetKey).Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func (c *contentModerationHashCache) CountFlaggedInputHashes(ctx context.Context) (int64, error) {
|
||||
if c == nil || c.rdb == nil {
|
||||
return 0, nil
|
||||
}
|
||||
return c.rdb.SCard(ctx, contentModerationFlaggedHashSetKey).Result()
|
||||
}
|
||||
274
backend/internal/repository/content_moderation_repo.go
Normal file
274
backend/internal/repository/content_moderation_repo.go
Normal file
@@ -0,0 +1,274 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// contentModerationRepository reads and writes rows of the
// content_moderation_logs table through database/sql.
type contentModerationRepository struct {
	// db is the SQL handle used for all queries issued by the repository.
	db *sql.DB
}
|
||||
|
||||
func NewContentModerationRepository(db *sql.DB) service.ContentModerationRepository {
|
||||
return &contentModerationRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *contentModerationRepository) CreateLog(ctx context.Context, log *service.ContentModerationLog) error {
|
||||
if log == nil {
|
||||
return nil
|
||||
}
|
||||
categoryScores, err := json.Marshal(log.CategoryScores)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal moderation category scores: %w", err)
|
||||
}
|
||||
thresholdSnapshot, err := json.Marshal(log.ThresholdSnapshot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal moderation thresholds: %w", err)
|
||||
}
|
||||
var userID any
|
||||
if log.UserID != nil {
|
||||
userID = *log.UserID
|
||||
}
|
||||
var apiKeyID any
|
||||
if log.APIKeyID != nil {
|
||||
apiKeyID = *log.APIKeyID
|
||||
}
|
||||
var groupID any
|
||||
if log.GroupID != nil {
|
||||
groupID = *log.GroupID
|
||||
}
|
||||
var latency any
|
||||
if log.UpstreamLatencyMS != nil {
|
||||
latency = *log.UpstreamLatencyMS
|
||||
}
|
||||
err = r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO content_moderation_logs (
|
||||
request_id, user_id, user_email, api_key_id, api_key_name, group_id, group_name,
|
||||
endpoint, provider, model, mode, action, flagged, highest_category, highest_score,
|
||||
category_scores, threshold_snapshot, input_excerpt, upstream_latency_ms, error,
|
||||
violation_count, auto_banned, email_sent, queue_delay_ms
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7,
|
||||
$8, $9, $10, $11, $12, $13, $14, $15,
|
||||
$16::jsonb, $17::jsonb, $18, $19, $20,
|
||||
$21, $22, $23, $24
|
||||
) RETURNING id, created_at`,
|
||||
log.RequestID, userID, log.UserEmail, apiKeyID, log.APIKeyName, groupID, log.GroupName,
|
||||
log.Endpoint, log.Provider, log.Model, log.Mode, log.Action, log.Flagged, log.HighestCategory, log.HighestScore,
|
||||
string(categoryScores), string(thresholdSnapshot), log.InputExcerpt, latency, log.Error,
|
||||
log.ViolationCount, log.AutoBanned, log.EmailSent, nullableIntPtr(log.QueueDelayMS),
|
||||
).Scan(&log.ID, &log.CreatedAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert content moderation log: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *contentModerationRepository) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) {
|
||||
where, args := buildContentModerationLogWhere(filter)
|
||||
whereSQL := "WHERE " + strings.Join(where, " AND ")
|
||||
|
||||
var total int64
|
||||
if err := r.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM content_moderation_logs l "+whereSQL, args...).Scan(&total); err != nil {
|
||||
return nil, nil, fmt.Errorf("count content moderation logs: %w", err)
|
||||
}
|
||||
|
||||
params := filter.Pagination
|
||||
if params.Page <= 0 {
|
||||
params.Page = 1
|
||||
}
|
||||
if params.PageSize <= 0 {
|
||||
params.PageSize = 20
|
||||
}
|
||||
if params.PageSize > 100 {
|
||||
params.PageSize = 100
|
||||
}
|
||||
queryArgs := append([]any{}, args...)
|
||||
queryArgs = append(queryArgs, params.Limit(), params.Offset())
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
l.id, l.request_id, l.user_id, l.user_email, l.api_key_id, l.api_key_name, l.group_id, l.group_name,
|
||||
l.endpoint, l.provider, l.model, l.mode, l.action, l.flagged, l.highest_category, l.highest_score,
|
||||
l.category_scores, l.threshold_snapshot, l.input_excerpt, l.upstream_latency_ms, l.error,
|
||||
l.violation_count, l.auto_banned, l.email_sent, COALESCE(u.status, ''), l.queue_delay_ms, l.created_at
|
||||
FROM content_moderation_logs l
|
||||
LEFT JOIN users u ON u.id = l.user_id `+whereSQL+`
|
||||
ORDER BY l.created_at DESC, l.id DESC
|
||||
LIMIT $`+fmt.Sprint(len(queryArgs)-1)+` OFFSET $`+fmt.Sprint(len(queryArgs)),
|
||||
queryArgs...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list content moderation logs: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
items := make([]service.ContentModerationLog, 0)
|
||||
for rows.Next() {
|
||||
var item service.ContentModerationLog
|
||||
var userID, apiKeyID, groupID, latency, queueDelay sql.NullInt64
|
||||
var scoresRaw, thresholdsRaw []byte
|
||||
if err := rows.Scan(
|
||||
&item.ID,
|
||||
&item.RequestID,
|
||||
&userID,
|
||||
&item.UserEmail,
|
||||
&apiKeyID,
|
||||
&item.APIKeyName,
|
||||
&groupID,
|
||||
&item.GroupName,
|
||||
&item.Endpoint,
|
||||
&item.Provider,
|
||||
&item.Model,
|
||||
&item.Mode,
|
||||
&item.Action,
|
||||
&item.Flagged,
|
||||
&item.HighestCategory,
|
||||
&item.HighestScore,
|
||||
&scoresRaw,
|
||||
&thresholdsRaw,
|
||||
&item.InputExcerpt,
|
||||
&latency,
|
||||
&item.Error,
|
||||
&item.ViolationCount,
|
||||
&item.AutoBanned,
|
||||
&item.EmailSent,
|
||||
&item.UserStatus,
|
||||
&queueDelay,
|
||||
&item.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, nil, fmt.Errorf("scan content moderation log: %w", err)
|
||||
}
|
||||
if userID.Valid {
|
||||
v := userID.Int64
|
||||
item.UserID = &v
|
||||
}
|
||||
if apiKeyID.Valid {
|
||||
v := apiKeyID.Int64
|
||||
item.APIKeyID = &v
|
||||
}
|
||||
if groupID.Valid {
|
||||
v := groupID.Int64
|
||||
item.GroupID = &v
|
||||
}
|
||||
if latency.Valid {
|
||||
v := int(latency.Int64)
|
||||
item.UpstreamLatencyMS = &v
|
||||
}
|
||||
if queueDelay.Valid {
|
||||
v := int(queueDelay.Int64)
|
||||
item.QueueDelayMS = &v
|
||||
}
|
||||
item.CategoryScores = map[string]float64{}
|
||||
_ = json.Unmarshal(scoresRaw, &item.CategoryScores)
|
||||
item.ThresholdSnapshot = map[string]float64{}
|
||||
_ = json.Unmarshal(thresholdsRaw, &item.ThresholdSnapshot)
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, nil, fmt.Errorf("iterate content moderation logs: %w", err)
|
||||
}
|
||||
return items, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *contentModerationRepository) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
|
||||
if userID <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
var count int
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
WITH last_auto_ban AS (
|
||||
SELECT MAX(created_at) AS at
|
||||
FROM content_moderation_logs
|
||||
WHERE user_id = $1 AND auto_banned = TRUE
|
||||
)
|
||||
SELECT COUNT(*)
|
||||
FROM content_moderation_logs
|
||||
WHERE user_id = $1
|
||||
AND flagged = TRUE
|
||||
AND created_at >= $2
|
||||
AND created_at > COALESCE((SELECT at FROM last_auto_ban), '-infinity'::timestamptz)
|
||||
`, userID, since).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("count user content moderation flagged logs: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *contentModerationRepository) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) {
|
||||
result := &service.ContentModerationCleanupResult{FinishedAt: time.Now()}
|
||||
if r == nil || r.db == nil {
|
||||
return result, nil
|
||||
}
|
||||
hitExec, err := r.db.ExecContext(ctx, `
|
||||
DELETE FROM content_moderation_logs
|
||||
WHERE flagged = TRUE AND created_at < $1
|
||||
`, hitBefore)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("delete expired hit content moderation logs: %w", err)
|
||||
}
|
||||
result.DeletedHit, _ = hitExec.RowsAffected()
|
||||
|
||||
nonHitExec, err := r.db.ExecContext(ctx, `
|
||||
DELETE FROM content_moderation_logs
|
||||
WHERE flagged = FALSE AND created_at < $1
|
||||
`, nonHitBefore)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("delete expired non-hit content moderation logs: %w", err)
|
||||
}
|
||||
result.DeletedNonHit, _ = nonHitExec.RowsAffected()
|
||||
|
||||
result.FinishedAt = time.Now()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// nullableIntPtr converts an optional int into a value usable as a SQL
// argument: the pointed-to int when value is set, or an untyped nil
// interface (not a typed nil pointer) when it is nil.
func nullableIntPtr(value *int) any {
	if value != nil {
		return *value
	}
	return nil
}
|
||||
|
||||
func buildContentModerationLogWhere(filter service.ContentModerationLogFilter) ([]string, []any) {
|
||||
where := []string{"l.id IS NOT NULL"}
|
||||
args := make([]any, 0)
|
||||
add := func(expr string, value any) {
|
||||
args = append(args, value)
|
||||
where = append(where, fmt.Sprintf(expr, len(args)))
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(filter.Result)) {
|
||||
case "hit", "flagged":
|
||||
where = append(where, "l.flagged = TRUE")
|
||||
case "blocked", "block":
|
||||
where = append(where, "l.action = 'block'")
|
||||
case "pass", "allow":
|
||||
where = append(where, "l.flagged = FALSE AND l.error = ''")
|
||||
case "error":
|
||||
where = append(where, "l.error <> ''")
|
||||
}
|
||||
if filter.GroupID != nil {
|
||||
add("l.group_id = $%d", *filter.GroupID)
|
||||
}
|
||||
if endpoint := strings.TrimSpace(filter.Endpoint); endpoint != "" {
|
||||
add("l.endpoint = $%d", endpoint)
|
||||
}
|
||||
if search := strings.TrimSpace(filter.Search); search != "" {
|
||||
like := "%" + search + "%"
|
||||
args = append(args, like, like, like, like, like)
|
||||
idx := len(args) - 4
|
||||
where = append(where, fmt.Sprintf("(l.request_id ILIKE $%d OR l.user_email ILIKE $%d OR l.api_key_name ILIKE $%d OR l.model ILIKE $%d OR l.input_excerpt ILIKE $%d)", idx, idx+1, idx+2, idx+3, idx+4))
|
||||
}
|
||||
if filter.From != nil && !filter.From.IsZero() {
|
||||
add("l.created_at >= $%d", *filter.From)
|
||||
}
|
||||
if filter.To != nil && !filter.To.IsZero() {
|
||||
add("l.created_at <= $%d", *filter.To)
|
||||
}
|
||||
return where, args
|
||||
}
|
||||
@@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewChannelRepository,
|
||||
NewChannelMonitorRepository,
|
||||
NewChannelMonitorRequestTemplateRepository,
|
||||
NewContentModerationRepository,
|
||||
NewAffiliateRepository,
|
||||
|
||||
// Cache implementations
|
||||
@@ -119,6 +120,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewRefreshTokenCache,
|
||||
NewErrorPassthroughCache,
|
||||
NewTLSFingerprintProfileCache,
|
||||
NewContentModerationHashCache,
|
||||
|
||||
// Encryptors
|
||||
NewAESEncryptor,
|
||||
|
||||
Reference in New Issue
Block a user