feat(auth): 实现 TOTP 双因素认证功能

新增功能:
- 支持 Google Authenticator 等应用进行 TOTP 二次验证
- 用户可在个人设置中启用/禁用 2FA
- 登录时支持 TOTP 验证流程
- 管理后台可全局开关 TOTP 功能

安全增强:
- TOTP 密钥使用 AES-256-GCM 加密存储
- 添加 TOTP_ENCRYPTION_KEY 配置项,必须手动配置才能启用功能
- 防止服务重启导致加密密钥变更使用户无法登录
- 验证失败次数限制,防止暴力破解

配置说明:
- Docker 部署:在 .env 中设置 TOTP_ENCRYPTION_KEY
- 非 Docker 部署:在 config.yaml 中设置 totp.encryption_key
- 生成密钥命令:openssl rand -hex 32
This commit is contained in:
shaw
2026-01-26 08:45:43 +08:00
parent 74e05b83ea
commit 1245f07a2d
60 changed files with 4140 additions and 350 deletions

View File

@@ -0,0 +1,95 @@
package repository
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// AESEncryptor implements SecretEncryptor using AES-256-GCM
type AESEncryptor struct {
key []byte
}
// NewAESEncryptor creates a new AES encryptor
func NewAESEncryptor(cfg *config.Config) (service.SecretEncryptor, error) {
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
if err != nil {
return nil, fmt.Errorf("invalid totp encryption key: %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("totp encryption key must be 32 bytes (64 hex chars), got %d bytes", len(key))
}
return &AESEncryptor{key: key}, nil
}
// Encrypt encrypts plaintext using AES-256-GCM
// Output format: base64(nonce + ciphertext + tag)
func (e *AESEncryptor) Encrypt(plaintext string) (string, error) {
block, err := aes.NewCipher(e.key)
if err != nil {
return "", fmt.Errorf("create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create gcm: %w", err)
}
// Generate a random nonce
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("generate nonce: %w", err)
}
// Encrypt the plaintext
// Seal appends the ciphertext and tag to the nonce
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
// Encode as base64
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt decrypts ciphertext using AES-256-GCM
func (e *AESEncryptor) Decrypt(ciphertext string) (string, error) {
// Decode from base64
data, err := base64.StdEncoding.DecodeString(ciphertext)
if err != nil {
return "", fmt.Errorf("decode base64: %w", err)
}
block, err := aes.NewCipher(e.key)
if err != nil {
return "", fmt.Errorf("create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create gcm: %w", err)
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return "", fmt.Errorf("ciphertext too short")
}
// Extract nonce and ciphertext
nonce, ciphertextData := data[:nonceSize], data[nonceSize:]
// Decrypt
plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
if err != nil {
return "", fmt.Errorf("decrypt: %w", err)
}
return string(plaintext), nil
}

View File

@@ -387,17 +387,20 @@ func userEntityToService(u *dbent.User) *service.User {
return nil
}
return &service.User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}

View File

@@ -0,0 +1,149 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const (
totpSetupKeyPrefix = "totp:setup:"
totpLoginKeyPrefix = "totp:login:"
totpAttemptsKeyPrefix = "totp:attempts:"
totpAttemptsTTL = 15 * time.Minute
)
// TotpCache implements service.TotpCache using Redis
type TotpCache struct {
rdb *redis.Client
}
// NewTotpCache creates a new TOTP cache
func NewTotpCache(rdb *redis.Client) service.TotpCache {
return &TotpCache{rdb: rdb}
}
// GetSetupSession retrieves a TOTP setup session
func (c *TotpCache) GetSetupSession(ctx context.Context, userID int64) (*service.TotpSetupSession, error) {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
data, err := c.rdb.Get(ctx, key).Bytes()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, fmt.Errorf("get setup session: %w", err)
}
var session service.TotpSetupSession
if err := json.Unmarshal(data, &session); err != nil {
return nil, fmt.Errorf("unmarshal setup session: %w", err)
}
return &session, nil
}
// SetSetupSession stores a TOTP setup session
func (c *TotpCache) SetSetupSession(ctx context.Context, userID int64, session *service.TotpSetupSession, ttl time.Duration) error {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
data, err := json.Marshal(session)
if err != nil {
return fmt.Errorf("marshal setup session: %w", err)
}
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
return fmt.Errorf("set setup session: %w", err)
}
return nil
}
// DeleteSetupSession deletes a TOTP setup session
func (c *TotpCache) DeleteSetupSession(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}
// GetLoginSession retrieves a TOTP login session
func (c *TotpCache) GetLoginSession(ctx context.Context, tempToken string) (*service.TotpLoginSession, error) {
key := totpLoginKeyPrefix + tempToken
data, err := c.rdb.Get(ctx, key).Bytes()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, fmt.Errorf("get login session: %w", err)
}
var session service.TotpLoginSession
if err := json.Unmarshal(data, &session); err != nil {
return nil, fmt.Errorf("unmarshal login session: %w", err)
}
return &session, nil
}
// SetLoginSession stores a TOTP login session
func (c *TotpCache) SetLoginSession(ctx context.Context, tempToken string, session *service.TotpLoginSession, ttl time.Duration) error {
key := totpLoginKeyPrefix + tempToken
data, err := json.Marshal(session)
if err != nil {
return fmt.Errorf("marshal login session: %w", err)
}
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
return fmt.Errorf("set login session: %w", err)
}
return nil
}
// DeleteLoginSession deletes a TOTP login session
func (c *TotpCache) DeleteLoginSession(ctx context.Context, tempToken string) error {
key := totpLoginKeyPrefix + tempToken
return c.rdb.Del(ctx, key).Err()
}
// IncrementVerifyAttempts increments the verify attempt counter
func (c *TotpCache) IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
// Use pipeline for atomic increment and set TTL
pipe := c.rdb.Pipeline()
incrCmd := pipe.Incr(ctx, key)
pipe.Expire(ctx, key, totpAttemptsTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, fmt.Errorf("increment verify attempts: %w", err)
}
count, err := incrCmd.Result()
if err != nil {
return 0, fmt.Errorf("get increment result: %w", err)
}
return int(count), nil
}
// GetVerifyAttempts gets the current verify attempt count
func (c *TotpCache) GetVerifyAttempts(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
count, err := c.rdb.Get(ctx, key).Int()
if err != nil {
if err == redis.Nil {
return 0, nil
}
return 0, fmt.Errorf("get verify attempts: %w", err)
}
return count, nil
}
// ClearVerifyAttempts clears the verify attempt counter
func (c *TotpCache) ClearVerifyAttempts(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"sort"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
@@ -466,3 +467,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
client := clientFromContext(ctx, r.client)
update := client.User.UpdateOneID(userID)
if encryptedSecret == nil {
update = update.ClearTotpSecretEncrypted()
} else {
update = update.SetTotpSecretEncrypted(*encryptedSecret)
}
_, err := update.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}
// EnableTotp 启用用户的 TOTP 双因素认证
func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.User.UpdateOneID(userID).
SetTotpEnabled(true).
SetTotpEnabledAt(time.Now()).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}
// DisableTotp 禁用用户的 TOTP 双因素认证
func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.User.UpdateOneID(userID).
SetTotpEnabled(false).
ClearTotpEnabledAt().
ClearTotpSecretEncrypted().
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}

View File

@@ -82,6 +82,10 @@ var ProviderSet = wire.NewSet(
NewSchedulerCache,
NewSchedulerOutboxRepository,
NewProxyLatencyCache,
NewTotpCache,
// Encryptors
NewAESEncryptor,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,