diff --git a/.gitignore b/.gitignore index 60959006..8c5b8302 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ docs/claude-relay-service/ backend/bin/ backend/server backend/sub2api +backend/main # 测试覆盖率 *.out diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index e020c218..ee55d599 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -212,11 +212,10 @@ func initDB(cfg *config.Config) (*gorm.DB, error) { return nil, err } - // 自动迁移(开发环境) - if cfg.Server.Mode == "debug" { - if err := model.AutoMigrate(db); err != nil { - return nil, err - } + // 自动迁移(始终执行,确保数据库结构与代码同步) + // GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的 + if err := model.AutoMigrate(db); err != nil { + return nil, err } return db, nil diff --git a/backend/internal/setup/setup.go b/backend/internal/setup/setup.go index a7d062da..e8cb702e 100644 --- a/backend/internal/setup/setup.go +++ b/backend/internal/setup/setup.go @@ -10,6 +10,8 @@ import ( "strconv" "time" + "sub2api/internal/model" + "github.com/redis/go-redis/v9" "golang.org/x/crypto/bcrypt" "gorm.io/driver/postgres" @@ -82,16 +84,17 @@ func NeedsSetup() bool { return true } -// TestDatabaseConnection tests the database connection +// TestDatabaseConnection tests the database connection and creates database if not exists func TestDatabaseConnection(cfg *DatabaseConfig) error { - dsn := fmt.Sprintf( - "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", - cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.DBName, cfg.SSLMode, + // First, connect to the default 'postgres' database to check/create target database + defaultDSN := fmt.Sprintf( + "host=%s port=%d user=%s password=%s dbname=postgres sslmode=%s", + cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.SSLMode, ) - db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + db, err := gorm.Open(postgres.Open(defaultDSN), &gorm.Config{}) if err != nil { - return fmt.Errorf("failed to connect: %w", err) + return fmt.Errorf("failed to connect to PostgreSQL: %w", err) } sqlDB, err := db.DB() @@ -107,6 +110,50 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error { return fmt.Errorf("ping failed: %w", err) } + // Check if target database exists + var exists bool + row := sqlDB.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)", cfg.DBName) + if err := row.Scan(&exists); err != nil { + return fmt.Errorf("failed to check database existence: %w", err) + } + + // Create database if not exists + if !exists { + // Note: Database names cannot be parameterized, but we've already validated cfg.DBName + // in the handler using validateDBName() which only allows [a-zA-Z][a-zA-Z0-9_]* + _, err := sqlDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", cfg.DBName)) + if err != nil { + return fmt.Errorf("failed to create database '%s': %w", cfg.DBName, err) + } + log.Printf("Database '%s' created successfully", cfg.DBName) + } + + // Now connect to the target database to verify + sqlDB.Close() + + targetDSN := fmt.Sprintf( + "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", + cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.DBName, cfg.SSLMode, + ) + + targetDB, err := gorm.Open(postgres.Open(targetDSN), &gorm.Config{}) + if err != nil { + return fmt.Errorf("failed to connect to database '%s': %w", cfg.DBName, err) + } + + targetSqlDB, err := targetDB.DB() + if err != nil { + return fmt.Errorf("failed to get target db instance: %w", err) + } + defer targetSqlDB.Close() + + ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel2() + + if err := targetSqlDB.PingContext(ctx2); err != nil { + return fmt.Errorf("ping target database failed: %w", err) + } + return nil } @@ -198,18 +245,8 @@ func initializeDatabase(cfg *SetupConfig) error { } defer sqlDB.Close() - // Run auto-migration for all models - return db.AutoMigrate( - &User{}, - &Group{}, - &APIKey{}, - &Account{}, - &Proxy{}, - &RedeemCode{}, - &UsageLog{}, - &UserSubscription{}, - &Setting{}, - ) + // 使用 model 包的 AutoMigrate,确保模型定义统一 + return model.AutoMigrate(db) } func createAdminUser(cfg *SetupConfig) error { @@ -232,7 +269,7 @@ func createAdminUser(cfg *SetupConfig) error { // Check if admin already exists var count int64 - db.Model(&User{}).Where("role = ?", "admin").Count(&count) + db.Model(&model.User{}).Where("role = ?", "admin").Count(&count) if count > 0 { return nil // Admin already exists } @@ -244,11 +281,11 @@ func createAdminUser(cfg *SetupConfig) error { } // Create admin user - admin := &User{ + admin := &model.User{ Email: cfg.Admin.Email, PasswordHash: string(hashedPassword), - Role: "admin", - Status: "active", + Role: model.RoleAdmin, + Status: model.StatusActive, Balance: 0, CreatedAt: time.Now(), UpdatedAt: time.Now(), @@ -321,119 +358,6 @@ func generateSecret(length int) string { return hex.EncodeToString(bytes) } -// Minimal model definitions for migration (to avoid circular import) -type User struct { - ID uint `gorm:"primaryKey"` - Email string `gorm:"uniqueIndex;not null"` - PasswordHash string `gorm:"not null"` - Role string `gorm:"default:user"` - Status string `gorm:"default:active"` - Balance float64 `gorm:"default:0"` - CreatedAt time.Time - UpdatedAt time.Time -} - -type Group struct { - ID uint `gorm:"primaryKey"` - Name string `gorm:"uniqueIndex;not null"` - Description string `gorm:"type:text"` - RateMultiplier float64 `gorm:"default:1.0"` - IsExclusive bool `gorm:"default:false"` - Priority int `gorm:"default:0"` - Status string `gorm:"default:active"` - CreatedAt time.Time - UpdatedAt time.Time -} - -type APIKey struct { - ID uint `gorm:"primaryKey"` - UserID uint `gorm:"index;not null"` - Key string `gorm:"uniqueIndex;not null"` - Name string - GroupID *uint - Status string `gorm:"default:active"` - CreatedAt time.Time - UpdatedAt time.Time -} - -type Account struct { - ID uint `gorm:"primaryKey"` - Platform string `gorm:"not null"` - Type string `gorm:"not null"` - Credentials string `gorm:"type:text"` - Status string `gorm:"default:active"` - Priority int `gorm:"default:0"` - ProxyID *uint - CreatedAt time.Time - UpdatedAt time.Time -} - -type Proxy struct { - ID uint `gorm:"primaryKey"` - Name string `gorm:"not null"` - Protocol string `gorm:"not null"` - Host string `gorm:"not null"` - Port int `gorm:"not null"` - Username string - Password string - Status string `gorm:"default:active"` - CreatedAt time.Time - UpdatedAt time.Time -} - -type RedeemCode struct { - ID uint `gorm:"primaryKey"` - Code string `gorm:"uniqueIndex;not null"` - Value float64 `gorm:"not null"` - Status string `gorm:"default:unused"` - UsedBy *uint - UsedAt *time.Time - ExpiresAt *time.Time - CreatedAt time.Time -} - -type UsageLog struct { - ID uint `gorm:"primaryKey"` - UserID uint `gorm:"index"` - APIKeyID uint `gorm:"index"` - AccountID *uint `gorm:"index"` - Model string `gorm:"index"` - InputTokens int - OutputTokens int - Cost float64 - CreatedAt time.Time -} - -type UserSubscription struct { - ID uint `gorm:"primaryKey"` - UserID uint `gorm:"index;not null"` - GroupID uint `gorm:"index;not null"` - Quota int64 - Used int64 `gorm:"default:0"` - Status string - ExpiresAt *time.Time - CreatedAt time.Time - UpdatedAt time.Time -} - -type Setting struct { - ID uint `gorm:"primaryKey"` - Key string `gorm:"uniqueIndex;not null"` - Value string `gorm:"type:text"` - CreatedAt time.Time - UpdatedAt time.Time -} - -func (User) TableName() string { return "users" } -func (Group) TableName() string { return "groups" } -func (APIKey) TableName() string { return "api_keys" } -func (Account) TableName() string { return "accounts" } -func (Proxy) TableName() string { return "proxies" } -func (RedeemCode) TableName() string { return "redeem_codes" } -func (UsageLog) TableName() string { return "usage_logs" } -func (UserSubscription) TableName() string { return "user_subscriptions" } -func (Setting) TableName() string { return "settings" } - // ============================================================================= // Auto Setup for Docker Deployment // =============================================================================