From c784a702778e4f972db8a111c5bae636e8834954 Mon Sep 17 00:00:00 2001 From: Seefs Date: Sat, 2 Aug 2025 14:53:28 +0800 Subject: [PATCH 1/8] feat: implement two-factor authentication (2FA) support with user login and settings integration --- common/totp.go | 153 +++++ controller/twofa.go | 547 ++++++++++++++++++ controller/user.go | 26 + go.mod | 2 + go.sum | 4 + model/main.go | 4 + model/twofa.go | 315 ++++++++++ router/api-router.go | 12 + web/bun.lock | 9 +- web/package.json | 1 + web/src/components/auth/LoginForm.js | 54 ++ web/src/components/auth/TwoFAVerification.js | 222 +++++++ .../components/settings/PersonalSetting.js | 4 + web/src/components/settings/TwoFASetting.js | 524 +++++++++++++++++ 14 files changed, 1874 insertions(+), 3 deletions(-) create mode 100644 common/totp.go create mode 100644 controller/twofa.go create mode 100644 model/twofa.go create mode 100644 web/src/components/auth/TwoFAVerification.js create mode 100644 web/src/components/settings/TwoFASetting.js diff --git a/common/totp.go b/common/totp.go new file mode 100644 index 00000000..ece5bc31 --- /dev/null +++ b/common/totp.go @@ -0,0 +1,153 @@ +package common + +import ( + "crypto/rand" + "fmt" + "os" + "strconv" + "strings" + + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" +) + +const ( + // 备用码配置 + BackupCodeLength = 8 // 备用码长度 + BackupCodeCount = 4 // 生成备用码数量 + + // 限制配置 + MaxFailAttempts = 5 // 最大失败尝试次数 + LockoutDuration = 300 // 锁定时间(秒) +) + +// GenerateTOTPSecret 生成TOTP密钥和配置 +func GenerateTOTPSecret(accountName string) (*otp.Key, error) { + issuer := Get2FAIssuer() + return totp.Generate(totp.GenerateOpts{ + Issuer: issuer, + AccountName: accountName, + Period: 30, + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }) +} + +// ValidateTOTPCode 验证TOTP验证码 +func ValidateTOTPCode(secret, code string) bool { + // 清理验证码格式 + cleanCode := strings.ReplaceAll(code, " ", "") + if len(cleanCode) != 6 { + return false + } + + // 验证验证码 + return totp.Validate(cleanCode, secret) +} + +// GenerateBackupCodes 生成备用恢复码 +func GenerateBackupCodes() ([]string, error) { + codes := make([]string, BackupCodeCount) + + for i := 0; i < BackupCodeCount; i++ { + code, err := generateRandomBackupCode() + if err != nil { + return nil, err + } + codes[i] = code + } + + return codes, nil +} + +// generateRandomBackupCode 生成单个备用码 +func generateRandomBackupCode() (string, error) { + const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + code := make([]byte, BackupCodeLength) + + for i := range code { + randomBytes := make([]byte, 1) + _, err := rand.Read(randomBytes) + if err != nil { + return "", err + } + code[i] = charset[int(randomBytes[0])%len(charset)] + } + + // 格式化为 XXXX-XXXX 格式 + return fmt.Sprintf("%s-%s", string(code[:4]), string(code[4:])), nil +} + +// ValidateBackupCode 验证备用码格式 +func ValidateBackupCode(code string) bool { + // 移除所有分隔符并转为大写 + cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", "")) + if len(cleanCode) != BackupCodeLength { + return false + } + + // 检查字符是否合法 + for _, char := range cleanCode { + if !((char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9')) { + return false + } + } + + return true +} + +// NormalizeBackupCode 标准化备用码格式 +func NormalizeBackupCode(code string) string { + cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", "")) + if len(cleanCode) == BackupCodeLength { + return fmt.Sprintf("%s-%s", cleanCode[:4], cleanCode[4:]) + } + return code +} + +// HashBackupCode 对备用码进行哈希 +func HashBackupCode(code string) (string, error) { + normalizedCode := NormalizeBackupCode(code) + return Password2Hash(normalizedCode) +} + +// Get2FAIssuer 获取2FA发行者名称 +func Get2FAIssuer() string { + if issuer := SystemName; issuer != "" { + return issuer + } + return "NewAPI" +} + +// getEnvOrDefault 获取环境变量或默认值 +func getEnvOrDefault(key, defaultValue string) string { + if value, exists := os.LookupEnv(key); exists { + return value + } + return defaultValue +} + +// ValidateNumericCode 验证数字验证码格式 +func ValidateNumericCode(code string) (string, error) { + // 移除空格 + code = strings.ReplaceAll(code, " ", "") + + if len(code) != 6 { + return "", fmt.Errorf("验证码必须是6位数字") + } + + // 检查是否为纯数字 + if _, err := strconv.Atoi(code); err != nil { + return "", fmt.Errorf("验证码只能包含数字") + } + + return code, nil +} + +// GenerateQRCodeData 生成二维码数据 +func GenerateQRCodeData(secret, username string) string { + issuer := Get2FAIssuer() + accountName := fmt.Sprintf("%s (%s)", username, issuer) + return fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&digits=6&period=30", + issuer, accountName, secret, issuer) +} diff --git a/controller/twofa.go b/controller/twofa.go new file mode 100644 index 00000000..368289c9 --- /dev/null +++ b/controller/twofa.go @@ -0,0 +1,547 @@ +package controller + +import ( + "fmt" + "net/http" + "one-api/common" + "one-api/model" + "strconv" + "strings" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" +) + +// Setup2FARequest 设置2FA请求结构 +type Setup2FARequest struct { + Code string `json:"code" binding:"required"` +} + +// Verify2FARequest 验证2FA请求结构 +type Verify2FARequest struct { + Code string `json:"code" binding:"required"` +} + +// Setup2FAResponse 设置2FA响应结构 +type Setup2FAResponse struct { + Secret string `json:"secret"` + QRCodeData string `json:"qr_code_data"` + BackupCodes []string `json:"backup_codes"` +} + +// Setup2FA 初始化2FA设置 +func Setup2FA(c *gin.Context) { + userId := c.GetInt("id") + + // 检查用户是否已经启用2FA + existing, err := model.GetTwoFAByUserId(userId) + if err != nil { + common.ApiError(c, err) + return + } + if existing != nil && existing.IsEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户已启用2FA,请先禁用后重新设置", + }) + return + } + + // 如果存在已禁用的2FA记录,先删除它 + if existing != nil && !existing.IsEnabled { + if err := existing.Delete(); err != nil { + common.ApiError(c, err) + return + } + existing = nil // 重置为nil,后续将创建新记录 + } + + // 获取用户信息 + user, err := model.GetUserById(userId, false) + if err != nil { + common.ApiError(c, err) + return + } + + // 生成TOTP密钥 + key, err := common.GenerateTOTPSecret(user.Username) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "生成2FA密钥失败", + }) + common.SysError("生成TOTP密钥失败: " + err.Error()) + return + } + + // 生成备用码 + backupCodes, err := common.GenerateBackupCodes() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "生成备用码失败", + }) + common.SysError("生成备用码失败: " + err.Error()) + return + } + + // 生成二维码数据 + qrCodeData := common.GenerateQRCodeData(key.Secret(), user.Username) + + // 创建或更新2FA记录(暂未启用) + twoFA := &model.TwoFA{ + UserId: userId, + Secret: key.Secret(), + IsEnabled: false, + } + + if existing != nil { + // 更新现有记录 + twoFA.Id = existing.Id + err = twoFA.Update() + } else { + // 创建新记录 + err = twoFA.Create() + } + + if err != nil { + common.ApiError(c, err) + return + } + + // 创建备用码记录 + if err := model.CreateBackupCodes(userId, backupCodes); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "保存备用码失败", + }) + common.SysError("保存备用码失败: " + err.Error()) + return + } + + // 记录操作日志 + model.RecordLog(userId, model.LogTypeSystem, "开始设置两步验证") + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "2FA设置初始化成功,请使用认证器扫描二维码并输入验证码完成设置", + "data": Setup2FAResponse{ + Secret: key.Secret(), + QRCodeData: qrCodeData, + BackupCodes: backupCodes, + }, + }) +} + +// Enable2FA 启用2FA +func Enable2FA(c *gin.Context) { + var req Setup2FARequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "参数错误", + }) + return + } + + userId := c.GetInt("id") + + // 获取2FA记录 + twoFA, err := model.GetTwoFAByUserId(userId) + if err != nil { + common.ApiError(c, err) + return + } + if twoFA == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "请先完成2FA初始化设置", + }) + return + } + if twoFA.IsEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "2FA已经启用", + }) + return + } + + // 验证TOTP验证码 + cleanCode, err := common.ValidateNumericCode(req.Code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + if !common.ValidateTOTPCode(twoFA.Secret, cleanCode) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "验证码或备用码错误,请重试", + }) + return + } + + // 启用2FA + if err := twoFA.Enable(); err != nil { + common.ApiError(c, err) + return + } + + // 记录操作日志 + model.RecordLog(userId, model.LogTypeSystem, "成功启用两步验证") + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "两步验证启用成功", + }) +} + +// Disable2FA 禁用2FA +func Disable2FA(c *gin.Context) { + var req Verify2FARequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "参数错误", + }) + return + } + + userId := c.GetInt("id") + + // 获取2FA记录 + twoFA, err := model.GetTwoFAByUserId(userId) + if err != nil { + common.ApiError(c, err) + return + } + if twoFA == nil || !twoFA.IsEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户未启用2FA", + }) + return + } + + // 验证TOTP验证码或备用码 + cleanCode, err := common.ValidateNumericCode(req.Code) + isValidTOTP := false + isValidBackup := false + + if err == nil { + // 尝试验证TOTP + isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode) + } + + if !isValidTOTP { + // 尝试验证备用码 + isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } + + if !isValidTOTP && !isValidBackup { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "验证码或备用码错误,请重试", + }) + return + } + + // 禁用2FA + if err := model.DisableTwoFA(userId); err != nil { + common.ApiError(c, err) + return + } + + // 记录操作日志 + model.RecordLog(userId, model.LogTypeSystem, "禁用两步验证") + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "两步验证已禁用", + }) +} + +// Get2FAStatus 获取用户2FA状态 +func Get2FAStatus(c *gin.Context) { + userId := c.GetInt("id") + + twoFA, err := model.GetTwoFAByUserId(userId) + if err != nil { + common.ApiError(c, err) + return + } + + status := map[string]interface{}{ + "enabled": false, + "locked": false, + } + + if twoFA != nil { + status["enabled"] = twoFA.IsEnabled + status["locked"] = twoFA.IsLocked() + if twoFA.IsEnabled { + // 获取剩余备用码数量 + backupCount, err := model.GetUnusedBackupCodeCount(userId) + if err != nil { + common.SysError("获取备用码数量失败: " + err.Error()) + } else { + status["backup_codes_remaining"] = backupCount + } + } + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": status, + }) +} + +// RegenerateBackupCodes 重新生成备用码 +func RegenerateBackupCodes(c *gin.Context) { + var req Verify2FARequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "参数错误", + }) + return + } + + userId := c.GetInt("id") + + // 获取2FA记录 + twoFA, err := model.GetTwoFAByUserId(userId) + if err != nil { + common.ApiError(c, err) + return + } + if twoFA == nil || !twoFA.IsEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户未启用2FA", + }) + return + } + + // 验证TOTP验证码 + cleanCode, err := common.ValidateNumericCode(req.Code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + valid, err := twoFA.ValidateTOTPAndUpdateUsage(cleanCode) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if !valid { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "验证码或备用码错误,请重试", + }) + return + } + + // 生成新的备用码 + backupCodes, err := common.GenerateBackupCodes() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "生成备用码失败", + }) + common.SysError("生成备用码失败: " + err.Error()) + return + } + + // 保存新的备用码 + if err := model.CreateBackupCodes(userId, backupCodes); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "保存备用码失败", + }) + common.SysError("保存备用码失败: " + err.Error()) + return + } + + // 记录操作日志 + model.RecordLog(userId, model.LogTypeSystem, "重新生成两步验证备用码") + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "备用码重新生成成功", + "data": map[string]interface{}{ + "backup_codes": backupCodes, + }, + }) +} + +// Verify2FALogin 登录时验证2FA +func Verify2FALogin(c *gin.Context) { + var req Verify2FARequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "参数错误", + }) + return + } + + // 从会话中获取pending用户信息 + session := sessions.Default(c) + pendingUserId := session.Get("pending_user_id") + if pendingUserId == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "会话已过期,请重新登录", + }) + return + } + userId := pendingUserId.(int) + + // 获取用户信息 + user, err := model.GetUserById(userId, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户不存在", + }) + return + } + + // 获取2FA记录 + twoFA, err := model.GetTwoFAByUserId(user.Id) + if err != nil { + common.ApiError(c, err) + return + } + if twoFA == nil || !twoFA.IsEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户未启用2FA", + }) + return + } + + // 验证TOTP验证码或备用码 + cleanCode, err := common.ValidateNumericCode(req.Code) + isValidTOTP := false + isValidBackup := false + + if err == nil { + // 尝试验证TOTP + isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode) + } + + if !isValidTOTP { + // 尝试验证备用码 + isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } + + if !isValidTOTP && !isValidBackup { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "验证码或备用码错误,请重试", + }) + return + } + + // 2FA验证成功,清理pending会话信息并完成登录 + session.Delete("pending_username") + session.Delete("pending_user_id") + session.Save() + + setupLogin(user, c) +} + +// Admin2FAStats 管理员获取2FA统计信息 +func Admin2FAStats(c *gin.Context) { + stats, err := model.GetTwoFAStats() + if err != nil { + common.ApiError(c, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": stats, + }) +} + +// AdminDisable2FA 管理员强制禁用用户2FA +func AdminDisable2FA(c *gin.Context) { + userIdStr := c.Param("id") + userId, err := strconv.Atoi(userIdStr) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户ID格式错误", + }) + return + } + + // 检查目标用户权限 + targetUser, err := model.GetUserById(userId, false) + if err != nil { + common.ApiError(c, err) + return + } + + myRole := c.GetInt("role") + if myRole <= targetUser.Role && myRole != common.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权操作同级或更高级用户的2FA设置", + }) + return + } + + // 禁用2FA + if err := model.DisableTwoFA(userId); err != nil { + if strings.Contains(err.Error(), "未启用2FA") { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户未启用2FA", + }) + return + } + common.ApiError(c, err) + return + } + + // 记录操作日志 + adminId := c.GetInt("id") + model.RecordLog(userId, model.LogTypeManage, + fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId)) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "用户2FA已被强制禁用", + }) +} diff --git a/controller/user.go b/controller/user.go index 292ed8c6..6e968037 100644 --- a/controller/user.go +++ b/controller/user.go @@ -62,6 +62,32 @@ func Login(c *gin.Context) { }) return } + + // 检查是否启用2FA + if model.IsTwoFAEnabled(user.Id) { + // 设置pending session,等待2FA验证 + session := sessions.Default(c) + session.Set("pending_username", user.Username) + session.Set("pending_user_id", user.Id) + err := session.Save() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "message": "无法保存会话信息,请重试", + "success": false, + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "message": "请输入两步验证码", + "success": true, + "data": map[string]interface{}{ + "require_2fa": true, + }, + }) + return + } + setupLogin(&user, c) } diff --git a/go.mod b/go.mod index 94873c88..1def0b08 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect github.com/aws/smithy-go v1.20.2 // indirect + github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -79,6 +80,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.1 // indirect + github.com/pquerna/otp v1.5.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect diff --git a/go.sum b/go.sum index 74eecd4c..4f5ae530 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,8 @@ github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76w github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg= github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0= github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= @@ -169,6 +171,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= +github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= diff --git a/model/main.go b/model/main.go index 013beacd..38dd2aee 100644 --- a/model/main.go +++ b/model/main.go @@ -251,6 +251,8 @@ func migrateDB() error { &QuotaData{}, &Task{}, &Setup{}, + &TwoFA{}, + &TwoFABackupCode{}, ) if err != nil { return err @@ -277,6 +279,8 @@ func migrateDBFast() error { {&QuotaData{}, "QuotaData"}, {&Task{}, "Task"}, {&Setup{}, "Setup"}, + {&TwoFA{}, "TwoFA"}, + {&TwoFABackupCode{}, "TwoFABackupCode"}, } // 动态计算migration数量,确保errChan缓冲区足够大 errChan := make(chan error, len(migrations)) diff --git a/model/twofa.go b/model/twofa.go new file mode 100644 index 00000000..4a96ffb0 --- /dev/null +++ b/model/twofa.go @@ -0,0 +1,315 @@ +package model + +import ( + "errors" + "fmt" + "one-api/common" + "time" + + "gorm.io/gorm" +) + +// TwoFA 用户2FA设置表 +type TwoFA struct { + Id int `json:"id" gorm:"primaryKey"` + UserId int `json:"user_id" gorm:"unique;not null;index"` + Secret string `json:"-" gorm:"type:varchar(255);not null"` // TOTP密钥,不返回给前端 + IsEnabled bool `json:"is_enabled" gorm:"default:false"` + FailedAttempts int `json:"failed_attempts" gorm:"default:0"` + LockedUntil *time.Time `json:"locked_until,omitempty"` + LastUsedAt *time.Time `json:"last_used_at,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` +} + +// TwoFABackupCode 备用码使用记录表 +type TwoFABackupCode struct { + Id int `json:"id" gorm:"primaryKey"` + UserId int `json:"user_id" gorm:"not null;index"` + CodeHash string `json:"-" gorm:"type:varchar(255);not null"` // 备用码哈希 + IsUsed bool `json:"is_used" gorm:"default:false"` + UsedAt *time.Time `json:"used_at,omitempty"` + CreatedAt time.Time `json:"created_at"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` +} + +// GetTwoFAByUserId 根据用户ID获取2FA设置 +func GetTwoFAByUserId(userId int) (*TwoFA, error) { + if userId == 0 { + return nil, errors.New("用户ID不能为空") + } + + var twoFA TwoFA + err := DB.Where("user_id = ?", userId).First(&twoFA).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil // 返回nil表示未设置2FA + } + return nil, err + } + + return &twoFA, nil +} + +// IsTwoFAEnabled 检查用户是否启用了2FA +func IsTwoFAEnabled(userId int) bool { + twoFA, err := GetTwoFAByUserId(userId) + if err != nil || twoFA == nil { + return false + } + return twoFA.IsEnabled +} + +// CreateTwoFA 创建2FA设置 +func (t *TwoFA) Create() error { + // 检查用户是否已存在2FA设置 + existing, err := GetTwoFAByUserId(t.UserId) + if err != nil { + return err + } + if existing != nil { + return errors.New("用户已存在2FA设置") + } + + // 验证用户存在 + var user User + if err := DB.First(&user, t.UserId).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("用户不存在") + } + return err + } + + return DB.Create(t).Error +} + +// Update 更新2FA设置 +func (t *TwoFA) Update() error { + if t.Id == 0 { + return errors.New("2FA记录ID不能为空") + } + return DB.Save(t).Error +} + +// Delete 删除2FA设置 +func (t *TwoFA) Delete() error { + if t.Id == 0 { + return errors.New("2FA记录ID不能为空") + } + + // 同时删除相关的备用码记录(硬删除) + if err := DB.Unscoped().Where("user_id = ?", t.UserId).Delete(&TwoFABackupCode{}).Error; err != nil { + return err + } + + // 硬删除2FA记录 + return DB.Unscoped().Delete(t).Error +} + +// ResetFailedAttempts 重置失败尝试次数 +func (t *TwoFA) ResetFailedAttempts() error { + t.FailedAttempts = 0 + t.LockedUntil = nil + return t.Update() +} + +// IncrementFailedAttempts 增加失败尝试次数 +func (t *TwoFA) IncrementFailedAttempts() error { + t.FailedAttempts++ + + // 检查是否需要锁定 + if t.FailedAttempts >= common.MaxFailAttempts { + lockUntil := time.Now().Add(time.Duration(common.LockoutDuration) * time.Second) + t.LockedUntil = &lockUntil + } + + return t.Update() +} + +// IsLocked 检查账户是否被锁定 +func (t *TwoFA) IsLocked() bool { + if t.LockedUntil == nil { + return false + } + return time.Now().Before(*t.LockedUntil) +} + +// CreateBackupCodes 创建备用码 +func CreateBackupCodes(userId int, codes []string) error { + // 先删除现有的备用码 + if err := DB.Where("user_id = ?", userId).Delete(&TwoFABackupCode{}).Error; err != nil { + return err + } + + // 创建新的备用码记录 + for _, code := range codes { + hashedCode, err := common.HashBackupCode(code) + if err != nil { + return err + } + + backupCode := TwoFABackupCode{ + UserId: userId, + CodeHash: hashedCode, + IsUsed: false, + } + + if err := DB.Create(&backupCode).Error; err != nil { + return err + } + } + + return nil +} + +// ValidateBackupCode 验证并使用备用码 +func ValidateBackupCode(userId int, code string) (bool, error) { + if !common.ValidateBackupCode(code) { + return false, errors.New("验证码或备用码不正确") + } + + normalizedCode := common.NormalizeBackupCode(code) + + // 查找未使用的备用码 + var backupCodes []TwoFABackupCode + if err := DB.Where("user_id = ? AND is_used = false", userId).Find(&backupCodes).Error; err != nil { + return false, err + } + + // 验证备用码 + for _, bc := range backupCodes { + if common.ValidatePasswordAndHash(normalizedCode, bc.CodeHash) { + // 标记为已使用 + now := time.Now() + bc.IsUsed = true + bc.UsedAt = &now + + if err := DB.Save(&bc).Error; err != nil { + return false, err + } + + return true, nil + } + } + + return false, nil +} + +// GetUnusedBackupCodeCount 获取未使用的备用码数量 +func GetUnusedBackupCodeCount(userId int) (int, error) { + var count int64 + err := DB.Model(&TwoFABackupCode{}).Where("user_id = ? AND is_used = false", userId).Count(&count).Error + return int(count), err +} + +// DisableTwoFA 禁用用户的2FA +func DisableTwoFA(userId int) error { + twoFA, err := GetTwoFAByUserId(userId) + if err != nil { + return err + } + if twoFA == nil { + return errors.New("用户未启用2FA") + } + + // 删除2FA设置和备用码 + return twoFA.Delete() +} + +// EnableTwoFA 启用2FA +func (t *TwoFA) Enable() error { + t.IsEnabled = true + t.FailedAttempts = 0 + t.LockedUntil = nil + return t.Update() +} + +// ValidateTOTPAndUpdateUsage 验证TOTP并更新使用记录 +func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) { + // 检查是否被锁定 + if t.IsLocked() { + return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05")) + } + + // 验证TOTP码 + if !common.ValidateTOTPCode(t.Secret, code) { + // 增加失败次数 + if err := t.IncrementFailedAttempts(); err != nil { + common.SysError("更新2FA失败次数失败: " + err.Error()) + } + return false, nil + } + + // 验证成功,重置失败次数并更新最后使用时间 + now := time.Now() + t.FailedAttempts = 0 + t.LockedUntil = nil + t.LastUsedAt = &now + + if err := t.Update(); err != nil { + common.SysError("更新2FA使用记录失败: " + err.Error()) + } + + return true, nil +} + +// ValidateBackupCodeAndUpdateUsage 验证备用码并更新使用记录 +func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) { + // 检查是否被锁定 + if t.IsLocked() { + return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05")) + } + + // 验证备用码 + valid, err := ValidateBackupCode(t.UserId, code) + if err != nil { + return false, err + } + + if !valid { + // 增加失败次数 + if err := t.IncrementFailedAttempts(); err != nil { + common.SysError("更新2FA失败次数失败: " + err.Error()) + } + return false, nil + } + + // 验证成功,重置失败次数并更新最后使用时间 + now := time.Now() + t.FailedAttempts = 0 + t.LockedUntil = nil + t.LastUsedAt = &now + + if err := t.Update(); err != nil { + common.SysError("更新2FA使用记录失败: " + err.Error()) + } + + return true, nil +} + +// GetTwoFAStats 获取2FA统计信息(管理员使用) +func GetTwoFAStats() (map[string]interface{}, error) { + var totalUsers, enabledUsers int64 + + // 总用户数 + if err := DB.Model(&User{}).Count(&totalUsers).Error; err != nil { + return nil, err + } + + // 启用2FA的用户数 + if err := DB.Model(&TwoFA{}).Where("is_enabled = true").Count(&enabledUsers).Error; err != nil { + return nil, err + } + + enabledRate := float64(0) + if totalUsers > 0 { + enabledRate = float64(enabledUsers) / float64(totalUsers) * 100 + } + + return map[string]interface{}{ + "total_users": totalUsers, + "enabled_users": enabledUsers, + "enabled_rate": fmt.Sprintf("%.1f%%", enabledRate), + }, nil +} diff --git a/router/api-router.go b/router/api-router.go index bc49803a..16c78186 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -44,6 +44,7 @@ func SetApiRouter(router *gin.Engine) { { userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register) userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login) + userRoute.POST("/login/2fa", middleware.CriticalRateLimit(), controller.Verify2FALogin) //userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog) userRoute.GET("/logout", controller.Logout) userRoute.GET("/epay/notify", controller.EpayNotify) @@ -66,6 +67,13 @@ func SetApiRouter(router *gin.Engine) { selfRoute.POST("/stripe/amount", controller.RequestStripeAmount) selfRoute.POST("/aff_transfer", controller.TransferAffQuota) selfRoute.PUT("/setting", controller.UpdateUserSetting) + + // 2FA routes + selfRoute.GET("/2fa/status", controller.Get2FAStatus) + selfRoute.POST("/2fa/setup", controller.Setup2FA) + selfRoute.POST("/2fa/enable", controller.Enable2FA) + selfRoute.POST("/2fa/disable", controller.Disable2FA) + selfRoute.POST("/2fa/backup_codes", controller.RegenerateBackupCodes) } adminRoute := userRoute.Group("/") @@ -78,6 +86,10 @@ func SetApiRouter(router *gin.Engine) { adminRoute.POST("/manage", controller.ManageUser) adminRoute.PUT("/", controller.UpdateUser) adminRoute.DELETE("/:id", controller.DeleteUser) + + // Admin 2FA routes + adminRoute.GET("/2fa/stats", controller.Admin2FAStats) + adminRoute.DELETE("/:id/2fa", controller.AdminDisable2FA) } } optionRoute := apiRouter.Group("/option") diff --git a/web/bun.lock b/web/bun.lock index ca4e337c..53467aa5 100644 --- a/web/bun.lock +++ b/web/bun.lock @@ -21,6 +21,7 @@ "lucide-react": "^0.511.0", "marked": "^4.1.1", "mermaid": "^11.6.0", + "qrcode.react": "^4.2.0", "react": "^18.2.0", "react-dom": "^18.2.0", "react-dropzone": "^14.2.3", @@ -1492,6 +1493,8 @@ "punycode": ["punycode@2.3.1", "", {}, "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg=="], + "qrcode.react": ["qrcode.react@4.2.0", "", { "peerDependencies": { "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-QpgqWi8rD9DsS9EP3z7BT+5lY5SFhsqGjpgW5DY/i3mK4M9DTBNz3ErMi8BWYEfI3L0d8GIbGmcdFAS1uIRGjA=="], + "quansync": ["quansync@0.2.10", "", {}, "sha512-t41VRkMYbkHyCYmOvx/6URnN80H7k4X0lLdBMGsz+maAwrJQYB1djpV6vHrQIBE0WBSGqhtEHrK9U3DWWH8v7A=="], "query-string": ["query-string@9.2.0", "", { "dependencies": { "decode-uri-component": "^0.4.1", "filter-obj": "^5.1.0", "split-on-first": "^3.0.0" } }, "sha512-YIRhrHujoQxhexwRLxfy3VSjOXmvZRd2nyw1PwL1UUqZ/ys1dEZd1+NSgXkne2l/4X/7OXkigEAuhTX0g/ivJQ=="], @@ -1502,7 +1505,7 @@ "rc-checkbox": ["rc-checkbox@3.5.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "^2.3.2", "rc-util": "^5.25.2" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-aOAQc3E98HteIIsSqm6Xk2FPKIER6+5vyEFMZfo73TqM+VVAIqOkHoPjgKLqSNtVLWScoaM7vY2ZrGEheI79yg=="], - "rc-collapse": ["rc-collapse@3.9.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-swDdz4QZ4dFTo4RAUMLL50qP0EY62N2kvmk2We5xYdRwcRn8WcYtuetCJpwpaCbUfUt5+huLpVxhvmnK+PHrkA=="], + "rc-collapse": ["rc-collapse@4.0.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-SwoOByE39/3oIokDs/BnkqI+ltwirZbP8HZdq1/3SkPSBi7xDdvWHTp7cpNI9ullozkR6mwTWQi6/E/9huQVrA=="], "rc-dialog": ["rc-dialog@9.6.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "@rc-component/portal": "^1.0.0-8", "classnames": "^2.2.6", "rc-motion": "^2.3.0", "rc-util": "^5.21.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-ApoVi9Z8PaCQg6FsUzS8yvBEQy0ZL2PkuvAgrmohPkN3okps5WZ5WQWPc1RNuiOKaAYv8B97ACdsFU5LizzCqg=="], @@ -1946,8 +1949,6 @@ "@lobehub/ui/lucide-react": ["lucide-react@0.484.0", "", { "peerDependencies": { "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-oZy8coK9kZzvqhSgfbGkPtTgyjpBvs3ukLgDPv14dSOZtBtboryWF5o8i3qen7QbGg7JhiJBz5mK1p8YoMZTLQ=="], - "@lobehub/ui/rc-collapse": ["rc-collapse@4.0.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-SwoOByE39/3oIokDs/BnkqI+ltwirZbP8HZdq1/3SkPSBi7xDdvWHTp7cpNI9ullozkR6mwTWQi6/E/9huQVrA=="], - "@radix-ui/react-dismissable-layer/@radix-ui/react-compose-refs": ["@radix-ui/react-compose-refs@1.0.0", "", { "dependencies": { "@babel/runtime": "^7.13.10" }, "peerDependencies": { "react": "^16.8 || ^17.0 || ^18.0" } }, "sha512-0KaSv6sx787/hK3eF53iOkiSLwAGlFMx5lotrqD2pTjB18KbybKoEIgkNZTKC60YECDQTKGTRcDBILwZVqVKvA=="], "@radix-ui/react-popper/@floating-ui/react-dom": ["@floating-ui/react-dom@0.7.2", "", { "dependencies": { "@floating-ui/dom": "^0.5.3", "use-isomorphic-layout-effect": "^1.1.1" }, "peerDependencies": { "react": ">=16.8.0", "react-dom": ">=16.8.0" } }, "sha512-1T0sJcpHgX/u4I1OzIEhlcrvkUN8ln39nz7fMoE/2HDHrPiMFoOGR7++GYyfUmIQHkkrTinaeQsO3XWubjSvGg=="], @@ -1964,6 +1965,8 @@ "@visactor/vrender-kits/roughjs": ["roughjs@4.5.2", "", { "dependencies": { "path-data-parser": "^0.1.0", "points-on-curve": "^0.2.0", "points-on-path": "^0.2.1" } }, "sha512-2xSlLDKdsWyFxrveYWk9YQ/Y9UfK38EAMRNkYkMqYBJvPX8abCa9PN0x3w02H8Oa6/0bcZICJU+U95VumPqseg=="], + "antd/rc-collapse": ["rc-collapse@3.9.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-swDdz4QZ4dFTo4RAUMLL50qP0EY62N2kvmk2We5xYdRwcRn8WcYtuetCJpwpaCbUfUt5+huLpVxhvmnK+PHrkA=="], + "antd/scroll-into-view-if-needed": ["scroll-into-view-if-needed@3.1.0", "", { "dependencies": { "compute-scroll-into-view": "^3.0.2" } }, "sha512-49oNpRjWRvnU8NyGVmUaYG4jtTkNonFZI86MmGRDqBphEK2EXT9gdEUoQPZhuBM8yWHxCWbobltqYO5M4XrUvQ=="], "chokidar/glob-parent": ["glob-parent@5.1.2", "", { "dependencies": { "is-glob": "^4.0.1" } }, "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow=="], diff --git a/web/package.json b/web/package.json index ba0df966..f014d84b 100644 --- a/web/package.json +++ b/web/package.json @@ -21,6 +21,7 @@ "lucide-react": "^0.511.0", "marked": "^4.1.1", "mermaid": "^11.6.0", + "qrcode.react": "^4.2.0", "react": "^18.2.0", "react-dom": "^18.2.0", "react-dropzone": "^14.2.3", diff --git a/web/src/components/auth/LoginForm.js b/web/src/components/auth/LoginForm.js index f81dfd81..9c6650f8 100644 --- a/web/src/components/auth/LoginForm.js +++ b/web/src/components/auth/LoginForm.js @@ -50,6 +50,7 @@ import { IconGithubLogo, IconMail, IconLock } from '@douyinfe/semi-icons'; import OIDCIcon from '../common/logo/OIDCIcon.js'; import WeChatIcon from '../common/logo/WeChatIcon.js'; import LinuxDoIcon from '../common/logo/LinuxDoIcon.js'; +import TwoFAVerification from './TwoFAVerification.js'; import { useTranslation } from 'react-i18next'; const LoginForm = () => { @@ -78,6 +79,7 @@ const LoginForm = () => { const [resetPasswordLoading, setResetPasswordLoading] = useState(false); const [otherLoginOptionsLoading, setOtherLoginOptionsLoading] = useState(false); const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false); + const [showTwoFA, setShowTwoFA] = useState(false); const logo = getLogo(); const systemName = getSystemName(); @@ -162,6 +164,13 @@ const LoginForm = () => { ); const { success, message, data } = res.data; if (success) { + // 检查是否需要2FA验证 + if (data && data.require_2fa) { + setShowTwoFA(true); + setLoginLoading(false); + return; + } + userDispatch({ type: 'login', payload: data }); setUserData(data); updateAPI(); @@ -280,6 +289,21 @@ const LoginForm = () => { setOtherLoginOptionsLoading(false); }; + // 2FA验证成功处理 + const handle2FASuccess = (data) => { + userDispatch({ type: 'login', payload: data }); + setUserData(data); + updateAPI(); + showSuccess('登录成功!'); + navigate('/console'); + }; + + // 返回登录页面 + const handleBackToLogin = () => { + setShowTwoFA(false); + setInputs({ username: '', password: '', wechat_verification_code: '' }); + }; + const renderOAuthOptions = () => { return (
@@ -537,6 +561,35 @@ const LoginForm = () => { ); }; + // 2FA验证弹窗 + const render2FAModal = () => { + return ( + +
+ + + +
+ 两步验证 +
+ } + visible={showTwoFA} + onCancel={handleBackToLogin} + footer={null} + width={450} + centered + > + + + ); + }; + return (
{/* 背景模糊晕染球 */} @@ -547,6 +600,7 @@ const LoginForm = () => { ? renderEmailLoginForm() : renderOAuthOptions()} {renderWeChatLoginModal()} + {render2FAModal()} {turnstileEnabled && (
diff --git a/web/src/components/auth/TwoFAVerification.js b/web/src/components/auth/TwoFAVerification.js new file mode 100644 index 00000000..384273ed --- /dev/null +++ b/web/src/components/auth/TwoFAVerification.js @@ -0,0 +1,222 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ +import { Button, Card, Divider, Form, Input, Typography } from '@douyinfe/semi-ui'; +import React, { useState } from 'react'; +import { showError, showSuccess, API } from '../../helpers'; + +const { Title, Text, Paragraph } = Typography; + +const TwoFAVerification = ({ onSuccess, onBack, isModal = false }) => { + const [loading, setLoading] = useState(false); + const [useBackupCode, setUseBackupCode] = useState(false); + const [verificationCode, setVerificationCode] = useState(''); + + const handleSubmit = async () => { + if (!verificationCode) { + showError('请输入验证码'); + return; + } + + setLoading(true); + try { + const res = await API.post('/api/user/login/2fa', { + code: verificationCode + }); + + if (res.data.success) { + showSuccess('登录成功'); + // 保存用户信息到本地存储 + localStorage.setItem('user', JSON.stringify(res.data.data)); + if (onSuccess) { + onSuccess(res.data.data); + } + } else { + showError(res.data.message); + } + } catch (error) { + showError('验证失败,请重试'); + } finally { + setLoading(false); + } + }; + + const handleKeyPress = (e) => { + if (e.key === 'Enter') { + handleSubmit(); + } + }; + + if (isModal) { + return ( +
+ + 请输入认证器应用显示的验证码完成登录 + + +
+ + + + + + + +
+ + + {onBack && ( + + )} +
+ +
+ + 提示: +
+ • 验证码每30秒更新一次 +
+ • 如果无法获取验证码,请使用备用码 +
+ • 每个备用码只能使用一次 +
+
+
+ ); + } + + return ( +
+ +
+ 两步验证 + + 请输入认证器应用显示的验证码完成登录 + +
+ +
+ + + + + + + +
+ + + {onBack && ( + + )} +
+ +
+ + 提示: +
+ • 验证码每30秒更新一次 +
+ • 如果无法获取验证码,请使用备用码 +
+ • 每个备用码只能使用一次 +
+
+
+
+ ); +}; + +export default TwoFAVerification; \ No newline at end of file diff --git a/web/src/components/settings/PersonalSetting.js b/web/src/components/settings/PersonalSetting.js index 1e0132cf..0a350084 100644 --- a/web/src/components/settings/PersonalSetting.js +++ b/web/src/components/settings/PersonalSetting.js @@ -36,6 +36,7 @@ import { renderModelTag, getModelCategories } from '../../helpers'; +import TwoFASetting from './TwoFASetting'; import Turnstile from 'react-turnstile'; import { UserContext } from '../../context/User'; import { useTheme } from '../../context/Theme'; @@ -1041,6 +1042,9 @@ const PersonalSetting = () => {
+ {/* 两步验证设置 */} + + {/* 危险区域 */} . + +For commercial licensing, please contact support@quantumnous.com +*/ +import { API, showError, showSuccess, showWarning } from '../../helpers'; +import { Banner, Button, Card, Checkbox, Divider, Form, Input, Modal, Tag, Typography } from '@douyinfe/semi-ui'; +import React, { useEffect, useState } from 'react'; + +import { QRCodeSVG } from 'qrcode.react'; + +const { Text, Paragraph } = Typography; + +const TwoFASetting = () => { + const [loading, setLoading] = useState(false); + const [status, setStatus] = useState({ + enabled: false, + locked: false, + backup_codes_remaining: 0 + }); + + // 模态框状态 + const [setupModalVisible, setSetupModalVisible] = useState(false); + const [enableModalVisible, setEnableModalVisible] = useState(false); + const [disableModalVisible, setDisableModalVisible] = useState(false); + const [backupModalVisible, setBackupModalVisible] = useState(false); + + // 表单数据 + const [setupData, setSetupData] = useState(null); + const [verificationCode, setVerificationCode] = useState(''); + const [backupCodes, setBackupCodes] = useState([]); + const [confirmDisable, setConfirmDisable] = useState(false); + + // 获取2FA状态 + const fetchStatus = async () => { + try { + const res = await API.get('/api/user/2fa/status'); + if (res.data.success) { + setStatus(res.data.data); + } + } catch (error) { + showError('获取2FA状态失败'); + } + }; + + useEffect(() => { + fetchStatus(); + }, []); + + // 初始化2FA设置 + const handleSetup2FA = async () => { + setLoading(true); + try { + const res = await API.post('/api/user/2fa/setup'); + if (res.data.success) { + setSetupData(res.data.data); + setSetupModalVisible(true); + } else { + showError(res.data.message); + } + } catch (error) { + showError('设置2FA失败'); + } finally { + setLoading(false); + } + }; + + // 启用2FA + const handleEnable2FA = async () => { + if (!verificationCode) { + showWarning('请输入验证码'); + return; + } + + setLoading(true); + try { + const res = await API.post('/api/user/2fa/enable', { + code: verificationCode + }); + if (res.data.success) { + showSuccess('两步验证启用成功!'); + setEnableModalVisible(false); + setSetupModalVisible(false); + setVerificationCode(''); + fetchStatus(); + } else { + showError(res.data.message); + } + } catch (error) { + showError('启用2FA失败'); + } finally { + setLoading(false); + } + }; + + // 禁用2FA + const handleDisable2FA = async () => { + if (!verificationCode) { + showWarning('请输入验证码或备用码'); + return; + } + + if (!confirmDisable) { + showWarning('请确认您已了解禁用两步验证的后果'); + return; + } + + setLoading(true); + try { + const res = await API.post('/api/user/2fa/disable', { + code: verificationCode + }); + if (res.data.success) { + showSuccess('两步验证已禁用'); + setDisableModalVisible(false); + setVerificationCode(''); + setConfirmDisable(false); + fetchStatus(); + } else { + showError(res.data.message); + } + } catch (error) { + showError('禁用2FA失败'); + } finally { + setLoading(false); + } + }; + + // 重新生成备用码 + const handleRegenerateBackupCodes = async () => { + if (!verificationCode) { + showWarning('请输入验证码'); + return; + } + + setLoading(true); + try { + const res = await API.post('/api/user/2fa/backup_codes', { + code: verificationCode + }); + if (res.data.success) { + setBackupCodes(res.data.data.backup_codes); + showSuccess('备用码重新生成成功'); + setVerificationCode(''); + fetchStatus(); + } else { + showError(res.data.message); + } + } catch (error) { + showError('重新生成备用码失败'); + } finally { + setLoading(false); + } + }; + + const copyBackupCodes = () => { + const codesText = backupCodes.join('\n'); + navigator.clipboard.writeText(codesText).then(() => { + showSuccess('备用码已复制到剪贴板'); + }).catch(() => { + showError('复制失败,请手动复制'); + }); + }; + + return ( +
+ +
+
+
+ + + +
+
+
两步验证设置
+
+ 两步验证(2FA)为您的账户提供额外的安全保护。启用后,登录时需要输入密码和验证器应用生成的验证码。 +
+
+ 当前状态: + {status.enabled ? ( + 已启用 + ) : ( + 未启用 + )} + {status.locked && ( + 账户已锁定 + )} +
+ {status.enabled && ( +
+ 剩余备用码:{status.backup_codes_remaining || 0} 个 +
+ )} +
+
+
+ {!status.enabled ? ( + + ) : ( +
+ + +
+ )} +
+
+
+ + {/* 2FA设置模态框 */} + +
+ + + +
+ 设置两步验证 +
+ } + visible={setupModalVisible} + onCancel={() => { + setSetupModalVisible(false); + setSetupData(null); + }} + footer={null} + width={650} + style={{ maxWidth: '90vw' }} + > + {setupData && ( +
+ {/* 步骤 1:扫描二维码 */} +
+
+
+ 1 +
+ 扫描二维码 +
+ + 使用认证器应用(如 Google Authenticator、Microsoft Authenticator)扫描下方二维码: + +
+
+ +
+
+
+ + 或手动输入密钥:{setupData.secret} + +
+
+ + {/* 步骤 2:保存备用码 */} +
+
+
+ 2 +
+ 保存备用码 +
+ + 请将以下备用码保存在安全的地方。如果丢失手机,可以使用这些备用码登录: + +
+
+ {setupData.backup_codes.map((code, index) => ( +
+ {code} +
+ ))} +
+ +
+
+ + {/* 步骤 3:验证设置 */} +
+
+
+ 3 +
+ 验证设置 +
+ + 输入认证器应用显示的6位数字验证码: + +
+ + + +
+
+ )} + + + {/* 禁用2FA模态框 */} + +
+ + + +
+ 禁用两步验证 +
+ } + visible={disableModalVisible} + onCancel={() => { + setDisableModalVisible(false); + setVerificationCode(''); + setConfirmDisable(false); + }} + footer={null} + width={550} + > +
+ +
警告:禁用两步验证将会:
+
    +
  • 降低您账户的安全性
  • +
  • 永久删除您的两步验证设置
  • +
  • 永久删除所有备用码(包括未使用的)
  • +
  • 需要重新完整设置才能再次启用
  • +
+
+ 此操作不可撤销,请谨慎操作! +
+
+ } + className="rounded-lg" + /> +
+ +
+ setConfirmDisable(e.target.checked)} + className="text-sm" + > + 我已了解禁用两步验证将永久删除所有相关设置和备用码,此操作不可撤销 + +
+ + + + + + {/* 重新生成备用码模态框 */} + +
+ + + +
+ 重新生成备用码 + + } + visible={backupModalVisible} + onCancel={() => { + setBackupModalVisible(false); + setVerificationCode(''); + setBackupCodes([]); + }} + footer={null} + width={500} + > +
+ {backupCodes.length === 0 ? ( + <> + +
+ + + + + ) : ( + <> +
+
+ + + +
+ 新的备用码已生成 + + 请将以下备用码保存在安全的地方: + +
+
+
+ {backupCodes.map((code, index) => ( +
+ {code} +
+ ))} +
+ +
+ + )} +
+
+ + ); +}; + +export default TwoFASetting; \ No newline at end of file From d85eeabf11f43aab1e1defcd458b590e2a53fd06 Mon Sep 17 00:00:00 2001 From: Seefs Date: Sun, 3 Aug 2025 10:41:00 +0800 Subject: [PATCH 2/8] fix: coderabbit review --- common/totp.go | 5 +---- controller/twofa.go | 12 +++++++++--- go.mod | 2 +- go.sum | 2 ++ model/twofa.go | 4 +++- web/src/components/auth/TwoFAVerification.js | 10 +++++++++- 6 files changed, 25 insertions(+), 10 deletions(-) diff --git a/common/totp.go b/common/totp.go index ece5bc31..400f9d05 100644 --- a/common/totp.go +++ b/common/totp.go @@ -113,10 +113,7 @@ func HashBackupCode(code string) (string, error) { // Get2FAIssuer 获取2FA发行者名称 func Get2FAIssuer() string { - if issuer := SystemName; issuer != "" { - return issuer - } - return "NewAPI" + return SystemName } // getEnvOrDefault 获取环境变量或默认值 diff --git a/controller/twofa.go b/controller/twofa.go index 368289c9..2a7016c5 100644 --- a/controller/twofa.go +++ b/controller/twofa.go @@ -46,7 +46,7 @@ func Setup2FA(c *gin.Context) { }) return } - + // 如果存在已禁用的2FA记录,先删除它 if existing != nil && !existing.IsEnabled { if err := existing.Delete(); err != nil { @@ -415,8 +415,14 @@ func Verify2FALogin(c *gin.Context) { }) return } - userId := pendingUserId.(int) - + userId, ok := pendingUserId.(int) + if !ok { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "会话数据无效,请重新登录", + }) + return + } // 获取用户信息 user, err := model.GetUserById(userId, false) if err != nil { diff --git a/go.mod b/go.mod index 1def0b08..86576bc2 100644 --- a/go.mod +++ b/go.mod @@ -45,7 +45,7 @@ require ( github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect github.com/aws/smithy-go v1.20.2 // indirect - github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect + github.com/boombuler/barcode v1.1.0 // indirect github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect diff --git a/go.sum b/go.sum index 4f5ae530..a1cc5ece 100644 --- a/go.sum +++ b/go.sum @@ -22,6 +22,8 @@ github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= +github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo= +github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0= github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= diff --git a/model/twofa.go b/model/twofa.go index 4a96ffb0..d7b08f93 100644 --- a/model/twofa.go +++ b/model/twofa.go @@ -9,6 +9,8 @@ import ( "gorm.io/gorm" ) +var ErrTwoFANotEnabled = errors.New("用户未启用2FA") + // TwoFA 用户2FA设置表 type TwoFA struct { Id int `json:"id" gorm:"primaryKey"` @@ -210,7 +212,7 @@ func DisableTwoFA(userId int) error { return err } if twoFA == nil { - return errors.New("用户未启用2FA") + return ErrTwoFANotEnabled } // 删除2FA设置和备用码 diff --git a/web/src/components/auth/TwoFAVerification.js b/web/src/components/auth/TwoFAVerification.js index 384273ed..69756384 100644 --- a/web/src/components/auth/TwoFAVerification.js +++ b/web/src/components/auth/TwoFAVerification.js @@ -16,9 +16,9 @@ along with this program. If not, see . For commercial licensing, please contact support@quantumnous.com */ +import { API, showError, showSuccess } from '../../helpers'; import { Button, Card, Divider, Form, Input, Typography } from '@douyinfe/semi-ui'; import React, { useState } from 'react'; -import { showError, showSuccess, API } from '../../helpers'; const { Title, Text, Paragraph } = Typography; @@ -32,6 +32,14 @@ const TwoFAVerification = ({ onSuccess, onBack, isModal = false }) => { showError('请输入验证码'); return; } + // Validate code format + if (useBackupCode && verificationCode.length !== 8) { + showError('备用码必须是8位'); + return; + } else if (!useBackupCode && !/^\d{6}$/.test(verificationCode)) { + showError('验证码必须是6位数字'); + return; + } setLoading(true); try { From 398ae7156b72f753d0c6893a2e5dffcc2e6ac2bd Mon Sep 17 00:00:00 2001 From: Seefs Date: Sun, 3 Aug 2025 10:49:55 +0800 Subject: [PATCH 3/8] refactor: improve error handling and database transactions in 2FA model methods --- controller/twofa.go | 4 ++-- model/twofa.go | 55 ++++++++++++++++++++++++--------------------- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/controller/twofa.go b/controller/twofa.go index 2a7016c5..9f48eed8 100644 --- a/controller/twofa.go +++ b/controller/twofa.go @@ -1,12 +1,12 @@ package controller import ( + "errors" "fmt" "net/http" "one-api/common" "one-api/model" "strconv" - "strings" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" @@ -530,7 +530,7 @@ func AdminDisable2FA(c *gin.Context) { // 禁用2FA if err := model.DisableTwoFA(userId); err != nil { - if strings.Contains(err.Error(), "未启用2FA") { + if errors.Is(err, model.ErrTwoFANotEnabled) { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "用户未启用2FA", diff --git a/model/twofa.go b/model/twofa.go index d7b08f93..d09ff9fe 100644 --- a/model/twofa.go +++ b/model/twofa.go @@ -100,13 +100,16 @@ func (t *TwoFA) Delete() error { return errors.New("2FA记录ID不能为空") } - // 同时删除相关的备用码记录(硬删除) - if err := DB.Unscoped().Where("user_id = ?", t.UserId).Delete(&TwoFABackupCode{}).Error; err != nil { - return err - } + // 使用事务确保原子性 + return DB.Transaction(func(tx *gorm.DB) error { + // 同时删除相关的备用码记录(硬删除) + if err := tx.Unscoped().Where("user_id = ?", t.UserId).Delete(&TwoFABackupCode{}).Error; err != nil { + return err + } - // 硬删除2FA记录 - return DB.Unscoped().Delete(t).Error + // 硬删除2FA记录 + return tx.Unscoped().Delete(t).Error + }) } // ResetFailedAttempts 重置失败尝试次数 @@ -139,30 +142,32 @@ func (t *TwoFA) IsLocked() bool { // CreateBackupCodes 创建备用码 func CreateBackupCodes(userId int, codes []string) error { - // 先删除现有的备用码 - if err := DB.Where("user_id = ?", userId).Delete(&TwoFABackupCode{}).Error; err != nil { - return err - } - - // 创建新的备用码记录 - for _, code := range codes { - hashedCode, err := common.HashBackupCode(code) - if err != nil { + return DB.Transaction(func(tx *gorm.DB) error { + // 先删除现有的备用码 + if err := tx.Where("user_id = ?", userId).Delete(&TwoFABackupCode{}).Error; err != nil { return err } - backupCode := TwoFABackupCode{ - UserId: userId, - CodeHash: hashedCode, - IsUsed: false, + // 创建新的备用码记录 + for _, code := range codes { + hashedCode, err := common.HashBackupCode(code) + if err != nil { + return err + } + + backupCode := TwoFABackupCode{ + UserId: userId, + CodeHash: hashedCode, + IsUsed: false, + } + + if err := tx.Create(&backupCode).Error; err != nil { + return err + } } - if err := DB.Create(&backupCode).Error; err != nil { - return err - } - } - - return nil + return nil + }) } // ValidateBackupCode 验证并使用备用码 From ecdd9d1ccbe2ce0ccdf452549bff5846dc50e022 Mon Sep 17 00:00:00 2001 From: CaIon Date: Mon, 4 Aug 2025 16:52:31 +0800 Subject: [PATCH 4/8] feat: add multi-key management --- controller/channel.go | 258 ++++++++++++ model/channel.go | 33 +- model/channel_cache.go | 2 +- router/api-router.go | 1 + .../table/channels/ChannelsColumnDefs.js | 105 +++-- .../table/channels/ChannelsTable.jsx | 7 + web/src/components/table/channels/index.jsx | 7 + .../channels/modals/MultiKeyManageModal.jsx | 372 ++++++++++++++++++ web/src/hooks/channels/useChannelsData.js | 10 + 9 files changed, 730 insertions(+), 65 deletions(-) create mode 100644 web/src/components/table/channels/modals/MultiKeyManageModal.jsx diff --git a/controller/channel.go b/controller/channel.go index d9e4d422..a2ee5743 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -1030,3 +1030,261 @@ func CopyChannel(c *gin.Context) { // success c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}}) } + +// MultiKeyManageRequest represents the request for multi-key management operations +type MultiKeyManageRequest struct { + ChannelId int `json:"channel_id"` + Action string `json:"action"` // "disable_key", "enable_key", "delete_disabled_keys", "get_key_status" + KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions +} + +// MultiKeyStatusResponse represents the response for key status query +type MultiKeyStatusResponse struct { + Keys []KeyStatus `json:"keys"` +} + +type KeyStatus struct { + Index int `json:"index"` + Status int `json:"status"` // 1: enabled, 2: disabled + DisabledTime int64 `json:"disabled_time,omitempty"` + Reason string `json:"reason,omitempty"` + KeyPreview string `json:"key_preview"` // first 10 chars of key for identification +} + +// ManageMultiKeys handles multi-key management operations +func ManageMultiKeys(c *gin.Context) { + request := MultiKeyManageRequest{} + err := c.ShouldBindJSON(&request) + if err != nil { + common.ApiError(c, err) + return + } + + channel, err := model.GetChannelById(request.ChannelId, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "渠道不存在", + }) + return + } + + if !channel.ChannelInfo.IsMultiKey { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该渠道不是多密钥模式", + }) + return + } + + switch request.Action { + case "get_key_status": + keys := channel.GetKeys() + var keyStatusList []KeyStatus + + for i, key := range keys { + status := 1 // default enabled + var disabledTime int64 + var reason string + + if channel.ChannelInfo.MultiKeyStatusList != nil { + if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists { + status = s + } + } + + if status != 1 { + if channel.ChannelInfo.MultiKeyDisabledTime != nil { + disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i] + } + if channel.ChannelInfo.MultiKeyDisabledReason != nil { + reason = channel.ChannelInfo.MultiKeyDisabledReason[i] + } + } + + // Create key preview (first 10 chars) + keyPreview := key + if len(key) > 10 { + keyPreview = key[:10] + "..." + } + + keyStatusList = append(keyStatusList, KeyStatus{ + Index: i, + Status: status, + DisabledTime: disabledTime, + Reason: reason, + KeyPreview: keyPreview, + }) + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": MultiKeyStatusResponse{Keys: keyStatusList}, + }) + return + + case "disable_key": + if request.KeyIndex == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "未指定要禁用的密钥索引", + }) + return + } + + keyIndex := *request.KeyIndex + if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "密钥索引超出范围", + }) + return + } + + if channel.ChannelInfo.MultiKeyStatusList == nil { + channel.ChannelInfo.MultiKeyStatusList = make(map[int]int) + } + if channel.ChannelInfo.MultiKeyDisabledTime == nil { + channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64) + } + if channel.ChannelInfo.MultiKeyDisabledReason == nil { + channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string) + } + + channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled + channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp() + channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = "手动禁用" + + err = channel.Update() + if err != nil { + common.ApiError(c, err) + return + } + + model.InitChannelCache() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "密钥已禁用", + }) + return + + case "enable_key": + if request.KeyIndex == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "未指定要启用的密钥索引", + }) + return + } + + keyIndex := *request.KeyIndex + if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "密钥索引超出范围", + }) + return + } + + // 从状态列表中删除该密钥的记录,使其回到默认启用状态 + if channel.ChannelInfo.MultiKeyStatusList != nil { + delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex) + } + if channel.ChannelInfo.MultiKeyDisabledTime != nil { + delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex) + } + if channel.ChannelInfo.MultiKeyDisabledReason != nil { + delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex) + } + + err = channel.Update() + if err != nil { + common.ApiError(c, err) + return + } + + model.InitChannelCache() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "密钥已启用", + }) + return + + case "delete_disabled_keys": + keys := channel.GetKeys() + var remainingKeys []string + var deletedCount int + var newStatusList = make(map[int]int) + var newDisabledTime = make(map[int]int64) + var newDisabledReason = make(map[int]string) + + newIndex := 0 + for i, key := range keys { + status := 1 // default enabled + if channel.ChannelInfo.MultiKeyStatusList != nil { + if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists { + status = s + } + } + + // 只删除自动禁用(status == 3)的密钥,保留启用(status == 1)和手动禁用(status == 2)的密钥 + if status == 3 { + deletedCount++ + } else { + remainingKeys = append(remainingKeys, key) + // 保留非自动禁用密钥的状态信息,重新索引 + if status != 1 { + newStatusList[newIndex] = status + if channel.ChannelInfo.MultiKeyDisabledTime != nil { + if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists { + newDisabledTime[newIndex] = t + } + } + if channel.ChannelInfo.MultiKeyDisabledReason != nil { + if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists { + newDisabledReason[newIndex] = r + } + } + } + newIndex++ + } + } + + if deletedCount == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "没有需要删除的自动禁用密钥", + }) + return + } + + // Update channel with remaining keys + channel.Key = strings.Join(remainingKeys, "\n") + channel.ChannelInfo.MultiKeySize = len(remainingKeys) + channel.ChannelInfo.MultiKeyStatusList = newStatusList + channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime + channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason + + err = channel.Update() + if err != nil { + common.ApiError(c, err) + return + } + + model.InitChannelCache() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount), + "data": deletedCount, + }) + return + + default: + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "不支持的操作", + }) + return + } +} diff --git a/model/channel.go b/model/channel.go index bcffc102..502171fa 100644 --- a/model/channel.go +++ b/model/channel.go @@ -41,6 +41,7 @@ type Channel struct { Priority *int64 `json:"priority" gorm:"bigint;default:0"` AutoBan *int `json:"auto_ban" gorm:"default:1"` OtherInfo string `json:"other_info"` + Settings string `json:"settings"` Tag *string `json:"tag" gorm:"index"` Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置 ParamOverride *string `json:"param_override" gorm:"type:text"` @@ -52,11 +53,13 @@ type Channel struct { } type ChannelInfo struct { - IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式 - MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量 - MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status - MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引 - MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` + IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式 + MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量 + MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status + MultiKeyDisabledReason map[int]string `json:"multi_key_disabled_reason"` // key禁用原因列表,key index -> reason + MultiKeyDisabledTime map[int]int64 `json:"multi_key_disabled_time"` // key禁用时间列表,key index -> time + MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引 + MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` } // Value implements driver.Valuer interface @@ -70,7 +73,7 @@ func (c *ChannelInfo) Scan(value interface{}) error { return common.Unmarshal(bytesValue, c) } -func (channel *Channel) getKeys() []string { +func (channel *Channel) GetKeys() []string { if channel.Key == "" { return []string{} } @@ -101,7 +104,7 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) { } // Obtain all keys (split by \n) - keys := channel.getKeys() + keys := channel.GetKeys() if len(keys) == 0 { // No keys available, return error, should disable the channel return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey) @@ -528,8 +531,8 @@ func CleanupChannelPollingLocks() { }) } -func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) { - keys := channel.getKeys() +func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason string) { + keys := channel.GetKeys() if len(keys) == 0 { channel.Status = status } else { @@ -547,6 +550,14 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) { delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex) } else { channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status + if channel.ChannelInfo.MultiKeyDisabledReason == nil { + channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string) + } + if channel.ChannelInfo.MultiKeyDisabledTime == nil { + channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64) + } + channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason + channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp() } if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize { channel.Status = common.ChannelStatusAutoDisabled @@ -569,7 +580,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri } if channelCache.ChannelInfo.IsMultiKey { // 如果是多Key模式,更新缓存中的状态 - handlerMultiKeyUpdate(channelCache, usingKey, status) + handlerMultiKeyUpdate(channelCache, usingKey, status, reason) //CacheUpdateChannel(channelCache) //return true } else { @@ -600,7 +611,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri if channel.ChannelInfo.IsMultiKey { beforeStatus := channel.Status - handlerMultiKeyUpdate(channel, usingKey, status) + handlerMultiKeyUpdate(channel, usingKey, status, reason) if beforeStatus != channel.Status { shouldUpdateAbilities = true } diff --git a/model/channel_cache.go b/model/channel_cache.go index ecd87607..6ca23cf9 100644 --- a/model/channel_cache.go +++ b/model/channel_cache.go @@ -70,7 +70,7 @@ func InitChannelCache() { //channelsIDM = newChannelId2channel for i, channel := range newChannelId2channel { if channel.ChannelInfo.IsMultiKey { - channel.Keys = channel.getKeys() + channel.Keys = channel.GetKeys() if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling { if oldChannel, ok := channelsIDM[i]; ok { // 存在旧的渠道,如果是多key且轮询,保留轮询索引信息 diff --git a/router/api-router.go b/router/api-router.go index bc49803a..12846012 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -120,6 +120,7 @@ func SetApiRouter(router *gin.Engine) { channelRoute.POST("/batch/tag", controller.BatchSetChannelTag) channelRoute.GET("/tag/models", controller.GetTagModels) channelRoute.POST("/copy/:id", controller.CopyChannel) + channelRoute.POST("/multi_key/manage", controller.ManageMultiKeys) } tokenRoute := apiRouter.Group("/token") tokenRoute.Use(middleware.UserAuth()) diff --git a/web/src/components/table/channels/ChannelsColumnDefs.js b/web/src/components/table/channels/ChannelsColumnDefs.js index beb5fe55..18cb5700 100644 --- a/web/src/components/table/channels/ChannelsColumnDefs.js +++ b/web/src/components/table/channels/ChannelsColumnDefs.js @@ -210,7 +210,9 @@ export const getChannelsColumns = ({ copySelectedChannel, refresh, activePage, - channels + channels, + setShowMultiKeyManageModal, + setCurrentMultiKeyChannel }) => { return [ { @@ -503,47 +505,7 @@ export const getChannelsColumns = ({ /> - {record.channel_info?.is_multi_key ? ( - - { - record.status === 1 ? ( - - ) : ( - - ) - } - manageChannel(record.id, 'enable_all', record), - } - ]} - > - + {record.channel_info?.is_multi_key ? ( + + + { + setCurrentMultiKeyChannel(record); + setShowMultiKeyManageModal(true); + }, + } + ]} + > + + )} { setEditingTag, copySelectedChannel, refresh, + // Multi-key management + setShowMultiKeyManageModal, + setCurrentMultiKeyChannel, } = channelsData; // Get all columns @@ -79,6 +82,8 @@ const ChannelsTable = (channelsData) => { refresh, activePage, channels, + setShowMultiKeyManageModal, + setCurrentMultiKeyChannel, }); }, [ t, @@ -98,6 +103,8 @@ const ChannelsTable = (channelsData) => { refresh, activePage, channels, + setShowMultiKeyManageModal, + setCurrentMultiKeyChannel, ]); // Filter columns based on visibility settings diff --git a/web/src/components/table/channels/index.jsx b/web/src/components/table/channels/index.jsx index b0106b4e..66e2d72d 100644 --- a/web/src/components/table/channels/index.jsx +++ b/web/src/components/table/channels/index.jsx @@ -30,6 +30,7 @@ import ModelTestModal from './modals/ModelTestModal.jsx'; import ColumnSelectorModal from './modals/ColumnSelectorModal.jsx'; import EditChannelModal from './modals/EditChannelModal.jsx'; import EditTagModal from './modals/EditTagModal.jsx'; +import MultiKeyManageModal from './modals/MultiKeyManageModal.jsx'; import { createCardProPagination } from '../../../helpers/utils'; const ChannelsPage = () => { @@ -54,6 +55,12 @@ const ChannelsPage = () => { /> + channelsData.setShowMultiKeyManageModal(false)} + channel={channelsData.currentMultiKeyChannel} + onRefresh={channelsData.refresh} + /> {/* Main Content */} . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React, { useState, useEffect } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + Modal, + Button, + Table, + Tag, + Typography, + Space, + Tooltip, + Popconfirm, + Empty, + Spin, + Banner +} from '@douyinfe/semi-ui'; +import { + IconRefresh, + IconDelete, + IconClose, + IconSave, + IconSetting +} from '@douyinfe/semi-icons'; +import { API, showError, showSuccess, timestamp2string } from '../../../../helpers/index.js'; + +const { Text, Title } = Typography; + +const MultiKeyManageModal = ({ + visible, + onCancel, + channel, + onRefresh +}) => { + const { t } = useTranslation(); + const [loading, setLoading] = useState(false); + const [keyStatusList, setKeyStatusList] = useState([]); + const [operationLoading, setOperationLoading] = useState({}); + + // Load key status data + const loadKeyStatus = async () => { + if (!channel?.id) return; + + setLoading(true); + try { + const res = await API.post('/api/channel/multi_key/manage', { + channel_id: channel.id, + action: 'get_key_status' + }); + + if (res.data.success) { + setKeyStatusList(res.data.data.keys || []); + } else { + showError(res.data.message); + } + } catch (error) { + showError(t('获取密钥状态失败')); + } finally { + setLoading(false); + } + }; + + // Disable a specific key + const handleDisableKey = async (keyIndex) => { + const operationId = `disable_${keyIndex}`; + setOperationLoading(prev => ({ ...prev, [operationId]: true })); + + try { + const res = await API.post('/api/channel/multi_key/manage', { + channel_id: channel.id, + action: 'disable_key', + key_index: keyIndex + }); + + if (res.data.success) { + showSuccess(t('密钥已禁用')); + await loadKeyStatus(); // Reload data + onRefresh && onRefresh(); // Refresh parent component + } else { + showError(res.data.message); + } + } catch (error) { + showError(t('禁用密钥失败')); + } finally { + setOperationLoading(prev => ({ ...prev, [operationId]: false })); + } + }; + + // Enable a specific key + const handleEnableKey = async (keyIndex) => { + const operationId = `enable_${keyIndex}`; + setOperationLoading(prev => ({ ...prev, [operationId]: true })); + + try { + const res = await API.post('/api/channel/multi_key/manage', { + channel_id: channel.id, + action: 'enable_key', + key_index: keyIndex + }); + + if (res.data.success) { + showSuccess(t('密钥已启用')); + await loadKeyStatus(); // Reload data + onRefresh && onRefresh(); // Refresh parent component + } else { + showError(res.data.message); + } + } catch (error) { + showError(t('启用密钥失败')); + } finally { + setOperationLoading(prev => ({ ...prev, [operationId]: false })); + } + }; + + // Delete all disabled keys + const handleDeleteDisabledKeys = async () => { + setOperationLoading(prev => ({ ...prev, delete_disabled: true })); + + try { + const res = await API.post('/api/channel/multi_key/manage', { + channel_id: channel.id, + action: 'delete_disabled_keys' + }); + + if (res.data.success) { + showSuccess(res.data.message); + await loadKeyStatus(); // Reload data + onRefresh && onRefresh(); // Refresh parent component + } else { + showError(res.data.message); + } + } catch (error) { + showError(t('删除禁用密钥失败')); + } finally { + setOperationLoading(prev => ({ ...prev, delete_disabled: false })); + } + }; + + // Effect to load data when modal opens + useEffect(() => { + if (visible && channel?.id) { + loadKeyStatus(); + } + }, [visible, channel?.id]); + + // Get status tag component + const renderStatusTag = (status) => { + switch (status) { + case 1: + return {t('已启用')}; + case 2: + return {t('已禁用')}; + case 3: + return {t('自动禁用')}; + default: + return {t('未知状态')}; + } + }; + + // Table columns definition + const columns = [ + { + title: t('索引'), + dataIndex: 'index', + render: (text) => `#${text}`, + }, + { + title: t('密钥预览'), + dataIndex: 'key_preview', + render: (text) => ( + + {text} + + ), + }, + { + title: t('状态'), + dataIndex: 'status', + width: 100, + render: (status) => renderStatusTag(status), + }, + { + title: t('禁用原因'), + dataIndex: 'reason', + width: 220, + render: (reason, record) => { + if (record.status === 1 || !reason) { + return -; + } + return ( + + + {reason} + + + ); + }, + }, + { + title: t('禁用时间'), + dataIndex: 'disabled_time', + width: 150, + render: (time, record) => { + if (record.status === 1 || !time) { + return -; + } + return ( + + + {timestamp2string(time)} + + + ); + }, + }, + { + title: t('操作'), + key: 'action', + width: 120, + render: (_, record) => ( + + {record.status === 1 ? ( + handleDisableKey(record.index)} + > + + + ) : ( + handleEnableKey(record.index)} + > + + + )} + + ), + }, + ]; + + // Calculate statistics + const enabledCount = keyStatusList.filter(key => key.status === 1).length; + const manualDisabledCount = keyStatusList.filter(key => key.status === 2).length; + const autoDisabledCount = keyStatusList.filter(key => key.status === 3).length; + const totalCount = keyStatusList.length; + + return ( + + + {t('多密钥管理')} - {channel?.name} + + } + visible={visible} + onCancel={onCancel} + width={800} + height={600} + footer={ + + + + {autoDisabledCount > 0 && ( + + + + )} + + } + > +
+ {/* Statistics Banner */} + + + {t('总共 {{total}} 个密钥,{{enabled}} 个已启用,{{manual}} 个手动禁用,{{auto}} 个自动禁用', { + total: totalCount, + enabled: enabledCount, + manual: manualDisabledCount, + auto: autoDisabledCount + })} + + {channel?.channel_info?.multi_key_mode && ( +
+ + {t('多密钥模式')}: {channel.channel_info.multi_key_mode === 'random' ? t('随机') : t('轮询')} + +
+ )} +
+ } + /> + + {/* Key Status Table */} + + {keyStatusList.length > 0 ? ( + + ) : ( + !loading && ( + + ) + )} + + + + ); +}; + +export default MultiKeyManageModal; \ No newline at end of file diff --git a/web/src/hooks/channels/useChannelsData.js b/web/src/hooks/channels/useChannelsData.js index d188c9fe..8f1f8c29 100644 --- a/web/src/hooks/channels/useChannelsData.js +++ b/web/src/hooks/channels/useChannelsData.js @@ -83,6 +83,10 @@ export const useChannelsData = () => { const [isProcessingQueue, setIsProcessingQueue] = useState(false); const [modelTablePage, setModelTablePage] = useState(1); + // Multi-key management states + const [showMultiKeyManageModal, setShowMultiKeyManageModal] = useState(false); + const [currentMultiKeyChannel, setCurrentMultiKeyChannel] = useState(null); + // Refs const requestCounter = useRef(0); const allSelectingRef = useRef(false); @@ -885,6 +889,12 @@ export const useChannelsData = () => { setModelTablePage, allSelectingRef, + // Multi-key management states + showMultiKeyManageModal, + setShowMultiKeyManageModal, + currentMultiKeyChannel, + setCurrentMultiKeyChannel, + // Form formApi, setFormApi, From 8357b15fec0a3e1a6c6607be12b641ee050fd6c0 Mon Sep 17 00:00:00 2001 From: CaIon Date: Mon, 4 Aug 2025 17:15:32 +0800 Subject: [PATCH 5/8] feat: enhance multi-key management with pagination and statistics --- controller/channel.go | 122 ++++++++++++--- model/channel.go | 12 +- .../channels/modals/MultiKeyManageModal.jsx | 147 +++++++++++++++--- 3 files changed, 228 insertions(+), 53 deletions(-) diff --git a/controller/channel.go b/controller/channel.go index a2ee5743..440815cc 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -71,6 +71,13 @@ func parseStatusFilter(statusParam string) int { } } +func clearChannelInfo(channel *model.Channel) { + if channel.ChannelInfo.IsMultiKey { + channel.ChannelInfo.MultiKeyDisabledReason = nil + channel.ChannelInfo.MultiKeyDisabledTime = nil + } +} + func GetAllChannels(c *gin.Context) { pageInfo := common.GetPageQuery(c) channelData := make([]*model.Channel, 0) @@ -145,6 +152,10 @@ func GetAllChannels(c *gin.Context) { } } + for _, datum := range channelData { + clearChannelInfo(datum) + } + countQuery := model.DB.Model(&model.Channel{}) if statusFilter == common.ChannelStatusEnabled { countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled) @@ -371,6 +382,10 @@ func SearchChannels(c *gin.Context) { pagedData := channelData[startIdx:endIdx] + for _, datum := range pagedData { + clearChannelInfo(datum) + } + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", @@ -394,6 +409,9 @@ func GetChannel(c *gin.Context) { common.ApiError(c, err) return } + if channel != nil { + clearChannelInfo(channel) + } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", @@ -827,6 +845,7 @@ func UpdateChannel(c *gin.Context) { } model.InitChannelCache() channel.Key = "" + clearChannelInfo(&channel.Channel) c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", @@ -1036,11 +1055,21 @@ type MultiKeyManageRequest struct { ChannelId int `json:"channel_id"` Action string `json:"action"` // "disable_key", "enable_key", "delete_disabled_keys", "get_key_status" KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions + Page int `json:"page,omitempty"` // for get_key_status pagination + PageSize int `json:"page_size,omitempty"` // for get_key_status pagination } // MultiKeyStatusResponse represents the response for key status query type MultiKeyStatusResponse struct { - Keys []KeyStatus `json:"keys"` + Keys []KeyStatus `json:"keys"` + Total int `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` + TotalPages int `json:"total_pages"` + // Statistics + EnabledCount int `json:"enabled_count"` + ManualDisabledCount int `json:"manual_disabled_count"` + AutoDisabledCount int `json:"auto_disabled_count"` } type KeyStatus struct { @@ -1080,8 +1109,35 @@ func ManageMultiKeys(c *gin.Context) { switch request.Action { case "get_key_status": keys := channel.GetKeys() - var keyStatusList []KeyStatus + total := len(keys) + // Default pagination parameters + page := request.Page + pageSize := request.PageSize + if page <= 0 { + page = 1 + } + if pageSize <= 0 { + pageSize = 50 // Default page size + } + + // Calculate pagination + totalPages := (total + pageSize - 1) / pageSize + if page > totalPages && totalPages > 0 { + page = totalPages + } + + // Calculate range + start := (page - 1) * pageSize + end := start + pageSize + if end > total { + end = total + } + + // Statistics for all keys + var enabledCount, manualDisabledCount, autoDisabledCount int + + var keyStatusList []KeyStatus for i, key := range keys { status := 1 // default enabled var disabledTime int64 @@ -1093,34 +1149,56 @@ func ManageMultiKeys(c *gin.Context) { } } - if status != 1 { - if channel.ChannelInfo.MultiKeyDisabledTime != nil { - disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i] - } - if channel.ChannelInfo.MultiKeyDisabledReason != nil { - reason = channel.ChannelInfo.MultiKeyDisabledReason[i] - } + // Count for statistics + switch status { + case 1: + enabledCount++ + case 2: + manualDisabledCount++ + case 3: + autoDisabledCount++ } - // Create key preview (first 10 chars) - keyPreview := key - if len(key) > 10 { - keyPreview = key[:10] + "..." - } + // Only include keys in current page + if i >= start && i < end { + if status != 1 { + if channel.ChannelInfo.MultiKeyDisabledTime != nil { + disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i] + } + if channel.ChannelInfo.MultiKeyDisabledReason != nil { + reason = channel.ChannelInfo.MultiKeyDisabledReason[i] + } + } - keyStatusList = append(keyStatusList, KeyStatus{ - Index: i, - Status: status, - DisabledTime: disabledTime, - Reason: reason, - KeyPreview: keyPreview, - }) + // Create key preview (first 10 chars) + keyPreview := key + if len(key) > 10 { + keyPreview = key[:10] + "..." + } + + keyStatusList = append(keyStatusList, KeyStatus{ + Index: i, + Status: status, + DisabledTime: disabledTime, + Reason: reason, + KeyPreview: keyPreview, + }) + } } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": MultiKeyStatusResponse{Keys: keyStatusList}, + "data": MultiKeyStatusResponse{ + Keys: keyStatusList, + Total: total, + Page: page, + PageSize: pageSize, + TotalPages: totalPages, + EnabledCount: enabledCount, + ManualDisabledCount: manualDisabledCount, + AutoDisabledCount: autoDisabledCount, + }, }) return diff --git a/model/channel.go b/model/channel.go index 502171fa..280781f1 100644 --- a/model/channel.go +++ b/model/channel.go @@ -53,12 +53,12 @@ type Channel struct { } type ChannelInfo struct { - IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式 - MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量 - MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status - MultiKeyDisabledReason map[int]string `json:"multi_key_disabled_reason"` // key禁用原因列表,key index -> reason - MultiKeyDisabledTime map[int]int64 `json:"multi_key_disabled_time"` // key禁用时间列表,key index -> time - MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引 + IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式 + MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量 + MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status + MultiKeyDisabledReason map[int]string `json:"multi_key_disabled_reason,omitempty"` // key禁用原因列表,key index -> reason + MultiKeyDisabledTime map[int]int64 `json:"multi_key_disabled_time,omitempty"` // key禁用时间列表,key index -> time + MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引 MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` } diff --git a/web/src/components/table/channels/modals/MultiKeyManageModal.jsx b/web/src/components/table/channels/modals/MultiKeyManageModal.jsx index 9ae46ea3..44f16c03 100644 --- a/web/src/components/table/channels/modals/MultiKeyManageModal.jsx +++ b/web/src/components/table/channels/modals/MultiKeyManageModal.jsx @@ -30,7 +30,9 @@ import { Popconfirm, Empty, Spin, - Banner + Banner, + Select, + Pagination } from '@douyinfe/semi-ui'; import { IconRefresh, @@ -53,24 +55,48 @@ const MultiKeyManageModal = ({ const [loading, setLoading] = useState(false); const [keyStatusList, setKeyStatusList] = useState([]); const [operationLoading, setOperationLoading] = useState({}); + + // Pagination states + const [currentPage, setCurrentPage] = useState(1); + const [pageSize, setPageSize] = useState(50); + const [total, setTotal] = useState(0); + const [totalPages, setTotalPages] = useState(0); + + // Statistics states + const [enabledCount, setEnabledCount] = useState(0); + const [manualDisabledCount, setManualDisabledCount] = useState(0); + const [autoDisabledCount, setAutoDisabledCount] = useState(0); // Load key status data - const loadKeyStatus = async () => { + const loadKeyStatus = async (page = currentPage, size = pageSize) => { if (!channel?.id) return; setLoading(true); try { const res = await API.post('/api/channel/multi_key/manage', { channel_id: channel.id, - action: 'get_key_status' + action: 'get_key_status', + page: page, + page_size: size }); if (res.data.success) { - setKeyStatusList(res.data.data.keys || []); + const data = res.data.data; + setKeyStatusList(data.keys || []); + setTotal(data.total || 0); + setCurrentPage(data.page || 1); + setPageSize(data.page_size || 50); + setTotalPages(data.total_pages || 0); + + // Update statistics + setEnabledCount(data.enabled_count || 0); + setManualDisabledCount(data.manual_disabled_count || 0); + setAutoDisabledCount(data.auto_disabled_count || 0); } else { showError(res.data.message); } } catch (error) { + console.error(error); showError(t('获取密钥状态失败')); } finally { setLoading(false); @@ -91,7 +117,7 @@ const MultiKeyManageModal = ({ if (res.data.success) { showSuccess(t('密钥已禁用')); - await loadKeyStatus(); // Reload data + await loadKeyStatus(currentPage, pageSize); // Reload current page onRefresh && onRefresh(); // Refresh parent component } else { showError(res.data.message); @@ -117,7 +143,7 @@ const MultiKeyManageModal = ({ if (res.data.success) { showSuccess(t('密钥已启用')); - await loadKeyStatus(); // Reload data + await loadKeyStatus(currentPage, pageSize); // Reload current page onRefresh && onRefresh(); // Refresh parent component } else { showError(res.data.message); @@ -141,7 +167,9 @@ const MultiKeyManageModal = ({ if (res.data.success) { showSuccess(res.data.message); - await loadKeyStatus(); // Reload data + // Reset to first page after deletion as data structure might change + setCurrentPage(1); + await loadKeyStatus(1, pageSize); onRefresh && onRefresh(); // Refresh parent component } else { showError(res.data.message); @@ -153,13 +181,40 @@ const MultiKeyManageModal = ({ } }; + // Handle page change + const handlePageChange = (page) => { + setCurrentPage(page); + loadKeyStatus(page, pageSize); + }; + + // Handle page size change + const handlePageSizeChange = (size) => { + setPageSize(size); + setCurrentPage(1); // Reset to first page + loadKeyStatus(1, size); + }; + // Effect to load data when modal opens useEffect(() => { if (visible && channel?.id) { - loadKeyStatus(); + setCurrentPage(1); // Reset to first page when opening + loadKeyStatus(1, pageSize); } }, [visible, channel?.id]); + // Reset pagination when modal closes + useEffect(() => { + if (!visible) { + setCurrentPage(1); + setKeyStatusList([]); + setTotal(0); + setTotalPages(0); + setEnabledCount(0); + setManualDisabledCount(0); + setAutoDisabledCount(0); + } + }, [visible]); + // Get status tag component const renderStatusTag = (status) => { switch (status) { @@ -270,12 +325,6 @@ const MultiKeyManageModal = ({ }, ]; - // Calculate statistics - const enabledCount = keyStatusList.filter(key => key.status === 1).length; - const manualDisabledCount = keyStatusList.filter(key => key.status === 2).length; - const autoDisabledCount = keyStatusList.filter(key => key.status === 3).length; - const totalCount = keyStatusList.length; - return ( {t('关闭')}
+ <> +
+ + {/* Pagination */} + {total > 0 && ( +
+ + {t('显示第 {{start}}-{{end}} 条,共 {{total}} 条', { + start: (currentPage - 1) * pageSize + 1, + end: Math.min(currentPage * pageSize, total), + total: total + })} + + +
+ + {t('每页显示')}: + + + + + t('第 {{current}} / {{total}} 页', { + current: currentPage, + total: totalPages + }) + } + /> +
+
+ )} + ) : ( !loading && ( Date: Mon, 4 Aug 2025 19:33:24 +0800 Subject: [PATCH 6/8] fix: correct option value for pagination in MultiKeyManageModal --- .../components/table/channels/modals/MultiKeyManageModal.jsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/components/table/channels/modals/MultiKeyManageModal.jsx b/web/src/components/table/channels/modals/MultiKeyManageModal.jsx index 44f16c03..6bb14184 100644 --- a/web/src/components/table/channels/modals/MultiKeyManageModal.jsx +++ b/web/src/components/table/channels/modals/MultiKeyManageModal.jsx @@ -428,7 +428,7 @@ const MultiKeyManageModal = ({ > 50 100 - 500 + 500 1000 From 12b4e80d4b242d79153687b4e4a93ccaafc96cbf Mon Sep 17 00:00:00 2001 From: CaIon Date: Mon, 4 Aug 2025 19:51:58 +0800 Subject: [PATCH 7/8] feat: add status filtering and bulk enable/disable functionality in multi-key management --- controller/channel.go | 181 ++++++++++++---- .../channels/modals/MultiKeyManageModal.jsx | 197 ++++++++++++++---- 2 files changed, 288 insertions(+), 90 deletions(-) diff --git a/controller/channel.go b/controller/channel.go index 440815cc..7756e18f 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -1057,6 +1057,7 @@ type MultiKeyManageRequest struct { KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions Page int `json:"page,omitempty"` // for get_key_status pagination PageSize int `json:"page_size,omitempty"` // for get_key_status pagination + Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all } // MultiKeyStatusResponse represents the response for key status query @@ -1109,7 +1110,6 @@ func ManageMultiKeys(c *gin.Context) { switch request.Action { case "get_key_status": keys := channel.GetKeys() - total := len(keys) // Default pagination parameters page := request.Page @@ -1121,23 +1121,11 @@ func ManageMultiKeys(c *gin.Context) { pageSize = 50 // Default page size } - // Calculate pagination - totalPages := (total + pageSize - 1) / pageSize - if page > totalPages && totalPages > 0 { - page = totalPages - } - - // Calculate range - start := (page - 1) * pageSize - end := start + pageSize - if end > total { - end = total - } - - // Statistics for all keys + // Statistics for all keys (unchanged by filtering) var enabledCount, manualDisabledCount, autoDisabledCount int - var keyStatusList []KeyStatus + // Build all key status data first + var allKeyStatusList []KeyStatus for i, key := range keys { status := 1 // default enabled var disabledTime int64 @@ -1149,7 +1137,7 @@ func ManageMultiKeys(c *gin.Context) { } } - // Count for statistics + // Count for statistics (all keys) switch status { case 1: enabledCount++ @@ -1159,45 +1147,77 @@ func ManageMultiKeys(c *gin.Context) { autoDisabledCount++ } - // Only include keys in current page - if i >= start && i < end { - if status != 1 { - if channel.ChannelInfo.MultiKeyDisabledTime != nil { - disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i] - } - if channel.ChannelInfo.MultiKeyDisabledReason != nil { - reason = channel.ChannelInfo.MultiKeyDisabledReason[i] - } + if status != 1 { + if channel.ChannelInfo.MultiKeyDisabledTime != nil { + disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i] } - - // Create key preview (first 10 chars) - keyPreview := key - if len(key) > 10 { - keyPreview = key[:10] + "..." + if channel.ChannelInfo.MultiKeyDisabledReason != nil { + reason = channel.ChannelInfo.MultiKeyDisabledReason[i] } - - keyStatusList = append(keyStatusList, KeyStatus{ - Index: i, - Status: status, - DisabledTime: disabledTime, - Reason: reason, - KeyPreview: keyPreview, - }) } + + // Create key preview (first 10 chars) + keyPreview := key + if len(key) > 10 { + keyPreview = key[:10] + "..." + } + + allKeyStatusList = append(allKeyStatusList, KeyStatus{ + Index: i, + Status: status, + DisabledTime: disabledTime, + Reason: reason, + KeyPreview: keyPreview, + }) + } + + // Apply status filter if specified + var filteredKeyStatusList []KeyStatus + if request.Status != nil { + for _, keyStatus := range allKeyStatusList { + if keyStatus.Status == *request.Status { + filteredKeyStatusList = append(filteredKeyStatusList, keyStatus) + } + } + } else { + filteredKeyStatusList = allKeyStatusList + } + + // Calculate pagination based on filtered results + filteredTotal := len(filteredKeyStatusList) + totalPages := (filteredTotal + pageSize - 1) / pageSize + if totalPages == 0 { + totalPages = 1 + } + if page > totalPages { + page = totalPages + } + + // Calculate range for current page + start := (page - 1) * pageSize + end := start + pageSize + if end > filteredTotal { + end = filteredTotal + } + + // Get the page data + var pageKeyStatusList []KeyStatus + if start < filteredTotal { + pageKeyStatusList = filteredKeyStatusList[start:end] } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": MultiKeyStatusResponse{ - Keys: keyStatusList, - Total: total, + Keys: pageKeyStatusList, + Total: filteredTotal, // Total of filtered results Page: page, PageSize: pageSize, TotalPages: totalPages, - EnabledCount: enabledCount, - ManualDisabledCount: manualDisabledCount, - AutoDisabledCount: autoDisabledCount, + EnabledCount: enabledCount, // Overall statistics + ManualDisabledCount: manualDisabledCount, // Overall statistics + AutoDisabledCount: autoDisabledCount, // Overall statistics }, }) return @@ -1231,8 +1251,6 @@ func ManageMultiKeys(c *gin.Context) { } channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled - channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp() - channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = "手动禁用" err = channel.Update() if err != nil { @@ -1289,6 +1307,77 @@ func ManageMultiKeys(c *gin.Context) { }) return + case "enable_all_keys": + // 清空所有禁用状态,使所有密钥回到默认启用状态 + var enabledCount int + if channel.ChannelInfo.MultiKeyStatusList != nil { + enabledCount = len(channel.ChannelInfo.MultiKeyStatusList) + } + + channel.ChannelInfo.MultiKeyStatusList = make(map[int]int) + channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64) + channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string) + + err = channel.Update() + if err != nil { + common.ApiError(c, err) + return + } + + model.InitChannelCache() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": fmt.Sprintf("已启用 %d 个密钥", enabledCount), + }) + return + + case "disable_all_keys": + // 禁用所有启用的密钥 + if channel.ChannelInfo.MultiKeyStatusList == nil { + channel.ChannelInfo.MultiKeyStatusList = make(map[int]int) + } + if channel.ChannelInfo.MultiKeyDisabledTime == nil { + channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64) + } + if channel.ChannelInfo.MultiKeyDisabledReason == nil { + channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string) + } + + var disabledCount int + for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ { + status := 1 // default enabled + if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists { + status = s + } + + // 只禁用当前启用的密钥 + if status == 1 { + channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled + disabledCount++ + } + } + + if disabledCount == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "没有可禁用的密钥", + }) + return + } + + err = channel.Update() + if err != nil { + common.ApiError(c, err) + return + } + + model.InitChannelCache() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount), + }) + return + case "delete_disabled_keys": keys := channel.GetKeys() var remainingKeys []string diff --git a/web/src/components/table/channels/modals/MultiKeyManageModal.jsx b/web/src/components/table/channels/modals/MultiKeyManageModal.jsx index 6bb14184..161da1cc 100644 --- a/web/src/components/table/channels/modals/MultiKeyManageModal.jsx +++ b/web/src/components/table/channels/modals/MultiKeyManageModal.jsx @@ -67,18 +67,28 @@ const MultiKeyManageModal = ({ const [manualDisabledCount, setManualDisabledCount] = useState(0); const [autoDisabledCount, setAutoDisabledCount] = useState(0); + // Filter states + const [statusFilter, setStatusFilter] = useState(null); // null=all, 1=enabled, 2=manual_disabled, 3=auto_disabled + // Load key status data - const loadKeyStatus = async (page = currentPage, size = pageSize) => { + const loadKeyStatus = async (page = currentPage, size = pageSize, status = statusFilter) => { if (!channel?.id) return; setLoading(true); try { - const res = await API.post('/api/channel/multi_key/manage', { + const requestData = { channel_id: channel.id, action: 'get_key_status', page: page, page_size: size - }); + }; + + // Add status filter if specified + if (status !== null) { + requestData.status = status; + } + + const res = await API.post('/api/channel/multi_key/manage', requestData); if (res.data.success) { const data = res.data.data; @@ -88,7 +98,7 @@ const MultiKeyManageModal = ({ setPageSize(data.page_size || 50); setTotalPages(data.total_pages || 0); - // Update statistics + // Update statistics (these are always the overall statistics) setEnabledCount(data.enabled_count || 0); setManualDisabledCount(data.manual_disabled_count || 0); setAutoDisabledCount(data.auto_disabled_count || 0); @@ -155,6 +165,58 @@ const MultiKeyManageModal = ({ } }; + // Enable all disabled keys + const handleEnableAll = async () => { + setOperationLoading(prev => ({ ...prev, enable_all: true })); + + try { + const res = await API.post('/api/channel/multi_key/manage', { + channel_id: channel.id, + action: 'enable_all_keys' + }); + + if (res.data.success) { + showSuccess(res.data.message || t('已启用所有密钥')); + // Reset to first page after bulk operation + setCurrentPage(1); + await loadKeyStatus(1, pageSize); + onRefresh && onRefresh(); // Refresh parent component + } else { + showError(res.data.message); + } + } catch (error) { + showError(t('启用所有密钥失败')); + } finally { + setOperationLoading(prev => ({ ...prev, enable_all: false })); + } + }; + + // Disable all enabled keys + const handleDisableAll = async () => { + setOperationLoading(prev => ({ ...prev, disable_all: true })); + + try { + const res = await API.post('/api/channel/multi_key/manage', { + channel_id: channel.id, + action: 'disable_all_keys' + }); + + if (res.data.success) { + showSuccess(res.data.message || t('已禁用所有密钥')); + // Reset to first page after bulk operation + setCurrentPage(1); + await loadKeyStatus(1, pageSize); + onRefresh && onRefresh(); // Refresh parent component + } else { + showError(res.data.message); + } + } catch (error) { + showError(t('禁用所有密钥失败')); + } finally { + setOperationLoading(prev => ({ ...prev, disable_all: false })); + } + }; + // Delete all disabled keys const handleDeleteDisabledKeys = async () => { setOperationLoading(prev => ({ ...prev, delete_disabled: true })); @@ -194,6 +256,13 @@ const MultiKeyManageModal = ({ loadKeyStatus(1, size); }; + // Handle status filter change + const handleStatusFilterChange = (status) => { + setStatusFilter(status); + setCurrentPage(1); // Reset to first page when filter changes + loadKeyStatus(1, pageSize, status); + }; + // Effect to load data when modal opens useEffect(() => { if (visible && channel?.id) { @@ -212,6 +281,7 @@ const MultiKeyManageModal = ({ setEnabledCount(0); setManualDisabledCount(0); setAutoDisabledCount(0); + setStatusFilter(null); // Reset filter } }, [visible]); @@ -236,15 +306,15 @@ const MultiKeyManageModal = ({ dataIndex: 'index', render: (text) => `#${text}`, }, - { - title: t('密钥预览'), - dataIndex: 'key_preview', - render: (text) => ( - - {text} - - ), - }, + // { + // title: t('密钥预览'), + // dataIndex: 'key_preview', + // render: (text) => ( + // + // {text} + // + // ), + // }, { title: t('状态'), dataIndex: 'status', @@ -292,33 +362,23 @@ const MultiKeyManageModal = ({ render: (_, record) => ( {record.status === 1 ? ( - handleDisableKey(record.index)} + - + {t('禁用')} + ) : ( - handleEnableKey(record.index)} + - + {t('启用')} + )} ), @@ -347,21 +407,48 @@ const MultiKeyManageModal = ({ > {t('刷新')} - {autoDisabledCount > 0 && ( + + + + {enabledCount > 0 && ( )} + + + } > @@ -391,6 +478,28 @@ const MultiKeyManageModal = ({ } /> + {/* Filter Controls */} +
+ {t('状态筛选')}: + + {statusFilter !== null && ( + + {t('当前显示 {{count}} 条筛选结果', { count: total })} + + )} +
+ {/* Key Status Table */} {keyStatusList.length > 0 ? ( From c00f5a17c81a1baab34de0c0e2a8277f701deb0f Mon Sep 17 00:00:00 2001 From: CaIon Date: Mon, 4 Aug 2025 20:16:51 +0800 Subject: [PATCH 8/8] feat: improve layout and pagination handling in MultiKeyManageModal --- .../channels/modals/MultiKeyManageModal.jsx | 157 ++++++++++-------- 1 file changed, 84 insertions(+), 73 deletions(-) diff --git a/web/src/components/table/channels/modals/MultiKeyManageModal.jsx b/web/src/components/table/channels/modals/MultiKeyManageModal.jsx index 161da1cc..89ab790f 100644 --- a/web/src/components/table/channels/modals/MultiKeyManageModal.jsx +++ b/web/src/components/table/channels/modals/MultiKeyManageModal.jsx @@ -395,8 +395,7 @@ const MultiKeyManageModal = ({ } visible={visible} onCancel={onCancel} - width={800} - height={600} + width={900} footer={ @@ -452,11 +451,11 @@ const MultiKeyManageModal = ({ } > -
+
{/* Statistics Banner */} @@ -479,7 +478,7 @@ const MultiKeyManageModal = ({ /> {/* Filter Controls */} -
+
{t('状态筛选')}:
- - {/* Pagination */} - {total > 0 && ( -
- - {t('显示第 {{start}}-{{end}} 条,共 {{total}} 条', { - start: (currentPage - 1) * pageSize + 1, - end: Math.min(currentPage * pageSize, total), - total: total - })} - - -
- - {t('每页显示')}: - - - - - t('第 {{current}} / {{total}} 页', { - current: currentPage, - total: totalPages - }) - } - /> -
+
+ + {keyStatusList.length > 0 ? ( +
+
+
- )} - - ) : ( - !loading && ( - - ) - )} - + + {/* Pagination */} + {total > 0 && ( +
+ + {t('显示第 {{start}}-{{end}} 条,共 {{total}} 条', { + start: (currentPage - 1) * pageSize + 1, + end: Math.min(currentPage * pageSize, total), + total: total + })} + + +
+ + {t('每页显示')}: + + + + + t('第 {{current}} / {{total}} 页', { + current: currentPage, + total: totalPages + }) + } + /> +
+
+ )} + + ) : ( + !loading && ( + + ) + )} + + );