diff --git a/common/database.go b/common/database.go index 3c0a944b..9cbaf46a 100644 --- a/common/database.go +++ b/common/database.go @@ -1,7 +1,14 @@ package common +const ( + DatabaseTypeMySQL = "mysql" + DatabaseTypeSQLite = "sqlite" + DatabaseTypePostgreSQL = "postgres" +) + var UsingSQLite = false var UsingPostgreSQL = false +var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries var UsingMySQL = false var UsingClickHouse = false diff --git a/model/ability.go b/model/ability.go index dd1a11be..96a9ef6a 100644 --- a/model/ability.go +++ b/model/ability.go @@ -24,7 +24,7 @@ type Ability struct { func GetGroupModels(group string) []string { var models []string // Find distinct models - DB.Table("abilities").Where(groupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models) + DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models) return models } @@ -42,16 +42,12 @@ func GetAllEnableAbilities() []Ability { } func getPriority(group string, model string, retry int) (int, error) { - trueVal := "1" - if common.UsingPostgreSQL { - trueVal = "true" - } var priorities []int err := DB.Model(&Ability{}). Select("DISTINCT(priority)"). - Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model). - Order("priority DESC"). // 按优先级降序排序 + Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal). + Order("priority DESC"). // 按优先级降序排序 Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中 if err != nil { @@ -76,18 +72,14 @@ func getPriority(group string, model string, retry int) (int, error) { } func getChannelQuery(group string, model string, retry int) *gorm.DB { - trueVal := "1" - if common.UsingPostgreSQL { - trueVal = "true" - } - maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) - channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal) + channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery) if retry != 0 { priority, err := getPriority(group, model, retry) if err != nil { common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error())) } else { - channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority) + channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority) } } diff --git a/model/channel.go b/model/channel.go index a302df40..b5503eee 100644 --- a/model/channel.go +++ b/model/channel.go @@ -145,7 +145,7 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([] } // 构造基础查询 - baseQuery := DB.Model(&Channel{}).Omit(keyCol) + baseQuery := DB.Model(&Channel{}).Omit("key") // 构造WHERE子句 var whereClause string @@ -153,15 +153,15 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([] if group != "" && group != "null" { var groupCondition string if common.UsingMySQL { - groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?` + groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?` } else { // sqlite, PostgreSQL - groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?` + groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?` } - whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition + whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%") } else { - whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?" + whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?" args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%") } @@ -478,7 +478,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str } // 构造基础查询 - baseQuery := DB.Model(&Channel{}).Omit(keyCol) + baseQuery := DB.Model(&Channel{}).Omit("key") // 构造WHERE子句 var whereClause string @@ -486,15 +486,15 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str if group != "" && group != "null" { var groupCondition string if common.UsingMySQL { - groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?` + groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?` } else { // sqlite, PostgreSQL - groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?` + groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?` } - whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition + whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%") } else { - whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?" + whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?" args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%") } diff --git a/model/log.go b/model/log.go index 3df961e1..b3fd1ad2 100644 --- a/model/log.go +++ b/model/log.go @@ -63,7 +63,7 @@ func formatUserLogs(logs []*Log) { func GetLogByKey(key string) (logs []*Log, err error) { if os.Getenv("LOG_SQL_DSN") != "" { var tk Token - if err = DB.Model(&Token{}).Where(keyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil { + if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil { return nil, err } err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error @@ -122,8 +122,13 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, UseTime: useTimeSeconds, IsStream: isStream, Group: group, - Ip: func() string { if needRecordIp { return c.ClientIP() }; return "" }(), - Other: otherStr, + Ip: func() string { + if needRecordIp { + return c.ClientIP() + } + return "" + }(), + Other: otherStr, } err := LOG_DB.Create(log).Error if err != nil { @@ -165,8 +170,13 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in UseTime: useTimeSeconds, IsStream: isStream, Group: group, - Ip: func() string { if needRecordIp { return c.ClientIP() }; return "" }(), - Other: otherStr, + Ip: func() string { + if needRecordIp { + return c.ClientIP() + } + return "" + }(), + Other: otherStr, } err := LOG_DB.Create(log).Error if err != nil { @@ -206,7 +216,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName tx = tx.Where("logs.channel_id = ?", channel) } if group != "" { - tx = tx.Where("logs."+groupCol+" = ?", group) + tx = tx.Where("logs."+logGroupCol+" = ?", group) } err = tx.Model(&Log{}).Count(&total).Error if err != nil { @@ -217,13 +227,18 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName return nil, 0, err } - channelIds := make([]int, 0) + channelIdsMap := make(map[int]struct{}) channelMap := make(map[int]string) for _, log := range logs { if log.ChannelId != 0 { - channelIds = append(channelIds, log.ChannelId) + channelIdsMap[log.ChannelId] = struct{}{} } } + + channelIds := make([]int, 0, len(channelIdsMap)) + for channelId := range channelIdsMap { + channelIds = append(channelIds, channelId) + } if len(channelIds) > 0 { var channels []struct { Id int `gorm:"column:id"` @@ -264,7 +279,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int tx = tx.Where("logs.created_at <= ?", endTimestamp) } if group != "" { - tx = tx.Where("logs."+groupCol+" = ?", group) + tx = tx.Where("logs."+logGroupCol+" = ?", group) } err = tx.Model(&Log{}).Count(&total).Error if err != nil { @@ -325,8 +340,8 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel) } if group != "" { - tx = tx.Where(groupCol+" = ?", group) - rpmTpmQuery = rpmTpmQuery.Where(groupCol+" = ?", group) + tx = tx.Where(logGroupCol+" = ?", group) + rpmTpmQuery = rpmTpmQuery.Where(logGroupCol+" = ?", group) } tx = tx.Where("type = ?", LogTypeConsume) diff --git a/model/main.go b/model/main.go index 61d6bb10..289baa2f 100644 --- a/model/main.go +++ b/model/main.go @@ -1,6 +1,7 @@ package model import ( + "fmt" "log" "one-api/common" "one-api/constant" @@ -15,18 +16,39 @@ import ( "gorm.io/gorm" ) -var groupCol string -var keyCol string +var commonGroupCol string +var commonKeyCol string +var commonTrueVal string +var commonFalseVal string + +var logKeyCol string +var logGroupCol string func initCol() { + // init common column names if common.UsingPostgreSQL { - groupCol = `"group"` - keyCol = `"key"` - + commonGroupCol = `"group"` + commonKeyCol = `"key"` + commonTrueVal = "true" + commonFalseVal = "false" } else { - groupCol = "`group`" - keyCol = "`key`" + commonGroupCol = "`group`" + commonKeyCol = "`key`" + commonTrueVal = "1" + commonFalseVal = "0" } + if os.Getenv("LOG_SQL_DSN") != "" { + switch common.LogSqlType { + case common.DatabaseTypePostgreSQL: + logGroupCol = `"group"` + logKeyCol = `"key"` + default: + logGroupCol = commonGroupCol + logKeyCol = commonKeyCol + } + } + // log sql type and database type + common.SysLog("Using Log SQL Type: " + common.LogSqlType) } var DB *gorm.DB @@ -83,7 +105,7 @@ func CheckSetup() { } } -func chooseDB(envName string) (*gorm.DB, error) { +func chooseDB(envName string, isLog bool) (*gorm.DB, error) { defer func() { initCol() }() @@ -92,7 +114,11 @@ func chooseDB(envName string) (*gorm.DB, error) { if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { // Use PostgreSQL common.SysLog("using PostgreSQL as database") - common.UsingPostgreSQL = true + if !isLog { + common.UsingPostgreSQL = true + } else { + common.LogSqlType = common.DatabaseTypePostgreSQL + } return gorm.Open(postgres.New(postgres.Config{ DSN: dsn, PreferSimpleProtocol: true, // disables implicit prepared statement usage @@ -102,7 +128,11 @@ func chooseDB(envName string) (*gorm.DB, error) { } if strings.HasPrefix(dsn, "local") { common.SysLog("SQL_DSN not set, using SQLite as database") - common.UsingSQLite = true + if !isLog { + common.UsingSQLite = true + } else { + common.LogSqlType = common.DatabaseTypeSQLite + } return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ PrepareStmt: true, // precompile SQL }) @@ -117,7 +147,11 @@ func chooseDB(envName string) (*gorm.DB, error) { dsn += "?parseTime=true" } } - common.UsingMySQL = true + if !isLog { + common.UsingMySQL = true + } else { + common.LogSqlType = common.DatabaseTypeMySQL + } return gorm.Open(mysql.Open(dsn), &gorm.Config{ PrepareStmt: true, // precompile SQL }) @@ -131,7 +165,7 @@ func chooseDB(envName string) (*gorm.DB, error) { } func InitDB() (err error) { - db, err := chooseDB("SQL_DSN") + db, err := chooseDB("SQL_DSN", false) if err == nil { if common.DebugEnabled { db = db.Debug() @@ -149,7 +183,7 @@ func InitDB() (err error) { return nil } if common.UsingMySQL { - _, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded + //_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded } common.SysLog("database migration started") err = migrateDB() @@ -165,7 +199,7 @@ func InitLogDB() (err error) { LOG_DB = DB return } - db, err := chooseDB("LOG_SQL_DSN") + db, err := chooseDB("LOG_SQL_DSN", true) if err == nil { if common.DebugEnabled { db = db.Debug() @@ -198,54 +232,50 @@ func InitLogDB() (err error) { } func migrateDB() error { - err := DB.AutoMigrate(&Channel{}) - if err != nil { - return err + var wg sync.WaitGroup + errChan := make(chan error, 12) // Buffer size matches number of migrations + + migrations := []struct { + model interface{} + name string + }{ + {&Channel{}, "Channel"}, + {&Token{}, "Token"}, + {&User{}, "User"}, + {&Option{}, "Option"}, + {&Redemption{}, "Redemption"}, + {&Ability{}, "Ability"}, + {&Log{}, "Log"}, + {&Midjourney{}, "Midjourney"}, + {&TopUp{}, "TopUp"}, + {&QuotaData{}, "QuotaData"}, + {&Task{}, "Task"}, + {&Setup{}, "Setup"}, } - err = DB.AutoMigrate(&Token{}) - if err != nil { - return err + + for _, m := range migrations { + wg.Add(1) + go func(model interface{}, name string) { + defer wg.Done() + if err := DB.AutoMigrate(model); err != nil { + errChan <- fmt.Errorf("failed to migrate %s: %v", name, err) + } + }(m.model, m.name) } - err = DB.AutoMigrate(&User{}) - if err != nil { - return err + + // Wait for all migrations to complete + wg.Wait() + close(errChan) + + // Check for any errors + for err := range errChan { + if err != nil { + return err + } } - err = DB.AutoMigrate(&Option{}) - if err != nil { - return err - } - err = DB.AutoMigrate(&Redemption{}) - if err != nil { - return err - } - err = DB.AutoMigrate(&Ability{}) - if err != nil { - return err - } - err = DB.AutoMigrate(&Log{}) - if err != nil { - return err - } - err = DB.AutoMigrate(&Midjourney{}) - if err != nil { - return err - } - err = DB.AutoMigrate(&TopUp{}) - if err != nil { - return err - } - err = DB.AutoMigrate(&QuotaData{}) - if err != nil { - return err - } - err = DB.AutoMigrate(&Task{}) - if err != nil { - return err - } - err = DB.AutoMigrate(&Setup{}) + common.SysLog("database migrated") - //err = createRootAccountIfNeed() - return err + return nil } func migrateLOGDB() error { diff --git a/model/token.go b/model/token.go index d4b26afe..2ed2c09a 100644 --- a/model/token.go +++ b/model/token.go @@ -66,7 +66,7 @@ func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token if token != "" { token = strings.Trim(token, "sk-") } - err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(keyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error + err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error return tokens, err } @@ -161,7 +161,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) { // Don't return error - fall through to DB } fromDB = true - err = DB.Where(keyCol+" = ?", key).First(&token).Error + err = DB.Where(commonKeyCol+" = ?", key).First(&token).Error return token, err } diff --git a/model/user.go b/model/user.go index 1a3372aa..1b3a04b6 100644 --- a/model/user.go +++ b/model/user.go @@ -175,7 +175,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User, // 如果是数字,同时搜索ID和其他字段 likeCondition = "id = ? OR " + likeCondition if group != "" { - query = query.Where("("+likeCondition+") AND "+groupCol+" = ?", + query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?", keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group) } else { query = query.Where(likeCondition, @@ -184,7 +184,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User, } else { // 非数字关键字,只搜索字符串字段 if group != "" { - query = query.Where("("+likeCondition+") AND "+groupCol+" = ?", + query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group) } else { query = query.Where(likeCondition, @@ -615,7 +615,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) { // Don't return error - fall through to DB } fromDB = true - err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error + err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error if err != nil { return "", err }