feat: rebuild auth identity foundation flow
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
@@ -106,6 +107,13 @@ func NewAuthService(
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) EntClient() *dbent.Client {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.entClient
|
||||
}
|
||||
|
||||
// Register 用户注册,返回token和用户
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
||||
return s.RegisterWithVerification(ctx, email, password, "", "", "")
|
||||
@@ -205,6 +213,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
s.postAuthUserBootstrap(ctx, user, "email", true)
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
|
||||
// 标记邀请码为已使用(如果使用了邀请码)
|
||||
@@ -421,6 +430,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
||||
if !user.IsActive() {
|
||||
return "", nil, ErrUserNotActive
|
||||
}
|
||||
s.touchUserLogin(ctx, user.ID)
|
||||
|
||||
// 生成JWT token
|
||||
token, err := s.GenerateToken(user)
|
||||
@@ -501,6 +511,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
}
|
||||
} else {
|
||||
@@ -520,6 +531,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
|
||||
}
|
||||
}
|
||||
s.touchUserLogin(ctx, user.ID)
|
||||
|
||||
token, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
@@ -630,6 +642,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
user = newUser
|
||||
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
}
|
||||
} else {
|
||||
@@ -646,6 +659,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
@@ -670,6 +684,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
|
||||
}
|
||||
}
|
||||
s.touchUserLogin(ctx, user.ID)
|
||||
|
||||
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
|
||||
if err != nil {
|
||||
@@ -678,63 +693,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
return tokenPair, user, nil
|
||||
}
|
||||
|
||||
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
|
||||
const pendingOAuthTokenTTL = 10 * time.Minute
|
||||
|
||||
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
|
||||
const pendingOAuthPurpose = "pending_oauth_registration"
|
||||
|
||||
type pendingOAuthClaims struct {
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Purpose string `json:"purpose"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
|
||||
// while waiting for the user to supply an invitation code.
|
||||
func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: email,
|
||||
Username: username,
|
||||
Purpose: pendingOAuthPurpose,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(s.cfg.JWT.Secret))
|
||||
}
|
||||
|
||||
// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
|
||||
// Returns ErrInvalidToken when the token is invalid or expired.
|
||||
func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
|
||||
if len(tokenStr) > maxTokenLength {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
|
||||
token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return []byte(s.cfg.JWT.Secret), nil
|
||||
})
|
||||
if parseErr != nil {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
claims, ok := token.Claims.(*pendingOAuthClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
if claims.Purpose != pendingOAuthPurpose {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
return claims.Email, claims.Username, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
|
||||
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
|
||||
return
|
||||
@@ -752,6 +710,95 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
|
||||
if user == nil || user.ID <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(signupSource) == "" {
|
||||
signupSource = "email"
|
||||
}
|
||||
s.updateUserSignupSource(ctx, user.ID, signupSource)
|
||||
|
||||
if signupSource == "email" {
|
||||
s.ensureEmailAuthIdentity(ctx, user)
|
||||
}
|
||||
if touchLogin {
|
||||
s.touchUserLogin(ctx, user.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) {
|
||||
if s == nil || s.entClient == nil || userID <= 0 {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(signupSource) == "" {
|
||||
return
|
||||
}
|
||||
if err := s.entClient.User.UpdateOneID(userID).
|
||||
SetSignupSource(signupSource).
|
||||
Exec(ctx); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) {
|
||||
if s == nil || s.entClient == nil || userID <= 0 {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if err := s.entClient.User.UpdateOneID(userID).
|
||||
SetLastLoginAt(now).
|
||||
SetLastActiveAt(now).
|
||||
Exec(ctx); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) {
|
||||
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.ToLower(strings.TrimSpace(user.Email))
|
||||
if email == "" || isReservedEmail(email) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.entClient.AuthIdentity.Create().
|
||||
SetUserID(user.ID).
|
||||
SetProviderType("email").
|
||||
SetProviderKey("email").
|
||||
SetProviderSubject(email).
|
||||
SetVerifiedAt(time.Now().UTC()).
|
||||
SetMetadata(map[string]any{
|
||||
"source": "auth_service_dual_write",
|
||||
}).
|
||||
OnConflictColumns(
|
||||
authidentity.FieldProviderType,
|
||||
authidentity.FieldProviderKey,
|
||||
authidentity.FieldProviderSubject,
|
||||
).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
||||
}
|
||||
}
|
||||
|
||||
func inferLegacySignupSource(email string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
switch {
|
||||
case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain):
|
||||
return "linuxdo"
|
||||
case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain):
|
||||
return "oidc"
|
||||
case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain):
|
||||
return "wechat"
|
||||
default:
|
||||
return "email"
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
|
||||
if s.settingService == nil {
|
||||
return nil
|
||||
@@ -834,7 +881,8 @@ func randomHexString(byteLength int) (string, error) {
|
||||
func isReservedEmail(email string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain)
|
||||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain)
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT access token
|
||||
|
||||
Reference in New Issue
Block a user