360 lines
8.1 KiB
Go
360 lines
8.1 KiB
Go
package model
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"one-api/common"
|
|
"one-api/constant"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/glebarez/sqlite"
|
|
"gorm.io/driver/mysql"
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Quoted column identifiers and boolean literals used when building raw SQL
// for the main database. All of them are assigned by initCol.
var (
	commonGroupCol string // quoted name of the "group" column
	commonKeyCol   string // quoted name of the "key" column
	commonTrueVal  string // literal used for boolean true
	commonFalseVal string // literal used for boolean false
)

// Quoted column identifiers used when building raw SQL for the log database.
// Both are assigned by initCol.
var (
	logKeyCol   string // quoted name of the "key" column
	logGroupCol string // quoted name of the "group" column
)
|
|
|
|
func initCol() {
|
|
// init common column names
|
|
if common.UsingPostgreSQL {
|
|
commonGroupCol = `"group"`
|
|
commonKeyCol = `"key"`
|
|
commonTrueVal = "true"
|
|
commonFalseVal = "false"
|
|
} else {
|
|
commonGroupCol = "`group`"
|
|
commonKeyCol = "`key`"
|
|
commonTrueVal = "1"
|
|
commonFalseVal = "0"
|
|
}
|
|
if os.Getenv("LOG_SQL_DSN") != "" {
|
|
switch common.LogSqlType {
|
|
case common.DatabaseTypePostgreSQL:
|
|
logGroupCol = `"group"`
|
|
logKeyCol = `"key"`
|
|
default:
|
|
logGroupCol = commonGroupCol
|
|
logKeyCol = commonKeyCol
|
|
}
|
|
}
|
|
// log sql type and database type
|
|
//common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
|
}
|
|
|
|
var DB *gorm.DB
|
|
|
|
var LOG_DB *gorm.DB
|
|
|
|
func createRootAccountIfNeed() error {
|
|
var user User
|
|
//if user.Status != common.UserStatusEnabled {
|
|
if err := DB.First(&user).Error; err != nil {
|
|
common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
|
|
hashedPassword, err := common.Password2Hash("123456")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rootUser := User{
|
|
Username: "root",
|
|
Password: hashedPassword,
|
|
Role: common.RoleRootUser,
|
|
Status: common.UserStatusEnabled,
|
|
DisplayName: "Root User",
|
|
AccessToken: nil,
|
|
Quota: 100000000,
|
|
}
|
|
DB.Create(&rootUser)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func CheckSetup() {
|
|
setup := GetSetup()
|
|
if setup == nil {
|
|
// No setup record exists, check if we have a root user
|
|
if RootUserExists() {
|
|
common.SysLog("system is not initialized, but root user exists")
|
|
// Create setup record
|
|
newSetup := Setup{
|
|
Version: common.Version,
|
|
InitializedAt: time.Now().Unix(),
|
|
}
|
|
err := DB.Create(&newSetup).Error
|
|
if err != nil {
|
|
common.SysLog("failed to create setup record: " + err.Error())
|
|
}
|
|
constant.Setup = true
|
|
} else {
|
|
common.SysLog("system is not initialized and no root user exists")
|
|
constant.Setup = false
|
|
}
|
|
} else {
|
|
// Setup record exists, system is initialized
|
|
common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
|
|
constant.Setup = true
|
|
}
|
|
}
|
|
|
|
func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
|
|
defer func() {
|
|
initCol()
|
|
}()
|
|
dsn := os.Getenv(envName)
|
|
if dsn != "" {
|
|
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
|
|
// Use PostgreSQL
|
|
common.SysLog("using PostgreSQL as database")
|
|
if !isLog {
|
|
common.UsingPostgreSQL = true
|
|
} else {
|
|
common.LogSqlType = common.DatabaseTypePostgreSQL
|
|
}
|
|
return gorm.Open(postgres.New(postgres.Config{
|
|
DSN: dsn,
|
|
PreferSimpleProtocol: true, // disables implicit prepared statement usage
|
|
}), &gorm.Config{
|
|
PrepareStmt: true, // precompile SQL
|
|
})
|
|
}
|
|
if strings.HasPrefix(dsn, "local") {
|
|
common.SysLog("SQL_DSN not set, using SQLite as database")
|
|
if !isLog {
|
|
common.UsingSQLite = true
|
|
} else {
|
|
common.LogSqlType = common.DatabaseTypeSQLite
|
|
}
|
|
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
|
|
PrepareStmt: true, // precompile SQL
|
|
})
|
|
}
|
|
// Use MySQL
|
|
common.SysLog("using MySQL as database")
|
|
// check parseTime
|
|
if !strings.Contains(dsn, "parseTime") {
|
|
if strings.Contains(dsn, "?") {
|
|
dsn += "&parseTime=true"
|
|
} else {
|
|
dsn += "?parseTime=true"
|
|
}
|
|
}
|
|
if !isLog {
|
|
common.UsingMySQL = true
|
|
} else {
|
|
common.LogSqlType = common.DatabaseTypeMySQL
|
|
}
|
|
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
|
PrepareStmt: true, // precompile SQL
|
|
})
|
|
}
|
|
// Use SQLite
|
|
common.SysLog("SQL_DSN not set, using SQLite as database")
|
|
common.UsingSQLite = true
|
|
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
|
|
PrepareStmt: true, // precompile SQL
|
|
})
|
|
}
|
|
|
|
func InitDB() (err error) {
|
|
db, err := chooseDB("SQL_DSN", false)
|
|
if err == nil {
|
|
if common.DebugEnabled {
|
|
db = db.Debug()
|
|
}
|
|
DB = db
|
|
sqlDB, err := DB.DB()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
|
|
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
|
|
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
|
|
|
|
if !common.IsMasterNode {
|
|
return nil
|
|
}
|
|
if common.UsingMySQL {
|
|
//_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
|
|
}
|
|
common.SysLog("database migration started")
|
|
err = migrateDB()
|
|
return err
|
|
} else {
|
|
common.FatalLog(err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func InitLogDB() (err error) {
|
|
if os.Getenv("LOG_SQL_DSN") == "" {
|
|
LOG_DB = DB
|
|
return
|
|
}
|
|
db, err := chooseDB("LOG_SQL_DSN", true)
|
|
if err == nil {
|
|
if common.DebugEnabled {
|
|
db = db.Debug()
|
|
}
|
|
LOG_DB = db
|
|
sqlDB, err := LOG_DB.DB()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
|
|
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
|
|
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
|
|
|
|
if !common.IsMasterNode {
|
|
return nil
|
|
}
|
|
//if common.UsingMySQL {
|
|
// _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
|
|
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
|
|
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
|
|
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
|
|
//}
|
|
common.SysLog("database migration started")
|
|
err = migrateLOGDB()
|
|
return err
|
|
} else {
|
|
common.FatalLog(err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func migrateDB() error {
|
|
if !common.UsingPostgreSQL {
|
|
return migrateDBFast()
|
|
}
|
|
err := DB.AutoMigrate(
|
|
&Channel{},
|
|
&Token{},
|
|
&User{},
|
|
&Option{},
|
|
&Redemption{},
|
|
&Ability{},
|
|
&Log{},
|
|
&Midjourney{},
|
|
&TopUp{},
|
|
&QuotaData{},
|
|
&Task{},
|
|
&Setup{},
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func migrateDBFast() error {
|
|
var wg sync.WaitGroup
|
|
errChan := make(chan error, 12) // Buffer size matches number of migrations
|
|
|
|
migrations := []struct {
|
|
model interface{}
|
|
name string
|
|
}{
|
|
{&Channel{}, "Channel"},
|
|
{&Token{}, "Token"},
|
|
{&User{}, "User"},
|
|
{&Option{}, "Option"},
|
|
{&Redemption{}, "Redemption"},
|
|
{&Ability{}, "Ability"},
|
|
{&Log{}, "Log"},
|
|
{&Midjourney{}, "Midjourney"},
|
|
{&TopUp{}, "TopUp"},
|
|
{&QuotaData{}, "QuotaData"},
|
|
{&Task{}, "Task"},
|
|
{&Setup{}, "Setup"},
|
|
}
|
|
|
|
for _, m := range migrations {
|
|
wg.Add(1)
|
|
go func(model interface{}, name string) {
|
|
defer wg.Done()
|
|
if err := DB.AutoMigrate(model); err != nil {
|
|
errChan <- fmt.Errorf("failed to migrate %s: %v", name, err)
|
|
}
|
|
}(m.model, m.name)
|
|
}
|
|
|
|
// Wait for all migrations to complete
|
|
wg.Wait()
|
|
close(errChan)
|
|
|
|
// Check for any errors
|
|
for err := range errChan {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
common.SysLog("database migrated")
|
|
return nil
|
|
}
|
|
|
|
func migrateLOGDB() error {
|
|
var err error
|
|
if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func closeDB(db *gorm.DB) error {
|
|
sqlDB, err := db.DB()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = sqlDB.Close()
|
|
return err
|
|
}
|
|
|
|
func CloseDB() error {
|
|
if LOG_DB != DB {
|
|
err := closeDB(LOG_DB)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return closeDB(DB)
|
|
}
|
|
|
|
// lastPingTime is the time of the most recent successful ping made by PingDB.
var lastPingTime time.Time

// pingMutex is held by PingDB for its whole duration; lastPingTime is only
// read and written while it is held.
var pingMutex sync.Mutex
|
|
|
|
func PingDB() error {
|
|
pingMutex.Lock()
|
|
defer pingMutex.Unlock()
|
|
|
|
if time.Since(lastPingTime) < time.Second*10 {
|
|
return nil
|
|
}
|
|
|
|
sqlDB, err := DB.DB()
|
|
if err != nil {
|
|
log.Printf("Error getting sql.DB from GORM: %v", err)
|
|
return err
|
|
}
|
|
|
|
err = sqlDB.Ping()
|
|
if err != nil {
|
|
log.Printf("Error pinging DB: %v", err)
|
|
return err
|
|
}
|
|
|
|
lastPingTime = time.Now()
|
|
common.SysLog("Database pinged successfully")
|
|
return nil
|
|
}
|