package service import ( "context" "errors" "fmt" "log" "sub2api/internal/config" "sub2api/internal/model" "sub2api/internal/repository" "time" "github.com/golang-jwt/jwt/v5" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" ) var ( ErrInvalidCredentials = errors.New("invalid email or password") ErrUserNotActive = errors.New("user is not active") ErrEmailExists = errors.New("email already exists") ErrInvalidToken = errors.New("invalid token") ErrTokenExpired = errors.New("token has expired") ErrEmailVerifyRequired = errors.New("email verification is required") ErrRegDisabled = errors.New("registration is currently disabled") ) // JWTClaims JWT载荷数据 type JWTClaims struct { UserID int64 `json:"user_id"` Email string `json:"email"` Role string `json:"role"` jwt.RegisteredClaims } // AuthService 认证服务 type AuthService struct { userRepo *repository.UserRepository cfg *config.Config settingService *SettingService emailService *EmailService turnstileService *TurnstileService emailQueueService *EmailQueueService } // NewAuthService 创建认证服务实例 func NewAuthService(userRepo *repository.UserRepository, cfg *config.Config) *AuthService { return &AuthService{ userRepo: userRepo, cfg: cfg, } } // SetSettingService 设置系统设置服务(用于检查注册开关和邮件验证) func (s *AuthService) SetSettingService(settingService *SettingService) { s.settingService = settingService } // SetEmailService 设置邮件服务(用于邮件验证) func (s *AuthService) SetEmailService(emailService *EmailService) { s.emailService = emailService } // SetTurnstileService 设置Turnstile服务(用于验证码校验) func (s *AuthService) SetTurnstileService(turnstileService *TurnstileService) { s.turnstileService = turnstileService } // SetEmailQueueService 设置邮件队列服务(用于异步发送邮件) func (s *AuthService) SetEmailQueueService(emailQueueService *EmailQueueService) { s.emailQueueService = emailQueueService } // Register 用户注册,返回token和用户 func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) { return s.RegisterWithVerification(ctx, email, password, "") } // RegisterWithVerification 用户注册(支持邮件验证),返回token和用户 func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *model.User, error) { // 检查是否开放注册 if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { return "", nil, ErrRegDisabled } // 检查是否需要邮件验证 if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) { if verifyCode == "" { return "", nil, ErrEmailVerifyRequired } // 验证邮箱验证码 if s.emailService != nil { if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil { return "", nil, fmt.Errorf("verify code: %w", err) } } } // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { return "", nil, fmt.Errorf("check email exists: %w", err) } if existsEmail { return "", nil, ErrEmailExists } // 密码哈希 hashedPassword, err := s.HashPassword(password) if err != nil { return "", nil, fmt.Errorf("hash password: %w", err) } // 获取默认配置 defaultBalance := s.cfg.Default.UserBalance defaultConcurrency := s.cfg.Default.UserConcurrency if s.settingService != nil { defaultBalance = s.settingService.GetDefaultBalance(ctx) defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) } // 创建用户 user := &model.User{ Email: email, PasswordHash: hashedPassword, Role: model.RoleUser, Balance: defaultBalance, Concurrency: defaultConcurrency, Status: model.StatusActive, } if err := s.userRepo.Create(ctx, user); err != nil { return "", nil, fmt.Errorf("create user: %w", err) } // 生成token token, err := s.GenerateToken(user) if err != nil { return "", nil, fmt.Errorf("generate token: %w", err) } return token, user, nil } // SendVerifyCodeResult 发送验证码返回结果 type SendVerifyCodeResult struct { Countdown int `json:"countdown"` // 倒计时秒数 } // SendVerifyCode 发送邮箱验证码(同步方式) func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { // 检查是否开放注册 if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { return ErrRegDisabled } // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { return fmt.Errorf("check email exists: %w", err) } if existsEmail { return ErrEmailExists } // 发送验证码 if s.emailService == nil { return errors.New("email service not configured") } // 获取网站名称 siteName := "Sub2API" if s.settingService != nil { siteName = s.settingService.GetSiteName(ctx) } return s.emailService.SendVerifyCode(ctx, email, siteName) } // SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时 func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) { log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email) // 检查是否开放注册 if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { log.Println("[Auth] Registration is disabled") return nil, ErrRegDisabled } // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { log.Printf("[Auth] Error checking email exists: %v", err) return nil, fmt.Errorf("check email exists: %w", err) } if existsEmail { log.Printf("[Auth] Email already exists: %s", email) return nil, ErrEmailExists } // 检查邮件队列服务是否配置 if s.emailQueueService == nil { log.Println("[Auth] Email queue service not configured") return nil, errors.New("email queue service not configured") } // 获取网站名称 siteName := "Sub2API" if s.settingService != nil { siteName = s.settingService.GetSiteName(ctx) } // 异步发送 log.Printf("[Auth] Enqueueing verify code for: %s", email) if err := s.emailQueueService.EnqueueVerifyCode(email, siteName); err != nil { log.Printf("[Auth] Failed to enqueue: %v", err) return nil, fmt.Errorf("enqueue verify code: %w", err) } log.Printf("[Auth] Verify code enqueued successfully for: %s", email) return &SendVerifyCodeResult{ Countdown: 60, // 60秒倒计时 }, nil } // VerifyTurnstile 验证Turnstile token func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error { if s.turnstileService == nil { return nil // 服务未配置则跳过验证 } return s.turnstileService.VerifyToken(ctx, token, remoteIP) } // IsTurnstileEnabled 检查是否启用Turnstile验证 func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool { if s.turnstileService == nil { return false } return s.turnstileService.IsEnabled(ctx) } // IsRegistrationEnabled 检查是否开放注册 func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool { if s.settingService == nil { return true } return s.settingService.IsRegistrationEnabled(ctx) } // IsEmailVerifyEnabled 检查是否开启邮件验证 func (s *AuthService) IsEmailVerifyEnabled(ctx context.Context) bool { if s.settingService == nil { return false } return s.settingService.IsEmailVerifyEnabled(ctx) } // Login 用户登录,返回JWT token func (s *AuthService) Login(ctx context.Context, email, password string) (string, *model.User, error) { // 查找用户 user, err := s.userRepo.GetByEmail(ctx, email) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return "", nil, ErrInvalidCredentials } return "", nil, fmt.Errorf("get user by email: %w", err) } // 验证密码 if !s.CheckPassword(password, user.PasswordHash) { return "", nil, ErrInvalidCredentials } // 检查用户状态 if !user.IsActive() { return "", nil, ErrUserNotActive } // 生成JWT token token, err := s.GenerateToken(user) if err != nil { return "", nil, fmt.Errorf("generate token: %w", err) } return token, user, nil } // ValidateToken 验证JWT token并返回用户声明 func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { // 验证签名方法 if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(s.cfg.JWT.Secret), nil }) if err != nil { if errors.Is(err, jwt.ErrTokenExpired) { return nil, ErrTokenExpired } return nil, ErrInvalidToken } if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid { return claims, nil } return nil, ErrInvalidToken } // GenerateToken 生成JWT token func (s *AuthService) GenerateToken(user *model.User) (string, error) { now := time.Now() expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour) claims := &JWTClaims{ UserID: user.ID, Email: user.Email, Role: user.Role, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(expiresAt), IssuedAt: jwt.NewNumericDate(now), NotBefore: jwt.NewNumericDate(now), }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString([]byte(s.cfg.JWT.Secret)) if err != nil { return "", fmt.Errorf("sign token: %w", err) } return tokenString, nil } // HashPassword 使用bcrypt加密密码 func (s *AuthService) HashPassword(password string) (string, error) { hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return "", err } return string(hashedBytes), nil } // CheckPassword 验证密码是否匹配 func (s *AuthService) CheckPassword(password, hashedPassword string) bool { err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) return err == nil } // RefreshToken 刷新token func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (string, error) { // 验证旧token(即使过期也允许,用于刷新) claims, err := s.ValidateToken(oldTokenString) if err != nil && !errors.Is(err, ErrTokenExpired) { return "", err } // 获取最新的用户信息 user, err := s.userRepo.GetByID(ctx, claims.UserID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return "", ErrInvalidToken } return "", fmt.Errorf("get user: %w", err) } // 检查用户状态 if !user.IsActive() { return "", ErrUserNotActive } // 生成新token return s.GenerateToken(user) }