Files
sub2api-ht/backend/internal/repository/content_moderation_repo.go

275 lines
8.3 KiB
Go

package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type contentModerationRepository struct {
db *sql.DB
}
// NewContentModerationRepository builds the database/sql-backed
// implementation of service.ContentModerationRepository around db.
func NewContentModerationRepository(db *sql.DB) service.ContentModerationRepository {
	repo := &contentModerationRepository{db: db}
	return repo
}
// CreateLog inserts one moderation log row and back-fills log.ID and
// log.CreatedAt from the values returned by the INSERT ... RETURNING clause.
// A nil log is a no-op.
//
// Unset optional pointer fields (user, API key, group, upstream latency,
// queue delay) are written as SQL NULL. CategoryScores and ThresholdSnapshot
// are stored as jsonb; a nil map is stored as an empty object ({}) rather
// than JSON null, so readers always get an object back.
func (r *contentModerationRepository) CreateLog(ctx context.Context, log *service.ContentModerationLog) error {
	if log == nil {
		return nil
	}

	// json.Marshal encodes a nil map as `null`; normalize it to `{}` so the
	// jsonb columns consistently hold objects.
	marshalScoreMap := func(m map[string]float64) ([]byte, error) {
		if m == nil {
			return []byte("{}"), nil
		}
		return json.Marshal(m)
	}
	categoryScores, err := marshalScoreMap(log.CategoryScores)
	if err != nil {
		return fmt.Errorf("marshal moderation category scores: %w", err)
	}
	thresholdSnapshot, err := marshalScoreMap(log.ThresholdSnapshot)
	if err != nil {
		return fmt.Errorf("marshal moderation thresholds: %w", err)
	}

	// Optional IDs are passed as an untyped nil (SQL NULL) when unset.
	var userID, apiKeyID, groupID any
	if log.UserID != nil {
		userID = *log.UserID
	}
	if log.APIKeyID != nil {
		apiKeyID = *log.APIKeyID
	}
	if log.GroupID != nil {
		groupID = *log.GroupID
	}

	err = r.db.QueryRowContext(ctx, `
		INSERT INTO content_moderation_logs (
			request_id, user_id, user_email, api_key_id, api_key_name, group_id, group_name,
			endpoint, provider, model, mode, action, flagged, highest_category, highest_score,
			category_scores, threshold_snapshot, input_excerpt, upstream_latency_ms, error,
			violation_count, auto_banned, email_sent, queue_delay_ms
		) VALUES (
			$1, $2, $3, $4, $5, $6, $7,
			$8, $9, $10, $11, $12, $13, $14, $15,
			$16::jsonb, $17::jsonb, $18, $19, $20,
			$21, $22, $23, $24
		) RETURNING id, created_at`,
		log.RequestID, userID, log.UserEmail, apiKeyID, log.APIKeyName, groupID, log.GroupName,
		log.Endpoint, log.Provider, log.Model, log.Mode, log.Action, log.Flagged, log.HighestCategory, log.HighestScore,
		string(categoryScores), string(thresholdSnapshot), log.InputExcerpt, nullableIntPtr(log.UpstreamLatencyMS), log.Error,
		log.ViolationCount, log.AutoBanned, log.EmailSent, nullableIntPtr(log.QueueDelayMS),
	).Scan(&log.ID, &log.CreatedAt)
	if err != nil {
		return fmt.Errorf("insert content moderation log: %w", err)
	}
	return nil
}
// ListLogs returns one page of moderation logs matching filter, ordered
// newest first (created_at DESC, id DESC), plus pagination metadata built
// from the total number of matching rows.
//
// Pagination is normalized before querying: a page below 1 becomes 1, a page
// size below 1 becomes 20, and page size is capped at 100. Each row is LEFT
// JOINed to users to expose the user's status ("" when no user row matches).
// Category-score and threshold JSON is decoded best-effort: an unreadable,
// NULL or JSON-null payload yields an empty, non-nil map instead of failing
// the whole page.
func (r *contentModerationRepository) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) {
	where, args := buildContentModerationLogWhere(filter)
	whereSQL := "WHERE " + strings.Join(where, " AND ")

	// The filter only references columns of l, so the count needs no join.
	var total int64
	if err := r.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM content_moderation_logs l "+whereSQL, args...).Scan(&total); err != nil {
		return nil, nil, fmt.Errorf("count content moderation logs: %w", err)
	}

	params := filter.Pagination
	if params.Page <= 0 {
		params.Page = 1
	}
	if params.PageSize <= 0 {
		params.PageSize = 20
	}
	if params.PageSize > 100 {
		params.PageSize = 100
	}

	// LIMIT and OFFSET bind to the two placeholders following the filter args.
	queryArgs := append([]any{}, args...)
	queryArgs = append(queryArgs, params.Limit(), params.Offset())
	rows, err := r.db.QueryContext(ctx, `
		SELECT
			l.id, l.request_id, l.user_id, l.user_email, l.api_key_id, l.api_key_name, l.group_id, l.group_name,
			l.endpoint, l.provider, l.model, l.mode, l.action, l.flagged, l.highest_category, l.highest_score,
			l.category_scores, l.threshold_snapshot, l.input_excerpt, l.upstream_latency_ms, l.error,
			l.violation_count, l.auto_banned, l.email_sent, COALESCE(u.status, ''), l.queue_delay_ms, l.created_at
		FROM content_moderation_logs l
		LEFT JOIN users u ON u.id = l.user_id `+whereSQL+`
		ORDER BY l.created_at DESC, l.id DESC
		LIMIT $`+fmt.Sprint(len(queryArgs)-1)+` OFFSET $`+fmt.Sprint(len(queryArgs)),
		queryArgs...,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("list content moderation logs: %w", err)
	}
	defer func() { _ = rows.Close() }()

	// decodeScores never returns nil. json.Unmarshal turns JSON `null` into
	// a nil map and may leave a partially filled map on error; both cases
	// collapse to an empty map. Decode errors are intentionally swallowed so
	// one bad row cannot break the listing.
	decodeScores := func(raw []byte) map[string]float64 {
		scores := map[string]float64{}
		if len(raw) == 0 {
			return scores
		}
		if err := json.Unmarshal(raw, &scores); err != nil || scores == nil {
			return map[string]float64{}
		}
		return scores
	}

	// Non-nil so an empty page serializes as [] rather than null.
	items := make([]service.ContentModerationLog, 0)
	for rows.Next() {
		var item service.ContentModerationLog
		var userID, apiKeyID, groupID, latency, queueDelay sql.NullInt64
		var scoresRaw, thresholdsRaw []byte
		if err := rows.Scan(
			&item.ID,
			&item.RequestID,
			&userID,
			&item.UserEmail,
			&apiKeyID,
			&item.APIKeyName,
			&groupID,
			&item.GroupName,
			&item.Endpoint,
			&item.Provider,
			&item.Model,
			&item.Mode,
			&item.Action,
			&item.Flagged,
			&item.HighestCategory,
			&item.HighestScore,
			&scoresRaw,
			&thresholdsRaw,
			&item.InputExcerpt,
			&latency,
			&item.Error,
			&item.ViolationCount,
			&item.AutoBanned,
			&item.EmailSent,
			&item.UserStatus,
			&queueDelay,
			&item.CreatedAt,
		); err != nil {
			return nil, nil, fmt.Errorf("scan content moderation log: %w", err)
		}
		if userID.Valid {
			v := userID.Int64
			item.UserID = &v
		}
		if apiKeyID.Valid {
			v := apiKeyID.Int64
			item.APIKeyID = &v
		}
		if groupID.Valid {
			v := groupID.Int64
			item.GroupID = &v
		}
		if latency.Valid {
			v := int(latency.Int64)
			item.UpstreamLatencyMS = &v
		}
		if queueDelay.Valid {
			v := int(queueDelay.Int64)
			item.QueueDelayMS = &v
		}
		item.CategoryScores = decodeScores(scoresRaw)
		item.ThresholdSnapshot = decodeScores(thresholdsRaw)
		items = append(items, item)
	}
	if err := rows.Err(); err != nil {
		return nil, nil, fmt.Errorf("iterate content moderation logs: %w", err)
	}
	return items, paginationResultFromTotal(total, params), nil
}
// CountFlaggedByUserSince returns how many flagged moderation logs exist for
// userID with created_at at or after since. Logs created at or before the
// user's most recent auto-banned log are excluded, so only logs newer than
// the latest auto-ban count. Non-positive user IDs return 0 without querying.
func (r *contentModerationRepository) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
	if userID <= 0 {
		return 0, nil
	}

	// With no prior auto-ban, MAX() is NULL and the COALESCE falls back to
	// -infinity, leaving only the `since` bound in effect.
	const query = `
WITH last_auto_ban AS (
SELECT MAX(created_at) AS at
FROM content_moderation_logs
WHERE user_id = $1 AND auto_banned = TRUE
)
SELECT COUNT(*)
FROM content_moderation_logs
WHERE user_id = $1
AND flagged = TRUE
AND created_at >= $2
AND created_at > COALESCE((SELECT at FROM last_auto_ban), '-infinity'::timestamptz)
`

	var flagged int
	if err := r.db.QueryRowContext(ctx, query, userID, since).Scan(&flagged); err != nil {
		return 0, fmt.Errorf("count user content moderation flagged logs: %w", err)
	}
	return flagged, nil
}
// CleanupExpiredLogs deletes moderation logs that are past retention:
// flagged ("hit") rows created before hitBefore, then unflagged rows created
// before nonHitBefore. It reports the number of rows removed by each delete.
//
// The two DELETE statements run independently (no transaction). If the
// second fails, rows already removed by the first stay deleted and only the
// error is returned. A nil repository or nil DB is a no-op that returns an
// empty result.
func (r *contentModerationRepository) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) {
	result := &service.ContentModerationCleanupResult{FinishedAt: time.Now()}
	if r == nil || r.db == nil {
		return result, nil
	}

	// purge executes one DELETE and returns its affected-row count. A
	// RowsAffected error is ignored on purpose: the delete itself succeeded,
	// and the count is informational only.
	purge := func(query string, cutoff time.Time) (int64, error) {
		res, err := r.db.ExecContext(ctx, query, cutoff)
		if err != nil {
			return 0, err
		}
		affected, _ := res.RowsAffected()
		return affected, nil
	}

	deletedHit, err := purge(`
DELETE FROM content_moderation_logs
WHERE flagged = TRUE AND created_at < $1
`, hitBefore)
	if err != nil {
		return nil, fmt.Errorf("delete expired hit content moderation logs: %w", err)
	}
	result.DeletedHit = deletedHit

	deletedNonHit, err := purge(`
DELETE FROM content_moderation_logs
WHERE flagged = FALSE AND created_at < $1
`, nonHitBefore)
	if err != nil {
		return nil, fmt.Errorf("delete expired non-hit content moderation logs: %w", err)
	}
	result.DeletedNonHit = deletedNonHit

	result.FinishedAt = time.Now()
	return result, nil
}
// nullableIntPtr converts an optional int into a query argument: the pointed-to
// int when value is set, or an untyped nil (SQL NULL) when value is nil.
func nullableIntPtr(value *int) any {
	if value != nil {
		return *value
	}
	return nil
}
// buildContentModerationLogWhere translates filter into SQL predicates over
// the logs table aliased as "l" (to be joined with AND) and the positional
// arguments those predicates reference, numbered $1, $2, ... in order.
//
// The predicate list always contains at least one entry, so callers may
// unconditionally prefix the joined list with "WHERE". Recognized Result
// values (case-insensitive, trimmed) are hit/flagged, blocked/block,
// pass/allow and error; anything else applies no result filter. Search is
// matched as a literal, case-insensitive substring across request ID, user
// email, API key name, model and input excerpt.
func buildContentModerationLogWhere(filter service.ContentModerationLogFilter) ([]string, []any) {
	where := []string{"l.id IS NOT NULL"}
	args := make([]any, 0)

	// add binds value as the next positional argument and appends expr with
	// every %d verb (or %[1]d) replaced by that argument's placeholder index.
	add := func(expr string, value any) {
		args = append(args, value)
		where = append(where, fmt.Sprintf(expr, len(args)))
	}

	switch strings.ToLower(strings.TrimSpace(filter.Result)) {
	case "hit", "flagged":
		where = append(where, "l.flagged = TRUE")
	case "blocked", "block":
		where = append(where, "l.action = 'block'")
	case "pass", "allow":
		where = append(where, "l.flagged = FALSE AND l.error = ''")
	case "error":
		where = append(where, "l.error <> ''")
	}

	if filter.GroupID != nil {
		add("l.group_id = $%d", *filter.GroupID)
	}
	if endpoint := strings.TrimSpace(filter.Endpoint); endpoint != "" {
		add("l.endpoint = $%d", endpoint)
	}
	if search := strings.TrimSpace(filter.Search); search != "" {
		// Escape LIKE metacharacters so user input is matched literally.
		// PostgreSQL's default LIKE/ILIKE escape character is '\'. The
		// replacer works in a single pass, so escapes are never re-escaped.
		escaped := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(search)
		// One bound value is referenced by all five column comparisons.
		add("(l.request_id ILIKE $%[1]d OR l.user_email ILIKE $%[1]d OR l.api_key_name ILIKE $%[1]d OR l.model ILIKE $%[1]d OR l.input_excerpt ILIKE $%[1]d)", "%"+escaped+"%")
	}
	if filter.From != nil && !filter.From.IsZero() {
		add("l.created_at >= $%d", *filter.From)
	}
	if filter.To != nil && !filter.To.IsZero() {
		add("l.created_at <= $%d", *filter.To)
	}
	return where, args
}