Merge pull request #1225 from QuantumNous/fix_mixing_sql_conflicts

Fix mixing databases conflicts
This commit is contained in:
Calcium-Ion
2025-06-14 18:24:53 +08:00
committed by GitHub
7 changed files with 142 additions and 98 deletions

View File

@@ -1,7 +1,14 @@
package common package common
const (
DatabaseTypeMySQL = "mysql"
DatabaseTypeSQLite = "sqlite"
DatabaseTypePostgreSQL = "postgres"
)
var UsingSQLite = false var UsingSQLite = false
var UsingPostgreSQL = false var UsingPostgreSQL = false
var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
var UsingMySQL = false var UsingMySQL = false
var UsingClickHouse = false var UsingClickHouse = false

View File

@@ -24,7 +24,7 @@ type Ability struct {
func GetGroupModels(group string) []string { func GetGroupModels(group string) []string {
var models []string var models []string
// Find distinct models // Find distinct models
DB.Table("abilities").Where(groupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models) DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
return models return models
} }
@@ -42,16 +42,12 @@ func GetAllEnableAbilities() []Ability {
} }
func getPriority(group string, model string, retry int) (int, error) { func getPriority(group string, model string, retry int) (int, error) {
trueVal := "1"
if common.UsingPostgreSQL {
trueVal = "true"
}
var priorities []int var priorities []int
err := DB.Model(&Ability{}). err := DB.Model(&Ability{}).
Select("DISTINCT(priority)"). Select("DISTINCT(priority)").
Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model). Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
Order("priority DESC"). // 按优先级降序排序 Order("priority DESC"). // 按优先级降序排序
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中 Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
if err != nil { if err != nil {
@@ -76,18 +72,14 @@ func getPriority(group string, model string, retry int) (int, error) {
} }
func getChannelQuery(group string, model string, retry int) *gorm.DB { func getChannelQuery(group string, model string, retry int) *gorm.DB {
trueVal := "1" maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
if common.UsingPostgreSQL { channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
trueVal = "true"
}
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if retry != 0 { if retry != 0 {
priority, err := getPriority(group, model, retry) priority, err := getPriority(group, model, retry)
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error())) common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
} else { } else {
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority) channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
} }
} }

View File

@@ -145,7 +145,7 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
} }
// 构造基础查询 // 构造基础查询
baseQuery := DB.Model(&Channel{}).Omit(keyCol) baseQuery := DB.Model(&Channel{}).Omit("key")
// 构造WHERE子句 // 构造WHERE子句
var whereClause string var whereClause string
@@ -153,15 +153,15 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
if group != "" && group != "null" { if group != "" && group != "null" {
var groupCondition string var groupCondition string
if common.UsingMySQL { if common.UsingMySQL {
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?` groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
} else { } else {
// sqlite, PostgreSQL // sqlite, PostgreSQL
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?` groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
} }
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%") args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
} else { } else {
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?" whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%") args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
} }
@@ -478,7 +478,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
} }
// 构造基础查询 // 构造基础查询
baseQuery := DB.Model(&Channel{}).Omit(keyCol) baseQuery := DB.Model(&Channel{}).Omit("key")
// 构造WHERE子句 // 构造WHERE子句
var whereClause string var whereClause string
@@ -486,15 +486,15 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
if group != "" && group != "null" { if group != "" && group != "null" {
var groupCondition string var groupCondition string
if common.UsingMySQL { if common.UsingMySQL {
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?` groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
} else { } else {
// sqlite, PostgreSQL // sqlite, PostgreSQL
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?` groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
} }
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%") args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
} else { } else {
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?" whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%") args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
} }

View File

@@ -63,7 +63,7 @@ func formatUserLogs(logs []*Log) {
func GetLogByKey(key string) (logs []*Log, err error) { func GetLogByKey(key string) (logs []*Log, err error) {
if os.Getenv("LOG_SQL_DSN") != "" { if os.Getenv("LOG_SQL_DSN") != "" {
var tk Token var tk Token
if err = DB.Model(&Token{}).Where(keyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil { if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
return nil, err return nil, err
} }
err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
@@ -122,8 +122,13 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
UseTime: useTimeSeconds, UseTime: useTimeSeconds,
IsStream: isStream, IsStream: isStream,
Group: group, Group: group,
Ip: func() string { if needRecordIp { return c.ClientIP() }; return "" }(), Ip: func() string {
Other: otherStr, if needRecordIp {
return c.ClientIP()
}
return ""
}(),
Other: otherStr,
} }
err := LOG_DB.Create(log).Error err := LOG_DB.Create(log).Error
if err != nil { if err != nil {
@@ -165,8 +170,13 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
UseTime: useTimeSeconds, UseTime: useTimeSeconds,
IsStream: isStream, IsStream: isStream,
Group: group, Group: group,
Ip: func() string { if needRecordIp { return c.ClientIP() }; return "" }(), Ip: func() string {
Other: otherStr, if needRecordIp {
return c.ClientIP()
}
return ""
}(),
Other: otherStr,
} }
err := LOG_DB.Create(log).Error err := LOG_DB.Create(log).Error
if err != nil { if err != nil {
@@ -206,7 +216,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
tx = tx.Where("logs.channel_id = ?", channel) tx = tx.Where("logs.channel_id = ?", channel)
} }
if group != "" { if group != "" {
tx = tx.Where("logs."+groupCol+" = ?", group) tx = tx.Where("logs."+logGroupCol+" = ?", group)
} }
err = tx.Model(&Log{}).Count(&total).Error err = tx.Model(&Log{}).Count(&total).Error
if err != nil { if err != nil {
@@ -217,13 +227,18 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
return nil, 0, err return nil, 0, err
} }
channelIds := make([]int, 0) channelIdsMap := make(map[int]struct{})
channelMap := make(map[int]string) channelMap := make(map[int]string)
for _, log := range logs { for _, log := range logs {
if log.ChannelId != 0 { if log.ChannelId != 0 {
channelIds = append(channelIds, log.ChannelId) channelIdsMap[log.ChannelId] = struct{}{}
} }
} }
channelIds := make([]int, 0, len(channelIdsMap))
for channelId := range channelIdsMap {
channelIds = append(channelIds, channelId)
}
if len(channelIds) > 0 { if len(channelIds) > 0 {
var channels []struct { var channels []struct {
Id int `gorm:"column:id"` Id int `gorm:"column:id"`
@@ -264,7 +279,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
tx = tx.Where("logs.created_at <= ?", endTimestamp) tx = tx.Where("logs.created_at <= ?", endTimestamp)
} }
if group != "" { if group != "" {
tx = tx.Where("logs."+groupCol+" = ?", group) tx = tx.Where("logs."+logGroupCol+" = ?", group)
} }
err = tx.Model(&Log{}).Count(&total).Error err = tx.Model(&Log{}).Count(&total).Error
if err != nil { if err != nil {
@@ -325,8 +340,8 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel) rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
} }
if group != "" { if group != "" {
tx = tx.Where(groupCol+" = ?", group) tx = tx.Where(logGroupCol+" = ?", group)
rpmTpmQuery = rpmTpmQuery.Where(groupCol+" = ?", group) rpmTpmQuery = rpmTpmQuery.Where(logGroupCol+" = ?", group)
} }
tx = tx.Where("type = ?", LogTypeConsume) tx = tx.Where("type = ?", LogTypeConsume)

View File

@@ -1,6 +1,7 @@
package model package model
import ( import (
"fmt"
"log" "log"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
@@ -15,18 +16,39 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
var groupCol string var commonGroupCol string
var keyCol string var commonKeyCol string
var commonTrueVal string
var commonFalseVal string
var logKeyCol string
var logGroupCol string
func initCol() { func initCol() {
// init common column names
if common.UsingPostgreSQL { if common.UsingPostgreSQL {
groupCol = `"group"` commonGroupCol = `"group"`
keyCol = `"key"` commonKeyCol = `"key"`
commonTrueVal = "true"
commonFalseVal = "false"
} else { } else {
groupCol = "`group`" commonGroupCol = "`group`"
keyCol = "`key`" commonKeyCol = "`key`"
commonTrueVal = "1"
commonFalseVal = "0"
} }
if os.Getenv("LOG_SQL_DSN") != "" {
switch common.LogSqlType {
case common.DatabaseTypePostgreSQL:
logGroupCol = `"group"`
logKeyCol = `"key"`
default:
logGroupCol = commonGroupCol
logKeyCol = commonKeyCol
}
}
// log sql type and database type
common.SysLog("Using Log SQL Type: " + common.LogSqlType)
} }
var DB *gorm.DB var DB *gorm.DB
@@ -83,7 +105,7 @@ func CheckSetup() {
} }
} }
func chooseDB(envName string) (*gorm.DB, error) { func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
defer func() { defer func() {
initCol() initCol()
}() }()
@@ -92,7 +114,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
// Use PostgreSQL // Use PostgreSQL
common.SysLog("using PostgreSQL as database") common.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true if !isLog {
common.UsingPostgreSQL = true
} else {
common.LogSqlType = common.DatabaseTypePostgreSQL
}
return gorm.Open(postgres.New(postgres.Config{ return gorm.Open(postgres.New(postgres.Config{
DSN: dsn, DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage PreferSimpleProtocol: true, // disables implicit prepared statement usage
@@ -102,7 +128,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
} }
if strings.HasPrefix(dsn, "local") { if strings.HasPrefix(dsn, "local") {
common.SysLog("SQL_DSN not set, using SQLite as database") common.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true if !isLog {
common.UsingSQLite = true
} else {
common.LogSqlType = common.DatabaseTypeSQLite
}
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL PrepareStmt: true, // precompile SQL
}) })
@@ -117,7 +147,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
dsn += "?parseTime=true" dsn += "?parseTime=true"
} }
} }
common.UsingMySQL = true if !isLog {
common.UsingMySQL = true
} else {
common.LogSqlType = common.DatabaseTypeMySQL
}
return gorm.Open(mysql.Open(dsn), &gorm.Config{ return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL PrepareStmt: true, // precompile SQL
}) })
@@ -131,7 +165,7 @@ func chooseDB(envName string) (*gorm.DB, error) {
} }
func InitDB() (err error) { func InitDB() (err error) {
db, err := chooseDB("SQL_DSN") db, err := chooseDB("SQL_DSN", false)
if err == nil { if err == nil {
if common.DebugEnabled { if common.DebugEnabled {
db = db.Debug() db = db.Debug()
@@ -149,7 +183,7 @@ func InitDB() (err error) {
return nil return nil
} }
if common.UsingMySQL { if common.UsingMySQL {
_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded //_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
} }
common.SysLog("database migration started") common.SysLog("database migration started")
err = migrateDB() err = migrateDB()
@@ -165,7 +199,7 @@ func InitLogDB() (err error) {
LOG_DB = DB LOG_DB = DB
return return
} }
db, err := chooseDB("LOG_SQL_DSN") db, err := chooseDB("LOG_SQL_DSN", true)
if err == nil { if err == nil {
if common.DebugEnabled { if common.DebugEnabled {
db = db.Debug() db = db.Debug()
@@ -198,54 +232,50 @@ func InitLogDB() (err error) {
} }
func migrateDB() error { func migrateDB() error {
err := DB.AutoMigrate(&Channel{}) var wg sync.WaitGroup
if err != nil { errChan := make(chan error, 12) // Buffer size matches number of migrations
return err
migrations := []struct {
model interface{}
name string
}{
{&Channel{}, "Channel"},
{&Token{}, "Token"},
{&User{}, "User"},
{&Option{}, "Option"},
{&Redemption{}, "Redemption"},
{&Ability{}, "Ability"},
{&Log{}, "Log"},
{&Midjourney{}, "Midjourney"},
{&TopUp{}, "TopUp"},
{&QuotaData{}, "QuotaData"},
{&Task{}, "Task"},
{&Setup{}, "Setup"},
} }
err = DB.AutoMigrate(&Token{})
if err != nil { for _, m := range migrations {
return err wg.Add(1)
go func(model interface{}, name string) {
defer wg.Done()
if err := DB.AutoMigrate(model); err != nil {
errChan <- fmt.Errorf("failed to migrate %s: %v", name, err)
}
}(m.model, m.name)
} }
err = DB.AutoMigrate(&User{})
if err != nil { // Wait for all migrations to complete
return err wg.Wait()
close(errChan)
// Check for any errors
for err := range errChan {
if err != nil {
return err
}
} }
err = DB.AutoMigrate(&Option{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Redemption{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Ability{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Log{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Midjourney{})
if err != nil {
return err
}
err = DB.AutoMigrate(&TopUp{})
if err != nil {
return err
}
err = DB.AutoMigrate(&QuotaData{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Task{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Setup{})
common.SysLog("database migrated") common.SysLog("database migrated")
//err = createRootAccountIfNeed() return nil
return err
} }
func migrateLOGDB() error { func migrateLOGDB() error {

View File

@@ -66,7 +66,7 @@ func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token
if token != "" { if token != "" {
token = strings.Trim(token, "sk-") token = strings.Trim(token, "sk-")
} }
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(keyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
return tokens, err return tokens, err
} }
@@ -161,7 +161,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
// Don't return error - fall through to DB // Don't return error - fall through to DB
} }
fromDB = true fromDB = true
err = DB.Where(keyCol+" = ?", key).First(&token).Error err = DB.Where(commonKeyCol+" = ?", key).First(&token).Error
return token, err return token, err
} }

View File

@@ -175,7 +175,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
// 如果是数字同时搜索ID和其他字段 // 如果是数字同时搜索ID和其他字段
likeCondition = "id = ? OR " + likeCondition likeCondition = "id = ? OR " + likeCondition
if group != "" { if group != "" {
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?", query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group) keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else { } else {
query = query.Where(likeCondition, query = query.Where(likeCondition,
@@ -184,7 +184,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
} else { } else {
// 非数字关键字,只搜索字符串字段 // 非数字关键字,只搜索字符串字段
if group != "" { if group != "" {
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?", query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group) "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else { } else {
query = query.Where(likeCondition, query = query.Where(likeCondition,
@@ -615,7 +615,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
// Don't return error - fall through to DB // Don't return error - fall through to DB
} }
fromDB = true fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error
if err != nil { if err != nil {
return "", err return "", err
} }