|
|
|
|
@@ -12,6 +12,7 @@ import (
|
|
|
|
|
"strings"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
|
|
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
|
|
|
|
@@ -59,6 +60,7 @@ type JWTClaims struct {
|
|
|
|
|
|
|
|
|
|
// AuthService 认证服务
|
|
|
|
|
type AuthService struct {
|
|
|
|
|
entClient *dbent.Client
|
|
|
|
|
userRepo UserRepository
|
|
|
|
|
redeemRepo RedeemCodeRepository
|
|
|
|
|
refreshTokenCache RefreshTokenCache
|
|
|
|
|
@@ -77,6 +79,7 @@ type DefaultSubscriptionAssigner interface {
|
|
|
|
|
|
|
|
|
|
// NewAuthService 创建认证服务实例
|
|
|
|
|
func NewAuthService(
|
|
|
|
|
entClient *dbent.Client,
|
|
|
|
|
userRepo UserRepository,
|
|
|
|
|
redeemRepo RedeemCodeRepository,
|
|
|
|
|
refreshTokenCache RefreshTokenCache,
|
|
|
|
|
@@ -89,6 +92,7 @@ func NewAuthService(
|
|
|
|
|
defaultSubAssigner DefaultSubscriptionAssigner,
|
|
|
|
|
) *AuthService {
|
|
|
|
|
return &AuthService{
|
|
|
|
|
entClient: entClient,
|
|
|
|
|
userRepo: userRepo,
|
|
|
|
|
redeemRepo: redeemRepo,
|
|
|
|
|
refreshTokenCache: refreshTokenCache,
|
|
|
|
|
@@ -597,24 +601,52 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
|
|
|
|
Status: StatusActive,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
|
|
|
|
if errors.Is(err, ErrEmailExists) {
|
|
|
|
|
user, err = s.userRepo.GetByEmail(ctx, email)
|
|
|
|
|
if err != nil {
|
|
|
|
|
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
|
|
|
|
if s.entClient != nil && invitationRedeemCode != nil {
|
|
|
|
|
tx, err := s.entClient.Tx(ctx)
|
|
|
|
|
if err != nil {
|
|
|
|
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to begin transaction for oauth registration: %v", err)
|
|
|
|
|
return nil, nil, ErrServiceUnavailable
|
|
|
|
|
}
|
|
|
|
|
defer func() { _ = tx.Rollback() }()
|
|
|
|
|
txCtx := dbent.NewTxContext(ctx, tx)
|
|
|
|
|
|
|
|
|
|
if err := s.userRepo.Create(txCtx, newUser); err != nil {
|
|
|
|
|
if errors.Is(err, ErrEmailExists) {
|
|
|
|
|
user, err = s.userRepo.GetByEmail(ctx, email)
|
|
|
|
|
if err != nil {
|
|
|
|
|
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
|
|
|
|
return nil, nil, ErrServiceUnavailable
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
|
|
|
|
return nil, nil, ErrServiceUnavailable
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
|
|
|
|
return nil, nil, ErrServiceUnavailable
|
|
|
|
|
if err := s.redeemRepo.Use(txCtx, invitationRedeemCode.ID, newUser.ID); err != nil {
|
|
|
|
|
return nil, nil, ErrInvitationCodeInvalid
|
|
|
|
|
}
|
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
|
|
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to commit oauth registration transaction: %v", err)
|
|
|
|
|
return nil, nil, ErrServiceUnavailable
|
|
|
|
|
}
|
|
|
|
|
user = newUser
|
|
|
|
|
s.assignDefaultSubscriptions(ctx, user.ID)
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
user = newUser
|
|
|
|
|
s.assignDefaultSubscriptions(ctx, user.ID)
|
|
|
|
|
if invitationRedeemCode != nil {
|
|
|
|
|
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
|
|
|
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to mark invitation code as used for oauth user %d: %v", user.ID, err)
|
|
|
|
|
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
|
|
|
|
if errors.Is(err, ErrEmailExists) {
|
|
|
|
|
user, err = s.userRepo.GetByEmail(ctx, email)
|
|
|
|
|
if err != nil {
|
|
|
|
|
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
|
|
|
|
return nil, nil, ErrServiceUnavailable
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
|
|
|
|
return nil, nil, ErrServiceUnavailable
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
user = newUser
|
|
|
|
|
s.assignDefaultSubscriptions(ctx, user.ID)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
@@ -644,9 +676,13 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
|
|
|
|
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
|
|
|
|
|
const pendingOAuthTokenTTL = 10 * time.Minute
|
|
|
|
|
|
|
|
|
|
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
|
|
|
|
|
const pendingOAuthPurpose = "pending_oauth_registration"
|
|
|
|
|
|
|
|
|
|
type pendingOAuthClaims struct {
|
|
|
|
|
Email string `json:"email"`
|
|
|
|
|
Username string `json:"username"`
|
|
|
|
|
Purpose string `json:"purpose"`
|
|
|
|
|
jwt.RegisteredClaims
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -657,6 +693,7 @@ func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, e
|
|
|
|
|
claims := &pendingOAuthClaims{
|
|
|
|
|
Email: email,
|
|
|
|
|
Username: username,
|
|
|
|
|
Purpose: pendingOAuthPurpose,
|
|
|
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
|
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
|
|
|
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
|
|
|
@@ -687,6 +724,9 @@ func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username
|
|
|
|
|
if !ok || !token.Valid {
|
|
|
|
|
return "", "", ErrInvalidToken
|
|
|
|
|
}
|
|
|
|
|
if claims.Purpose != pendingOAuthPurpose {
|
|
|
|
|
return "", "", ErrInvalidToken
|
|
|
|
|
}
|
|
|
|
|
return claims.Email, claims.Username, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|