fix(auth): 注册接口安全加固 - 默认关闭注册

This commit is contained in:
shaw
2026-01-09 14:49:20 +08:00
parent 0a9c17b9d1
commit 43f104bdf7
4 changed files with 52 additions and 15 deletions

View File

@@ -75,8 +75,8 @@ func (s *AuthService) Register(ctx context.Context, email, password string) (str
// RegisterWithVerification 用户注册支持邮件验证返回token和用户 // RegisterWithVerification 用户注册支持邮件验证返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) { func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
// 检查是否开放注册 // 检查是否开放注册默认关闭settingService 未配置时不允许注册)
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled return "", nil, ErrRegDisabled
} }
@@ -132,6 +132,10 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
} }
if err := s.userRepo.Create(ctx, user); err != nil { if err := s.userRepo.Create(ctx, user); err != nil {
// 优先检查邮箱冲突错误(竞态条件下可能发生)
if errors.Is(err, ErrEmailExists) {
return "", nil, ErrEmailExists
}
log.Printf("[Auth] Database error creating user: %v", err) log.Printf("[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable return "", nil, ErrServiceUnavailable
} }
@@ -152,8 +156,8 @@ type SendVerifyCodeResult struct {
// SendVerifyCode 发送邮箱验证码(同步方式) // SendVerifyCode 发送邮箱验证码(同步方式)
func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
// 检查是否开放注册 // 检查是否开放注册(默认关闭)
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return ErrRegDisabled return ErrRegDisabled
} }
@@ -185,8 +189,8 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) { func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email) log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email)
// 检查是否开放注册 // 检查是否开放注册(默认关闭)
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
log.Println("[Auth] Registration is disabled") log.Println("[Auth] Registration is disabled")
return nil, ErrRegDisabled return nil, ErrRegDisabled
} }
@@ -270,7 +274,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool {
// IsRegistrationEnabled 检查是否开放注册 // IsRegistrationEnabled 检查是否开放注册
func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool { func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool {
if s.settingService == nil { if s.settingService == nil {
return true return false // 安全默认settingService 未配置时关闭注册
} }
return s.settingService.IsRegistrationEnabled(ctx) return s.settingService.IsRegistrationEnabled(ctx)
} }

View File

@@ -113,6 +113,15 @@ func TestAuthService_Register_Disabled(t *testing.T) {
require.ErrorIs(t, err, ErrRegDisabled) require.ErrorIs(t, err, ErrRegDisabled)
} }
func TestAuthService_Register_DisabledByDefault(t *testing.T) {
// 当 settings 为 nil设置项不存在注册应该默认关闭
repo := &userRepoStub{}
service := newAuthService(repo, nil, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrRegDisabled)
}
func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) { func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) {
repo := &userRepoStub{} repo := &userRepoStub{}
// 邮件验证开启但 emailCache 为 nilemailService 未配置) // 邮件验证开启但 emailCache 为 nilemailService 未配置)
@@ -155,7 +164,9 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
func TestAuthService_Register_EmailExists(t *testing.T) { func TestAuthService_Register_EmailExists(t *testing.T) {
repo := &userRepoStub{exists: true} repo := &userRepoStub{exists: true}
service := newAuthService(repo, nil, nil) service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password") _, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrEmailExists) require.ErrorIs(t, err, ErrEmailExists)
@@ -163,7 +174,9 @@ func TestAuthService_Register_EmailExists(t *testing.T) {
func TestAuthService_Register_CheckEmailError(t *testing.T) { func TestAuthService_Register_CheckEmailError(t *testing.T) {
repo := &userRepoStub{existsErr: errors.New("db down")} repo := &userRepoStub{existsErr: errors.New("db down")}
service := newAuthService(repo, nil, nil) service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password") _, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrServiceUnavailable) require.ErrorIs(t, err, ErrServiceUnavailable)
@@ -171,15 +184,30 @@ func TestAuthService_Register_CheckEmailError(t *testing.T) {
func TestAuthService_Register_CreateError(t *testing.T) { func TestAuthService_Register_CreateError(t *testing.T) {
repo := &userRepoStub{createErr: errors.New("create failed")} repo := &userRepoStub{createErr: errors.New("create failed")}
service := newAuthService(repo, nil, nil) service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password") _, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrServiceUnavailable) require.ErrorIs(t, err, ErrServiceUnavailable)
} }
func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
// 模拟竞态条件ExistsByEmail 返回 false但 Create 时因唯一约束失败
repo := &userRepoStub{createErr: ErrEmailExists}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrEmailExists)
}
func TestAuthService_Register_Success(t *testing.T) { func TestAuthService_Register_Success(t *testing.T) {
repo := &userRepoStub{nextID: 5} repo := &userRepoStub{nextID: 5}
service := newAuthService(repo, nil, nil) service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
token, user, err := service.Register(context.Background(), "user@test.com", "password") token, user, err := service.Register(context.Background(), "user@test.com", "password")
require.NoError(t, err) require.NoError(t, err)

View File

@@ -5,6 +5,7 @@ import (
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"log"
"math/big" "math/big"
"net/smtp" "net/smtp"
"strconv" "strconv"
@@ -256,7 +257,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证码不匹配 // 验证码不匹配
if data.Code != code { if data.Code != code {
data.Attempts++ data.Attempts++
_ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL) if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
log.Printf("[Email] Failed to update verification attempt count: %v", err)
}
if data.Attempts >= maxVerifyCodeAttempts { if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts return ErrVerifyCodeMaxAttempts
} }
@@ -264,7 +267,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
} }
// 验证成功,删除验证码 // 验证成功,删除验证码
_ = s.cache.DeleteVerificationCode(ctx, email) if err := s.cache.DeleteVerificationCode(ctx, email); err != nil {
log.Printf("[Email] Failed to delete verification code after success: %v", err)
}
return nil return nil
} }

View File

@@ -141,8 +141,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool { func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled) value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
if err != nil { if err != nil {
// 默认开放注册 // 安全默认:如果设置不存在或查询出错,默认关闭注册
return true return false
} }
return value == "true" return value == "true"
} }