Merge remote-tracking branch 'origin/alpha' into refactor/model-pricing

This commit is contained in:
t0ng7u
2025-08-04 21:37:38 +08:00
22 changed files with 3008 additions and 68 deletions

View File

@@ -41,6 +41,7 @@ type Channel struct {
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AutoBan *int `json:"auto_ban" gorm:"default:1"`
OtherInfo string `json:"other_info"`
Settings string `json:"settings"`
Tag *string `json:"tag" gorm:"index"`
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
ParamOverride *string `json:"param_override" gorm:"type:text"`
@@ -52,11 +53,13 @@ type Channel struct {
}
type ChannelInfo struct {
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表key index -> status
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表key index -> status
MultiKeyDisabledReason map[int]string `json:"multi_key_disabled_reason,omitempty"` // key禁用原因列表key index -> reason
MultiKeyDisabledTime map[int]int64 `json:"multi_key_disabled_time,omitempty"` // key禁用时间列表key index -> time
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
}
// Value implements driver.Valuer interface
@@ -70,7 +73,7 @@ func (c *ChannelInfo) Scan(value interface{}) error {
return common.Unmarshal(bytesValue, c)
}
func (channel *Channel) getKeys() []string {
func (channel *Channel) GetKeys() []string {
if channel.Key == "" {
return []string{}
}
@@ -101,7 +104,7 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
}
// Obtain all keys (split by \n)
keys := channel.getKeys()
keys := channel.GetKeys()
if len(keys) == 0 {
// No keys available, return error, should disable the channel
return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
@@ -528,8 +531,8 @@ func CleanupChannelPollingLocks() {
})
}
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
keys := channel.getKeys()
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason string) {
keys := channel.GetKeys()
if len(keys) == 0 {
channel.Status = status
} else {
@@ -547,6 +550,14 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
} else {
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
}
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
}
channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason
channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp()
}
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
channel.Status = common.ChannelStatusAutoDisabled
@@ -569,7 +580,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
}
if channelCache.ChannelInfo.IsMultiKey {
// 如果是多Key模式更新缓存中的状态
handlerMultiKeyUpdate(channelCache, usingKey, status)
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
//CacheUpdateChannel(channelCache)
//return true
} else {
@@ -600,7 +611,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
if channel.ChannelInfo.IsMultiKey {
beforeStatus := channel.Status
handlerMultiKeyUpdate(channel, usingKey, status)
handlerMultiKeyUpdate(channel, usingKey, status, reason)
if beforeStatus != channel.Status {
shouldUpdateAbilities = true
}

View File

@@ -70,7 +70,7 @@ func InitChannelCache() {
//channelsIDM = newChannelId2channel
for i, channel := range newChannelId2channel {
if channel.ChannelInfo.IsMultiKey {
channel.Keys = channel.getKeys()
channel.Keys = channel.GetKeys()
if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
if oldChannel, ok := channelsIDM[i]; ok {
// 存在旧的渠道如果是多key且轮询保留轮询索引信息

View File

@@ -254,6 +254,8 @@ func migrateDB() error {
&Vendor{},
&PrefillGroup{},
&Setup{},
&TwoFA{},
&TwoFABackupCode{},
)
if err != nil {
return err
@@ -283,6 +285,8 @@ func migrateDBFast() error {
{&Vendor{}, "Vendor"},
{&PrefillGroup{}, "PrefillGroup"},
{&Setup{}, "Setup"},
{&TwoFA{}, "TwoFA"},
{&TwoFABackupCode{}, "TwoFABackupCode"},
}
// 动态计算migration数量确保errChan缓冲区足够大
errChan := make(chan error, len(migrations))

322
model/twofa.go Normal file
View File

@@ -0,0 +1,322 @@
package model
import (
"errors"
"fmt"
"one-api/common"
"time"
"gorm.io/gorm"
)
var ErrTwoFANotEnabled = errors.New("用户未启用2FA")
// TwoFA 用户2FA设置表
type TwoFA struct {
Id int `json:"id" gorm:"primaryKey"`
UserId int `json:"user_id" gorm:"unique;not null;index"`
Secret string `json:"-" gorm:"type:varchar(255);not null"` // TOTP密钥不返回给前端
IsEnabled bool `json:"is_enabled" gorm:"default:false"`
FailedAttempts int `json:"failed_attempts" gorm:"default:0"`
LockedUntil *time.Time `json:"locked_until,omitempty"`
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
}
// TwoFABackupCode 备用码使用记录表
type TwoFABackupCode struct {
Id int `json:"id" gorm:"primaryKey"`
UserId int `json:"user_id" gorm:"not null;index"`
CodeHash string `json:"-" gorm:"type:varchar(255);not null"` // 备用码哈希
IsUsed bool `json:"is_used" gorm:"default:false"`
UsedAt *time.Time `json:"used_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
}
// GetTwoFAByUserId 根据用户ID获取2FA设置
func GetTwoFAByUserId(userId int) (*TwoFA, error) {
if userId == 0 {
return nil, errors.New("用户ID不能为空")
}
var twoFA TwoFA
err := DB.Where("user_id = ?", userId).First(&twoFA).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil // 返回nil表示未设置2FA
}
return nil, err
}
return &twoFA, nil
}
// IsTwoFAEnabled 检查用户是否启用了2FA
func IsTwoFAEnabled(userId int) bool {
twoFA, err := GetTwoFAByUserId(userId)
if err != nil || twoFA == nil {
return false
}
return twoFA.IsEnabled
}
// CreateTwoFA 创建2FA设置
func (t *TwoFA) Create() error {
// 检查用户是否已存在2FA设置
existing, err := GetTwoFAByUserId(t.UserId)
if err != nil {
return err
}
if existing != nil {
return errors.New("用户已存在2FA设置")
}
// 验证用户存在
var user User
if err := DB.First(&user, t.UserId).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("用户不存在")
}
return err
}
return DB.Create(t).Error
}
// Update 更新2FA设置
func (t *TwoFA) Update() error {
if t.Id == 0 {
return errors.New("2FA记录ID不能为空")
}
return DB.Save(t).Error
}
// Delete 删除2FA设置
func (t *TwoFA) Delete() error {
if t.Id == 0 {
return errors.New("2FA记录ID不能为空")
}
// 使用事务确保原子性
return DB.Transaction(func(tx *gorm.DB) error {
// 同时删除相关的备用码记录(硬删除)
if err := tx.Unscoped().Where("user_id = ?", t.UserId).Delete(&TwoFABackupCode{}).Error; err != nil {
return err
}
// 硬删除2FA记录
return tx.Unscoped().Delete(t).Error
})
}
// ResetFailedAttempts 重置失败尝试次数
func (t *TwoFA) ResetFailedAttempts() error {
t.FailedAttempts = 0
t.LockedUntil = nil
return t.Update()
}
// IncrementFailedAttempts 增加失败尝试次数
func (t *TwoFA) IncrementFailedAttempts() error {
t.FailedAttempts++
// 检查是否需要锁定
if t.FailedAttempts >= common.MaxFailAttempts {
lockUntil := time.Now().Add(time.Duration(common.LockoutDuration) * time.Second)
t.LockedUntil = &lockUntil
}
return t.Update()
}
// IsLocked 检查账户是否被锁定
func (t *TwoFA) IsLocked() bool {
if t.LockedUntil == nil {
return false
}
return time.Now().Before(*t.LockedUntil)
}
// CreateBackupCodes 创建备用码
func CreateBackupCodes(userId int, codes []string) error {
return DB.Transaction(func(tx *gorm.DB) error {
// 先删除现有的备用码
if err := tx.Where("user_id = ?", userId).Delete(&TwoFABackupCode{}).Error; err != nil {
return err
}
// 创建新的备用码记录
for _, code := range codes {
hashedCode, err := common.HashBackupCode(code)
if err != nil {
return err
}
backupCode := TwoFABackupCode{
UserId: userId,
CodeHash: hashedCode,
IsUsed: false,
}
if err := tx.Create(&backupCode).Error; err != nil {
return err
}
}
return nil
})
}
// ValidateBackupCode 验证并使用备用码
func ValidateBackupCode(userId int, code string) (bool, error) {
if !common.ValidateBackupCode(code) {
return false, errors.New("验证码或备用码不正确")
}
normalizedCode := common.NormalizeBackupCode(code)
// 查找未使用的备用码
var backupCodes []TwoFABackupCode
if err := DB.Where("user_id = ? AND is_used = false", userId).Find(&backupCodes).Error; err != nil {
return false, err
}
// 验证备用码
for _, bc := range backupCodes {
if common.ValidatePasswordAndHash(normalizedCode, bc.CodeHash) {
// 标记为已使用
now := time.Now()
bc.IsUsed = true
bc.UsedAt = &now
if err := DB.Save(&bc).Error; err != nil {
return false, err
}
return true, nil
}
}
return false, nil
}
// GetUnusedBackupCodeCount 获取未使用的备用码数量
func GetUnusedBackupCodeCount(userId int) (int, error) {
var count int64
err := DB.Model(&TwoFABackupCode{}).Where("user_id = ? AND is_used = false", userId).Count(&count).Error
return int(count), err
}
// DisableTwoFA 禁用用户的2FA
func DisableTwoFA(userId int) error {
twoFA, err := GetTwoFAByUserId(userId)
if err != nil {
return err
}
if twoFA == nil {
return ErrTwoFANotEnabled
}
// 删除2FA设置和备用码
return twoFA.Delete()
}
// EnableTwoFA 启用2FA
func (t *TwoFA) Enable() error {
t.IsEnabled = true
t.FailedAttempts = 0
t.LockedUntil = nil
return t.Update()
}
// ValidateTOTPAndUpdateUsage 验证TOTP并更新使用记录
func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
// 检查是否被锁定
if t.IsLocked() {
return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
}
// 验证TOTP码
if !common.ValidateTOTPCode(t.Secret, code) {
// 增加失败次数
if err := t.IncrementFailedAttempts(); err != nil {
common.SysError("更新2FA失败次数失败: " + err.Error())
}
return false, nil
}
// 验证成功,重置失败次数并更新最后使用时间
now := time.Now()
t.FailedAttempts = 0
t.LockedUntil = nil
t.LastUsedAt = &now
if err := t.Update(); err != nil {
common.SysError("更新2FA使用记录失败: " + err.Error())
}
return true, nil
}
// ValidateBackupCodeAndUpdateUsage 验证备用码并更新使用记录
func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
// 检查是否被锁定
if t.IsLocked() {
return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
}
// 验证备用码
valid, err := ValidateBackupCode(t.UserId, code)
if err != nil {
return false, err
}
if !valid {
// 增加失败次数
if err := t.IncrementFailedAttempts(); err != nil {
common.SysError("更新2FA失败次数失败: " + err.Error())
}
return false, nil
}
// 验证成功,重置失败次数并更新最后使用时间
now := time.Now()
t.FailedAttempts = 0
t.LockedUntil = nil
t.LastUsedAt = &now
if err := t.Update(); err != nil {
common.SysError("更新2FA使用记录失败: " + err.Error())
}
return true, nil
}
// GetTwoFAStats 获取2FA统计信息管理员使用
func GetTwoFAStats() (map[string]interface{}, error) {
var totalUsers, enabledUsers int64
// 总用户数
if err := DB.Model(&User{}).Count(&totalUsers).Error; err != nil {
return nil, err
}
// 启用2FA的用户数
if err := DB.Model(&TwoFA{}).Where("is_enabled = true").Count(&enabledUsers).Error; err != nil {
return nil, err
}
enabledRate := float64(0)
if totalUsers > 0 {
enabledRate = float64(enabledUsers) / float64(totalUsers) * 100
}
return map[string]interface{}{
"total_users": totalUsers,
"enabled_users": enabledUsers,
"enabled_rate": fmt.Sprintf("%.1f%%", enabledRate),
}, nil
}