382 lines
10 KiB
Go
382 lines
10 KiB
Go
package setup
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"net/mail"
|
|
"os/exec"
|
|
"regexp"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"sub2api/internal/pkg/response"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// installMutex serializes the install handler. Holding it while re-checking
// the "needs setup" state prevents two concurrent requests from both passing
// that check and running the installation twice (a TOCTOU race).
var (
	installMutex sync.Mutex
)
|
|
|
|
// RegisterRoutes registers setup wizard routes
|
|
func RegisterRoutes(r *gin.Engine) {
|
|
setup := r.Group("/setup")
|
|
{
|
|
// Status endpoint is always accessible (read-only)
|
|
setup.GET("/status", getStatus)
|
|
|
|
// All modification endpoints are protected by setupGuard
|
|
protected := setup.Group("")
|
|
protected.Use(setupGuard())
|
|
{
|
|
protected.POST("/test-db", testDatabase)
|
|
protected.POST("/test-redis", testRedis)
|
|
protected.POST("/install", install)
|
|
}
|
|
}
|
|
}
|
|
|
|
// SetupStatus is the payload returned by GET /setup/status.
//
// NeedsSetup mirrors the package-level NeedsSetup() result at request time.
// Step names the wizard step for the client; getStatus always sets it to
// "welcome".
type SetupStatus struct {
	NeedsSetup bool   `json:"needs_setup"`
	Step       string `json:"step"`
}
|
|
|
|
// getStatus returns the current setup status
|
|
func getStatus(c *gin.Context) {
|
|
response.Success(c, SetupStatus{
|
|
NeedsSetup: NeedsSetup(),
|
|
Step: "welcome",
|
|
})
|
|
}
|
|
|
|
// setupGuard middleware ensures setup endpoints are only accessible during setup mode
|
|
func setupGuard() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if !NeedsSetup() {
|
|
response.Error(c, http.StatusForbidden, "Setup is not allowed: system is already installed")
|
|
c.Abort()
|
|
return
|
|
}
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// validHostnameRE admits only ASCII letters, digits, dots, hyphens and colons
// (the colon allows bare IPv6 literals). It is compiled once at package init
// rather than on every validation call.
var validHostnameRE = regexp.MustCompile(`^[a-zA-Z0-9.\-:]+$`)

// validateHostname reports whether host is safe to pass on as a database or
// Redis host. The value must:
//   - be 1-253 bytes long;
//   - use only the characters allowed by validHostnameRE;
//   - not start with '-'.
//
// No valid hostname or IP literal begins with a hyphen. Rejecting one
// defensively keeps the value from being read as an option flag if it ever
// reaches a command line.
func validateHostname(host string) bool {
	// Cheap length check first so oversized input never reaches the regex.
	if len(host) == 0 || len(host) > 253 {
		return false
	}
	if strings.HasPrefix(host, "-") {
		return false
	}
	return validHostnameRE.MatchString(host)
}
|
|
|
|
// validDBNameRE requires a leading ASCII letter followed by letters, digits
// or underscores. It is compiled once at package init.
var validDBNameRE = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_]*$`)

// validateDBName reports whether name is an acceptable database name: it must
// match validDBNameRE and be at most 63 bytes long. The 63-byte cap matches
// PostgreSQL's default identifier limit (assumes a PostgreSQL backend, as the
// sslmode values used in this package suggest — confirm).
func validateDBName(name string) bool {
	return len(name) <= 63 && validDBNameRE.MatchString(name)
}
|
|
|
|
// validUsernameRE admits one or more ASCII letters, digits or underscores.
// It is compiled once at package init.
var validUsernameRE = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)

// validateUsername reports whether name is an acceptable database username:
// it must be non-empty, match validUsernameRE, and be at most 63 bytes long.
// The 63-byte cap matches PostgreSQL's default identifier limit (assumes a
// PostgreSQL backend — confirm).
func validateUsername(name string) bool {
	return len(name) <= 63 && validUsernameRE.MatchString(name)
}
|
|
|
|
// validateEmail reports whether email is a single bare address such as
// "admin@example.com" (surrounding whitespace is ignored), at most 254 bytes
// long once trimmed.
//
// mail.ParseAddress alone also accepts RFC 5322 forms such as
// "Admin <admin@example.com>". Callers store the input string verbatim, so
// the parsed address must equal the trimmed input exactly; otherwise a
// display name or angle brackets would end up in the stored email.
func validateEmail(email string) bool {
	trimmed := strings.TrimSpace(email)
	if len(trimmed) == 0 || len(trimmed) > 254 {
		return false
	}
	addr, err := mail.ParseAddress(trimmed)
	if err != nil {
		return false
	}
	return addr.Address == trimmed
}
|
|
|
|
// validatePassword enforces the admin password length policy. It returns nil
// when the password is between 8 and 128 bytes inclusive, and a descriptive
// error otherwise. Length is measured with len, i.e. in bytes, not runes.
func validatePassword(password string) error {
	switch n := len(password); {
	case n < 8:
		return fmt.Errorf("password must be at least 8 characters")
	case n > 128:
		return fmt.Errorf("password must be at most 128 characters")
	default:
		return nil
	}
}
|
|
|
|
// validatePort reports whether port is a usable TCP port number, i.e. lies in
// the closed range 1..65535.
func validatePort(port int) bool {
	const minPort, maxPort = 1, 65535
	return minPort <= port && port <= maxPort
}
|
|
|
|
// validateSSLMode reports whether mode is one of the accepted database SSL
// modes: "disable", "require", "verify-ca" or "verify-full". The comparison is
// exact and case-sensitive.
//
// A switch is used instead of building a lookup map on every call, which
// allocated needlessly for a fixed four-value set.
func validateSSLMode(mode string) bool {
	switch mode {
	case "disable", "require", "verify-ca", "verify-full":
		return true
	default:
		return false
	}
}
|
|
|
|
// TestDatabaseRequest is the JSON body accepted by POST /setup/test-db.
//
// Fields tagged binding:"required" are enforced by gin's binding before the
// handler sees them. testDatabase then validates host, port, user and dbname
// against strict whitelists. An empty sslmode is replaced with "disable" and
// then checked by validateSSLMode. Password is optional and is passed through
// without any format check.
type TestDatabaseRequest struct {
	Host     string `json:"host" binding:"required"`
	Port     int    `json:"port" binding:"required"`
	User     string `json:"user" binding:"required"`
	Password string `json:"password"`
	DBName   string `json:"dbname" binding:"required"`
	SSLMode  string `json:"sslmode"`
}
|
|
|
|
// testDatabase tests database connection
|
|
func testDatabase(c *gin.Context) {
|
|
var req TestDatabaseRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
response.Error(c, http.StatusBadRequest, "Invalid request: "+err.Error())
|
|
return
|
|
}
|
|
|
|
// Security: Validate all inputs to prevent injection attacks
|
|
if !validateHostname(req.Host) {
|
|
response.Error(c, http.StatusBadRequest, "Invalid hostname format")
|
|
return
|
|
}
|
|
if !validatePort(req.Port) {
|
|
response.Error(c, http.StatusBadRequest, "Invalid port number")
|
|
return
|
|
}
|
|
if !validateUsername(req.User) {
|
|
response.Error(c, http.StatusBadRequest, "Invalid username format")
|
|
return
|
|
}
|
|
if !validateDBName(req.DBName) {
|
|
response.Error(c, http.StatusBadRequest, "Invalid database name format")
|
|
return
|
|
}
|
|
|
|
if req.SSLMode == "" {
|
|
req.SSLMode = "disable"
|
|
}
|
|
if !validateSSLMode(req.SSLMode) {
|
|
response.Error(c, http.StatusBadRequest, "Invalid SSL mode")
|
|
return
|
|
}
|
|
|
|
cfg := &DatabaseConfig{
|
|
Host: req.Host,
|
|
Port: req.Port,
|
|
User: req.User,
|
|
Password: req.Password,
|
|
DBName: req.DBName,
|
|
SSLMode: req.SSLMode,
|
|
}
|
|
|
|
if err := TestDatabaseConnection(cfg); err != nil {
|
|
response.Error(c, http.StatusBadRequest, "Connection failed: "+err.Error())
|
|
return
|
|
}
|
|
|
|
response.Success(c, gin.H{"message": "Connection successful"})
|
|
}
|
|
|
|
// TestRedisRequest is the JSON body accepted by POST /setup/test-redis.
//
// Host and Port are required by gin's binding. testRedis additionally
// validates the host format and the port range. It also limits DB to 0-15;
// an omitted db decodes to 0. Password is optional and is not format-checked.
type TestRedisRequest struct {
	Host     string `json:"host" binding:"required"`
	Port     int    `json:"port" binding:"required"`
	Password string `json:"password"`
	DB       int    `json:"db"`
}
|
|
|
|
// testRedis tests Redis connection
|
|
func testRedis(c *gin.Context) {
|
|
var req TestRedisRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
response.Error(c, http.StatusBadRequest, "Invalid request: "+err.Error())
|
|
return
|
|
}
|
|
|
|
// Security: Validate inputs
|
|
if !validateHostname(req.Host) {
|
|
response.Error(c, http.StatusBadRequest, "Invalid hostname format")
|
|
return
|
|
}
|
|
if !validatePort(req.Port) {
|
|
response.Error(c, http.StatusBadRequest, "Invalid port number")
|
|
return
|
|
}
|
|
if req.DB < 0 || req.DB > 15 {
|
|
response.Error(c, http.StatusBadRequest, "Invalid Redis database number (0-15)")
|
|
return
|
|
}
|
|
|
|
cfg := &RedisConfig{
|
|
Host: req.Host,
|
|
Port: req.Port,
|
|
Password: req.Password,
|
|
DB: req.DB,
|
|
}
|
|
|
|
if err := TestRedisConnection(cfg); err != nil {
|
|
response.Error(c, http.StatusBadRequest, "Connection failed: "+err.Error())
|
|
return
|
|
}
|
|
|
|
response.Success(c, gin.H{"message": "Connection successful"})
|
|
}
|
|
|
|
// InstallRequest is the JSON body accepted by POST /setup/install.
//
// Database, Redis and Admin are validated field by field in install. Server is
// optional; install fills defaults for an empty host, zero port and empty mode.
//
// NOTE(review): binding:"required" on non-pointer struct-typed fields may not
// enforce presence with gin's validator (go-playground/validator) unless
// required-struct validation is enabled — confirm. A missing section is still
// rejected by install's per-field checks (e.g. an empty host or email fails
// validation).
type InstallRequest struct {
	Database DatabaseConfig `json:"database" binding:"required"`
	Redis    RedisConfig    `json:"redis" binding:"required"`
	Admin    AdminConfig    `json:"admin" binding:"required"`
	Server   ServerConfig   `json:"server"`
}
|
|
|
|
// install performs the installation
|
|
func install(c *gin.Context) {
|
|
// TOCTOU Protection: Acquire mutex to prevent concurrent installation
|
|
installMutex.Lock()
|
|
defer installMutex.Unlock()
|
|
|
|
// Double-check after acquiring lock
|
|
if !NeedsSetup() {
|
|
response.Error(c, http.StatusForbidden, "Setup is not allowed: system is already installed")
|
|
return
|
|
}
|
|
|
|
var req InstallRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
response.Error(c, http.StatusBadRequest, "Invalid request: "+err.Error())
|
|
return
|
|
}
|
|
|
|
// ========== COMPREHENSIVE INPUT VALIDATION ==========
|
|
// Database validation
|
|
if !validateHostname(req.Database.Host) {
|
|
response.Error(c, http.StatusBadRequest, "Invalid database hostname")
|
|
return
|
|
}
|
|
if !validatePort(req.Database.Port) {
|
|
response.Error(c, http.StatusBadRequest, "Invalid database port")
|
|
return
|
|
}
|
|
if !validateUsername(req.Database.User) {
|
|
response.Error(c, http.StatusBadRequest, "Invalid database username")
|
|
return
|
|
}
|
|
if !validateDBName(req.Database.DBName) {
|
|
response.Error(c, http.StatusBadRequest, "Invalid database name")
|
|
return
|
|
}
|
|
|
|
// Redis validation
|
|
if !validateHostname(req.Redis.Host) {
|
|
response.Error(c, http.StatusBadRequest, "Invalid Redis hostname")
|
|
return
|
|
}
|
|
if !validatePort(req.Redis.Port) {
|
|
response.Error(c, http.StatusBadRequest, "Invalid Redis port")
|
|
return
|
|
}
|
|
if req.Redis.DB < 0 || req.Redis.DB > 15 {
|
|
response.Error(c, http.StatusBadRequest, "Invalid Redis database number")
|
|
return
|
|
}
|
|
|
|
// Admin validation
|
|
if !validateEmail(req.Admin.Email) {
|
|
response.Error(c, http.StatusBadRequest, "Invalid admin email format")
|
|
return
|
|
}
|
|
if err := validatePassword(req.Admin.Password); err != nil {
|
|
response.Error(c, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
// Server validation
|
|
if req.Server.Port != 0 && !validatePort(req.Server.Port) {
|
|
response.Error(c, http.StatusBadRequest, "Invalid server port")
|
|
return
|
|
}
|
|
|
|
// ========== SET DEFAULTS ==========
|
|
if req.Database.SSLMode == "" {
|
|
req.Database.SSLMode = "disable"
|
|
}
|
|
if !validateSSLMode(req.Database.SSLMode) {
|
|
response.Error(c, http.StatusBadRequest, "Invalid SSL mode")
|
|
return
|
|
}
|
|
if req.Server.Host == "" {
|
|
req.Server.Host = "0.0.0.0"
|
|
}
|
|
if req.Server.Port == 0 {
|
|
req.Server.Port = 8080
|
|
}
|
|
if req.Server.Mode == "" {
|
|
req.Server.Mode = "release"
|
|
}
|
|
// Validate server mode
|
|
if req.Server.Mode != "release" && req.Server.Mode != "debug" {
|
|
response.Error(c, http.StatusBadRequest, "Invalid server mode (must be 'release' or 'debug')")
|
|
return
|
|
}
|
|
|
|
// Trim whitespace from string inputs
|
|
req.Admin.Email = strings.TrimSpace(req.Admin.Email)
|
|
req.Database.Host = strings.TrimSpace(req.Database.Host)
|
|
req.Database.User = strings.TrimSpace(req.Database.User)
|
|
req.Database.DBName = strings.TrimSpace(req.Database.DBName)
|
|
req.Redis.Host = strings.TrimSpace(req.Redis.Host)
|
|
|
|
cfg := &SetupConfig{
|
|
Database: req.Database,
|
|
Redis: req.Redis,
|
|
Admin: req.Admin,
|
|
Server: req.Server,
|
|
JWT: JWTConfig{
|
|
ExpireHour: 24,
|
|
},
|
|
}
|
|
|
|
if err := Install(cfg); err != nil {
|
|
response.Error(c, http.StatusInternalServerError, "Installation failed: "+err.Error())
|
|
return
|
|
}
|
|
|
|
// Schedule service restart in background after sending response
|
|
// This ensures the client receives the success response before the service restarts
|
|
go func() {
|
|
// Wait a moment to ensure the response is sent
|
|
time.Sleep(500 * time.Millisecond)
|
|
triggerServiceRestart()
|
|
}()
|
|
|
|
response.Success(c, gin.H{
|
|
"message": "Installation completed successfully. Service will restart automatically.",
|
|
"restart": true,
|
|
})
|
|
}
|
|
|
|
// triggerServiceRestart asks systemd to restart the "sub2api" unit so the
// process leaves setup mode and comes back up in normal mode.
//
// It does nothing but log on non-Linux systems. On Linux it tries two
// commands in turn:
//   - `systemctl restart sub2api`, which works when running as root or with
//     sufficient permissions;
//   - `sudo -n systemctl restart sub2api` as a fallback.
//
// The -n flag makes sudo fail immediately instead of waiting for a password
// nobody can type, so the fallback only works with a NOPASSWD sudoers entry.
// Failures are logged together with the commands' combined output and are
// not reported to the caller.
//
// NOTE(review): if the restart succeeds, systemd presumably stops this process
// while the command is still running, so the final success log line may never
// be written.
func triggerServiceRestart() {
	if runtime.GOOS != "linux" {
		log.Println("Service restart: not on Linux, manual restart required")
		return
	}

	log.Println("Setup completed, triggering service restart...")

	out, err := exec.Command("systemctl", "restart", "sub2api").CombinedOutput()
	if err != nil {
		sudoOut, sudoErr := exec.Command("sudo", "-n", "systemctl", "restart", "sub2api").CombinedOutput()
		if sudoErr != nil {
			log.Printf("Service restart failed: %v (sudo also failed: %v)", err, sudoErr)
			log.Printf("systemctl output: %q; sudo output: %q",
				strings.TrimSpace(string(out)), strings.TrimSpace(string(sudoOut)))
			log.Println("Please restart the service manually: sudo systemctl restart sub2api")
			return
		}
	}

	log.Println("Service restart initiated successfully")
}
|