Files
sub2api/backend/internal/repository/migrations_schema_integration_test.go
2026-04-22 14:56:56 +08:00

283 lines
11 KiB
Go

//go:build integration
package repository
import (
"context"
"database/sql"
"testing"
"github.com/stretchr/testify/require"
)
// TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate runs ApplyMigrations
// again against integrationDB and then spot-checks, through the test
// transaction, the columns, indexes and tables the schema is expected to have.
func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
	tx := testTx(t)
	ctx := context.Background()

	// requireTable asserts that public.<table> resolves through to_regclass.
	// The name is spliced into the SQL text; callers pass only literals below.
	requireTable := func(table string) {
		t.Helper()
		var regclass sql.NullString
		err := tx.QueryRowContext(ctx, "SELECT to_regclass('public."+table+"')").Scan(&regclass)
		require.NoError(t, err)
		require.True(t, regclass.Valid, "expected %s table to exist", table)
	}

	// Running the migrations a second time must complete without error.
	require.NoError(t, ApplyMigrations(ctx, integrationDB))

	// schema_migrations should hold at least the current migration set.
	var migrationCount int
	err := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations").Scan(&migrationCount)
	require.NoError(t, err)
	require.GreaterOrEqual(t, migrationCount, 7, "expected schema_migrations to contain applied migrations")

	// users: columns required by repository queries.
	requireColumn(t, tx, "users", "username", "character varying", 100, false)
	requireColumn(t, tx, "users", "notes", "text", 0, false)

	// accounts: schedulable and rate-limit fields.
	requireColumn(t, tx, "accounts", "notes", "text", 0, true)
	requireColumn(t, tx, "accounts", "schedulable", "boolean", 0, false)
	requireColumn(t, tx, "accounts", "rate_limited_at", "timestamp with time zone", 0, true)
	requireColumn(t, tx, "accounts", "rate_limit_reset_at", "timestamp with time zone", 0, true)
	requireColumn(t, tx, "accounts", "overload_until", "timestamp with time zone", 0, true)
	requireColumn(t, tx, "accounts", "session_window_status", "character varying", 20, true)

	// api_keys: key length should be 128.
	requireColumn(t, tx, "api_keys", "key", "character varying", 128, false)

	// redeem_codes: subscription fields.
	requireColumn(t, tx, "redeem_codes", "group_id", "bigint", 0, true)
	requireColumn(t, tx, "redeem_codes", "validity_days", "integer", 0, false)

	// usage_logs: billing_type used by filters/stats.
	requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false)
	requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
	requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)

	// usage_billing_dedup: billing idempotency narrow table.
	requireTable("usage_billing_dedup")
	requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
	requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
	requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")

	// usage_billing_dedup_archive: archive counterpart of the dedup table.
	requireTable("usage_billing_dedup_archive")
	requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
	requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")

	// Tables that only need to exist.
	requireTable("settings")
	requireTable("security_secrets")
	requireTable("user_allowed_groups")

	// user_subscriptions: deleted_at for soft delete support (migration 012).
	requireColumn(t, tx, "user_subscriptions", "deleted_at", "timestamp with time zone", 0, true)

	// orphan_allowed_groups_audit table should exist (migration 013).
	requireTable("orphan_allowed_groups_audit")

	// account_groups / user_allowed_groups: created_at should be timestamptz.
	requireColumn(t, tx, "account_groups", "created_at", "timestamp with time zone", 0, false)
	requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
// TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned checks selected
// columns, defaults, check constraints, foreign-key ON DELETE actions and
// payment_orders indexes through the test transaction.
func TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned(t *testing.T) {
	tx := testTx(t)

	requireColumn(t, tx, "auth_identity_migration_reports", "report_type", "character varying", 80, false)

	// users.signup_source: type, default and the values its check constraint mentions.
	requireColumn(t, tx, "users", "signup_source", "character varying", 20, false)
	requireColumnDefaultContains(t, tx, "users", "signup_source", "email")
	requireConstraintDefinitionContains(t, tx, "users", "users_signup_source_check",
		"signup_source", "'email'", "'linuxdo'", "'wechat'", "'oidc'")

	// Expected ON DELETE behaviour for each foreign key, checked in this order.
	foreignKeys := []struct {
		table, column, refTable, onDelete string
	}{
		{"auth_identities", "user_id", "users", "CASCADE"},
		{"auth_identity_channels", "identity_id", "auth_identities", "CASCADE"},
		{"pending_auth_sessions", "target_user_id", "users", "SET NULL"},
		{"identity_adoption_decisions", "pending_auth_session_id", "pending_auth_sessions", "CASCADE"},
		{"identity_adoption_decisions", "identity_id", "auth_identities", "SET NULL"},
	}
	for _, fk := range foreignKeys {
		requireForeignKeyOnDelete(t, tx, fk.table, fk.column, fk.refTable, fk.onDelete)
	}

	// payment_orders: paymentorder_out_trade_no must exist as a unique index
	// with a WHERE clause, and the *_unique variant must not exist.
	requireIndex(t, tx, "payment_orders", "paymentorder_out_trade_no")
	requirePartialUniqueIndexDefinition(t, tx, "payment_orders", "paymentorder_out_trade_no", "out_trade_no", "WHERE")
	requireIndexAbsent(t, tx, "payment_orders", "paymentorder_out_trade_no_unique")
}
// requireIndex fails the test unless pg_indexes lists an index called index
// on table in the public schema.
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
	t.Helper()

	const query = `
SELECT EXISTS (
SELECT 1
FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = $1
AND indexname = $2
)
`
	var found bool
	row := tx.QueryRowContext(context.Background(), query, table, index)
	require.NoError(t, row.Scan(&found), "query pg_indexes for %s.%s", table, index)
	require.True(t, found, "expected index %s on %s", index, table)
}
// requireIndexAbsent fails the test if pg_indexes lists an index called index
// on table in the public schema.
func requireIndexAbsent(t *testing.T, tx *sql.Tx, table, index string) {
	t.Helper()

	const query = `
SELECT EXISTS (
SELECT 1
FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = $1
AND indexname = $2
)
`
	var present bool
	row := tx.QueryRowContext(context.Background(), query, table, index)
	require.NoError(t, row.Scan(&present), "query pg_indexes for %s.%s", table, index)
	require.False(t, present, "expected index %s on %s to be absent", index, table)
}
// requirePartialUniqueIndexDefinition asserts that index on table (public
// schema) is a unique, partial index whose definition contains every fragment.
//
// Uniqueness is read from pg_index.indisunique. Partialness is read from
// pg_index.indpred being non-NULL, so the check no longer depends on callers
// remembering to pass a "WHERE" fragment. Each fragment is matched as a plain
// substring of the text returned by pg_get_indexdef.
//
// A missing index surfaces as a failed NoError assertion, because Scan finds
// no row.
func requirePartialUniqueIndexDefinition(t *testing.T, tx *sql.Tx, table, index string, fragments ...string) {
	t.Helper()
	var (
		unique  bool
		partial bool
		def     string
	)
	err := tx.QueryRowContext(context.Background(), `
SELECT
i.indisunique,
i.indpred IS NOT NULL,
pg_get_indexdef(i.indexrelid)
FROM pg_class idx
JOIN pg_index i ON i.indexrelid = idx.oid
JOIN pg_class tbl ON tbl.oid = i.indrelid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
WHERE ns.nspname = 'public'
AND tbl.relname = $1
AND idx.relname = $2
`, table, index).Scan(&unique, &partial, &def)
	require.NoError(t, err, "query index definition for %s.%s", table, index)
	require.True(t, unique, "expected index %s on %s to be unique", index, table)
	// Enforce the "partial" half of the helper's name directly from the catalog.
	require.True(t, partial, "expected index %s on %s to be partial", index, table)
	for _, fragment := range fragments {
		require.Contains(t, def, fragment, "expected index definition for %s.%s to contain %q", table, index, fragment)
	}
}
// requireForeignKeyOnDelete asserts the ON DELETE action of a foreign key on
// table (public schema) that includes column and references refTable.
//
// expected is compared against the label the query derives from
// pg_constraint.confdeltype: NO ACTION, RESTRICT, CASCADE, SET NULL or
// SET DEFAULT. The query ends in LIMIT 1, so only one matching constraint is
// inspected.
func requireForeignKeyOnDelete(t *testing.T, tx *sql.Tx, table, column, refTable, expected string) {
	t.Helper()

	const query = `
SELECT CASE c.confdeltype
WHEN 'a' THEN 'NO ACTION'
WHEN 'r' THEN 'RESTRICT'
WHEN 'c' THEN 'CASCADE'
WHEN 'n' THEN 'SET NULL'
WHEN 'd' THEN 'SET DEFAULT'
END
FROM pg_constraint c
JOIN pg_class tbl ON tbl.oid = c.conrelid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
JOIN pg_class ref_tbl ON ref_tbl.oid = c.confrelid
JOIN pg_attribute attr ON attr.attrelid = tbl.oid AND attr.attnum = ANY(c.conkey)
WHERE ns.nspname = 'public'
AND c.contype = 'f'
AND tbl.relname = $1
AND attr.attname = $2
AND ref_tbl.relname = $3
LIMIT 1
`
	var onDelete string
	row := tx.QueryRowContext(context.Background(), query, table, column, refTable)
	require.NoError(t, row.Scan(&onDelete), "query foreign key action for %s.%s -> %s", table, column, refTable)
	require.Equal(t, expected, onDelete, "unexpected ON DELETE action for %s.%s -> %s", table, column, refTable)
}
// requireConstraintDefinitionContains fetches pg_get_constraintdef for the
// named constraint on table (public schema) and asserts that the returned
// text contains each of fragments as a substring.
func requireConstraintDefinitionContains(t *testing.T, tx *sql.Tx, table, constraint string, fragments ...string) {
	t.Helper()

	const query = `
SELECT pg_get_constraintdef(c.oid)
FROM pg_constraint c
JOIN pg_class tbl ON tbl.oid = c.conrelid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
WHERE ns.nspname = 'public'
AND tbl.relname = $1
AND c.conname = $2
`
	var definition string
	row := tx.QueryRowContext(context.Background(), query, table, constraint)
	require.NoError(t, row.Scan(&definition), "query constraint definition for %s.%s", table, constraint)

	for i := range fragments {
		require.Contains(t, definition, fragments[i], "expected constraint definition for %s.%s to contain %q", table, constraint, fragments[i])
	}
}
// requireColumnDefaultContains reads column_default for table.column from
// information_schema.columns (public schema), asserts it is non-NULL, and
// asserts it contains each of fragments as a substring.
func requireColumnDefaultContains(t *testing.T, tx *sql.Tx, table, column string, fragments ...string) {
	t.Helper()

	const query = `
SELECT column_default
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = $1
AND column_name = $2
`
	var def sql.NullString
	row := tx.QueryRowContext(context.Background(), query, table, column)
	require.NoError(t, row.Scan(&def), "query column_default for %s.%s", table, column)
	require.True(t, def.Valid, "expected column_default for %s.%s", table, column)

	for i := range fragments {
		require.Contains(t, def.String, fragments[i], "expected default for %s.%s to contain %q", table, column, fragments[i])
	}
}
// requireColumn asserts the shape of table.column as reported by
// information_schema.columns for the public schema.
//
// dataType must equal data_type exactly. maxLen is compared with
// character_maximum_length only when it is greater than zero; a value of zero
// skips the length check. nullable selects whether is_nullable must be "YES"
// or "NO".
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
	t.Helper()

	const query = `
SELECT
data_type,
character_maximum_length,
is_nullable
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = $1
AND column_name = $2
`
	var (
		gotType     string
		gotMaxLen   sql.NullInt64
		gotNullable string
	)
	row := tx.QueryRowContext(context.Background(), query, table, column)
	err := row.Scan(&gotType, &gotMaxLen, &gotNullable)
	require.NoError(t, err, "query information_schema.columns for %s.%s", table, column)

	require.Equal(t, dataType, gotType, "data_type mismatch for %s.%s", table, column)

	if maxLen > 0 {
		require.True(t, gotMaxLen.Valid, "expected maxLen for %s.%s", table, column)
		require.Equal(t, int64(maxLen), gotMaxLen.Int64, "maxLen mismatch for %s.%s", table, column)
	}

	// information_schema reports nullability as the strings "YES" / "NO".
	wantNullable := "NO"
	if nullable {
		wantNullable = "YES"
	}
	require.Equal(t, wantNullable, gotNullable, "nullable mismatch for %s.%s", table, column)
}