Merge upstream/main: v0.1.88-v0.1.101 updates
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -8,9 +8,11 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
@@ -20,23 +22,25 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||
ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token")
|
||||
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
|
||||
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
|
||||
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
|
||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||
ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token")
|
||||
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
|
||||
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
|
||||
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
|
||||
ErrOAuthInvitationRequired = infraerrors.Forbidden("OAUTH_INVITATION_REQUIRED", "invitation code required to complete oauth registration")
|
||||
)
|
||||
|
||||
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
|
||||
@@ -56,6 +60,7 @@ type JWTClaims struct {
|
||||
|
||||
// AuthService 认证服务
|
||||
type AuthService struct {
|
||||
entClient *dbent.Client
|
||||
userRepo UserRepository
|
||||
redeemRepo RedeemCodeRepository
|
||||
refreshTokenCache RefreshTokenCache
|
||||
@@ -74,6 +79,7 @@ type DefaultSubscriptionAssigner interface {
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
func NewAuthService(
|
||||
entClient *dbent.Client,
|
||||
userRepo UserRepository,
|
||||
redeemRepo RedeemCodeRepository,
|
||||
refreshTokenCache RefreshTokenCache,
|
||||
@@ -86,6 +92,7 @@ func NewAuthService(
|
||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
entClient: entClient,
|
||||
userRepo: userRepo,
|
||||
redeemRepo: redeemRepo,
|
||||
refreshTokenCache: refreshTokenCache,
|
||||
@@ -115,6 +122,9 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
if isReservedEmail(email) {
|
||||
return "", nil, ErrEmailReserved
|
||||
}
|
||||
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// 检查是否需要邀请码
|
||||
var invitationRedeemCode *RedeemCode
|
||||
@@ -241,6 +251,9 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
if isReservedEmail(email) {
|
||||
return ErrEmailReserved
|
||||
}
|
||||
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
@@ -279,6 +292,9 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
|
||||
if isReservedEmail(email) {
|
||||
return nil, ErrEmailReserved
|
||||
}
|
||||
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
@@ -512,9 +528,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair
|
||||
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) {
|
||||
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
|
||||
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
|
||||
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, nil, errors.New("refresh token cache not configured")
|
||||
@@ -541,6 +558,22 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
return nil, nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
// 检查是否需要邀请码
|
||||
var invitationRedeemCode *RedeemCode
|
||||
if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) {
|
||||
if invitationCode == "" {
|
||||
return nil, nil, ErrOAuthInvitationRequired
|
||||
}
|
||||
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
|
||||
if err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
invitationRedeemCode = redeemCode
|
||||
}
|
||||
|
||||
randomPassword, err := randomHexString(32)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err)
|
||||
@@ -568,20 +601,58 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
||||
if s.entClient != nil && invitationRedeemCode != nil {
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to begin transaction for oauth registration: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
|
||||
if err := s.userRepo.Create(txCtx, newUser); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
if err := s.redeemRepo.Use(txCtx, invitationRedeemCode.ID, newUser.ID); err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to commit oauth registration transaction: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
user = newUser
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
|
||||
@@ -607,6 +678,63 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
return tokenPair, user, nil
|
||||
}
|
||||
|
||||
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
|
||||
const pendingOAuthTokenTTL = 10 * time.Minute
|
||||
|
||||
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
|
||||
const pendingOAuthPurpose = "pending_oauth_registration"
|
||||
|
||||
type pendingOAuthClaims struct {
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Purpose string `json:"purpose"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
|
||||
// while waiting for the user to supply an invitation code.
|
||||
func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: email,
|
||||
Username: username,
|
||||
Purpose: pendingOAuthPurpose,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(s.cfg.JWT.Secret))
|
||||
}
|
||||
|
||||
// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
|
||||
// Returns ErrInvalidToken when the token is invalid or expired.
|
||||
func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
|
||||
if len(tokenStr) > maxTokenLength {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
|
||||
token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return []byte(s.cfg.JWT.Secret), nil
|
||||
})
|
||||
if parseErr != nil {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
claims, ok := token.Claims.(*pendingOAuthClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
if claims.Purpose != pendingOAuthPurpose {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
return claims.Email, claims.Username, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
|
||||
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
|
||||
return
|
||||
@@ -624,6 +752,32 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
|
||||
if s.settingService == nil {
|
||||
return nil
|
||||
}
|
||||
whitelist := s.settingService.GetRegistrationEmailSuffixWhitelist(ctx)
|
||||
if !IsRegistrationEmailSuffixAllowed(email, whitelist) {
|
||||
return buildEmailSuffixNotAllowedError(whitelist)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildEmailSuffixNotAllowedError(whitelist []string) error {
|
||||
if len(whitelist) == 0 {
|
||||
return ErrEmailSuffixNotAllowed
|
||||
}
|
||||
|
||||
allowed := strings.Join(whitelist, ", ")
|
||||
return infraerrors.BadRequest(
|
||||
"EMAIL_SUFFIX_NOT_ALLOWED",
|
||||
fmt.Sprintf("email suffix is not allowed, allowed suffixes: %s", allowed),
|
||||
).WithMetadata(map[string]string{
|
||||
"allowed_suffixes": strings.Join(whitelist, ","),
|
||||
"allowed_suffix_count": strconv.Itoa(len(whitelist)),
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateToken 验证JWT token并返回用户声明
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
|
||||
@@ -933,6 +1087,12 @@ type TokenPair struct {
|
||||
ExpiresIn int `json:"expires_in"` // Access Token有效期(秒)
|
||||
}
|
||||
|
||||
// TokenPairWithUser extends TokenPair with user role for backend mode checks
|
||||
type TokenPairWithUser struct {
|
||||
TokenPair
|
||||
UserRole string
|
||||
}
|
||||
|
||||
// GenerateTokenPair 生成Access Token和Refresh Token对
|
||||
// familyID: 可选的Token家族ID,用于Token轮转时保持家族关系
|
||||
func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) {
|
||||
@@ -1014,7 +1174,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
|
||||
|
||||
// RefreshTokenPair 使用Refresh Token刷新Token对
|
||||
// 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效
|
||||
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) {
|
||||
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPairWithUser, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, ErrRefreshTokenInvalid
|
||||
@@ -1079,7 +1239,14 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
|
||||
}
|
||||
|
||||
// 生成新的Token对,保持同一个家族ID
|
||||
return s.GenerateTokenPair(ctx, user, data.FamilyID)
|
||||
pair, err := s.GenerateTokenPair(ctx, user, data.FamilyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenPairWithUser{
|
||||
TokenPair: *pair,
|
||||
UserRole: user.Role,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RevokeRefreshToken 撤销单个Refresh Token
|
||||
|
||||
Reference in New Issue
Block a user