Files
sub2api/backend/internal/repository/usage_cleanup_repo.go

364 lines
9.2 KiB
Go

package repository
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type usageCleanupRepository struct {
sql sqlExecutor
}
func NewUsageCleanupRepository(sqlDB *sql.DB) service.UsageCleanupRepository {
return &usageCleanupRepository{sql: sqlDB}
}
func (r *usageCleanupRepository) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
if task == nil {
return nil
}
filtersJSON, err := json.Marshal(task.Filters)
if err != nil {
return fmt.Errorf("marshal cleanup filters: %w", err)
}
query := `
INSERT INTO usage_cleanup_tasks (
status,
filters,
created_by,
deleted_rows
) VALUES ($1, $2, $3, $4)
RETURNING id, created_at, updated_at
`
if err := scanSingleRow(ctx, r.sql, query, []any{task.Status, filtersJSON, task.CreatedBy, task.DeletedRows}, &task.ID, &task.CreatedAt, &task.UpdatedAt); err != nil {
return err
}
return nil
}
func (r *usageCleanupRepository) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
var total int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM usage_cleanup_tasks", nil, &total); err != nil {
return nil, nil, err
}
if total == 0 {
return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil
}
query := `
SELECT id, status, filters, created_by, deleted_rows, error_message,
canceled_by, canceled_at,
started_at, finished_at, created_at, updated_at
FROM usage_cleanup_tasks
ORDER BY created_at DESC
LIMIT $1 OFFSET $2
`
rows, err := r.sql.QueryContext(ctx, query, params.Limit(), params.Offset())
if err != nil {
return nil, nil, err
}
defer rows.Close()
tasks := make([]service.UsageCleanupTask, 0)
for rows.Next() {
var task service.UsageCleanupTask
var filtersJSON []byte
var errMsg sql.NullString
var canceledBy sql.NullInt64
var canceledAt sql.NullTime
var startedAt sql.NullTime
var finishedAt sql.NullTime
if err := rows.Scan(
&task.ID,
&task.Status,
&filtersJSON,
&task.CreatedBy,
&task.DeletedRows,
&errMsg,
&canceledBy,
&canceledAt,
&startedAt,
&finishedAt,
&task.CreatedAt,
&task.UpdatedAt,
); err != nil {
return nil, nil, err
}
if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
return nil, nil, fmt.Errorf("parse cleanup filters: %w", err)
}
if errMsg.Valid {
task.ErrorMsg = &errMsg.String
}
if canceledBy.Valid {
v := canceledBy.Int64
task.CanceledBy = &v
}
if canceledAt.Valid {
task.CanceledAt = &canceledAt.Time
}
if startedAt.Valid {
task.StartedAt = &startedAt.Time
}
if finishedAt.Valid {
task.FinishedAt = &finishedAt.Time
}
tasks = append(tasks, task)
}
if err := rows.Err(); err != nil {
return nil, nil, err
}
return tasks, paginationResultFromTotal(total, params), nil
}
func (r *usageCleanupRepository) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
if staleRunningAfterSeconds <= 0 {
staleRunningAfterSeconds = 1800
}
query := `
WITH next AS (
SELECT id
FROM usage_cleanup_tasks
WHERE status = $1
OR (
status = $2
AND started_at IS NOT NULL
AND started_at < NOW() - ($3 * interval '1 second')
)
ORDER BY created_at ASC
LIMIT 1
FOR UPDATE SKIP LOCKED
)
UPDATE usage_cleanup_tasks AS tasks
SET status = $4,
started_at = NOW(),
finished_at = NULL,
error_message = NULL,
updated_at = NOW()
FROM next
WHERE tasks.id = next.id
RETURNING tasks.id, tasks.status, tasks.filters, tasks.created_by, tasks.deleted_rows, tasks.error_message,
tasks.started_at, tasks.finished_at, tasks.created_at, tasks.updated_at
`
var task service.UsageCleanupTask
var filtersJSON []byte
var errMsg sql.NullString
var startedAt sql.NullTime
var finishedAt sql.NullTime
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{
service.UsageCleanupStatusPending,
service.UsageCleanupStatusRunning,
staleRunningAfterSeconds,
service.UsageCleanupStatusRunning,
},
&task.ID,
&task.Status,
&filtersJSON,
&task.CreatedBy,
&task.DeletedRows,
&errMsg,
&startedAt,
&finishedAt,
&task.CreatedAt,
&task.UpdatedAt,
); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, err
}
if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
return nil, fmt.Errorf("parse cleanup filters: %w", err)
}
if errMsg.Valid {
task.ErrorMsg = &errMsg.String
}
if startedAt.Valid {
task.StartedAt = &startedAt.Time
}
if finishedAt.Valid {
task.FinishedAt = &finishedAt.Time
}
return &task, nil
}
func (r *usageCleanupRepository) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
var status string
if err := scanSingleRow(ctx, r.sql, "SELECT status FROM usage_cleanup_tasks WHERE id = $1", []any{taskID}, &status); err != nil {
return "", err
}
return status, nil
}
func (r *usageCleanupRepository) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
query := `
UPDATE usage_cleanup_tasks
SET deleted_rows = $1,
updated_at = NOW()
WHERE id = $2
`
_, err := r.sql.ExecContext(ctx, query, deletedRows, taskID)
return err
}
func (r *usageCleanupRepository) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
query := `
UPDATE usage_cleanup_tasks
SET status = $1,
canceled_by = $3,
canceled_at = NOW(),
finished_at = NOW(),
error_message = NULL,
updated_at = NOW()
WHERE id = $2
AND status IN ($4, $5)
RETURNING id
`
var id int64
err := scanSingleRow(ctx, r.sql, query, []any{
service.UsageCleanupStatusCanceled,
taskID,
canceledBy,
service.UsageCleanupStatusPending,
service.UsageCleanupStatusRunning,
}, &id)
if errors.Is(err, sql.ErrNoRows) {
return false, nil
}
if err != nil {
return false, err
}
return true, nil
}
func (r *usageCleanupRepository) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
query := `
UPDATE usage_cleanup_tasks
SET status = $1,
deleted_rows = $2,
finished_at = NOW(),
updated_at = NOW()
WHERE id = $3
`
_, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusSucceeded, deletedRows, taskID)
return err
}
func (r *usageCleanupRepository) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
query := `
UPDATE usage_cleanup_tasks
SET status = $1,
deleted_rows = $2,
error_message = $3,
finished_at = NOW(),
updated_at = NOW()
WHERE id = $4
`
_, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusFailed, deletedRows, errorMsg, taskID)
return err
}
func (r *usageCleanupRepository) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
if filters.StartTime.IsZero() || filters.EndTime.IsZero() {
return 0, fmt.Errorf("cleanup filters missing time range")
}
whereClause, args := buildUsageCleanupWhere(filters)
if whereClause == "" {
return 0, fmt.Errorf("cleanup filters missing time range")
}
args = append(args, limit)
query := fmt.Sprintf(`
WITH target AS (
SELECT id
FROM usage_logs
WHERE %s
ORDER BY created_at ASC, id ASC
LIMIT $%d
)
DELETE FROM usage_logs
WHERE id IN (SELECT id FROM target)
RETURNING id
`, whereClause, len(args))
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
defer rows.Close()
var deleted int64
for rows.Next() {
deleted++
}
if err := rows.Err(); err != nil {
return 0, err
}
return deleted, nil
}
func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any) {
conditions := make([]string, 0, 8)
args := make([]any, 0, 8)
idx := 1
if !filters.StartTime.IsZero() {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", idx))
args = append(args, filters.StartTime)
idx++
}
if !filters.EndTime.IsZero() {
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", idx))
args = append(args, filters.EndTime)
idx++
}
if filters.UserID != nil {
conditions = append(conditions, fmt.Sprintf("user_id = $%d", idx))
args = append(args, *filters.UserID)
idx++
}
if filters.APIKeyID != nil {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", idx))
args = append(args, *filters.APIKeyID)
idx++
}
if filters.AccountID != nil {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", idx))
args = append(args, *filters.AccountID)
idx++
}
if filters.GroupID != nil {
conditions = append(conditions, fmt.Sprintf("group_id = $%d", idx))
args = append(args, *filters.GroupID)
idx++
}
if filters.Model != nil {
model := strings.TrimSpace(*filters.Model)
if model != "" {
conditions = append(conditions, fmt.Sprintf("model = $%d", idx))
args = append(args, model)
idx++
}
}
if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", idx))
args = append(args, *filters.Stream)
idx++
}
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", idx))
args = append(args, *filters.BillingType)
idx++
}
return strings.Join(conditions, " AND "), args
}