Files
yinghuoapi/backend/internal/service/auth_service.go
2025-12-25 17:15:01 +08:00

375 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"errors"
"fmt"
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
var (
ErrInvalidCredentials = errors.New("invalid email or password")
ErrUserNotActive = errors.New("user is not active")
ErrEmailExists = errors.New("email already exists")
ErrInvalidToken = errors.New("invalid token")
ErrTokenExpired = errors.New("token has expired")
ErrEmailVerifyRequired = errors.New("email verification is required")
ErrRegDisabled = errors.New("registration is currently disabled")
ErrServiceUnavailable = errors.New("service temporarily unavailable")
)
// JWTClaims JWT载荷数据
type JWTClaims struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Role string `json:"role"`
jwt.RegisteredClaims
}
// AuthService 认证服务
type AuthService struct {
userRepo UserRepository
cfg *config.Config
settingService *SettingService
emailService *EmailService
turnstileService *TurnstileService
emailQueueService *EmailQueueService
}
// NewAuthService 创建认证服务实例
func NewAuthService(
userRepo UserRepository,
cfg *config.Config,
settingService *SettingService,
emailService *EmailService,
turnstileService *TurnstileService,
emailQueueService *EmailQueueService,
) *AuthService {
return &AuthService{
userRepo: userRepo,
cfg: cfg,
settingService: settingService,
emailService: emailService,
turnstileService: turnstileService,
emailQueueService: emailQueueService,
}
}
// Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) {
return s.RegisterWithVerification(ctx, email, password, "")
}
// RegisterWithVerification 用户注册支持邮件验证返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *model.User, error) {
// 检查是否开放注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
}
// 检查是否需要邮件验证
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
if verifyCode == "" {
return "", nil, ErrEmailVerifyRequired
}
// 验证邮箱验证码
if s.emailService != nil {
if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil {
return "", nil, fmt.Errorf("verify code: %w", err)
}
}
}
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
log.Printf("[Auth] Database error checking email exists: %v", err)
return "", nil, ErrServiceUnavailable
}
if existsEmail {
return "", nil, ErrEmailExists
}
// 密码哈希
hashedPassword, err := s.HashPassword(password)
if err != nil {
return "", nil, fmt.Errorf("hash password: %w", err)
}
// 获取默认配置
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
}
// 创建用户
user := &model.User{
Email: email,
PasswordHash: hashedPassword,
Role: model.RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Status: model.StatusActive,
}
if err := s.userRepo.Create(ctx, user); err != nil {
log.Printf("[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable
}
// 生成token
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
}
return token, user, nil
}
// SendVerifyCodeResult 发送验证码返回结果
type SendVerifyCodeResult struct {
Countdown int `json:"countdown"` // 倒计时秒数
}
// SendVerifyCode 发送邮箱验证码(同步方式)
func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
// 检查是否开放注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
return ErrRegDisabled
}
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
log.Printf("[Auth] Database error checking email exists: %v", err)
return ErrServiceUnavailable
}
if existsEmail {
return ErrEmailExists
}
// 发送验证码
if s.emailService == nil {
return errors.New("email service not configured")
}
// 获取网站名称
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
return s.emailService.SendVerifyCode(ctx, email, siteName)
}
// SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email)
// 检查是否开放注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
log.Println("[Auth] Registration is disabled")
return nil, ErrRegDisabled
}
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
log.Printf("[Auth] Database error checking email exists: %v", err)
return nil, ErrServiceUnavailable
}
if existsEmail {
log.Printf("[Auth] Email already exists: %s", email)
return nil, ErrEmailExists
}
// 检查邮件队列服务是否配置
if s.emailQueueService == nil {
log.Println("[Auth] Email queue service not configured")
return nil, errors.New("email queue service not configured")
}
// 获取网站名称
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
// 异步发送
log.Printf("[Auth] Enqueueing verify code for: %s", email)
if err := s.emailQueueService.EnqueueVerifyCode(email, siteName); err != nil {
log.Printf("[Auth] Failed to enqueue: %v", err)
return nil, fmt.Errorf("enqueue verify code: %w", err)
}
log.Printf("[Auth] Verify code enqueued successfully for: %s", email)
return &SendVerifyCodeResult{
Countdown: 60, // 60秒倒计时
}, nil
}
// VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
if s.turnstileService == nil {
return nil // 服务未配置则跳过验证
}
return s.turnstileService.VerifyToken(ctx, token, remoteIP)
}
// IsTurnstileEnabled 检查是否启用Turnstile验证
func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool {
if s.turnstileService == nil {
return false
}
return s.turnstileService.IsEnabled(ctx)
}
// IsRegistrationEnabled 检查是否开放注册
func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool {
if s.settingService == nil {
return true
}
return s.settingService.IsRegistrationEnabled(ctx)
}
// IsEmailVerifyEnabled 检查是否开启邮件验证
func (s *AuthService) IsEmailVerifyEnabled(ctx context.Context) bool {
if s.settingService == nil {
return false
}
return s.settingService.IsEmailVerifyEnabled(ctx)
}
// Login 用户登录返回JWT token
func (s *AuthService) Login(ctx context.Context, email, password string) (string, *model.User, error) {
// 查找用户
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", nil, ErrInvalidCredentials
}
// 记录数据库错误但不暴露给用户
log.Printf("[Auth] Database error during login: %v", err)
return "", nil, ErrServiceUnavailable
}
// 验证密码
if !s.CheckPassword(password, user.PasswordHash) {
return "", nil, ErrInvalidCredentials
}
// 检查用户状态
if !user.IsActive() {
return "", nil, ErrUserNotActive
}
// 生成JWT token
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
}
return token, user, nil
}
// ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
// 验证签名方法
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(s.cfg.JWT.Secret), nil
})
if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
return nil, ErrTokenExpired
}
return nil, ErrInvalidToken
}
if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid {
return claims, nil
}
return nil, ErrInvalidToken
}
// GenerateToken 生成JWT token
func (s *AuthService) GenerateToken(user *model.User) (string, error) {
now := time.Now()
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
claims := &JWTClaims{
UserID: user.ID,
Email: user.Email,
Role: user.Role,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(s.cfg.JWT.Secret))
if err != nil {
return "", fmt.Errorf("sign token: %w", err)
}
return tokenString, nil
}
// HashPassword 使用bcrypt加密密码
func (s *AuthService) HashPassword(password string) (string, error) {
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return string(hashedBytes), nil
}
// CheckPassword 验证密码是否匹配
func (s *AuthService) CheckPassword(password, hashedPassword string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
return err == nil
}
// RefreshToken 刷新token
func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (string, error) {
// 验证旧token即使过期也允许用于刷新
claims, err := s.ValidateToken(oldTokenString)
if err != nil && !errors.Is(err, ErrTokenExpired) {
return "", err
}
// 获取最新的用户信息
user, err := s.userRepo.GetByID(ctx, claims.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrInvalidToken
}
log.Printf("[Auth] Database error refreshing token: %v", err)
return "", ErrServiceUnavailable
}
// 检查用户状态
if !user.IsActive() {
return "", ErrUserNotActive
}
// 生成新token
return s.GenerateToken(user)
}