Merge pull request #1225 from QuantumNous/fix_mixing_sql_conflicts
Fix mixing databases conflicts
This commit is contained in:
@@ -1,7 +1,14 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
|
const (
|
||||||
|
DatabaseTypeMySQL = "mysql"
|
||||||
|
DatabaseTypeSQLite = "sqlite"
|
||||||
|
DatabaseTypePostgreSQL = "postgres"
|
||||||
|
)
|
||||||
|
|
||||||
var UsingSQLite = false
|
var UsingSQLite = false
|
||||||
var UsingPostgreSQL = false
|
var UsingPostgreSQL = false
|
||||||
|
var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
|
||||||
var UsingMySQL = false
|
var UsingMySQL = false
|
||||||
var UsingClickHouse = false
|
var UsingClickHouse = false
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ type Ability struct {
|
|||||||
func GetGroupModels(group string) []string {
|
func GetGroupModels(group string) []string {
|
||||||
var models []string
|
var models []string
|
||||||
// Find distinct models
|
// Find distinct models
|
||||||
DB.Table("abilities").Where(groupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
||||||
return models
|
return models
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -42,16 +42,12 @@ func GetAllEnableAbilities() []Ability {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getPriority(group string, model string, retry int) (int, error) {
|
func getPriority(group string, model string, retry int) (int, error) {
|
||||||
trueVal := "1"
|
|
||||||
if common.UsingPostgreSQL {
|
|
||||||
trueVal = "true"
|
|
||||||
}
|
|
||||||
|
|
||||||
var priorities []int
|
var priorities []int
|
||||||
err := DB.Model(&Ability{}).
|
err := DB.Model(&Ability{}).
|
||||||
Select("DISTINCT(priority)").
|
Select("DISTINCT(priority)").
|
||||||
Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
|
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
|
||||||
Order("priority DESC"). // 按优先级降序排序
|
Order("priority DESC"). // 按优先级降序排序
|
||||||
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
|
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -76,18 +72,14 @@ func getPriority(group string, model string, retry int) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getChannelQuery(group string, model string, retry int) *gorm.DB {
|
func getChannelQuery(group string, model string, retry int) *gorm.DB {
|
||||||
trueVal := "1"
|
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
|
||||||
if common.UsingPostgreSQL {
|
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
|
||||||
trueVal = "true"
|
|
||||||
}
|
|
||||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
|
||||||
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
|
||||||
if retry != 0 {
|
if retry != 0 {
|
||||||
priority, err := getPriority(group, model, retry)
|
priority, err := getPriority(group, model, retry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
|
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
|
||||||
} else {
|
} else {
|
||||||
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
|
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -145,7 +145,7 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 构造基础查询
|
// 构造基础查询
|
||||||
baseQuery := DB.Model(&Channel{}).Omit(keyCol)
|
baseQuery := DB.Model(&Channel{}).Omit("key")
|
||||||
|
|
||||||
// 构造WHERE子句
|
// 构造WHERE子句
|
||||||
var whereClause string
|
var whereClause string
|
||||||
@@ -153,15 +153,15 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
|
|||||||
if group != "" && group != "null" {
|
if group != "" && group != "null" {
|
||||||
var groupCondition string
|
var groupCondition string
|
||||||
if common.UsingMySQL {
|
if common.UsingMySQL {
|
||||||
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
|
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
|
||||||
} else {
|
} else {
|
||||||
// sqlite, PostgreSQL
|
// sqlite, PostgreSQL
|
||||||
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
|
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
|
||||||
}
|
}
|
||||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
|
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
|
||||||
} else {
|
} else {
|
||||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
|
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -478,7 +478,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 构造基础查询
|
// 构造基础查询
|
||||||
baseQuery := DB.Model(&Channel{}).Omit(keyCol)
|
baseQuery := DB.Model(&Channel{}).Omit("key")
|
||||||
|
|
||||||
// 构造WHERE子句
|
// 构造WHERE子句
|
||||||
var whereClause string
|
var whereClause string
|
||||||
@@ -486,15 +486,15 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
|||||||
if group != "" && group != "null" {
|
if group != "" && group != "null" {
|
||||||
var groupCondition string
|
var groupCondition string
|
||||||
if common.UsingMySQL {
|
if common.UsingMySQL {
|
||||||
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
|
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
|
||||||
} else {
|
} else {
|
||||||
// sqlite, PostgreSQL
|
// sqlite, PostgreSQL
|
||||||
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
|
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
|
||||||
}
|
}
|
||||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
|
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
|
||||||
} else {
|
} else {
|
||||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
|
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
37
model/log.go
37
model/log.go
@@ -63,7 +63,7 @@ func formatUserLogs(logs []*Log) {
|
|||||||
func GetLogByKey(key string) (logs []*Log, err error) {
|
func GetLogByKey(key string) (logs []*Log, err error) {
|
||||||
if os.Getenv("LOG_SQL_DSN") != "" {
|
if os.Getenv("LOG_SQL_DSN") != "" {
|
||||||
var tk Token
|
var tk Token
|
||||||
if err = DB.Model(&Token{}).Where(keyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
|
if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
|
err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
|
||||||
@@ -122,8 +122,13 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
|
|||||||
UseTime: useTimeSeconds,
|
UseTime: useTimeSeconds,
|
||||||
IsStream: isStream,
|
IsStream: isStream,
|
||||||
Group: group,
|
Group: group,
|
||||||
Ip: func() string { if needRecordIp { return c.ClientIP() }; return "" }(),
|
Ip: func() string {
|
||||||
Other: otherStr,
|
if needRecordIp {
|
||||||
|
return c.ClientIP()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}(),
|
||||||
|
Other: otherStr,
|
||||||
}
|
}
|
||||||
err := LOG_DB.Create(log).Error
|
err := LOG_DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -165,8 +170,13 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
|
|||||||
UseTime: useTimeSeconds,
|
UseTime: useTimeSeconds,
|
||||||
IsStream: isStream,
|
IsStream: isStream,
|
||||||
Group: group,
|
Group: group,
|
||||||
Ip: func() string { if needRecordIp { return c.ClientIP() }; return "" }(),
|
Ip: func() string {
|
||||||
Other: otherStr,
|
if needRecordIp {
|
||||||
|
return c.ClientIP()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}(),
|
||||||
|
Other: otherStr,
|
||||||
}
|
}
|
||||||
err := LOG_DB.Create(log).Error
|
err := LOG_DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -206,7 +216,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
|||||||
tx = tx.Where("logs.channel_id = ?", channel)
|
tx = tx.Where("logs.channel_id = ?", channel)
|
||||||
}
|
}
|
||||||
if group != "" {
|
if group != "" {
|
||||||
tx = tx.Where("logs."+groupCol+" = ?", group)
|
tx = tx.Where("logs."+logGroupCol+" = ?", group)
|
||||||
}
|
}
|
||||||
err = tx.Model(&Log{}).Count(&total).Error
|
err = tx.Model(&Log{}).Count(&total).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -217,13 +227,18 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
|||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
channelIds := make([]int, 0)
|
channelIdsMap := make(map[int]struct{})
|
||||||
channelMap := make(map[int]string)
|
channelMap := make(map[int]string)
|
||||||
for _, log := range logs {
|
for _, log := range logs {
|
||||||
if log.ChannelId != 0 {
|
if log.ChannelId != 0 {
|
||||||
channelIds = append(channelIds, log.ChannelId)
|
channelIdsMap[log.ChannelId] = struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
channelIds := make([]int, 0, len(channelIdsMap))
|
||||||
|
for channelId := range channelIdsMap {
|
||||||
|
channelIds = append(channelIds, channelId)
|
||||||
|
}
|
||||||
if len(channelIds) > 0 {
|
if len(channelIds) > 0 {
|
||||||
var channels []struct {
|
var channels []struct {
|
||||||
Id int `gorm:"column:id"`
|
Id int `gorm:"column:id"`
|
||||||
@@ -264,7 +279,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
|||||||
tx = tx.Where("logs.created_at <= ?", endTimestamp)
|
tx = tx.Where("logs.created_at <= ?", endTimestamp)
|
||||||
}
|
}
|
||||||
if group != "" {
|
if group != "" {
|
||||||
tx = tx.Where("logs."+groupCol+" = ?", group)
|
tx = tx.Where("logs."+logGroupCol+" = ?", group)
|
||||||
}
|
}
|
||||||
err = tx.Model(&Log{}).Count(&total).Error
|
err = tx.Model(&Log{}).Count(&total).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -325,8 +340,8 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
|||||||
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
|
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
|
||||||
}
|
}
|
||||||
if group != "" {
|
if group != "" {
|
||||||
tx = tx.Where(groupCol+" = ?", group)
|
tx = tx.Where(logGroupCol+" = ?", group)
|
||||||
rpmTpmQuery = rpmTpmQuery.Where(groupCol+" = ?", group)
|
rpmTpmQuery = rpmTpmQuery.Where(logGroupCol+" = ?", group)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx = tx.Where("type = ?", LogTypeConsume)
|
tx = tx.Where("type = ?", LogTypeConsume)
|
||||||
|
|||||||
146
model/main.go
146
model/main.go
@@ -1,6 +1,7 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
@@ -15,18 +16,39 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
var groupCol string
|
var commonGroupCol string
|
||||||
var keyCol string
|
var commonKeyCol string
|
||||||
|
var commonTrueVal string
|
||||||
|
var commonFalseVal string
|
||||||
|
|
||||||
|
var logKeyCol string
|
||||||
|
var logGroupCol string
|
||||||
|
|
||||||
func initCol() {
|
func initCol() {
|
||||||
|
// init common column names
|
||||||
if common.UsingPostgreSQL {
|
if common.UsingPostgreSQL {
|
||||||
groupCol = `"group"`
|
commonGroupCol = `"group"`
|
||||||
keyCol = `"key"`
|
commonKeyCol = `"key"`
|
||||||
|
commonTrueVal = "true"
|
||||||
|
commonFalseVal = "false"
|
||||||
} else {
|
} else {
|
||||||
groupCol = "`group`"
|
commonGroupCol = "`group`"
|
||||||
keyCol = "`key`"
|
commonKeyCol = "`key`"
|
||||||
|
commonTrueVal = "1"
|
||||||
|
commonFalseVal = "0"
|
||||||
}
|
}
|
||||||
|
if os.Getenv("LOG_SQL_DSN") != "" {
|
||||||
|
switch common.LogSqlType {
|
||||||
|
case common.DatabaseTypePostgreSQL:
|
||||||
|
logGroupCol = `"group"`
|
||||||
|
logKeyCol = `"key"`
|
||||||
|
default:
|
||||||
|
logGroupCol = commonGroupCol
|
||||||
|
logKeyCol = commonKeyCol
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// log sql type and database type
|
||||||
|
common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
||||||
}
|
}
|
||||||
|
|
||||||
var DB *gorm.DB
|
var DB *gorm.DB
|
||||||
@@ -83,7 +105,7 @@ func CheckSetup() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func chooseDB(envName string) (*gorm.DB, error) {
|
func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
initCol()
|
initCol()
|
||||||
}()
|
}()
|
||||||
@@ -92,7 +114,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
|||||||
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
|
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
|
||||||
// Use PostgreSQL
|
// Use PostgreSQL
|
||||||
common.SysLog("using PostgreSQL as database")
|
common.SysLog("using PostgreSQL as database")
|
||||||
common.UsingPostgreSQL = true
|
if !isLog {
|
||||||
|
common.UsingPostgreSQL = true
|
||||||
|
} else {
|
||||||
|
common.LogSqlType = common.DatabaseTypePostgreSQL
|
||||||
|
}
|
||||||
return gorm.Open(postgres.New(postgres.Config{
|
return gorm.Open(postgres.New(postgres.Config{
|
||||||
DSN: dsn,
|
DSN: dsn,
|
||||||
PreferSimpleProtocol: true, // disables implicit prepared statement usage
|
PreferSimpleProtocol: true, // disables implicit prepared statement usage
|
||||||
@@ -102,7 +128,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
|||||||
}
|
}
|
||||||
if strings.HasPrefix(dsn, "local") {
|
if strings.HasPrefix(dsn, "local") {
|
||||||
common.SysLog("SQL_DSN not set, using SQLite as database")
|
common.SysLog("SQL_DSN not set, using SQLite as database")
|
||||||
common.UsingSQLite = true
|
if !isLog {
|
||||||
|
common.UsingSQLite = true
|
||||||
|
} else {
|
||||||
|
common.LogSqlType = common.DatabaseTypeSQLite
|
||||||
|
}
|
||||||
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
|
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
|
||||||
PrepareStmt: true, // precompile SQL
|
PrepareStmt: true, // precompile SQL
|
||||||
})
|
})
|
||||||
@@ -117,7 +147,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
|||||||
dsn += "?parseTime=true"
|
dsn += "?parseTime=true"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
common.UsingMySQL = true
|
if !isLog {
|
||||||
|
common.UsingMySQL = true
|
||||||
|
} else {
|
||||||
|
common.LogSqlType = common.DatabaseTypeMySQL
|
||||||
|
}
|
||||||
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||||
PrepareStmt: true, // precompile SQL
|
PrepareStmt: true, // precompile SQL
|
||||||
})
|
})
|
||||||
@@ -131,7 +165,7 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func InitDB() (err error) {
|
func InitDB() (err error) {
|
||||||
db, err := chooseDB("SQL_DSN")
|
db, err := chooseDB("SQL_DSN", false)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
db = db.Debug()
|
db = db.Debug()
|
||||||
@@ -149,7 +183,7 @@ func InitDB() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if common.UsingMySQL {
|
if common.UsingMySQL {
|
||||||
_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
|
//_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
|
||||||
}
|
}
|
||||||
common.SysLog("database migration started")
|
common.SysLog("database migration started")
|
||||||
err = migrateDB()
|
err = migrateDB()
|
||||||
@@ -165,7 +199,7 @@ func InitLogDB() (err error) {
|
|||||||
LOG_DB = DB
|
LOG_DB = DB
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
db, err := chooseDB("LOG_SQL_DSN")
|
db, err := chooseDB("LOG_SQL_DSN", true)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
db = db.Debug()
|
db = db.Debug()
|
||||||
@@ -198,54 +232,50 @@ func InitLogDB() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func migrateDB() error {
|
func migrateDB() error {
|
||||||
err := DB.AutoMigrate(&Channel{})
|
var wg sync.WaitGroup
|
||||||
if err != nil {
|
errChan := make(chan error, 12) // Buffer size matches number of migrations
|
||||||
return err
|
|
||||||
|
migrations := []struct {
|
||||||
|
model interface{}
|
||||||
|
name string
|
||||||
|
}{
|
||||||
|
{&Channel{}, "Channel"},
|
||||||
|
{&Token{}, "Token"},
|
||||||
|
{&User{}, "User"},
|
||||||
|
{&Option{}, "Option"},
|
||||||
|
{&Redemption{}, "Redemption"},
|
||||||
|
{&Ability{}, "Ability"},
|
||||||
|
{&Log{}, "Log"},
|
||||||
|
{&Midjourney{}, "Midjourney"},
|
||||||
|
{&TopUp{}, "TopUp"},
|
||||||
|
{&QuotaData{}, "QuotaData"},
|
||||||
|
{&Task{}, "Task"},
|
||||||
|
{&Setup{}, "Setup"},
|
||||||
}
|
}
|
||||||
err = DB.AutoMigrate(&Token{})
|
|
||||||
if err != nil {
|
for _, m := range migrations {
|
||||||
return err
|
wg.Add(1)
|
||||||
|
go func(model interface{}, name string) {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := DB.AutoMigrate(model); err != nil {
|
||||||
|
errChan <- fmt.Errorf("failed to migrate %s: %v", name, err)
|
||||||
|
}
|
||||||
|
}(m.model, m.name)
|
||||||
}
|
}
|
||||||
err = DB.AutoMigrate(&User{})
|
|
||||||
if err != nil {
|
// Wait for all migrations to complete
|
||||||
return err
|
wg.Wait()
|
||||||
|
close(errChan)
|
||||||
|
|
||||||
|
// Check for any errors
|
||||||
|
for err := range errChan {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
err = DB.AutoMigrate(&Option{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = DB.AutoMigrate(&Redemption{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = DB.AutoMigrate(&Ability{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = DB.AutoMigrate(&Log{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = DB.AutoMigrate(&Midjourney{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = DB.AutoMigrate(&TopUp{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = DB.AutoMigrate(&QuotaData{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = DB.AutoMigrate(&Task{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = DB.AutoMigrate(&Setup{})
|
|
||||||
common.SysLog("database migrated")
|
common.SysLog("database migrated")
|
||||||
//err = createRootAccountIfNeed()
|
return nil
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func migrateLOGDB() error {
|
func migrateLOGDB() error {
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token
|
|||||||
if token != "" {
|
if token != "" {
|
||||||
token = strings.Trim(token, "sk-")
|
token = strings.Trim(token, "sk-")
|
||||||
}
|
}
|
||||||
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(keyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
|
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
|
||||||
return tokens, err
|
return tokens, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -161,7 +161,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
|
|||||||
// Don't return error - fall through to DB
|
// Don't return error - fall through to DB
|
||||||
}
|
}
|
||||||
fromDB = true
|
fromDB = true
|
||||||
err = DB.Where(keyCol+" = ?", key).First(&token).Error
|
err = DB.Where(commonKeyCol+" = ?", key).First(&token).Error
|
||||||
return token, err
|
return token, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -175,7 +175,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
|
|||||||
// 如果是数字,同时搜索ID和其他字段
|
// 如果是数字,同时搜索ID和其他字段
|
||||||
likeCondition = "id = ? OR " + likeCondition
|
likeCondition = "id = ? OR " + likeCondition
|
||||||
if group != "" {
|
if group != "" {
|
||||||
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
|
query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
|
||||||
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||||
} else {
|
} else {
|
||||||
query = query.Where(likeCondition,
|
query = query.Where(likeCondition,
|
||||||
@@ -184,7 +184,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
|
|||||||
} else {
|
} else {
|
||||||
// 非数字关键字,只搜索字符串字段
|
// 非数字关键字,只搜索字符串字段
|
||||||
if group != "" {
|
if group != "" {
|
||||||
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
|
query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
|
||||||
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||||
} else {
|
} else {
|
||||||
query = query.Where(likeCondition,
|
query = query.Where(likeCondition,
|
||||||
@@ -615,7 +615,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
|
|||||||
// Don't return error - fall through to DB
|
// Don't return error - fall through to DB
|
||||||
}
|
}
|
||||||
fromDB = true
|
fromDB = true
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user