diff --git a/backend/internal/setup/setup.go b/backend/internal/setup/setup.go index f81f75cf..d35b82d0 100644 --- a/backend/internal/setup/setup.go +++ b/backend/internal/setup/setup.go @@ -10,6 +10,7 @@ import ( "log" "os" "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/repository" @@ -103,6 +104,36 @@ type JWTConfig struct { ExpireHour int `json:"expire_hour" yaml:"expire_hour"` } +const ( + adminBootstrapReasonEmptyDatabase = "empty_database" + adminBootstrapReasonAdminExists = "admin_exists" + adminBootstrapReasonUsersExistWithoutAdmin = "users_exist_without_admin" +) + +type adminBootstrapDecision struct { + shouldCreate bool + reason string +} + +func decideAdminBootstrap(totalUsers, adminUsers int64) adminBootstrapDecision { + if adminUsers > 0 { + return adminBootstrapDecision{ + shouldCreate: false, + reason: adminBootstrapReasonAdminExists, + } + } + if totalUsers > 0 { + return adminBootstrapDecision{ + shouldCreate: false, + reason: adminBootstrapReasonUsersExistWithoutAdmin, + } + } + return adminBootstrapDecision{ + shouldCreate: true, + reason: adminBootstrapReasonEmptyDatabase, + } +} + // NeedsSetup checks if the system needs initial setup // Uses multiple checks to prevent attackers from forcing re-setup by deleting config func NeedsSetup() bool { @@ -262,8 +293,8 @@ func Install(cfg *SetupConfig) error { return fmt.Errorf("database initialization failed: %w", err) } - // Create admin user - if err := createAdminUser(cfg); err != nil { + // Create admin user (only when database is empty and no admin exists). + if _, _, err := createAdminUser(cfg); err != nil { return fmt.Errorf("admin user creation failed: %w", err) } @@ -309,7 +340,7 @@ func initializeDatabase(cfg *SetupConfig) error { return repository.ApplyMigrations(migrationCtx, db) } -func createAdminUser(cfg *SetupConfig) error { +func createAdminUser(cfg *SetupConfig) (bool, string, error) { dsn := fmt.Sprintf( "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", cfg.Database.Host, cfg.Database.Port, cfg.Database.User, @@ -318,7 +349,7 @@ func createAdminUser(cfg *SetupConfig) error { db, err := sql.Open("postgres", dsn) if err != nil { - return err + return false, "", err } defer func() { @@ -331,13 +362,27 @@ func createAdminUser(cfg *SetupConfig) error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - // Check if admin already exists - var count int64 - if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM users WHERE role = $1", service.RoleAdmin).Scan(&count); err != nil { - return err + var totalUsers int64 + if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM users").Scan(&totalUsers); err != nil { + return false, "", err } - if count > 0 { - return nil // Admin already exists + var adminUsers int64 + if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM users WHERE role = $1", service.RoleAdmin).Scan(&adminUsers); err != nil { + return false, "", err + } + decision := decideAdminBootstrap(totalUsers, adminUsers) + if !decision.shouldCreate { + return false, decision.reason, nil + } + + if strings.TrimSpace(cfg.Admin.Password) == "" { + password, genErr := generateSecret(16) + if genErr != nil { + return false, "", fmt.Errorf("failed to generate admin password: %w", genErr) + } + cfg.Admin.Password = password + fmt.Printf("Generated admin password (one-time): %s\n", cfg.Admin.Password) + fmt.Println("IMPORTANT: Save this password! It will not be shown again.") } admin := &service.User{ @@ -351,7 +396,7 @@ func createAdminUser(cfg *SetupConfig) error { } if err := admin.SetPassword(cfg.Admin.Password); err != nil { - return err + return false, "", err } _, err = db.ExecContext( @@ -367,7 +412,10 @@ func createAdminUser(cfg *SetupConfig) error { admin.CreatedAt, admin.UpdatedAt, ) - return err + if err != nil { + return false, "", err + } + return true, decision.reason, nil } func writeConfigFile(cfg *SetupConfig) error { @@ -528,17 +576,6 @@ func AutoSetupFromEnv() error { log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.") } - // Generate admin password if not provided - if cfg.Admin.Password == "" { - password, err := generateSecret(16) - if err != nil { - return fmt.Errorf("failed to generate admin password: %w", err) - } - cfg.Admin.Password = password - fmt.Printf("Generated admin password (one-time): %s\n", cfg.Admin.Password) - fmt.Println("IMPORTANT: Save this password! It will not be shown again.") - } - // Test database connection log.Println("Testing database connection...") if err := TestDatabaseConnection(&cfg.Database); err != nil { @@ -562,10 +599,22 @@ func AutoSetupFromEnv() error { // Create admin user log.Println("Creating admin user...") - if err := createAdminUser(cfg); err != nil { + created, reason, err := createAdminUser(cfg) + if err != nil { return fmt.Errorf("admin user creation failed: %w", err) } - log.Printf("Admin user created: %s", cfg.Admin.Email) + if created { + log.Printf("Admin user created: %s", cfg.Admin.Email) + } else { + switch reason { + case adminBootstrapReasonAdminExists: + log.Println("Admin user already exists, skipping admin bootstrap") + case adminBootstrapReasonUsersExistWithoutAdmin: + log.Println("Database already has user data; skipping auto admin bootstrap to avoid password overwrite") + default: + log.Println("Admin bootstrap skipped") + } + } // Write config file log.Println("Writing configuration file...") diff --git a/backend/internal/setup/setup_test.go b/backend/internal/setup/setup_test.go new file mode 100644 index 00000000..69655e92 --- /dev/null +++ b/backend/internal/setup/setup_test.go @@ -0,0 +1,51 @@ +package setup + +import "testing" + +func TestDecideAdminBootstrap(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + totalUsers int64 + adminUsers int64 + should bool + reason string + }{ + { + name: "empty database should create admin", + totalUsers: 0, + adminUsers: 0, + should: true, + reason: adminBootstrapReasonEmptyDatabase, + }, + { + name: "admin exists should skip", + totalUsers: 10, + adminUsers: 1, + should: false, + reason: adminBootstrapReasonAdminExists, + }, + { + name: "users exist without admin should skip", + totalUsers: 5, + adminUsers: 0, + should: false, + reason: adminBootstrapReasonUsersExistWithoutAdmin, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := decideAdminBootstrap(tc.totalUsers, tc.adminUsers) + if got.shouldCreate != tc.should { + t.Fatalf("shouldCreate=%v, want %v", got.shouldCreate, tc.should) + } + if got.reason != tc.reason { + t.Fatalf("reason=%q, want %q", got.reason, tc.reason) + } + }) + } +}