Merge remote-tracking branch 'upstream/main'
# Conflicts: # backend/internal/server/api_contract_test.go # backend/internal/service/setting_service.go # deploy/docker-compose.yml # frontend/src/components/layout/AppSidebar.vue # frontend/src/views/admin/SettingsView.vue
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
@@ -70,6 +72,7 @@ type AuthService struct {
|
||||
turnstileService *TurnstileService
|
||||
emailQueueService *EmailQueueService
|
||||
promoService *PromoService
|
||||
affiliateService *AffiliateService
|
||||
defaultSubAssigner DefaultSubscriptionAssigner
|
||||
}
|
||||
|
||||
@@ -77,6 +80,12 @@ type DefaultSubscriptionAssigner interface {
|
||||
AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error)
|
||||
}
|
||||
|
||||
type signupGrantPlan struct {
|
||||
Balance float64
|
||||
Concurrency int
|
||||
Subscriptions []DefaultSubscriptionSetting
|
||||
}
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
func NewAuthService(
|
||||
entClient *dbent.Client,
|
||||
@@ -90,6 +99,7 @@ func NewAuthService(
|
||||
emailQueueService *EmailQueueService,
|
||||
promoService *PromoService,
|
||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||
affiliateService *AffiliateService,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
entClient: entClient,
|
||||
@@ -102,17 +112,25 @@ func NewAuthService(
|
||||
turnstileService: turnstileService,
|
||||
emailQueueService: emailQueueService,
|
||||
promoService: promoService,
|
||||
affiliateService: affiliateService,
|
||||
defaultSubAssigner: defaultSubAssigner,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 用户注册,返回token和用户
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
||||
return s.RegisterWithVerification(ctx, email, password, "", "", "")
|
||||
func (s *AuthService) EntClient() *dbent.Client {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.entClient
|
||||
}
|
||||
|
||||
// RegisterWithVerification 用户注册(支持邮件验证、优惠码和邀请码),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) {
|
||||
// Register 用户注册,返回token和用户
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
||||
return s.RegisterWithVerification(ctx, email, password, "", "", "", "")
|
||||
}
|
||||
|
||||
// RegisterWithVerification 用户注册(支持邮件验证、优惠码、邀请码和邀请返利码),返回token和用户。
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode, affiliateCode string) (string, *User, error) {
|
||||
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
@@ -179,12 +197,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
return "", nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
// 获取默认配置
|
||||
defaultBalance := s.cfg.Default.UserBalance
|
||||
defaultConcurrency := s.cfg.Default.UserConcurrency
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, "email")
|
||||
|
||||
// 新用户默认 RPM(0 = 不限制)。注册时写入,后续作为用户级兜底。
|
||||
var defaultRPMLimit int
|
||||
if s.settingService != nil {
|
||||
defaultBalance = s.settingService.GetDefaultBalance(ctx)
|
||||
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
|
||||
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
@@ -192,8 +210,9 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
Email: email,
|
||||
PasswordHash: hashedPassword,
|
||||
Role: RoleUser,
|
||||
Balance: defaultBalance,
|
||||
Concurrency: defaultConcurrency,
|
||||
Balance: grantPlan.Balance,
|
||||
Concurrency: grantPlan.Concurrency,
|
||||
RPMLimit: defaultRPMLimit,
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
@@ -205,7 +224,19 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
s.postAuthUserBootstrap(ctx, user, "email", true)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
if s.affiliateService != nil {
|
||||
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err)
|
||||
}
|
||||
if code := strings.TrimSpace(affiliateCode); code != "" {
|
||||
if err := s.affiliateService.BindInviterByCode(ctx, user.ID, code); err != nil {
|
||||
// 邀请返利码绑定失败不影响注册,只记录日志
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", user.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 标记邀请码为已使用(如果使用了邀请码)
|
||||
if invitationRedeemCode != nil {
|
||||
@@ -469,12 +500,11 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
return "", nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
// 新用户默认值。
|
||||
defaultBalance := s.cfg.Default.UserBalance
|
||||
defaultConcurrency := s.cfg.Default.UserConcurrency
|
||||
signupSource := inferLegacySignupSource(email)
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||
var defaultRPMLimit int
|
||||
if s.settingService != nil {
|
||||
defaultBalance = s.settingService.GetDefaultBalance(ctx)
|
||||
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
|
||||
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||
}
|
||||
|
||||
newUser := &User{
|
||||
@@ -482,9 +512,11 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
Username: username,
|
||||
PasswordHash: hashedPassword,
|
||||
Role: RoleUser,
|
||||
Balance: defaultBalance,
|
||||
Concurrency: defaultConcurrency,
|
||||
Balance: grantPlan.Balance,
|
||||
Concurrency: grantPlan.Concurrency,
|
||||
RPMLimit: defaultRPMLimit,
|
||||
Status: StatusActive,
|
||||
SignupSource: signupSource,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
@@ -501,7 +533,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
s.postAuthUserBootstrap(ctx, user, signupSource, false)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
|
||||
@@ -520,7 +553,6 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
token, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("generate token: %w", err)
|
||||
@@ -531,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
|
||||
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
|
||||
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
|
||||
// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, nil, errors.New("refresh token cache not configured")
|
||||
@@ -584,11 +617,11 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
return nil, nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
defaultBalance := s.cfg.Default.UserBalance
|
||||
defaultConcurrency := s.cfg.Default.UserConcurrency
|
||||
signupSource := inferLegacySignupSource(email)
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||
var defaultRPMLimit int
|
||||
if s.settingService != nil {
|
||||
defaultBalance = s.settingService.GetDefaultBalance(ctx)
|
||||
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
|
||||
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||
}
|
||||
|
||||
newUser := &User{
|
||||
@@ -596,9 +629,11 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
Username: username,
|
||||
PasswordHash: hashedPassword,
|
||||
Role: RoleUser,
|
||||
Balance: defaultBalance,
|
||||
Concurrency: defaultConcurrency,
|
||||
Balance: grantPlan.Balance,
|
||||
Concurrency: grantPlan.Concurrency,
|
||||
RPMLimit: defaultRPMLimit,
|
||||
Status: StatusActive,
|
||||
SignupSource: signupSource,
|
||||
}
|
||||
|
||||
if s.entClient != nil && invitationRedeemCode != nil {
|
||||
@@ -630,7 +665,9 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
user = newUser
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
s.postAuthUserBootstrap(ctx, user, signupSource, false)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
}
|
||||
} else {
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
@@ -646,7 +683,9 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
s.postAuthUserBootstrap(ctx, user, signupSource, false)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
@@ -670,7 +709,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generate token pair: %w", err)
|
||||
@@ -678,80 +716,289 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
return tokenPair, user, nil
|
||||
}
|
||||
|
||||
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
|
||||
const pendingOAuthTokenTTL = 10 * time.Minute
|
||||
|
||||
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
|
||||
const pendingOAuthPurpose = "pending_oauth_registration"
|
||||
|
||||
type pendingOAuthClaims struct {
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Purpose string `json:"purpose"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
|
||||
// while waiting for the user to supply an invitation code.
|
||||
func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: email,
|
||||
Username: username,
|
||||
Purpose: pendingOAuthPurpose,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(s.cfg.JWT.Secret))
|
||||
}
|
||||
|
||||
// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
|
||||
// Returns ErrInvalidToken when the token is invalid or expired.
|
||||
func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
|
||||
if len(tokenStr) > maxTokenLength {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
|
||||
token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return []byte(s.cfg.JWT.Secret), nil
|
||||
})
|
||||
if parseErr != nil {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
claims, ok := token.Claims.(*pendingOAuthClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
if claims.Purpose != pendingOAuthPurpose {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
return claims.Email, claims.Username, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
|
||||
func (s *AuthService) assignSubscriptions(ctx context.Context, userID int64, items []DefaultSubscriptionSetting, notes string) {
|
||||
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
|
||||
return
|
||||
}
|
||||
items := s.settingService.GetDefaultSubscriptions(ctx)
|
||||
for _, item := range items {
|
||||
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
|
||||
UserID: userID,
|
||||
GroupID: item.GroupID,
|
||||
ValidityDays: item.ValidityDays,
|
||||
Notes: "auto assigned by default user subscriptions setting",
|
||||
Notes: notes,
|
||||
}); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource string) signupGrantPlan {
|
||||
plan := signupGrantPlan{}
|
||||
if s != nil && s.cfg != nil {
|
||||
plan.Balance = s.cfg.Default.UserBalance
|
||||
plan.Concurrency = s.cfg.Default.UserConcurrency
|
||||
}
|
||||
if s == nil || s.settingService == nil {
|
||||
return plan
|
||||
}
|
||||
|
||||
plan.Balance = s.settingService.GetDefaultBalance(ctx)
|
||||
plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx)
|
||||
plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx)
|
||||
|
||||
resolved, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, signupSource, false)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err)
|
||||
return plan
|
||||
}
|
||||
if !enabled {
|
||||
return plan
|
||||
}
|
||||
|
||||
plan.Balance = resolved.Balance
|
||||
plan.Concurrency = resolved.Concurrency
|
||||
plan.Subscriptions = resolved.Subscriptions
|
||||
return plan
|
||||
}
|
||||
|
||||
func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource string) (ProviderDefaultGrantSettings, bool) {
|
||||
if defaults == nil {
|
||||
return ProviderDefaultGrantSettings{}, false
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(signupSource)) {
|
||||
case "email":
|
||||
return defaults.Email, true
|
||||
case "linuxdo":
|
||||
return defaults.LinuxDo, true
|
||||
case "oidc":
|
||||
return defaults.OIDC, true
|
||||
case "wechat":
|
||||
return defaults.WeChat, true
|
||||
default:
|
||||
return ProviderDefaultGrantSettings{}, false
|
||||
}
|
||||
}
|
||||
|
||||
// bindOAuthAffiliate initializes the affiliate profile and binds the inviter
|
||||
// for an OAuth-registered user. Failures are logged but never block registration.
|
||||
func (s *AuthService) bindOAuthAffiliate(ctx context.Context, userID int64, affiliateCode string) {
|
||||
if s.affiliateService == nil || userID <= 0 {
|
||||
return
|
||||
}
|
||||
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, userID); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", userID, err)
|
||||
}
|
||||
if code := strings.TrimSpace(affiliateCode); code != "" {
|
||||
if err := s.affiliateService.BindInviterByCode(ctx, userID, code); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
|
||||
if user == nil || user.ID <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(signupSource) == "" {
|
||||
signupSource = "email"
|
||||
}
|
||||
s.updateUserSignupSource(ctx, user.ID, signupSource)
|
||||
|
||||
if touchLogin {
|
||||
s.touchUserLogin(ctx, user.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) {
|
||||
if s == nil || s.entClient == nil || userID <= 0 {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(signupSource) == "" {
|
||||
return
|
||||
}
|
||||
if err := s.entClient.User.UpdateOneID(userID).
|
||||
SetSignupSource(signupSource).
|
||||
Exec(ctx); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) {
|
||||
if s == nil || s.entClient == nil || userID <= 0 {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if err := s.entClient.User.UpdateOneID(userID).
|
||||
SetLastLoginAt(now).
|
||||
SetLastActiveAt(now).
|
||||
Exec(ctx); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context, user *User) {
|
||||
if s == nil || user == nil || user.ID <= 0 {
|
||||
return
|
||||
}
|
||||
identity, created := s.ensureEmailAuthIdentity(ctx, user, "auth_service_login_backfill")
|
||||
if s.shouldApplyEmailFirstBindDefaults(ctx, user.ID, identity, created) {
|
||||
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) shouldApplyEmailFirstBindDefaults(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
identity *dbent.AuthIdentity,
|
||||
created bool,
|
||||
) bool {
|
||||
source := emailAuthIdentitySource(identity.Metadata)
|
||||
if source == "auth_service_login_backfill" {
|
||||
return false
|
||||
}
|
||||
if created {
|
||||
return true
|
||||
}
|
||||
if s == nil || s.entClient == nil || userID <= 0 || identity == nil || identity.UserID != userID {
|
||||
return false
|
||||
}
|
||||
if source != "auth_service_dual_write" {
|
||||
return false
|
||||
}
|
||||
|
||||
hasGrant, err := s.hasProviderGrantRecord(ctx, userID, "email", "first_bind")
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email first bind grant state: user_id=%d err=%v", userID, err)
|
||||
return false
|
||||
}
|
||||
return !hasGrant
|
||||
}
|
||||
|
||||
func emailAuthIdentitySource(metadata map[string]any) string {
|
||||
if len(metadata) == 0 {
|
||||
return ""
|
||||
}
|
||||
raw, ok := metadata["source"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprint(raw))
|
||||
}
|
||||
|
||||
func (s *AuthService) hasProviderGrantRecord(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
providerType string,
|
||||
grantReason string,
|
||||
) (bool, error) {
|
||||
if s == nil || s.entClient == nil || userID <= 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
rows, err := s.entClient.QueryContext(
|
||||
ctx,
|
||||
`SELECT 1 FROM user_provider_default_grants WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3 LIMIT 1`,
|
||||
userID,
|
||||
strings.TrimSpace(providerType),
|
||||
strings.TrimSpace(grantReason),
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
return rows.Next(), rows.Err()
|
||||
}
|
||||
|
||||
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, source string) (*dbent.AuthIdentity, bool) {
|
||||
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
email := strings.ToLower(strings.TrimSpace(user.Email))
|
||||
if email == "" || isReservedEmail(email) {
|
||||
return nil, false
|
||||
}
|
||||
if strings.TrimSpace(source) == "" {
|
||||
source = "auth_service_dual_write"
|
||||
}
|
||||
|
||||
client := s.entClient
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
client = tx.Client()
|
||||
}
|
||||
|
||||
buildQuery := func() *dbent.AuthIdentityQuery {
|
||||
return client.AuthIdentity.Query().Where(
|
||||
authidentity.ProviderTypeEQ("email"),
|
||||
authidentity.ProviderKeyEQ("email"),
|
||||
authidentity.ProviderSubjectEQ(email),
|
||||
)
|
||||
}
|
||||
|
||||
existed, err := buildQuery().Exist(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if !existed {
|
||||
if err := client.AuthIdentity.Create().
|
||||
SetUserID(user.ID).
|
||||
SetProviderType("email").
|
||||
SetProviderKey("email").
|
||||
SetProviderSubject(email).
|
||||
SetVerifiedAt(time.Now().UTC()).
|
||||
SetMetadata(map[string]any{
|
||||
"source": strings.TrimSpace(source),
|
||||
}).
|
||||
OnConflictColumns(
|
||||
authidentity.FieldProviderType,
|
||||
authidentity.FieldProviderKey,
|
||||
authidentity.FieldProviderSubject,
|
||||
).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
if isSQLNoRowsError(err) {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
identity, err := buildQuery().Only(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
||||
return nil, false
|
||||
}
|
||||
if identity.UserID != user.ID {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return identity, !existed
|
||||
}
|
||||
|
||||
func inferLegacySignupSource(email string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
switch {
|
||||
case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain):
|
||||
return "linuxdo"
|
||||
case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain):
|
||||
return "oidc"
|
||||
case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain):
|
||||
return "wechat"
|
||||
default:
|
||||
return "email"
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
|
||||
if s.settingService == nil {
|
||||
return nil
|
||||
@@ -833,7 +1080,9 @@ func randomHexString(byteLength int) (string, error) {
|
||||
|
||||
func isReservedEmail(email string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
|
||||
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain)
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT access token
|
||||
@@ -852,7 +1101,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
|
||||
UserID: user.ID,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
TokenVersion: user.TokenVersion,
|
||||
TokenVersion: resolvedTokenVersion(user),
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
@@ -918,7 +1167,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
|
||||
|
||||
// Security: Check TokenVersion to prevent refreshing revoked tokens
|
||||
// This ensures tokens issued before a password change cannot be refreshed
|
||||
if claims.TokenVersion != user.TokenVersion {
|
||||
if claims.TokenVersion != resolvedTokenVersion(user) {
|
||||
return "", ErrTokenRevoked
|
||||
}
|
||||
|
||||
@@ -1146,7 +1395,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
|
||||
|
||||
data := &RefreshTokenData{
|
||||
UserID: user.ID,
|
||||
TokenVersion: user.TokenVersion,
|
||||
TokenVersion: resolvedTokenVersion(user),
|
||||
FamilyID: familyID,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(ttl),
|
||||
@@ -1226,7 +1475,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
|
||||
}
|
||||
|
||||
// 检查TokenVersion(密码更改后所有Token失效)
|
||||
if data.TokenVersion != user.TokenVersion {
|
||||
if data.TokenVersion != resolvedTokenVersion(user) {
|
||||
// TokenVersion不匹配,撤销整个Token家族
|
||||
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
|
||||
return nil, ErrTokenRevoked
|
||||
@@ -1271,8 +1520,42 @@ func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) e
|
||||
return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID)
|
||||
}
|
||||
|
||||
// RevokeAllUserTokens invalidates both stateless access tokens and refresh sessions.
|
||||
// Access/refresh token verification both depend on TokenVersion, so bumping it provides
|
||||
// immediate revocation even if refresh-token cache cleanup later fails.
|
||||
func (s *AuthService) RevokeAllUserTokens(ctx context.Context, userID int64) error {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
user.TokenVersion++
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return fmt.Errorf("update user: %w", err)
|
||||
}
|
||||
|
||||
if err := s.RevokeAllUserSessions(ctx, userID); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after token invalidation for user %d: %v", userID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// hashToken 计算Token的SHA256哈希
|
||||
func hashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func resolvedTokenVersion(user *User) int64 {
|
||||
if user == nil {
|
||||
return 0
|
||||
}
|
||||
if user.TokenVersionResolved {
|
||||
return user.TokenVersion
|
||||
}
|
||||
|
||||
material := strings.ToLower(strings.TrimSpace(user.Email)) + "\n" + user.PasswordHash
|
||||
sum := sha256.Sum256([]byte(material))
|
||||
fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff)
|
||||
return user.TokenVersion ^ fingerprint
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user