为共享 req 客户端增加 HTTP/2 选项与缓存隔离 OpenAI OAuth 超时提升到 120s,并按协议控制强制 新增客户端池与 OAuth 客户端单测覆盖 修复 usage cleanup 相关 errcheck/ineffassign/staticcheck 并统一格式 测试: make test
367 lines
9.2 KiB
Go
367 lines
9.2 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
)
|
|
|
|
// usageCleanupRepository is the SQL-backed implementation of
// service.UsageCleanupRepository returned by NewUsageCleanupRepository.
// It persists cleanup tasks in the usage_cleanup_tasks table and deletes
// matching rows from usage_logs.
type usageCleanupRepository struct {
	// sql runs every query; NewUsageCleanupRepository populates it with a *sql.DB.
	sql sqlExecutor
}
|
|
|
|
func NewUsageCleanupRepository(sqlDB *sql.DB) service.UsageCleanupRepository {
|
|
return &usageCleanupRepository{sql: sqlDB}
|
|
}
|
|
|
|
func (r *usageCleanupRepository) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
|
|
if task == nil {
|
|
return nil
|
|
}
|
|
filtersJSON, err := json.Marshal(task.Filters)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal cleanup filters: %w", err)
|
|
}
|
|
query := `
|
|
INSERT INTO usage_cleanup_tasks (
|
|
status,
|
|
filters,
|
|
created_by,
|
|
deleted_rows
|
|
) VALUES ($1, $2, $3, $4)
|
|
RETURNING id, created_at, updated_at
|
|
`
|
|
if err := scanSingleRow(ctx, r.sql, query, []any{task.Status, filtersJSON, task.CreatedBy, task.DeletedRows}, &task.ID, &task.CreatedAt, &task.UpdatedAt); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *usageCleanupRepository) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
|
|
var total int64
|
|
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM usage_cleanup_tasks", nil, &total); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
if total == 0 {
|
|
return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil
|
|
}
|
|
|
|
query := `
|
|
SELECT id, status, filters, created_by, deleted_rows, error_message,
|
|
canceled_by, canceled_at,
|
|
started_at, finished_at, created_at, updated_at
|
|
FROM usage_cleanup_tasks
|
|
ORDER BY created_at DESC
|
|
LIMIT $1 OFFSET $2
|
|
`
|
|
rows, err := r.sql.QueryContext(ctx, query, params.Limit(), params.Offset())
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
defer func() {
|
|
_ = rows.Close()
|
|
}()
|
|
|
|
tasks := make([]service.UsageCleanupTask, 0)
|
|
for rows.Next() {
|
|
var task service.UsageCleanupTask
|
|
var filtersJSON []byte
|
|
var errMsg sql.NullString
|
|
var canceledBy sql.NullInt64
|
|
var canceledAt sql.NullTime
|
|
var startedAt sql.NullTime
|
|
var finishedAt sql.NullTime
|
|
if err := rows.Scan(
|
|
&task.ID,
|
|
&task.Status,
|
|
&filtersJSON,
|
|
&task.CreatedBy,
|
|
&task.DeletedRows,
|
|
&errMsg,
|
|
&canceledBy,
|
|
&canceledAt,
|
|
&startedAt,
|
|
&finishedAt,
|
|
&task.CreatedAt,
|
|
&task.UpdatedAt,
|
|
); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
|
|
return nil, nil, fmt.Errorf("parse cleanup filters: %w", err)
|
|
}
|
|
if errMsg.Valid {
|
|
task.ErrorMsg = &errMsg.String
|
|
}
|
|
if canceledBy.Valid {
|
|
v := canceledBy.Int64
|
|
task.CanceledBy = &v
|
|
}
|
|
if canceledAt.Valid {
|
|
task.CanceledAt = &canceledAt.Time
|
|
}
|
|
if startedAt.Valid {
|
|
task.StartedAt = &startedAt.Time
|
|
}
|
|
if finishedAt.Valid {
|
|
task.FinishedAt = &finishedAt.Time
|
|
}
|
|
tasks = append(tasks, task)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return tasks, paginationResultFromTotal(total, params), nil
|
|
}
|
|
|
|
func (r *usageCleanupRepository) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
|
|
if staleRunningAfterSeconds <= 0 {
|
|
staleRunningAfterSeconds = 1800
|
|
}
|
|
query := `
|
|
WITH next AS (
|
|
SELECT id
|
|
FROM usage_cleanup_tasks
|
|
WHERE status = $1
|
|
OR (
|
|
status = $2
|
|
AND started_at IS NOT NULL
|
|
AND started_at < NOW() - ($3 * interval '1 second')
|
|
)
|
|
ORDER BY created_at ASC
|
|
LIMIT 1
|
|
FOR UPDATE SKIP LOCKED
|
|
)
|
|
UPDATE usage_cleanup_tasks
|
|
SET status = $4,
|
|
started_at = NOW(),
|
|
finished_at = NULL,
|
|
error_message = NULL,
|
|
updated_at = NOW()
|
|
FROM next
|
|
WHERE usage_cleanup_tasks.id = next.id
|
|
RETURNING id, status, filters, created_by, deleted_rows, error_message,
|
|
started_at, finished_at, created_at, updated_at
|
|
`
|
|
var task service.UsageCleanupTask
|
|
var filtersJSON []byte
|
|
var errMsg sql.NullString
|
|
var startedAt sql.NullTime
|
|
var finishedAt sql.NullTime
|
|
if err := scanSingleRow(
|
|
ctx,
|
|
r.sql,
|
|
query,
|
|
[]any{
|
|
service.UsageCleanupStatusPending,
|
|
service.UsageCleanupStatusRunning,
|
|
staleRunningAfterSeconds,
|
|
service.UsageCleanupStatusRunning,
|
|
},
|
|
&task.ID,
|
|
&task.Status,
|
|
&filtersJSON,
|
|
&task.CreatedBy,
|
|
&task.DeletedRows,
|
|
&errMsg,
|
|
&startedAt,
|
|
&finishedAt,
|
|
&task.CreatedAt,
|
|
&task.UpdatedAt,
|
|
); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
|
|
return nil, fmt.Errorf("parse cleanup filters: %w", err)
|
|
}
|
|
if errMsg.Valid {
|
|
task.ErrorMsg = &errMsg.String
|
|
}
|
|
if startedAt.Valid {
|
|
task.StartedAt = &startedAt.Time
|
|
}
|
|
if finishedAt.Valid {
|
|
task.FinishedAt = &finishedAt.Time
|
|
}
|
|
return &task, nil
|
|
}
|
|
|
|
func (r *usageCleanupRepository) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
|
|
var status string
|
|
if err := scanSingleRow(ctx, r.sql, "SELECT status FROM usage_cleanup_tasks WHERE id = $1", []any{taskID}, &status); err != nil {
|
|
return "", err
|
|
}
|
|
return status, nil
|
|
}
|
|
|
|
func (r *usageCleanupRepository) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
|
|
query := `
|
|
UPDATE usage_cleanup_tasks
|
|
SET deleted_rows = $1,
|
|
updated_at = NOW()
|
|
WHERE id = $2
|
|
`
|
|
_, err := r.sql.ExecContext(ctx, query, deletedRows, taskID)
|
|
return err
|
|
}
|
|
|
|
func (r *usageCleanupRepository) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
|
|
query := `
|
|
UPDATE usage_cleanup_tasks
|
|
SET status = $1,
|
|
canceled_by = $3,
|
|
canceled_at = NOW(),
|
|
finished_at = NOW(),
|
|
error_message = NULL,
|
|
updated_at = NOW()
|
|
WHERE id = $2
|
|
AND status IN ($4, $5)
|
|
RETURNING id
|
|
`
|
|
var id int64
|
|
err := scanSingleRow(ctx, r.sql, query, []any{
|
|
service.UsageCleanupStatusCanceled,
|
|
taskID,
|
|
canceledBy,
|
|
service.UsageCleanupStatusPending,
|
|
service.UsageCleanupStatusRunning,
|
|
}, &id)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return false, nil
|
|
}
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
func (r *usageCleanupRepository) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
|
|
query := `
|
|
UPDATE usage_cleanup_tasks
|
|
SET status = $1,
|
|
deleted_rows = $2,
|
|
finished_at = NOW(),
|
|
updated_at = NOW()
|
|
WHERE id = $3
|
|
`
|
|
_, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusSucceeded, deletedRows, taskID)
|
|
return err
|
|
}
|
|
|
|
func (r *usageCleanupRepository) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
|
|
query := `
|
|
UPDATE usage_cleanup_tasks
|
|
SET status = $1,
|
|
deleted_rows = $2,
|
|
error_message = $3,
|
|
finished_at = NOW(),
|
|
updated_at = NOW()
|
|
WHERE id = $4
|
|
`
|
|
_, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusFailed, deletedRows, errorMsg, taskID)
|
|
return err
|
|
}
|
|
|
|
func (r *usageCleanupRepository) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
|
|
if filters.StartTime.IsZero() || filters.EndTime.IsZero() {
|
|
return 0, fmt.Errorf("cleanup filters missing time range")
|
|
}
|
|
whereClause, args := buildUsageCleanupWhere(filters)
|
|
if whereClause == "" {
|
|
return 0, fmt.Errorf("cleanup filters missing time range")
|
|
}
|
|
args = append(args, limit)
|
|
query := fmt.Sprintf(`
|
|
WITH target AS (
|
|
SELECT id
|
|
FROM usage_logs
|
|
WHERE %s
|
|
ORDER BY created_at ASC, id ASC
|
|
LIMIT $%d
|
|
)
|
|
DELETE FROM usage_logs
|
|
WHERE id IN (SELECT id FROM target)
|
|
RETURNING id
|
|
`, whereClause, len(args))
|
|
|
|
rows, err := r.sql.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer func() {
|
|
_ = rows.Close()
|
|
}()
|
|
|
|
var deleted int64
|
|
for rows.Next() {
|
|
deleted++
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return 0, err
|
|
}
|
|
return deleted, nil
|
|
}
|
|
|
|
func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any) {
|
|
conditions := make([]string, 0, 8)
|
|
args := make([]any, 0, 8)
|
|
idx := 1
|
|
if !filters.StartTime.IsZero() {
|
|
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", idx))
|
|
args = append(args, filters.StartTime)
|
|
idx++
|
|
}
|
|
if !filters.EndTime.IsZero() {
|
|
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", idx))
|
|
args = append(args, filters.EndTime)
|
|
idx++
|
|
}
|
|
if filters.UserID != nil {
|
|
conditions = append(conditions, fmt.Sprintf("user_id = $%d", idx))
|
|
args = append(args, *filters.UserID)
|
|
idx++
|
|
}
|
|
if filters.APIKeyID != nil {
|
|
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", idx))
|
|
args = append(args, *filters.APIKeyID)
|
|
idx++
|
|
}
|
|
if filters.AccountID != nil {
|
|
conditions = append(conditions, fmt.Sprintf("account_id = $%d", idx))
|
|
args = append(args, *filters.AccountID)
|
|
idx++
|
|
}
|
|
if filters.GroupID != nil {
|
|
conditions = append(conditions, fmt.Sprintf("group_id = $%d", idx))
|
|
args = append(args, *filters.GroupID)
|
|
idx++
|
|
}
|
|
if filters.Model != nil {
|
|
model := strings.TrimSpace(*filters.Model)
|
|
if model != "" {
|
|
conditions = append(conditions, fmt.Sprintf("model = $%d", idx))
|
|
args = append(args, model)
|
|
idx++
|
|
}
|
|
}
|
|
if filters.Stream != nil {
|
|
conditions = append(conditions, fmt.Sprintf("stream = $%d", idx))
|
|
args = append(args, *filters.Stream)
|
|
idx++
|
|
}
|
|
if filters.BillingType != nil {
|
|
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", idx))
|
|
args = append(args, *filters.BillingType)
|
|
}
|
|
return strings.Join(conditions, " AND "), args
|
|
}
|