Merge remote-tracking branch 'upstream/main'
Some checks failed
CI / test (push) Has been cancelled
CI / frontend (push) Has been cancelled
CI / golangci-lint (push) Has been cancelled
Security Scan / backend-security (push) Has been cancelled
Security Scan / frontend-security (push) Has been cancelled

# Conflicts:
#	backend/internal/server/api_contract_test.go
#	backend/internal/service/setting_service.go
#	deploy/docker-compose.yml
#	frontend/src/components/layout/AppSidebar.vue
#	frontend/src/views/admin/SettingsView.vue
This commit is contained in:
huangzhenpc
2026-04-27 17:01:41 +08:00
997 changed files with 221727 additions and 36134 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
@@ -13,6 +14,7 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -70,6 +72,7 @@ type AuthService struct {
turnstileService *TurnstileService
emailQueueService *EmailQueueService
promoService *PromoService
affiliateService *AffiliateService
defaultSubAssigner DefaultSubscriptionAssigner
}
@@ -77,6 +80,12 @@ type DefaultSubscriptionAssigner interface {
AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error)
}
type signupGrantPlan struct {
Balance float64
Concurrency int
Subscriptions []DefaultSubscriptionSetting
}
// NewAuthService 创建认证服务实例
func NewAuthService(
entClient *dbent.Client,
@@ -90,6 +99,7 @@ func NewAuthService(
emailQueueService *EmailQueueService,
promoService *PromoService,
defaultSubAssigner DefaultSubscriptionAssigner,
affiliateService *AffiliateService,
) *AuthService {
return &AuthService{
entClient: entClient,
@@ -102,17 +112,25 @@ func NewAuthService(
turnstileService: turnstileService,
emailQueueService: emailQueueService,
promoService: promoService,
affiliateService: affiliateService,
defaultSubAssigner: defaultSubAssigner,
}
}
// Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "", "")
func (s *AuthService) EntClient() *dbent.Client {
if s == nil {
return nil
}
return s.entClient
}
// RegisterWithVerification 用户注册(支持邮件验证、优惠码和邀请码)返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) {
// Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "", "", "")
}
// RegisterWithVerification 用户注册支持邮件验证、优惠码、邀请码和邀请返利码返回token和用户。
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode, affiliateCode string) (string, *User, error) {
// 检查是否开放注册默认关闭settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
@@ -179,12 +197,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, fmt.Errorf("hash password: %w", err)
}
// 获取默认配置
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
grantPlan := s.resolveSignupGrantPlan(ctx, "email")
// 新用户默认 RPM0 = 不限制)。注册时写入,后续作为用户级兜底。
var defaultRPMLimit int
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
// 创建用户
@@ -192,8 +210,9 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
Email: email,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
}
@@ -205,7 +224,19 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable
}
s.assignDefaultSubscriptions(ctx, user.ID)
s.postAuthUserBootstrap(ctx, user, "email", true)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if s.affiliateService != nil {
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err)
}
if code := strings.TrimSpace(affiliateCode); code != "" {
if err := s.affiliateService.BindInviterByCode(ctx, user.ID, code); err != nil {
// 邀请返利码绑定失败不影响注册,只记录日志
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", user.ID, err)
}
}
}
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
@@ -469,12 +500,11 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
return "", nil, fmt.Errorf("hash password: %w", err)
}
// 新用户默认值。
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
signupSource := inferLegacySignupSource(email)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
newUser := &User{
@@ -482,9 +512,11 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
SignupSource: signupSource,
}
if err := s.userRepo.Create(ctx, newUser); err != nil {
@@ -501,7 +533,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
}
} else {
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
@@ -520,7 +553,6 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
@@ -531,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, nil, errors.New("refresh token cache not configured")
@@ -584,11 +617,11 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, fmt.Errorf("hash password: %w", err)
}
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
signupSource := inferLegacySignupSource(email)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
newUser := &User{
@@ -596,9 +629,11 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
SignupSource: signupSource,
}
if s.entClient != nil && invitationRedeemCode != nil {
@@ -630,7 +665,9 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable
}
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
}
} else {
if err := s.userRepo.Create(ctx, newUser); err != nil {
@@ -646,7 +683,9 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
} else {
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
@@ -670,7 +709,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
return nil, nil, fmt.Errorf("generate token pair: %w", err)
@@ -678,80 +716,289 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return tokenPair, user, nil
}
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
const pendingOAuthTokenTTL = 10 * time.Minute
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
const pendingOAuthPurpose = "pending_oauth_registration"
type pendingOAuthClaims struct {
Email string `json:"email"`
Username string `json:"username"`
Purpose string `json:"purpose"`
jwt.RegisteredClaims
}
// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
// while waiting for the user to supply an invitation code.
func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
now := time.Now()
claims := &pendingOAuthClaims{
Email: email,
Username: username,
Purpose: pendingOAuthPurpose,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(s.cfg.JWT.Secret))
}
// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
// Returns ErrInvalidToken when the token is invalid or expired.
func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
if len(tokenStr) > maxTokenLength {
return "", "", ErrInvalidToken
}
parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return []byte(s.cfg.JWT.Secret), nil
})
if parseErr != nil {
return "", "", ErrInvalidToken
}
claims, ok := token.Claims.(*pendingOAuthClaims)
if !ok || !token.Valid {
return "", "", ErrInvalidToken
}
if claims.Purpose != pendingOAuthPurpose {
return "", "", ErrInvalidToken
}
return claims.Email, claims.Username, nil
}
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
func (s *AuthService) assignSubscriptions(ctx context.Context, userID int64, items []DefaultSubscriptionSetting, notes string) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return
}
items := s.settingService.GetDefaultSubscriptions(ctx)
for _, item := range items {
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: item.GroupID,
ValidityDays: item.ValidityDays,
Notes: "auto assigned by default user subscriptions setting",
Notes: notes,
}); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
}
}
}
func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource string) signupGrantPlan {
plan := signupGrantPlan{}
if s != nil && s.cfg != nil {
plan.Balance = s.cfg.Default.UserBalance
plan.Concurrency = s.cfg.Default.UserConcurrency
}
if s == nil || s.settingService == nil {
return plan
}
plan.Balance = s.settingService.GetDefaultBalance(ctx)
plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx)
plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx)
resolved, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, signupSource, false)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err)
return plan
}
if !enabled {
return plan
}
plan.Balance = resolved.Balance
plan.Concurrency = resolved.Concurrency
plan.Subscriptions = resolved.Subscriptions
return plan
}
func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource string) (ProviderDefaultGrantSettings, bool) {
if defaults == nil {
return ProviderDefaultGrantSettings{}, false
}
switch strings.ToLower(strings.TrimSpace(signupSource)) {
case "email":
return defaults.Email, true
case "linuxdo":
return defaults.LinuxDo, true
case "oidc":
return defaults.OIDC, true
case "wechat":
return defaults.WeChat, true
default:
return ProviderDefaultGrantSettings{}, false
}
}
// bindOAuthAffiliate initializes the affiliate profile and binds the inviter
// for an OAuth-registered user. Failures are logged but never block registration.
func (s *AuthService) bindOAuthAffiliate(ctx context.Context, userID int64, affiliateCode string) {
if s.affiliateService == nil || userID <= 0 {
return
}
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, userID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", userID, err)
}
if code := strings.TrimSpace(affiliateCode); code != "" {
if err := s.affiliateService.BindInviterByCode(ctx, userID, code); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", userID, err)
}
}
}
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
if user == nil || user.ID <= 0 {
return
}
if strings.TrimSpace(signupSource) == "" {
signupSource = "email"
}
s.updateUserSignupSource(ctx, user.ID, signupSource)
if touchLogin {
s.touchUserLogin(ctx, user.ID)
}
}
func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) {
if s == nil || s.entClient == nil || userID <= 0 {
return
}
if strings.TrimSpace(signupSource) == "" {
return
}
if err := s.entClient.User.UpdateOneID(userID).
SetSignupSource(signupSource).
Exec(ctx); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err)
}
}
func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) {
if s == nil || s.entClient == nil || userID <= 0 {
return
}
now := time.Now().UTC()
if err := s.entClient.User.UpdateOneID(userID).
SetLastLoginAt(now).
SetLastActiveAt(now).
Exec(ctx); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err)
}
}
func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context, user *User) {
if s == nil || user == nil || user.ID <= 0 {
return
}
identity, created := s.ensureEmailAuthIdentity(ctx, user, "auth_service_login_backfill")
if s.shouldApplyEmailFirstBindDefaults(ctx, user.ID, identity, created) {
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err)
}
}
}
func (s *AuthService) shouldApplyEmailFirstBindDefaults(
ctx context.Context,
userID int64,
identity *dbent.AuthIdentity,
created bool,
) bool {
source := emailAuthIdentitySource(identity.Metadata)
if source == "auth_service_login_backfill" {
return false
}
if created {
return true
}
if s == nil || s.entClient == nil || userID <= 0 || identity == nil || identity.UserID != userID {
return false
}
if source != "auth_service_dual_write" {
return false
}
hasGrant, err := s.hasProviderGrantRecord(ctx, userID, "email", "first_bind")
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email first bind grant state: user_id=%d err=%v", userID, err)
return false
}
return !hasGrant
}
func emailAuthIdentitySource(metadata map[string]any) string {
if len(metadata) == 0 {
return ""
}
raw, ok := metadata["source"]
if !ok {
return ""
}
return strings.TrimSpace(fmt.Sprint(raw))
}
func (s *AuthService) hasProviderGrantRecord(
ctx context.Context,
userID int64,
providerType string,
grantReason string,
) (bool, error) {
if s == nil || s.entClient == nil || userID <= 0 {
return false, nil
}
rows, err := s.entClient.QueryContext(
ctx,
`SELECT 1 FROM user_provider_default_grants WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3 LIMIT 1`,
userID,
strings.TrimSpace(providerType),
strings.TrimSpace(grantReason),
)
if err != nil {
return false, err
}
defer func() { _ = rows.Close() }()
return rows.Next(), rows.Err()
}
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, source string) (*dbent.AuthIdentity, bool) {
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
return nil, false
}
email := strings.ToLower(strings.TrimSpace(user.Email))
if email == "" || isReservedEmail(email) {
return nil, false
}
if strings.TrimSpace(source) == "" {
source = "auth_service_dual_write"
}
client := s.entClient
if tx := dbent.TxFromContext(ctx); tx != nil {
client = tx.Client()
}
buildQuery := func() *dbent.AuthIdentityQuery {
return client.AuthIdentity.Query().Where(
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ(email),
)
}
existed, err := buildQuery().Exist(ctx)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
return nil, false
}
if !existed {
if err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("email").
SetProviderKey("email").
SetProviderSubject(email).
SetVerifiedAt(time.Now().UTC()).
SetMetadata(map[string]any{
"source": strings.TrimSpace(source),
}).
OnConflictColumns(
authidentity.FieldProviderType,
authidentity.FieldProviderKey,
authidentity.FieldProviderSubject,
).
DoNothing().
Exec(ctx); err != nil {
if isSQLNoRowsError(err) {
return nil, false
}
}
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
return nil, false
}
}
identity, err := buildQuery().Only(ctx)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
return nil, false
}
if identity.UserID != user.ID {
logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID)
return nil, false
}
return identity, !existed
}
func inferLegacySignupSource(email string) string {
normalized := strings.ToLower(strings.TrimSpace(email))
switch {
case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain):
return "linuxdo"
case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain):
return "oidc"
case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain):
return "wechat"
default:
return "email"
}
}
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
if s.settingService == nil {
return nil
@@ -833,7 +1080,9 @@ func randomHexString(byteLength int) (string, error) {
func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT access token
@@ -852,7 +1101,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
UserID: user.ID,
Email: user.Email,
Role: user.Role,
TokenVersion: user.TokenVersion,
TokenVersion: resolvedTokenVersion(user),
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(now),
@@ -918,7 +1167,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// Security: Check TokenVersion to prevent refreshing revoked tokens
// This ensures tokens issued before a password change cannot be refreshed
if claims.TokenVersion != user.TokenVersion {
if claims.TokenVersion != resolvedTokenVersion(user) {
return "", ErrTokenRevoked
}
@@ -1146,7 +1395,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
data := &RefreshTokenData{
UserID: user.ID,
TokenVersion: user.TokenVersion,
TokenVersion: resolvedTokenVersion(user),
FamilyID: familyID,
CreatedAt: now,
ExpiresAt: now.Add(ttl),
@@ -1226,7 +1475,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
}
// 检查TokenVersion密码更改后所有Token失效
if data.TokenVersion != user.TokenVersion {
if data.TokenVersion != resolvedTokenVersion(user) {
// TokenVersion不匹配撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrTokenRevoked
@@ -1271,8 +1520,42 @@ func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) e
return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID)
}
// RevokeAllUserTokens invalidates both stateless access tokens and refresh sessions.
// Access/refresh token verification both depend on TokenVersion, so bumping it provides
// immediate revocation even if refresh-token cache cleanup later fails.
func (s *AuthService) RevokeAllUserTokens(ctx context.Context, userID int64) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user: %w", err)
}
user.TokenVersion++
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user: %w", err)
}
if err := s.RevokeAllUserSessions(ctx, userID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after token invalidation for user %d: %v", userID, err)
}
return nil
}
// hashToken 计算Token的SHA256哈希
func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
func resolvedTokenVersion(user *User) int64 {
if user == nil {
return 0
}
if user.TokenVersionResolved {
return user.TokenVersion
}
material := strings.ToLower(strings.TrimSpace(user.Email)) + "\n" + user.PasswordHash
sum := sha256.Sum256([]byte(material))
fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff)
return user.TokenVersion ^ fingerprint
}