Files
sub2api/backend/internal/service/auth_service_register_test.go
shaw 9cc8352593 feat(auth): 密码重置邮件队列化与限流优化
- 邮件发送改为异步队列处理,避免并发导致发送失败
- 新增 Email 维度限流(30秒冷却期),防止邮件轰炸
- Token 验证使用常量时间比较,防止时序攻击
- 重构代码消除冗余,提取公共验证逻辑
2026-01-24 22:55:28 +08:00

316 lines
9.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type settingRepoStub struct {
values map[string]string
err error
}
func (s *settingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *settingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
if s.err != nil {
return "", s.err
}
if v, ok := s.values[key]; ok {
return v, nil
}
return "", ErrSettingNotFound
}
func (s *settingRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *settingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *settingRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
type emailCacheStub struct {
data *VerificationCodeData
err error
}
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
if s.err != nil {
return nil, s.err
}
return s.data, nil
}
func (s *emailCacheStub) SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error {
return nil
}
func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email string) error {
return nil
}
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
return nil, nil
}
func (s *emailCacheStub) SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error {
return nil
}
func (s *emailCacheStub) DeletePasswordResetToken(ctx context.Context, email string) error {
return nil
}
func (s *emailCacheStub) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
return false
}
func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
return nil
}
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
},
Default: config.DefaultConfig{
UserBalance: 3.5,
UserConcurrency: 2,
},
}
var settingService *SettingService
if settings != nil {
settingService = NewSettingService(&settingRepoStub{values: settings}, cfg)
}
var emailService *EmailService
if emailCache != nil {
emailService = NewEmailService(&settingRepoStub{values: settings}, emailCache)
}
return NewAuthService(
repo,
cfg,
settingService,
emailService,
nil,
nil,
nil, // promoService
)
}
func TestAuthService_Register_Disabled(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "false",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrRegDisabled)
}
func TestAuthService_Register_DisabledByDefault(t *testing.T) {
// 当 settings 为 nil设置项不存在注册应该默认关闭
repo := &userRepoStub{}
service := newAuthService(repo, nil, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrRegDisabled)
}
func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) {
repo := &userRepoStub{}
// 邮件验证开启但 emailCache 为 nilemailService 未配置)
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
}, nil)
// 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
repo := &userRepoStub{}
cache := &emailCacheStub{} // 配置 emailService
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired)
}
func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
repo := &userRepoStub{}
cache := &emailCacheStub{
data: &VerificationCodeData{Code: "expected", Attempts: 0},
}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode)
require.ErrorContains(t, err, "verify code")
}
func TestAuthService_Register_EmailExists(t *testing.T) {
repo := &userRepoStub{exists: true}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrEmailExists)
}
func TestAuthService_Register_CheckEmailError(t *testing.T) {
repo := &userRepoStub{existsErr: errors.New("db down")}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
func TestAuthService_Register_ReservedEmail(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password")
require.ErrorIs(t, err, ErrEmailReserved)
}
func TestAuthService_Register_CreateError(t *testing.T) {
repo := &userRepoStub{createErr: errors.New("create failed")}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
// 模拟竞态条件ExistsByEmail 返回 false但 Create 时因唯一约束失败
repo := &userRepoStub{createErr: ErrEmailExists}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrEmailExists)
}
func TestAuthService_Register_Success(t *testing.T) {
repo := &userRepoStub{nextID: 5}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
token, user, err := service.Register(context.Background(), "user@test.com", "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, user)
require.Equal(t, int64(5), user.ID)
require.Equal(t, "user@test.com", user.Email)
require.Equal(t, RoleUser, user.Role)
require.Equal(t, StatusActive, user.Status)
require.Equal(t, 3.5, user.Balance)
require.Equal(t, 2, user.Concurrency)
require.Len(t, repo.created, 1)
require.True(t, user.CheckPassword("password"))
}
func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, nil, nil)
// 创建用户并生成 token
user := &User{
ID: 1,
Email: "test@test.com",
Role: RoleUser,
Status: StatusActive,
TokenVersion: 1,
}
token, err := service.GenerateToken(user)
require.NoError(t, err)
// 验证有效 token
claims, err := service.ValidateToken(token)
require.NoError(t, err)
require.NotNil(t, claims)
require.Equal(t, int64(1), claims.UserID)
// 模拟过期 token通过创建一个过期很久的 token
service.cfg.JWT.ExpireHour = -1 // 设置为负数使 token 立即过期
expiredToken, err := service.GenerateToken(user)
require.NoError(t, err)
service.cfg.JWT.ExpireHour = 1 // 恢复
// 验证过期 token 应返回 claims 和 ErrTokenExpired
claims, err = service.ValidateToken(expiredToken)
require.ErrorIs(t, err, ErrTokenExpired)
require.NotNil(t, claims, "claims should not be nil when token is expired")
require.Equal(t, int64(1), claims.UserID)
require.Equal(t, "test@test.com", claims.Email)
}
func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) {
user := &User{
ID: 1,
Email: "test@test.com",
Role: RoleUser,
Status: StatusActive,
TokenVersion: 1,
}
repo := &userRepoStub{user: user}
service := newAuthService(repo, nil, nil)
// 创建过期 token
service.cfg.JWT.ExpireHour = -1
expiredToken, err := service.GenerateToken(user)
require.NoError(t, err)
service.cfg.JWT.ExpireHour = 1
// RefreshToken 使用过期 token 不应 panic
require.NotPanics(t, func() {
newToken, err := service.RefreshToken(context.Background(), expiredToken)
require.NoError(t, err)
require.NotEmpty(t, newToken)
})
}