@@ -2,9 +2,13 @@ package service
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"log"
"net/mail"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -18,6 +22,7 @@ var (
ErrInvalidCredentials = infraerrors . Unauthorized ( "INVALID_CREDENTIALS" , "invalid email or password" )
ErrUserNotActive = infraerrors . Forbidden ( "USER_NOT_ACTIVE" , "user is not active" )
ErrEmailExists = infraerrors . Conflict ( "EMAIL_EXISTS" , "email already exists" )
ErrEmailReserved = infraerrors . BadRequest ( "EMAIL_RESERVED" , "email is reserved" )
ErrInvalidToken = infraerrors . Unauthorized ( "INVALID_TOKEN" , "invalid token" )
ErrTokenExpired = infraerrors . Unauthorized ( "TOKEN_EXPIRED" , "token has expired" )
ErrTokenTooLarge = infraerrors . BadRequest ( "TOKEN_TOO_LARGE" , "token too large" )
@@ -27,6 +32,8 @@ var (
ErrServiceUnavailable = infraerrors . ServiceUnavailable ( "SERVICE_UNAVAILABLE" , "service temporarily unavailable" )
)
const linuxDoSyntheticEmailDomain = "@linuxdo-connect.invalid"
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
const maxTokenLength = 8192
@@ -80,6 +87,11 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "" , nil , ErrRegDisabled
}
// Prevent users from registering emails reserved for synthetic OAuth accounts.
if isReservedEmail ( email ) {
return "" , nil , ErrEmailReserved
}
// 检查是否需要邮件验证
if s . settingService != nil && s . settingService . IsEmailVerifyEnabled ( ctx ) {
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
@@ -161,6 +173,10 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
return ErrRegDisabled
}
if isReservedEmail ( email ) {
return ErrEmailReserved
}
// 检查邮箱是否已存在
existsEmail , err := s . userRepo . ExistsByEmail ( ctx , email )
if err != nil {
@@ -195,6 +211,10 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
return nil , ErrRegDisabled
}
if isReservedEmail ( email ) {
return nil , ErrEmailReserved
}
// 检查邮箱是否已存在
existsEmail , err := s . userRepo . ExistsByEmail ( ctx , email )
if err != nil {
@@ -319,6 +339,101 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
return token , user , nil
}
// LoginOrRegisterOAuth logs a user in by email (trusted from an OAuth provider) or creates a new user.
//
// This is used by end-user OAuth/SSO login flows (e.g. LinuxDo Connect), and intentionally does
// NOT require the local password. A random password hash is generated for new users to satisfy
// the existing database constraint.
func ( s * AuthService ) LoginOrRegisterOAuth ( ctx context . Context , email , username string ) ( string , * User , error ) {
email = strings . TrimSpace ( email )
if email == "" || len ( email ) > 255 {
return "" , nil , infraerrors . BadRequest ( "INVALID_EMAIL" , "invalid email" )
}
if _ , err := mail . ParseAddress ( email ) ; err != nil {
return "" , nil , infraerrors . BadRequest ( "INVALID_EMAIL" , "invalid email" )
}
username = strings . TrimSpace ( username )
if len ( [ ] rune ( username ) ) > 100 {
username = string ( [ ] rune ( username ) [ : 100 ] )
}
user , err := s . userRepo . GetByEmail ( ctx , email )
if err != nil {
if errors . Is ( err , ErrUserNotFound ) {
// Treat OAuth-first login as registration.
if s . settingService != nil && ! s . settingService . IsRegistrationEnabled ( ctx ) {
return "" , nil , ErrRegDisabled
}
randomPassword , err := randomHexString ( 32 )
if err != nil {
log . Printf ( "[Auth] Failed to generate random password for oauth signup: %v" , err )
return "" , nil , ErrServiceUnavailable
}
hashedPassword , err := s . HashPassword ( randomPassword )
if err != nil {
return "" , nil , fmt . Errorf ( "hash password: %w" , err )
}
// Defaults for new users.
defaultBalance := s . cfg . Default . UserBalance
defaultConcurrency := s . cfg . Default . UserConcurrency
if s . settingService != nil {
defaultBalance = s . settingService . GetDefaultBalance ( ctx )
defaultConcurrency = s . settingService . GetDefaultConcurrency ( ctx )
}
newUser := & User {
Email : email ,
Username : username ,
PasswordHash : hashedPassword ,
Role : RoleUser ,
Balance : defaultBalance ,
Concurrency : defaultConcurrency ,
Status : StatusActive ,
}
if err := s . userRepo . Create ( ctx , newUser ) ; err != nil {
if errors . Is ( err , ErrEmailExists ) {
// Race: user created between GetByEmail and Create.
user , err = s . userRepo . GetByEmail ( ctx , email )
if err != nil {
log . Printf ( "[Auth] Database error getting user after conflict: %v" , err )
return "" , nil , ErrServiceUnavailable
}
} else {
log . Printf ( "[Auth] Database error creating oauth user: %v" , err )
return "" , nil , ErrServiceUnavailable
}
} else {
user = newUser
}
} else {
log . Printf ( "[Auth] Database error during oauth login: %v" , err )
return "" , nil , ErrServiceUnavailable
}
}
if ! user . IsActive ( ) {
return "" , nil , ErrUserNotActive
}
// Best-effort: fill username when empty.
if user . Username == "" && username != "" {
user . Username = username
if err := s . userRepo . Update ( ctx , user ) ; err != nil {
log . Printf ( "[Auth] Failed to update username after oauth login: %v" , err )
}
}
token , err := s . GenerateToken ( user )
if err != nil {
return "" , nil , fmt . Errorf ( "generate token: %w" , err )
}
return token , user , nil
}
// ValidateToken 验证JWT token并返回用户声明
func ( s * AuthService ) ValidateToken ( tokenString string ) ( * JWTClaims , error ) {
// 先做长度校验,尽早拒绝异常超长 token, 降低 DoS 风险。
@@ -361,6 +476,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
return nil , ErrInvalidToken
}
func randomHexString ( byteLength int ) ( string , error ) {
if byteLength <= 0 {
byteLength = 16
}
buf := make ( [ ] byte , byteLength )
if _ , err := rand . Read ( buf ) ; err != nil {
return "" , err
}
return hex . EncodeToString ( buf ) , nil
}
func isReservedEmail ( email string ) bool {
normalized := strings . ToLower ( strings . TrimSpace ( email ) )
return strings . HasSuffix ( normalized , linuxDoSyntheticEmailDomain )
}
// GenerateToken 生成JWT token
func ( s * AuthService ) GenerateToken ( user * User ) ( string , error ) {
now := time . Now ( )