package service import ( "context" "crypto/rand" "crypto/sha256" "encoding/hex" "errors" "fmt" "log" "net/mail" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/golang-jwt/jwt/v5" "golang.org/x/crypto/bcrypt" ) var ( ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired") ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked") ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token") ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired") ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused") ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required") ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code") ) // maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。 const maxTokenLength = 8192 // refreshTokenPrefix is the prefix for refresh tokens to distinguish them from access tokens. const refreshTokenPrefix = "rt_" // JWTClaims JWT载荷数据 type JWTClaims struct { UserID int64 `json:"user_id"` Email string `json:"email"` Role string `json:"role"` TokenVersion int64 `json:"token_version"` // Used to invalidate tokens on password change jwt.RegisteredClaims } // AuthService 认证服务 type AuthService struct { userRepo UserRepository redeemRepo RedeemCodeRepository refreshTokenCache RefreshTokenCache cfg *config.Config settingService *SettingService emailService *EmailService turnstileService *TurnstileService emailQueueService *EmailQueueService promoService *PromoService } // NewAuthService 创建认证服务实例 func NewAuthService( userRepo UserRepository, redeemRepo RedeemCodeRepository, refreshTokenCache RefreshTokenCache, cfg *config.Config, settingService *SettingService, emailService *EmailService, turnstileService *TurnstileService, emailQueueService *EmailQueueService, promoService *PromoService, ) *AuthService { return &AuthService{ userRepo: userRepo, redeemRepo: redeemRepo, refreshTokenCache: refreshTokenCache, cfg: cfg, settingService: settingService, emailService: emailService, turnstileService: turnstileService, emailQueueService: emailQueueService, promoService: promoService, } } // Register 用户注册,返回token和用户 func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { return s.RegisterWithVerification(ctx, email, password, "", "", "") } // RegisterWithVerification 用户注册(支持邮件验证、优惠码和邀请码),返回token和用户 func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) { // 检查是否开放注册(默认关闭:settingService 未配置时不允许注册) if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return "", nil, ErrRegDisabled } // 防止用户注册 LinuxDo OAuth 合成邮箱,避免第三方登录与本地账号发生碰撞。 if isReservedEmail(email) { return "", nil, ErrEmailReserved } // 检查是否需要邀请码 var invitationRedeemCode *RedeemCode if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) { if invitationCode == "" { return "", nil, ErrInvitationCodeRequired } // 验证邀请码 redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode) if err != nil { log.Printf("[Auth] Invalid invitation code: %s, error: %v", invitationCode, err) return "", nil, ErrInvitationCodeInvalid } // 检查类型和状态 if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { log.Printf("[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status) return "", nil, ErrInvitationCodeInvalid } invitationRedeemCode = redeemCode } // 检查是否需要邮件验证 if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) { // 如果邮件验证已开启但邮件服务未配置,拒绝注册 // 这是一个配置错误,不应该允许绕过验证 if s.emailService == nil { log.Println("[Auth] Email verification enabled but email service not configured, rejecting registration") return "", nil, ErrServiceUnavailable } if verifyCode == "" { return "", nil, ErrEmailVerifyRequired } // 验证邮箱验证码 if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil { return "", nil, fmt.Errorf("verify code: %w", err) } } // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { log.Printf("[Auth] Database error checking email exists: %v", err) return "", nil, ErrServiceUnavailable } if existsEmail { return "", nil, ErrEmailExists } // 密码哈希 hashedPassword, err := s.HashPassword(password) if err != nil { return "", nil, fmt.Errorf("hash password: %w", err) } // 获取默认配置 defaultBalance := s.cfg.Default.UserBalance defaultConcurrency := s.cfg.Default.UserConcurrency if s.settingService != nil { defaultBalance = s.settingService.GetDefaultBalance(ctx) defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) } // 创建用户 user := &User{ Email: email, PasswordHash: hashedPassword, Role: RoleUser, Balance: defaultBalance, Concurrency: defaultConcurrency, Status: StatusActive, } if err := s.userRepo.Create(ctx, user); err != nil { // 优先检查邮箱冲突错误(竞态条件下可能发生) if errors.Is(err, ErrEmailExists) { return "", nil, ErrEmailExists } log.Printf("[Auth] Database error creating user: %v", err) return "", nil, ErrServiceUnavailable } // 标记邀请码为已使用(如果使用了邀请码) if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { // 邀请码标记失败不影响注册,只记录日志 log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err) } } // 应用优惠码(如果提供且功能已启用) if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) { if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil { // 优惠码应用失败不影响注册,只记录日志 log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err) } else { // 重新获取用户信息以获取更新后的余额 if updatedUser, err := s.userRepo.GetByID(ctx, user.ID); err == nil { user = updatedUser } } } // 生成token token, err := s.GenerateToken(user) if err != nil { return "", nil, fmt.Errorf("generate token: %w", err) } return token, user, nil } // SendVerifyCodeResult 发送验证码返回结果 type SendVerifyCodeResult struct { Countdown int `json:"countdown"` // 倒计时秒数 } // SendVerifyCode 发送邮箱验证码(同步方式) func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { // 检查是否开放注册(默认关闭) if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return ErrRegDisabled } if isReservedEmail(email) { return ErrEmailReserved } // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { log.Printf("[Auth] Database error checking email exists: %v", err) return ErrServiceUnavailable } if existsEmail { return ErrEmailExists } // 发送验证码 if s.emailService == nil { return errors.New("email service not configured") } // 获取网站名称 siteName := "TianShuAPI" if s.settingService != nil { siteName = s.settingService.GetSiteName(ctx) } return s.emailService.SendVerifyCode(ctx, email, siteName) } // SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时 func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) { log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email) // 检查是否开放注册(默认关闭) if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { log.Println("[Auth] Registration is disabled") return nil, ErrRegDisabled } if isReservedEmail(email) { return nil, ErrEmailReserved } // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { log.Printf("[Auth] Database error checking email exists: %v", err) return nil, ErrServiceUnavailable } if existsEmail { log.Printf("[Auth] Email already exists: %s", email) return nil, ErrEmailExists } // 检查邮件队列服务是否配置 if s.emailQueueService == nil { log.Println("[Auth] Email queue service not configured") return nil, errors.New("email queue service not configured") } // 获取网站名称 siteName := "TianShuAPI" if s.settingService != nil { siteName = s.settingService.GetSiteName(ctx) } // 异步发送 log.Printf("[Auth] Enqueueing verify code for: %s", email) if err := s.emailQueueService.EnqueueVerifyCode(email, siteName); err != nil { log.Printf("[Auth] Failed to enqueue: %v", err) return nil, fmt.Errorf("enqueue verify code: %w", err) } log.Printf("[Auth] Verify code enqueued successfully for: %s", email) return &SendVerifyCodeResult{ Countdown: 60, // 60秒倒计时 }, nil } // VerifyTurnstile 验证Turnstile token func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error { required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required if required { if s.settingService == nil { log.Println("[Auth] Turnstile required but settings service is not configured") return ErrTurnstileNotConfigured } enabled := s.settingService.IsTurnstileEnabled(ctx) secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != "" if !enabled || !secretConfigured { log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured) return ErrTurnstileNotConfigured } } if s.turnstileService == nil { if required { log.Println("[Auth] Turnstile required but service not configured") return ErrTurnstileNotConfigured } return nil // 服务未配置则跳过验证 } if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" { log.Println("[Auth] Turnstile enabled but secret key not configured") } return s.turnstileService.VerifyToken(ctx, token, remoteIP) } // IsTurnstileEnabled 检查是否启用Turnstile验证 func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool { if s.turnstileService == nil { return false } return s.turnstileService.IsEnabled(ctx) } // IsRegistrationEnabled 检查是否开放注册 func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool { if s.settingService == nil { return false // 安全默认:settingService 未配置时关闭注册 } return s.settingService.IsRegistrationEnabled(ctx) } // IsEmailVerifyEnabled 检查是否开启邮件验证 func (s *AuthService) IsEmailVerifyEnabled(ctx context.Context) bool { if s.settingService == nil { return false } return s.settingService.IsEmailVerifyEnabled(ctx) } // Login 用户登录,返回JWT token func (s *AuthService) Login(ctx context.Context, email, password string) (string, *User, error) { // 查找用户 user, err := s.userRepo.GetByEmail(ctx, email) if err != nil { if errors.Is(err, ErrUserNotFound) { return "", nil, ErrInvalidCredentials } // 记录数据库错误但不暴露给用户 log.Printf("[Auth] Database error during login: %v", err) return "", nil, ErrServiceUnavailable } // 验证密码 if !s.CheckPassword(password, user.PasswordHash) { return "", nil, ErrInvalidCredentials } // 检查用户状态 if !user.IsActive() { return "", nil, ErrUserNotActive } // 生成JWT token token, err := s.GenerateToken(user) if err != nil { return "", nil, fmt.Errorf("generate token: %w", err) } return token, user, nil } // LoginOrRegisterOAuth 用于第三方 OAuth/SSO 登录: // - 如果邮箱已存在:直接登录(不需要本地密码) // - 如果邮箱不存在:创建新用户并登录 // // 注意:该函数用于 LinuxDo OAuth 登录场景(不同于上游账号的 OAuth,例如 Claude/OpenAI/Gemini)。 // 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。 func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) { email = strings.TrimSpace(email) if email == "" || len(email) > 255 { return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") } if _, err := mail.ParseAddress(email); err != nil { return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") } username = strings.TrimSpace(username) if len([]rune(username)) > 100 { username = string([]rune(username)[:100]) } user, err := s.userRepo.GetByEmail(ctx, email) if err != nil { if errors.Is(err, ErrUserNotFound) { // OAuth 首次登录视为注册(fail-close:settingService 未配置时不允许注册) if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return "", nil, ErrRegDisabled } randomPassword, err := randomHexString(32) if err != nil { log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err) return "", nil, ErrServiceUnavailable } hashedPassword, err := s.HashPassword(randomPassword) if err != nil { return "", nil, fmt.Errorf("hash password: %w", err) } // 新用户默认值。 defaultBalance := s.cfg.Default.UserBalance defaultConcurrency := s.cfg.Default.UserConcurrency if s.settingService != nil { defaultBalance = s.settingService.GetDefaultBalance(ctx) defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) } newUser := &User{ Email: email, Username: username, PasswordHash: hashedPassword, Role: RoleUser, Balance: defaultBalance, Concurrency: defaultConcurrency, Status: StatusActive, } if err := s.userRepo.Create(ctx, newUser); err != nil { if errors.Is(err, ErrEmailExists) { // 并发场景:GetByEmail 与 Create 之间用户被创建。 user, err = s.userRepo.GetByEmail(ctx, email) if err != nil { log.Printf("[Auth] Database error getting user after conflict: %v", err) return "", nil, ErrServiceUnavailable } } else { log.Printf("[Auth] Database error creating oauth user: %v", err) return "", nil, ErrServiceUnavailable } } else { user = newUser } } else { log.Printf("[Auth] Database error during oauth login: %v", err) return "", nil, ErrServiceUnavailable } } if !user.IsActive() { return "", nil, ErrUserNotActive } // 尽力补全:当用户名为空时,使用第三方返回的用户名回填。 if user.Username == "" && username != "" { user.Username = username if err := s.userRepo.Update(ctx, user); err != nil { log.Printf("[Auth] Failed to update username after oauth login: %v", err) } } token, err := s.GenerateToken(user) if err != nil { return "", nil, fmt.Errorf("generate token: %w", err) } return token, user, nil } // LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair // 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) { // 检查 refreshTokenCache 是否可用 if s.refreshTokenCache == nil { return nil, nil, errors.New("refresh token cache not configured") } email = strings.TrimSpace(email) if email == "" || len(email) > 255 { return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") } if _, err := mail.ParseAddress(email); err != nil { return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") } username = strings.TrimSpace(username) if len([]rune(username)) > 100 { username = string([]rune(username)[:100]) } user, err := s.userRepo.GetByEmail(ctx, email) if err != nil { if errors.Is(err, ErrUserNotFound) { // OAuth 首次登录视为注册 if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return nil, nil, ErrRegDisabled } randomPassword, err := randomHexString(32) if err != nil { log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err) return nil, nil, ErrServiceUnavailable } hashedPassword, err := s.HashPassword(randomPassword) if err != nil { return nil, nil, fmt.Errorf("hash password: %w", err) } defaultBalance := s.cfg.Default.UserBalance defaultConcurrency := s.cfg.Default.UserConcurrency if s.settingService != nil { defaultBalance = s.settingService.GetDefaultBalance(ctx) defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) } newUser := &User{ Email: email, Username: username, PasswordHash: hashedPassword, Role: RoleUser, Balance: defaultBalance, Concurrency: defaultConcurrency, Status: StatusActive, } if err := s.userRepo.Create(ctx, newUser); err != nil { if errors.Is(err, ErrEmailExists) { user, err = s.userRepo.GetByEmail(ctx, email) if err != nil { log.Printf("[Auth] Database error getting user after conflict: %v", err) return nil, nil, ErrServiceUnavailable } } else { log.Printf("[Auth] Database error creating oauth user: %v", err) return nil, nil, ErrServiceUnavailable } } else { user = newUser } } else { log.Printf("[Auth] Database error during oauth login: %v", err) return nil, nil, ErrServiceUnavailable } } if !user.IsActive() { return nil, nil, ErrUserNotActive } if user.Username == "" && username != "" { user.Username = username if err := s.userRepo.Update(ctx, user); err != nil { log.Printf("[Auth] Failed to update username after oauth login: %v", err) } } tokenPair, err := s.GenerateTokenPair(ctx, user, "") if err != nil { return nil, nil, fmt.Errorf("generate token pair: %w", err) } return tokenPair, user, nil } // ValidateToken 验证JWT token并返回用户声明 func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 if len(tokenString) > maxTokenLength { return nil, ErrTokenTooLarge } // 使用解析器并限制可接受的签名算法,防止算法混淆。 parser := jwt.NewParser(jwt.WithValidMethods([]string{ jwt.SigningMethodHS256.Name, jwt.SigningMethodHS384.Name, jwt.SigningMethodHS512.Name, })) // 保留默认 claims 校验(exp/nbf),避免放行过期或未生效的 token。 token, err := parser.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) { // 验证签名方法 if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(s.cfg.JWT.Secret), nil }) if err != nil { if errors.Is(err, jwt.ErrTokenExpired) { // token 过期但仍返回 claims(用于 RefreshToken 等场景) // jwt-go 在解析时即使遇到过期错误,token.Claims 仍会被填充 if claims, ok := token.Claims.(*JWTClaims); ok { return claims, ErrTokenExpired } return nil, ErrTokenExpired } return nil, ErrInvalidToken } if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid { return claims, nil } return nil, ErrInvalidToken } func randomHexString(byteLength int) (string, error) { if byteLength <= 0 { byteLength = 16 } buf := make([]byte, byteLength) if _, err := rand.Read(buf); err != nil { return "", err } return hex.EncodeToString(buf), nil } func isReservedEmail(email string) bool { normalized := strings.ToLower(strings.TrimSpace(email)) return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) } // GenerateToken 生成JWT access token // 使用新的access_token_expire_minutes配置项(如果配置了),否则回退到expire_hour func (s *AuthService) GenerateToken(user *User) (string, error) { now := time.Now() var expiresAt time.Time if s.cfg.JWT.AccessTokenExpireMinutes > 0 { expiresAt = now.Add(time.Duration(s.cfg.JWT.AccessTokenExpireMinutes) * time.Minute) } else { // 向后兼容:使用旧的expire_hour配置 expiresAt = now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour) } claims := &JWTClaims{ UserID: user.ID, Email: user.Email, Role: user.Role, TokenVersion: user.TokenVersion, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(expiresAt), IssuedAt: jwt.NewNumericDate(now), NotBefore: jwt.NewNumericDate(now), }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString([]byte(s.cfg.JWT.Secret)) if err != nil { return "", fmt.Errorf("sign token: %w", err) } return tokenString, nil } // GetAccessTokenExpiresIn 返回Access Token的有效期(秒) // 用于前端设置刷新定时器 func (s *AuthService) GetAccessTokenExpiresIn() int { if s.cfg.JWT.AccessTokenExpireMinutes > 0 { return s.cfg.JWT.AccessTokenExpireMinutes * 60 } return s.cfg.JWT.ExpireHour * 3600 } // HashPassword 使用bcrypt加密密码 func (s *AuthService) HashPassword(password string) (string, error) { hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return "", err } return string(hashedBytes), nil } // CheckPassword 验证密码是否匹配 func (s *AuthService) CheckPassword(password, hashedPassword string) bool { err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) return err == nil } // RefreshToken 刷新token func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (string, error) { // 验证旧token(即使过期也允许,用于刷新) claims, err := s.ValidateToken(oldTokenString) if err != nil && !errors.Is(err, ErrTokenExpired) { return "", err } // 获取最新的用户信息 user, err := s.userRepo.GetByID(ctx, claims.UserID) if err != nil { if errors.Is(err, ErrUserNotFound) { return "", ErrInvalidToken } log.Printf("[Auth] Database error refreshing token: %v", err) return "", ErrServiceUnavailable } // 检查用户状态 if !user.IsActive() { return "", ErrUserNotActive } // Security: Check TokenVersion to prevent refreshing revoked tokens // This ensures tokens issued before a password change cannot be refreshed if claims.TokenVersion != user.TokenVersion { return "", ErrTokenRevoked } // 生成新token return s.GenerateToken(user) } // IsPasswordResetEnabled 检查是否启用密码重置功能 // 要求:必须同时开启邮件验证且 SMTP 配置正确 func (s *AuthService) IsPasswordResetEnabled(ctx context.Context) bool { if s.settingService == nil { return false } // Must have email verification enabled and SMTP configured if !s.settingService.IsEmailVerifyEnabled(ctx) { return false } return s.settingService.IsPasswordResetEnabled(ctx) } // preparePasswordReset validates the password reset request and returns necessary data // Returns (siteName, resetURL, shouldProceed) // shouldProceed is false when we should silently return success (to prevent enumeration) func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendBaseURL string) (string, string, bool) { // Check if user exists (but don't reveal this to the caller) user, err := s.userRepo.GetByEmail(ctx, email) if err != nil { if errors.Is(err, ErrUserNotFound) { // Security: Log but don't reveal that user doesn't exist log.Printf("[Auth] Password reset requested for non-existent email: %s", email) return "", "", false } log.Printf("[Auth] Database error checking email for password reset: %v", err) return "", "", false } // Check if user is active if !user.IsActive() { log.Printf("[Auth] Password reset requested for inactive user: %s", email) return "", "", false } // Get site name siteName := "Sub2API" if s.settingService != nil { siteName = s.settingService.GetSiteName(ctx) } // Build reset URL base resetURL := fmt.Sprintf("%s/reset-password", strings.TrimSuffix(frontendBaseURL, "/")) return siteName, resetURL, true } // RequestPasswordReset 请求密码重置(同步发送) // Security: Returns the same response regardless of whether the email exists (prevent user enumeration) func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string) error { if !s.IsPasswordResetEnabled(ctx) { return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled") } if s.emailService == nil { return ErrServiceUnavailable } siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL) if !shouldProceed { return nil // Silent success to prevent enumeration } if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil { log.Printf("[Auth] Failed to send password reset email to %s: %v", email, err) return nil // Silent success to prevent enumeration } log.Printf("[Auth] Password reset email sent to: %s", email) return nil } // RequestPasswordResetAsync 异步请求密码重置(队列发送) // Security: Returns the same response regardless of whether the email exists (prevent user enumeration) func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string) error { if !s.IsPasswordResetEnabled(ctx) { return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled") } if s.emailQueueService == nil { return ErrServiceUnavailable } siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL) if !shouldProceed { return nil // Silent success to prevent enumeration } if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil { log.Printf("[Auth] Failed to enqueue password reset email for %s: %v", email, err) return nil // Silent success to prevent enumeration } log.Printf("[Auth] Password reset email enqueued for: %s", email) return nil } // ResetPassword 重置密码 // Security: Increments TokenVersion to invalidate all existing JWT tokens func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPassword string) error { // Check if password reset is enabled if !s.IsPasswordResetEnabled(ctx) { return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled") } if s.emailService == nil { return ErrServiceUnavailable } // Verify and consume the reset token (one-time use) if err := s.emailService.ConsumePasswordResetToken(ctx, email, token); err != nil { return err } // Get user user, err := s.userRepo.GetByEmail(ctx, email) if err != nil { if errors.Is(err, ErrUserNotFound) { return ErrInvalidResetToken // Token was valid but user was deleted } log.Printf("[Auth] Database error getting user for password reset: %v", err) return ErrServiceUnavailable } // Check if user is active if !user.IsActive() { return ErrUserNotActive } // Hash new password hashedPassword, err := s.HashPassword(newPassword) if err != nil { return fmt.Errorf("hash password: %w", err) } // Update password and increment TokenVersion user.PasswordHash = hashedPassword user.TokenVersion++ // Invalidate all existing tokens if err := s.userRepo.Update(ctx, user); err != nil { log.Printf("[Auth] Database error updating password for user %d: %v", user.ID, err) return ErrServiceUnavailable } // Also revoke all refresh tokens for this user if err := s.RevokeAllUserSessions(ctx, user.ID); err != nil { log.Printf("[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err) // Don't return error - password was already changed successfully } log.Printf("[Auth] Password reset successful for user: %s", email) return nil } // ==================== Refresh Token Methods ==================== // TokenPair 包含Access Token和Refresh Token type TokenPair struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresIn int `json:"expires_in"` // Access Token有效期(秒) } // GenerateTokenPair 生成Access Token和Refresh Token对 // familyID: 可选的Token家族ID,用于Token轮转时保持家族关系 func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) { // 检查 refreshTokenCache 是否可用 if s.refreshTokenCache == nil { return nil, errors.New("refresh token cache not configured") } // 生成Access Token accessToken, err := s.GenerateToken(user) if err != nil { return nil, fmt.Errorf("generate access token: %w", err) } // 生成Refresh Token refreshToken, err := s.generateRefreshToken(ctx, user, familyID) if err != nil { return nil, fmt.Errorf("generate refresh token: %w", err) } return &TokenPair{ AccessToken: accessToken, RefreshToken: refreshToken, ExpiresIn: s.GetAccessTokenExpiresIn(), }, nil } // generateRefreshToken 生成并存储Refresh Token func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, familyID string) (string, error) { // 生成随机Token tokenBytes := make([]byte, 32) if _, err := rand.Read(tokenBytes); err != nil { return "", fmt.Errorf("generate random bytes: %w", err) } rawToken := refreshTokenPrefix + hex.EncodeToString(tokenBytes) // 计算Token哈希(存储哈希而非原始Token) tokenHash := hashToken(rawToken) // 如果没有提供familyID,生成新的 if familyID == "" { familyBytes := make([]byte, 16) if _, err := rand.Read(familyBytes); err != nil { return "", fmt.Errorf("generate family id: %w", err) } familyID = hex.EncodeToString(familyBytes) } now := time.Now() ttl := time.Duration(s.cfg.JWT.RefreshTokenExpireDays) * 24 * time.Hour data := &RefreshTokenData{ UserID: user.ID, TokenVersion: user.TokenVersion, FamilyID: familyID, CreatedAt: now, ExpiresAt: now.Add(ttl), } // 存储Token数据 if err := s.refreshTokenCache.StoreRefreshToken(ctx, tokenHash, data, ttl); err != nil { return "", fmt.Errorf("store refresh token: %w", err) } // 添加到用户Token集合 if err := s.refreshTokenCache.AddToUserTokenSet(ctx, user.ID, tokenHash, ttl); err != nil { log.Printf("[Auth] Failed to add token to user set: %v", err) // 不影响主流程 } // 添加到家族Token集合 if err := s.refreshTokenCache.AddToFamilyTokenSet(ctx, familyID, tokenHash, ttl); err != nil { log.Printf("[Auth] Failed to add token to family set: %v", err) // 不影响主流程 } return rawToken, nil } // RefreshTokenPair 使用Refresh Token刷新Token对 // 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效 func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) { // 检查 refreshTokenCache 是否可用 if s.refreshTokenCache == nil { return nil, ErrRefreshTokenInvalid } // 验证Token格式 if !strings.HasPrefix(refreshToken, refreshTokenPrefix) { return nil, ErrRefreshTokenInvalid } tokenHash := hashToken(refreshToken) // 获取Token数据 data, err := s.refreshTokenCache.GetRefreshToken(ctx, tokenHash) if err != nil { if errors.Is(err, ErrRefreshTokenNotFound) { // Token不存在,可能是已被使用(Token轮转)或已过期 log.Printf("[Auth] Refresh token not found, possible reuse attack") return nil, ErrRefreshTokenInvalid } log.Printf("[Auth] Error getting refresh token: %v", err) return nil, ErrServiceUnavailable } // 检查Token是否过期 if time.Now().After(data.ExpiresAt) { // 删除过期Token _ = s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash) return nil, ErrRefreshTokenExpired } // 获取用户信息 user, err := s.userRepo.GetByID(ctx, data.UserID) if err != nil { if errors.Is(err, ErrUserNotFound) { // 用户已删除,撤销整个Token家族 _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) return nil, ErrRefreshTokenInvalid } log.Printf("[Auth] Database error getting user for token refresh: %v", err) return nil, ErrServiceUnavailable } // 检查用户状态 if !user.IsActive() { // 用户被禁用,撤销整个Token家族 _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) return nil, ErrUserNotActive } // 检查TokenVersion(密码更改后所有Token失效) if data.TokenVersion != user.TokenVersion { // TokenVersion不匹配,撤销整个Token家族 _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) return nil, ErrTokenRevoked } // Token轮转:立即使旧Token失效 if err := s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash); err != nil { log.Printf("[Auth] Failed to delete old refresh token: %v", err) // 继续处理,不影响主流程 } // 生成新的Token对,保持同一个家族ID return s.GenerateTokenPair(ctx, user, data.FamilyID) } // RevokeRefreshToken 撤销单个Refresh Token func (s *AuthService) RevokeRefreshToken(ctx context.Context, refreshToken string) error { if s.refreshTokenCache == nil { return nil // No-op if cache not configured } if !strings.HasPrefix(refreshToken, refreshTokenPrefix) { return ErrRefreshTokenInvalid } tokenHash := hashToken(refreshToken) return s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash) } // RevokeAllUserSessions 撤销用户的所有会话(所有Refresh Token) // 用于密码更改或用户主动登出所有设备 func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) error { if s.refreshTokenCache == nil { return nil // No-op if cache not configured } return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID) } // hashToken 计算Token的SHA256哈希 func hashToken(token string) string { hash := sha256.Sum256([]byte(token)) return hex.EncodeToString(hash[:]) }