feat(auth): 实现 Refresh Token 机制
- 新增 Access Token + Refresh Token 双令牌认证 - 支持 Token 自动刷新和轮转 - 添加登出和撤销所有会话接口 - 前端实现无感刷新和主动刷新定时器
This commit is contained in:
@@ -33,7 +33,7 @@ func main() {
|
||||
}()
|
||||
|
||||
userRepo := repository.NewUserRepository(client, sqlDB)
|
||||
authService := service.NewAuthService(userRepo, nil, cfg, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -44,9 +44,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
}
|
||||
userRepository := repository.NewUserRepository(client, db)
|
||||
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
|
||||
redisClient := repository.ProvideRedis(configConfig)
|
||||
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
|
||||
settingRepository := repository.NewSettingRepository(client)
|
||||
settingService := service.NewSettingService(settingRepository, configConfig)
|
||||
redisClient := repository.ProvideRedis(configConfig)
|
||||
emailCache := repository.NewEmailCache(redisClient)
|
||||
emailService := service.NewEmailService(settingRepository, emailCache)
|
||||
turnstileVerifier := repository.NewTurnstileVerifier()
|
||||
@@ -62,7 +63,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
authService := service.NewAuthService(userRepository, redeemCodeRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
||||
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
||||
redeemCache := repository.NewRedeemCache(redisClient)
|
||||
|
||||
@@ -467,6 +467,13 @@ type OpsMetricsCollectorCacheConfig struct {
|
||||
type JWTConfig struct {
|
||||
Secret string `mapstructure:"secret"`
|
||||
ExpireHour int `mapstructure:"expire_hour"`
|
||||
// AccessTokenExpireMinutes: Access Token有效期(分钟),默认15分钟
|
||||
// 短有效期减少被盗用风险,配合Refresh Token实现无感续期
|
||||
AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"`
|
||||
// RefreshTokenExpireDays: Refresh Token有效期(天),默认30天
|
||||
RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"`
|
||||
// RefreshWindowMinutes: 刷新窗口(分钟),在Access Token过期前多久开始允许刷新
|
||||
RefreshWindowMinutes int `mapstructure:"refresh_window_minutes"`
|
||||
}
|
||||
|
||||
// TotpConfig TOTP 双因素认证配置
|
||||
@@ -783,6 +790,9 @@ func setDefaults() {
|
||||
// JWT
|
||||
viper.SetDefault("jwt.secret", "")
|
||||
viper.SetDefault("jwt.expire_hour", 24)
|
||||
viper.SetDefault("jwt.access_token_expire_minutes", 360) // 6小时Access Token有效期
|
||||
viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期
|
||||
viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新
|
||||
|
||||
// TOTP
|
||||
viper.SetDefault("totp.encryption_key", "")
|
||||
@@ -912,6 +922,22 @@ func (c *Config) Validate() error {
|
||||
if c.JWT.ExpireHour > 24 {
|
||||
log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour)
|
||||
}
|
||||
// JWT Refresh Token配置验证
|
||||
if c.JWT.AccessTokenExpireMinutes <= 0 {
|
||||
return fmt.Errorf("jwt.access_token_expire_minutes must be positive")
|
||||
}
|
||||
if c.JWT.AccessTokenExpireMinutes > 720 {
|
||||
log.Printf("Warning: jwt.access_token_expire_minutes is %d (> 720). Consider shorter expiration for security.", c.JWT.AccessTokenExpireMinutes)
|
||||
}
|
||||
if c.JWT.RefreshTokenExpireDays <= 0 {
|
||||
return fmt.Errorf("jwt.refresh_token_expire_days must be positive")
|
||||
}
|
||||
if c.JWT.RefreshTokenExpireDays > 90 {
|
||||
log.Printf("Warning: jwt.refresh_token_expire_days is %d (> 90). Consider shorter expiration for security.", c.JWT.RefreshTokenExpireDays)
|
||||
}
|
||||
if c.JWT.RefreshWindowMinutes < 0 {
|
||||
return fmt.Errorf("jwt.refresh_window_minutes must be non-negative")
|
||||
}
|
||||
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
|
||||
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
|
||||
}
|
||||
|
||||
@@ -68,9 +68,39 @@ type LoginRequest struct {
|
||||
|
||||
// AuthResponse 认证响应格式(匹配前端期望)
|
||||
type AuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
User *dto.User `json:"user"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"` // 新增:Refresh Token
|
||||
ExpiresIn int `json:"expires_in,omitempty"` // 新增:Access Token有效期(秒)
|
||||
TokenType string `json:"token_type"`
|
||||
User *dto.User `json:"user"`
|
||||
}
|
||||
|
||||
// respondWithTokenPair 生成 Token 对并返回认证响应
|
||||
// 如果 Token 对生成失败,回退到只返回 Access Token(向后兼容)
|
||||
func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
|
||||
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
|
||||
if err != nil {
|
||||
slog.Error("failed to generate token pair", "error", err, "user_id", user.ID)
|
||||
// 回退到只返回Access Token
|
||||
token, tokenErr := h.authService.GenerateToken(user)
|
||||
if tokenErr != nil {
|
||||
response.InternalError(c, "Failed to generate token")
|
||||
return
|
||||
}
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
return
|
||||
}
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: tokenPair.AccessToken,
|
||||
RefreshToken: tokenPair.RefreshToken,
|
||||
ExpiresIn: tokenPair.ExpiresIn,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
}
|
||||
|
||||
// Register handles user registration
|
||||
@@ -90,17 +120,13 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
|
||||
_, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
// SendVerifyCode 发送邮箱验证码
|
||||
@@ -150,6 +176,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
_ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成
|
||||
|
||||
// Check if TOTP 2FA is enabled for this user
|
||||
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
|
||||
@@ -168,11 +195,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
// TotpLoginResponse represents the response when 2FA is required
|
||||
@@ -238,18 +261,7 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Generate the JWT token
|
||||
token, err := h.authService.GenerateToken(user)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate token")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
// GetCurrentUser handles getting current authenticated user
|
||||
@@ -491,3 +503,96 @@ func (h *AuthHandler) ResetPassword(c *gin.Context) {
|
||||
Message: "Your password has been reset successfully. You can now log in with your new password.",
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== Token Refresh Endpoints ====================
|
||||
|
||||
// RefreshTokenRequest 刷新Token请求
|
||||
type RefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
}
|
||||
|
||||
// RefreshTokenResponse 刷新Token响应
|
||||
type RefreshTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"` // Access Token有效期(秒)
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// RefreshToken 刷新Token
|
||||
// POST /api/v1/auth/refresh
|
||||
func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req RefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, RefreshTokenResponse{
|
||||
AccessToken: tokenPair.AccessToken,
|
||||
RefreshToken: tokenPair.RefreshToken,
|
||||
ExpiresIn: tokenPair.ExpiresIn,
|
||||
TokenType: "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
// LogoutRequest 登出请求
|
||||
type LogoutRequest struct {
|
||||
RefreshToken string `json:"refresh_token,omitempty"` // 可选:撤销指定的Refresh Token
|
||||
}
|
||||
|
||||
// LogoutResponse 登出响应
|
||||
type LogoutResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Logout 用户登出
|
||||
// POST /api/v1/auth/logout
|
||||
func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
var req LogoutRequest
|
||||
// 允许空请求体(向后兼容)
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
|
||||
// 如果提供了Refresh Token,撤销它
|
||||
if req.RefreshToken != "" {
|
||||
if err := h.authService.RevokeRefreshToken(c.Request.Context(), req.RefreshToken); err != nil {
|
||||
slog.Debug("failed to revoke refresh token", "error", err)
|
||||
// 不影响登出流程
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, LogoutResponse{
|
||||
Message: "Logged out successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// RevokeAllSessionsResponse 撤销所有会话响应
|
||||
type RevokeAllSessionsResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// RevokeAllSessions 撤销当前用户的所有会话
|
||||
// POST /api/v1/auth/revoke-all-sessions
|
||||
func (h *AuthHandler) RevokeAllSessions(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil {
|
||||
slog.Error("failed to revoke all sessions", "user_id", subject.UserID, "error", err)
|
||||
response.InternalError(c, "Failed to revoke sessions")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, RevokeAllSessionsResponse{
|
||||
Message: "All sessions have been revoked. Please log in again.",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -211,7 +211,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
email = linuxDoSyntheticEmail(subject)
|
||||
}
|
||||
|
||||
jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username)
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username)
|
||||
if err != nil {
|
||||
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
@@ -219,7 +219,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
}
|
||||
|
||||
fragment := url.Values{}
|
||||
fragment.Set("access_token", jwtToken)
|
||||
fragment.Set("access_token", tokenPair.AccessToken)
|
||||
fragment.Set("refresh_token", tokenPair.RefreshToken)
|
||||
fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
|
||||
fragment.Set("token_type", "Bearer")
|
||||
fragment.Set("redirect", redirectTo)
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
|
||||
158
backend/internal/repository/refresh_token_cache.go
Normal file
158
backend/internal/repository/refresh_token_cache.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
refreshTokenKeyPrefix = "refresh_token:"
|
||||
userRefreshTokensPrefix = "user_refresh_tokens:"
|
||||
tokenFamilyPrefix = "token_family:"
|
||||
)
|
||||
|
||||
// refreshTokenKey generates the Redis key for a refresh token.
|
||||
func refreshTokenKey(tokenHash string) string {
|
||||
return refreshTokenKeyPrefix + tokenHash
|
||||
}
|
||||
|
||||
// userRefreshTokensKey generates the Redis key for user's token set.
|
||||
func userRefreshTokensKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", userRefreshTokensPrefix, userID)
|
||||
}
|
||||
|
||||
// tokenFamilyKey generates the Redis key for token family set.
|
||||
func tokenFamilyKey(familyID string) string {
|
||||
return tokenFamilyPrefix + familyID
|
||||
}
|
||||
|
||||
type refreshTokenCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewRefreshTokenCache creates a new RefreshTokenCache implementation.
|
||||
func NewRefreshTokenCache(rdb *redis.Client) service.RefreshTokenCache {
|
||||
return &refreshTokenCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) StoreRefreshToken(ctx context.Context, tokenHash string, data *service.RefreshTokenData, ttl time.Duration) error {
|
||||
key := refreshTokenKey(tokenHash)
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal refresh token data: %w", err)
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) GetRefreshToken(ctx context.Context, tokenHash string) (*service.RefreshTokenData, error) {
|
||||
key := refreshTokenKey(tokenHash)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, service.ErrRefreshTokenNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
var data service.RefreshTokenData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal refresh token data: %w", err)
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) DeleteRefreshToken(ctx context.Context, tokenHash string) error {
|
||||
key := refreshTokenKey(tokenHash)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) DeleteUserRefreshTokens(ctx context.Context, userID int64) error {
|
||||
// Get all token hashes for this user
|
||||
tokenHashes, err := c.GetUserTokenHashes(ctx, userID)
|
||||
if err != nil && err != redis.Nil {
|
||||
return fmt.Errorf("get user token hashes: %w", err)
|
||||
}
|
||||
|
||||
if len(tokenHashes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build keys to delete
|
||||
keys := make([]string, 0, len(tokenHashes)+1)
|
||||
for _, hash := range tokenHashes {
|
||||
keys = append(keys, refreshTokenKey(hash))
|
||||
}
|
||||
keys = append(keys, userRefreshTokensKey(userID))
|
||||
|
||||
// Delete all keys in a pipeline
|
||||
pipe := c.rdb.Pipeline()
|
||||
for _, key := range keys {
|
||||
pipe.Del(ctx, key)
|
||||
}
|
||||
_, err = pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) DeleteTokenFamily(ctx context.Context, familyID string) error {
|
||||
// Get all token hashes in this family
|
||||
tokenHashes, err := c.GetFamilyTokenHashes(ctx, familyID)
|
||||
if err != nil && err != redis.Nil {
|
||||
return fmt.Errorf("get family token hashes: %w", err)
|
||||
}
|
||||
|
||||
if len(tokenHashes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build keys to delete
|
||||
keys := make([]string, 0, len(tokenHashes)+1)
|
||||
for _, hash := range tokenHashes {
|
||||
keys = append(keys, refreshTokenKey(hash))
|
||||
}
|
||||
keys = append(keys, tokenFamilyKey(familyID))
|
||||
|
||||
// Delete all keys in a pipeline
|
||||
pipe := c.rdb.Pipeline()
|
||||
for _, key := range keys {
|
||||
pipe.Del(ctx, key)
|
||||
}
|
||||
_, err = pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error {
|
||||
key := userRefreshTokensKey(userID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.SAdd(ctx, key, tokenHash)
|
||||
pipe.Expire(ctx, key, ttl)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error {
|
||||
key := tokenFamilyKey(familyID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.SAdd(ctx, key, tokenHash)
|
||||
pipe.Expire(ctx, key, ttl)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error) {
|
||||
key := userRefreshTokensKey(userID)
|
||||
return c.rdb.SMembers(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error) {
|
||||
key := tokenFamilyKey(familyID)
|
||||
return c.rdb.SMembers(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error) {
|
||||
key := tokenFamilyKey(familyID)
|
||||
return c.rdb.SIsMember(ctx, key, tokenHash).Result()
|
||||
}
|
||||
@@ -85,6 +85,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewSchedulerOutboxRepository,
|
||||
NewProxyLatencyCache,
|
||||
NewTotpCache,
|
||||
NewRefreshTokenCache,
|
||||
|
||||
// Encryptors
|
||||
NewAESEncryptor,
|
||||
|
||||
@@ -28,6 +28,12 @@ func RegisterAuthRoutes(
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/login/2fa", h.Auth.Login2FA)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
// Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close)
|
||||
auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.RefreshToken)
|
||||
// 登出接口(公开,允许未认证用户调用以撤销Refresh Token)
|
||||
auth.POST("/logout", h.Auth.Logout)
|
||||
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
|
||||
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
@@ -59,5 +65,7 @@ func RegisterAuthRoutes(
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
{
|
||||
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
||||
// 撤销所有会话(需要认证)
|
||||
authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -25,8 +26,12 @@ var (
|
||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||
ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token")
|
||||
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
|
||||
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
@@ -37,6 +42,9 @@ var (
|
||||
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
|
||||
const maxTokenLength = 8192
|
||||
|
||||
// refreshTokenPrefix is the prefix for refresh tokens to distinguish them from access tokens.
|
||||
const refreshTokenPrefix = "rt_"
|
||||
|
||||
// JWTClaims JWT载荷数据
|
||||
type JWTClaims struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
@@ -50,6 +58,7 @@ type JWTClaims struct {
|
||||
type AuthService struct {
|
||||
userRepo UserRepository
|
||||
redeemRepo RedeemCodeRepository
|
||||
refreshTokenCache RefreshTokenCache
|
||||
cfg *config.Config
|
||||
settingService *SettingService
|
||||
emailService *EmailService
|
||||
@@ -62,6 +71,7 @@ type AuthService struct {
|
||||
func NewAuthService(
|
||||
userRepo UserRepository,
|
||||
redeemRepo RedeemCodeRepository,
|
||||
refreshTokenCache RefreshTokenCache,
|
||||
cfg *config.Config,
|
||||
settingService *SettingService,
|
||||
emailService *EmailService,
|
||||
@@ -72,6 +82,7 @@ func NewAuthService(
|
||||
return &AuthService{
|
||||
userRepo: userRepo,
|
||||
redeemRepo: redeemRepo,
|
||||
refreshTokenCache: refreshTokenCache,
|
||||
cfg: cfg,
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
@@ -481,6 +492,100 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair
|
||||
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, nil, errors.New("refresh token cache not configured")
|
||||
}
|
||||
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" || len(email) > 255 {
|
||||
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
if _, err := mail.ParseAddress(email); err != nil {
|
||||
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
|
||||
username = strings.TrimSpace(username)
|
||||
if len([]rune(username)) > 100 {
|
||||
username = string([]rune(username)[:100])
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
// OAuth 首次登录视为注册
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return nil, nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
randomPassword, err := randomHexString(32)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
hashedPassword, err := s.HashPassword(randomPassword)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
defaultBalance := s.cfg.Default.UserBalance
|
||||
defaultConcurrency := s.cfg.Default.UserConcurrency
|
||||
if s.settingService != nil {
|
||||
defaultBalance = s.settingService.GetDefaultBalance(ctx)
|
||||
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
|
||||
}
|
||||
|
||||
newUser := &User{
|
||||
Email: email,
|
||||
Username: username,
|
||||
PasswordHash: hashedPassword,
|
||||
Role: RoleUser,
|
||||
Balance: defaultBalance,
|
||||
Concurrency: defaultConcurrency,
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Database error getting user after conflict: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Auth] Database error creating oauth user: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Auth] Database error during oauth login: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
}
|
||||
|
||||
if !user.IsActive() {
|
||||
return nil, nil, ErrUserNotActive
|
||||
}
|
||||
|
||||
if user.Username == "" && username != "" {
|
||||
user.Username = username
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
log.Printf("[Auth] Failed to update username after oauth login: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generate token pair: %w", err)
|
||||
}
|
||||
return tokenPair, user, nil
|
||||
}
|
||||
|
||||
// ValidateToken 验证JWT token并返回用户声明
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
|
||||
@@ -539,10 +644,17 @@ func isReservedEmail(email string) bool {
|
||||
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT token
|
||||
// GenerateToken 生成JWT access token
|
||||
// 使用新的access_token_expire_minutes配置项(如果配置了),否则回退到expire_hour
|
||||
func (s *AuthService) GenerateToken(user *User) (string, error) {
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
|
||||
var expiresAt time.Time
|
||||
if s.cfg.JWT.AccessTokenExpireMinutes > 0 {
|
||||
expiresAt = now.Add(time.Duration(s.cfg.JWT.AccessTokenExpireMinutes) * time.Minute)
|
||||
} else {
|
||||
// 向后兼容:使用旧的expire_hour配置
|
||||
expiresAt = now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
|
||||
}
|
||||
|
||||
claims := &JWTClaims{
|
||||
UserID: user.ID,
|
||||
@@ -565,6 +677,15 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// GetAccessTokenExpiresIn 返回Access Token的有效期(秒)
|
||||
// 用于前端设置刷新定时器
|
||||
func (s *AuthService) GetAccessTokenExpiresIn() int {
|
||||
if s.cfg.JWT.AccessTokenExpireMinutes > 0 {
|
||||
return s.cfg.JWT.AccessTokenExpireMinutes * 60
|
||||
}
|
||||
return s.cfg.JWT.ExpireHour * 3600
|
||||
}
|
||||
|
||||
// HashPassword 使用bcrypt加密密码
|
||||
func (s *AuthService) HashPassword(password string) (string, error) {
|
||||
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
@@ -755,6 +876,198 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// Also revoke all refresh tokens for this user
|
||||
if err := s.RevokeAllUserSessions(ctx, user.ID); err != nil {
|
||||
log.Printf("[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err)
|
||||
// Don't return error - password was already changed successfully
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Password reset successful for user: %s", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Refresh Token Methods ====================
|
||||
|
||||
// TokenPair 包含Access Token和Refresh Token
|
||||
type TokenPair struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"` // Access Token有效期(秒)
|
||||
}
|
||||
|
||||
// GenerateTokenPair 生成Access Token和Refresh Token对
|
||||
// familyID: 可选的Token家族ID,用于Token轮转时保持家族关系
|
||||
func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, errors.New("refresh token cache not configured")
|
||||
}
|
||||
|
||||
// 生成Access Token
|
||||
accessToken, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate access token: %w", err)
|
||||
}
|
||||
|
||||
// 生成Refresh Token
|
||||
refreshToken, err := s.generateRefreshToken(ctx, user, familyID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate refresh token: %w", err)
|
||||
}
|
||||
|
||||
return &TokenPair{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresIn: s.GetAccessTokenExpiresIn(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// generateRefreshToken 生成并存储Refresh Token
|
||||
func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, familyID string) (string, error) {
|
||||
// 生成随机Token
|
||||
tokenBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
rawToken := refreshTokenPrefix + hex.EncodeToString(tokenBytes)
|
||||
|
||||
// 计算Token哈希(存储哈希而非原始Token)
|
||||
tokenHash := hashToken(rawToken)
|
||||
|
||||
// 如果没有提供familyID,生成新的
|
||||
if familyID == "" {
|
||||
familyBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(familyBytes); err != nil {
|
||||
return "", fmt.Errorf("generate family id: %w", err)
|
||||
}
|
||||
familyID = hex.EncodeToString(familyBytes)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
ttl := time.Duration(s.cfg.JWT.RefreshTokenExpireDays) * 24 * time.Hour
|
||||
|
||||
data := &RefreshTokenData{
|
||||
UserID: user.ID,
|
||||
TokenVersion: user.TokenVersion,
|
||||
FamilyID: familyID,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(ttl),
|
||||
}
|
||||
|
||||
// 存储Token数据
|
||||
if err := s.refreshTokenCache.StoreRefreshToken(ctx, tokenHash, data, ttl); err != nil {
|
||||
return "", fmt.Errorf("store refresh token: %w", err)
|
||||
}
|
||||
|
||||
// 添加到用户Token集合
|
||||
if err := s.refreshTokenCache.AddToUserTokenSet(ctx, user.ID, tokenHash, ttl); err != nil {
|
||||
log.Printf("[Auth] Failed to add token to user set: %v", err)
|
||||
// 不影响主流程
|
||||
}
|
||||
|
||||
// 添加到家族Token集合
|
||||
if err := s.refreshTokenCache.AddToFamilyTokenSet(ctx, familyID, tokenHash, ttl); err != nil {
|
||||
log.Printf("[Auth] Failed to add token to family set: %v", err)
|
||||
// 不影响主流程
|
||||
}
|
||||
|
||||
return rawToken, nil
|
||||
}
|
||||
|
||||
// RefreshTokenPair 使用Refresh Token刷新Token对
|
||||
// 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效
|
||||
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, ErrRefreshTokenInvalid
|
||||
}
|
||||
|
||||
// 验证Token格式
|
||||
if !strings.HasPrefix(refreshToken, refreshTokenPrefix) {
|
||||
return nil, ErrRefreshTokenInvalid
|
||||
}
|
||||
|
||||
tokenHash := hashToken(refreshToken)
|
||||
|
||||
// 获取Token数据
|
||||
data, err := s.refreshTokenCache.GetRefreshToken(ctx, tokenHash)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrRefreshTokenNotFound) {
|
||||
// Token不存在,可能是已被使用(Token轮转)或已过期
|
||||
log.Printf("[Auth] Refresh token not found, possible reuse attack")
|
||||
return nil, ErrRefreshTokenInvalid
|
||||
}
|
||||
log.Printf("[Auth] Error getting refresh token: %v", err)
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 检查Token是否过期
|
||||
if time.Now().After(data.ExpiresAt) {
|
||||
// 删除过期Token
|
||||
_ = s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash)
|
||||
return nil, ErrRefreshTokenExpired
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, data.UserID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
// 用户已删除,撤销整个Token家族
|
||||
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
|
||||
return nil, ErrRefreshTokenInvalid
|
||||
}
|
||||
log.Printf("[Auth] Database error getting user for token refresh: %v", err)
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
// 用户被禁用,撤销整个Token家族
|
||||
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
|
||||
return nil, ErrUserNotActive
|
||||
}
|
||||
|
||||
// 检查TokenVersion(密码更改后所有Token失效)
|
||||
if data.TokenVersion != user.TokenVersion {
|
||||
// TokenVersion不匹配,撤销整个Token家族
|
||||
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
|
||||
return nil, ErrTokenRevoked
|
||||
}
|
||||
|
||||
// Token轮转:立即使旧Token失效
|
||||
if err := s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash); err != nil {
|
||||
log.Printf("[Auth] Failed to delete old refresh token: %v", err)
|
||||
// 继续处理,不影响主流程
|
||||
}
|
||||
|
||||
// 生成新的Token对,保持同一个家族ID
|
||||
return s.GenerateTokenPair(ctx, user, data.FamilyID)
|
||||
}
|
||||
|
||||
// RevokeRefreshToken 撤销单个Refresh Token
|
||||
func (s *AuthService) RevokeRefreshToken(ctx context.Context, refreshToken string) error {
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil // No-op if cache not configured
|
||||
}
|
||||
if !strings.HasPrefix(refreshToken, refreshTokenPrefix) {
|
||||
return ErrRefreshTokenInvalid
|
||||
}
|
||||
|
||||
tokenHash := hashToken(refreshToken)
|
||||
return s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash)
|
||||
}
|
||||
|
||||
// RevokeAllUserSessions 撤销用户的所有会话(所有Refresh Token)
|
||||
// 用于密码更改或用户主动登出所有设备
|
||||
func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) error {
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil // No-op if cache not configured
|
||||
}
|
||||
return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID)
|
||||
}
|
||||
|
||||
// hashToken 计算Token的SHA256哈希
|
||||
func hashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
@@ -116,6 +116,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
|
||||
return NewAuthService(
|
||||
repo,
|
||||
nil, // redeemRepo
|
||||
nil, // refreshTokenCache
|
||||
cfg,
|
||||
settingService,
|
||||
emailService,
|
||||
|
||||
73
backend/internal/service/refresh_token_cache.go
Normal file
73
backend/internal/service/refresh_token_cache.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrRefreshTokenNotFound is returned when a refresh token is not found in cache.
|
||||
// This is used to abstract away the underlying cache implementation (e.g., redis.Nil).
|
||||
var ErrRefreshTokenNotFound = errors.New("refresh token not found")
|
||||
|
||||
// RefreshTokenData 存储在Redis中的Refresh Token数据
|
||||
type RefreshTokenData struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
TokenVersion int64 `json:"token_version"` // 用于检测密码更改后的Token失效
|
||||
FamilyID string `json:"family_id"` // Token家族ID,用于防重放攻击
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
// RefreshTokenCache 管理Refresh Token的Redis缓存
|
||||
// 用于JWT Token刷新机制,支持Token轮转和防重放攻击
|
||||
//
|
||||
// Key 格式:
|
||||
// - refresh_token:{token_hash} -> RefreshTokenData (JSON)
|
||||
// - user_refresh_tokens:{user_id} -> Set<token_hash>
|
||||
// - token_family:{family_id} -> Set<token_hash>
|
||||
type RefreshTokenCache interface {
|
||||
// StoreRefreshToken 存储Refresh Token
|
||||
// tokenHash: Token的SHA256哈希值(不存储原始Token)
|
||||
// data: Token关联的数据
|
||||
// ttl: Token过期时间
|
||||
StoreRefreshToken(ctx context.Context, tokenHash string, data *RefreshTokenData, ttl time.Duration) error
|
||||
|
||||
// GetRefreshToken 获取Refresh Token数据
|
||||
// 返回 (data, nil) 如果Token存在
|
||||
// 返回 (nil, ErrRefreshTokenNotFound) 如果Token不存在
|
||||
// 返回 (nil, err) 如果发生其他错误
|
||||
GetRefreshToken(ctx context.Context, tokenHash string) (*RefreshTokenData, error)
|
||||
|
||||
// DeleteRefreshToken 删除单个Refresh Token
|
||||
// 用于Token轮转时使旧Token失效
|
||||
DeleteRefreshToken(ctx context.Context, tokenHash string) error
|
||||
|
||||
// DeleteUserRefreshTokens 删除用户的所有Refresh Token
|
||||
// 用于密码更改或用户主动登出所有设备
|
||||
DeleteUserRefreshTokens(ctx context.Context, userID int64) error
|
||||
|
||||
// DeleteTokenFamily 删除整个Token家族
|
||||
// 用于检测到Token重放攻击时,撤销整个会话链
|
||||
DeleteTokenFamily(ctx context.Context, familyID string) error
|
||||
|
||||
// AddToUserTokenSet 将Token添加到用户的Token集合
|
||||
// 用于跟踪用户的所有活跃Refresh Token
|
||||
AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error
|
||||
|
||||
// AddToFamilyTokenSet 将Token添加到家族Token集合
|
||||
// 用于跟踪同一登录会话的所有Token
|
||||
AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error
|
||||
|
||||
// GetUserTokenHashes 获取用户的所有Token哈希
|
||||
// 用于批量删除用户Token
|
||||
GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error)
|
||||
|
||||
// GetFamilyTokenHashes 获取家族的所有Token哈希
|
||||
// 用于批量删除家族Token
|
||||
GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error)
|
||||
|
||||
// IsTokenInFamily 检查Token是否属于指定家族
|
||||
// 用于验证Token家族关系
|
||||
IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error)
|
||||
}
|
||||
@@ -35,6 +35,22 @@ export function setAuthToken(token: string): void {
|
||||
localStorage.setItem('auth_token', token)
|
||||
}
|
||||
|
||||
/**
|
||||
* Store refresh token in localStorage
|
||||
*/
|
||||
export function setRefreshToken(token: string): void {
|
||||
localStorage.setItem('refresh_token', token)
|
||||
}
|
||||
|
||||
/**
|
||||
* Store token expiration timestamp in localStorage
|
||||
* Converts expires_in (seconds) to absolute timestamp (milliseconds)
|
||||
*/
|
||||
export function setTokenExpiresAt(expiresIn: number): void {
|
||||
const expiresAt = Date.now() + expiresIn * 1000
|
||||
localStorage.setItem('token_expires_at', String(expiresAt))
|
||||
}
|
||||
|
||||
/**
|
||||
* Get authentication token from localStorage
|
||||
*/
|
||||
@@ -42,12 +58,29 @@ export function getAuthToken(): string | null {
|
||||
return localStorage.getItem('auth_token')
|
||||
}
|
||||
|
||||
/**
|
||||
* Get refresh token from localStorage
|
||||
*/
|
||||
export function getRefreshToken(): string | null {
|
||||
return localStorage.getItem('refresh_token')
|
||||
}
|
||||
|
||||
/**
|
||||
* Get token expiration timestamp from localStorage
|
||||
*/
|
||||
export function getTokenExpiresAt(): number | null {
|
||||
const value = localStorage.getItem('token_expires_at')
|
||||
return value ? parseInt(value, 10) : null
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear authentication token from localStorage
|
||||
*/
|
||||
export function clearAuthToken(): void {
|
||||
localStorage.removeItem('auth_token')
|
||||
localStorage.removeItem('refresh_token')
|
||||
localStorage.removeItem('auth_user')
|
||||
localStorage.removeItem('token_expires_at')
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -61,6 +94,12 @@ export async function login(credentials: LoginRequest): Promise<LoginResponse> {
|
||||
// Only store token if 2FA is not required
|
||||
if (!isTotp2FARequired(data)) {
|
||||
setAuthToken(data.access_token)
|
||||
if (data.refresh_token) {
|
||||
setRefreshToken(data.refresh_token)
|
||||
}
|
||||
if (data.expires_in) {
|
||||
setTokenExpiresAt(data.expires_in)
|
||||
}
|
||||
localStorage.setItem('auth_user', JSON.stringify(data.user))
|
||||
}
|
||||
|
||||
@@ -77,6 +116,12 @@ export async function login2FA(request: TotpLogin2FARequest): Promise<AuthRespon
|
||||
|
||||
// Store token and user data
|
||||
setAuthToken(data.access_token)
|
||||
if (data.refresh_token) {
|
||||
setRefreshToken(data.refresh_token)
|
||||
}
|
||||
if (data.expires_in) {
|
||||
setTokenExpiresAt(data.expires_in)
|
||||
}
|
||||
localStorage.setItem('auth_user', JSON.stringify(data.user))
|
||||
|
||||
return data
|
||||
@@ -92,6 +137,12 @@ export async function register(userData: RegisterRequest): Promise<AuthResponse>
|
||||
|
||||
// Store token and user data
|
||||
setAuthToken(data.access_token)
|
||||
if (data.refresh_token) {
|
||||
setRefreshToken(data.refresh_token)
|
||||
}
|
||||
if (data.expires_in) {
|
||||
setTokenExpiresAt(data.expires_in)
|
||||
}
|
||||
localStorage.setItem('auth_user', JSON.stringify(data.user))
|
||||
|
||||
return data
|
||||
@@ -108,11 +159,62 @@ export async function getCurrentUser() {
|
||||
/**
|
||||
* User logout
|
||||
* Clears authentication token and user data from localStorage
|
||||
* Optionally revokes the refresh token on the server
|
||||
*/
|
||||
export function logout(): void {
|
||||
export async function logout(): Promise<void> {
|
||||
const refreshToken = getRefreshToken()
|
||||
|
||||
// Try to revoke the refresh token on the server
|
||||
if (refreshToken) {
|
||||
try {
|
||||
await apiClient.post('/auth/logout', { refresh_token: refreshToken })
|
||||
} catch {
|
||||
// Ignore errors - we still want to clear local state
|
||||
}
|
||||
}
|
||||
|
||||
clearAuthToken()
|
||||
// Optionally redirect to login page
|
||||
// window.location.href = '/login';
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh token response
|
||||
*/
|
||||
export interface RefreshTokenResponse {
|
||||
access_token: string
|
||||
refresh_token: string
|
||||
expires_in: number
|
||||
token_type: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh the access token using the refresh token
|
||||
* @returns New token pair
|
||||
*/
|
||||
export async function refreshToken(): Promise<RefreshTokenResponse> {
|
||||
const currentRefreshToken = getRefreshToken()
|
||||
if (!currentRefreshToken) {
|
||||
throw new Error('No refresh token available')
|
||||
}
|
||||
|
||||
const { data } = await apiClient.post<RefreshTokenResponse>('/auth/refresh', {
|
||||
refresh_token: currentRefreshToken
|
||||
})
|
||||
|
||||
// Update tokens in localStorage
|
||||
setAuthToken(data.access_token)
|
||||
setRefreshToken(data.refresh_token)
|
||||
setTokenExpiresAt(data.expires_in)
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Revoke all sessions for the current user
|
||||
* @returns Response with message
|
||||
*/
|
||||
export async function revokeAllSessions(): Promise<{ message: string }> {
|
||||
const { data } = await apiClient.post<{ message: string }>('/auth/revoke-all-sessions')
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -242,14 +344,20 @@ export const authAPI = {
|
||||
logout,
|
||||
isAuthenticated,
|
||||
setAuthToken,
|
||||
setRefreshToken,
|
||||
setTokenExpiresAt,
|
||||
getAuthToken,
|
||||
getRefreshToken,
|
||||
getTokenExpiresAt,
|
||||
clearAuthToken,
|
||||
getPublicSettings,
|
||||
sendVerifyCode,
|
||||
validatePromoCode,
|
||||
validateInvitationCode,
|
||||
forgotPassword,
|
||||
resetPassword
|
||||
resetPassword,
|
||||
refreshToken,
|
||||
revokeAllSessions
|
||||
}
|
||||
|
||||
export default authAPI
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
/**
|
||||
* Axios HTTP Client Configuration
|
||||
* Base client with interceptors for authentication and error handling
|
||||
* Base client with interceptors for authentication, token refresh, and error handling
|
||||
*/
|
||||
|
||||
import axios, { AxiosInstance, AxiosError, InternalAxiosRequestConfig } from 'axios'
|
||||
import axios, { AxiosInstance, AxiosError, InternalAxiosRequestConfig, AxiosResponse } from 'axios'
|
||||
import type { ApiResponse } from '@/types'
|
||||
import { getLocale } from '@/i18n'
|
||||
|
||||
@@ -19,6 +19,28 @@ export const apiClient: AxiosInstance = axios.create({
|
||||
}
|
||||
})
|
||||
|
||||
// ==================== Token Refresh State ====================
|
||||
|
||||
// Track if a token refresh is in progress to prevent multiple simultaneous refresh requests
|
||||
let isRefreshing = false
|
||||
// Queue of requests waiting for token refresh
|
||||
let refreshSubscribers: Array<(token: string) => void> = []
|
||||
|
||||
/**
|
||||
* Subscribe to token refresh completion
|
||||
*/
|
||||
function subscribeTokenRefresh(callback: (token: string) => void): void {
|
||||
refreshSubscribers.push(callback)
|
||||
}
|
||||
|
||||
/**
|
||||
* Notify all subscribers that token has been refreshed
|
||||
*/
|
||||
function onTokenRefreshed(token: string): void {
|
||||
refreshSubscribers.forEach((callback) => callback(token))
|
||||
refreshSubscribers = []
|
||||
}
|
||||
|
||||
// ==================== Request Interceptor ====================
|
||||
|
||||
// Get user's timezone
|
||||
@@ -61,7 +83,7 @@ apiClient.interceptors.request.use(
|
||||
// ==================== Response Interceptor ====================
|
||||
|
||||
apiClient.interceptors.response.use(
|
||||
(response) => {
|
||||
(response: AxiosResponse) => {
|
||||
// Unwrap standard API response format { code, message, data }
|
||||
const apiResponse = response.data as ApiResponse<unknown>
|
||||
if (apiResponse && typeof apiResponse === 'object' && 'code' in apiResponse) {
|
||||
@@ -79,13 +101,15 @@ apiClient.interceptors.response.use(
|
||||
}
|
||||
return response
|
||||
},
|
||||
(error: AxiosError<ApiResponse<unknown>>) => {
|
||||
async (error: AxiosError<ApiResponse<unknown>>) => {
|
||||
// Request cancellation: keep the original axios cancellation error so callers can ignore it.
|
||||
// Otherwise we'd misclassify it as a generic "network error".
|
||||
if (error.code === 'ERR_CANCELED' || axios.isCancel(error)) {
|
||||
return Promise.reject(error)
|
||||
}
|
||||
|
||||
const originalRequest = error.config as InternalAxiosRequestConfig & { _retry?: boolean }
|
||||
|
||||
// Handle common errors
|
||||
if (error.response) {
|
||||
const { status, data } = error.response
|
||||
@@ -120,23 +144,116 @@ apiClient.interceptors.response.use(
|
||||
})
|
||||
}
|
||||
|
||||
// 401: Unauthorized - clear token and redirect to login
|
||||
if (status === 401) {
|
||||
const hasToken = !!localStorage.getItem('auth_token')
|
||||
const url = error.config?.url || ''
|
||||
// 401: Try to refresh the token if we have a refresh token
|
||||
// This handles TOKEN_EXPIRED, INVALID_TOKEN, TOKEN_REVOKED, etc.
|
||||
if (status === 401 && !originalRequest._retry) {
|
||||
const refreshToken = localStorage.getItem('refresh_token')
|
||||
const isAuthEndpoint =
|
||||
url.includes('/auth/login') || url.includes('/auth/register') || url.includes('/auth/refresh')
|
||||
|
||||
// If we have a refresh token and this is not an auth endpoint, try to refresh
|
||||
if (refreshToken && !isAuthEndpoint) {
|
||||
if (isRefreshing) {
|
||||
// Wait for the ongoing refresh to complete
|
||||
return new Promise((resolve, reject) => {
|
||||
subscribeTokenRefresh((newToken: string) => {
|
||||
if (newToken) {
|
||||
// Mark as retried to prevent infinite loop if retry also returns 401
|
||||
originalRequest._retry = true
|
||||
if (originalRequest.headers) {
|
||||
originalRequest.headers.Authorization = `Bearer ${newToken}`
|
||||
}
|
||||
resolve(apiClient(originalRequest))
|
||||
} else {
|
||||
// Refresh failed, reject with original error
|
||||
reject({
|
||||
status,
|
||||
code: apiData.code,
|
||||
message: apiData.message || apiData.detail || error.message
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
originalRequest._retry = true
|
||||
isRefreshing = true
|
||||
|
||||
try {
|
||||
// Call refresh endpoint directly to avoid circular dependency
|
||||
const refreshResponse = await axios.post(
|
||||
`${API_BASE_URL}/auth/refresh`,
|
||||
{ refresh_token: refreshToken },
|
||||
{ headers: { 'Content-Type': 'application/json' } }
|
||||
)
|
||||
|
||||
const refreshData = refreshResponse.data as ApiResponse<{
|
||||
access_token: string
|
||||
refresh_token: string
|
||||
expires_in: number
|
||||
}>
|
||||
|
||||
if (refreshData.code === 0 && refreshData.data) {
|
||||
const { access_token, refresh_token: newRefreshToken, expires_in } = refreshData.data
|
||||
|
||||
// Update tokens in localStorage (convert expires_in to timestamp)
|
||||
localStorage.setItem('auth_token', access_token)
|
||||
localStorage.setItem('refresh_token', newRefreshToken)
|
||||
localStorage.setItem('token_expires_at', String(Date.now() + expires_in * 1000))
|
||||
|
||||
// Notify subscribers with new token
|
||||
onTokenRefreshed(access_token)
|
||||
|
||||
// Retry the original request with new token
|
||||
if (originalRequest.headers) {
|
||||
originalRequest.headers.Authorization = `Bearer ${access_token}`
|
||||
}
|
||||
|
||||
isRefreshing = false
|
||||
return apiClient(originalRequest)
|
||||
}
|
||||
|
||||
// Refresh response was not successful, fall through to clear auth
|
||||
throw new Error('Token refresh failed')
|
||||
} catch (refreshError) {
|
||||
// Refresh failed - notify subscribers with empty token
|
||||
onTokenRefreshed('')
|
||||
isRefreshing = false
|
||||
|
||||
// Clear tokens and redirect to login
|
||||
localStorage.removeItem('auth_token')
|
||||
localStorage.removeItem('refresh_token')
|
||||
localStorage.removeItem('auth_user')
|
||||
localStorage.removeItem('token_expires_at')
|
||||
sessionStorage.setItem('auth_expired', '1')
|
||||
|
||||
if (!window.location.pathname.includes('/login')) {
|
||||
window.location.href = '/login'
|
||||
}
|
||||
|
||||
return Promise.reject({
|
||||
status: 401,
|
||||
code: 'TOKEN_REFRESH_FAILED',
|
||||
message: 'Session expired. Please log in again.'
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// No refresh token or is auth endpoint - clear auth and redirect
|
||||
const hasToken = !!localStorage.getItem('auth_token')
|
||||
const headers = error.config?.headers as Record<string, unknown> | undefined
|
||||
const authHeader = headers?.Authorization ?? headers?.authorization
|
||||
const sentAuth =
|
||||
typeof authHeader === 'string'
|
||||
? authHeader.trim() !== ''
|
||||
: Array.isArray(authHeader)
|
||||
? authHeader.length > 0
|
||||
: !!authHeader
|
||||
? authHeader.length > 0
|
||||
: !!authHeader
|
||||
|
||||
localStorage.removeItem('auth_token')
|
||||
localStorage.removeItem('refresh_token')
|
||||
localStorage.removeItem('auth_user')
|
||||
localStorage.removeItem('token_expires_at')
|
||||
if ((hasToken || sentAuth) && !isAuthEndpoint) {
|
||||
sessionStorage.setItem('auth_expired', '1')
|
||||
}
|
||||
|
||||
@@ -283,7 +283,12 @@ function closeDropdown() {
|
||||
|
||||
async function handleLogout() {
|
||||
closeDropdown()
|
||||
authStore.logout()
|
||||
try {
|
||||
await authStore.logout()
|
||||
} catch (error) {
|
||||
// Ignore logout errors - still redirect to login
|
||||
console.error('Logout error:', error)
|
||||
}
|
||||
await router.push('/login')
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/**
|
||||
* Authentication Store
|
||||
* Manages user authentication state, login/logout, and token persistence
|
||||
* Manages user authentication state, login/logout, token refresh, and token persistence
|
||||
*/
|
||||
|
||||
import { defineStore } from 'pinia'
|
||||
@@ -10,15 +10,21 @@ import type { User, LoginRequest, RegisterRequest, AuthResponse } from '@/types'
|
||||
|
||||
const AUTH_TOKEN_KEY = 'auth_token'
|
||||
const AUTH_USER_KEY = 'auth_user'
|
||||
const AUTO_REFRESH_INTERVAL = 60 * 1000 // 60 seconds
|
||||
const REFRESH_TOKEN_KEY = 'refresh_token'
|
||||
const TOKEN_EXPIRES_AT_KEY = 'token_expires_at' // 存储过期时间戳而非有效期
|
||||
const AUTO_REFRESH_INTERVAL = 60 * 1000 // 60 seconds for user data refresh
|
||||
const TOKEN_REFRESH_BUFFER = 120 * 1000 // 120 seconds before expiry to refresh token
|
||||
|
||||
export const useAuthStore = defineStore('auth', () => {
|
||||
// ==================== State ====================
|
||||
|
||||
const user = ref<User | null>(null)
|
||||
const token = ref<string | null>(null)
|
||||
const refreshTokenValue = ref<string | null>(null)
|
||||
const tokenExpiresAt = ref<number | null>(null) // 过期时间戳(毫秒)
|
||||
const runMode = ref<'standard' | 'simple'>('standard')
|
||||
let refreshIntervalId: ReturnType<typeof setInterval> | null = null
|
||||
let tokenRefreshTimeoutId: ReturnType<typeof setTimeout> | null = null
|
||||
|
||||
// ==================== Computed ====================
|
||||
|
||||
@@ -42,19 +48,29 @@ export const useAuthStore = defineStore('auth', () => {
|
||||
function checkAuth(): void {
|
||||
const savedToken = localStorage.getItem(AUTH_TOKEN_KEY)
|
||||
const savedUser = localStorage.getItem(AUTH_USER_KEY)
|
||||
const savedRefreshToken = localStorage.getItem(REFRESH_TOKEN_KEY)
|
||||
const savedExpiresAt = localStorage.getItem(TOKEN_EXPIRES_AT_KEY)
|
||||
|
||||
if (savedToken && savedUser) {
|
||||
try {
|
||||
token.value = savedToken
|
||||
user.value = JSON.parse(savedUser)
|
||||
refreshTokenValue.value = savedRefreshToken
|
||||
tokenExpiresAt.value = savedExpiresAt ? parseInt(savedExpiresAt, 10) : null
|
||||
|
||||
// Immediately refresh user data from backend (async, don't block)
|
||||
refreshUser().catch((error) => {
|
||||
console.error('Failed to refresh user on init:', error)
|
||||
})
|
||||
|
||||
// Start auto-refresh interval
|
||||
// Start auto-refresh interval for user data
|
||||
startAutoRefresh()
|
||||
|
||||
// Start proactive token refresh if we have refresh token and expiry info
|
||||
// Note: use !== null to handle case when tokenExpiresAt.value is 0 (expired)
|
||||
if (savedRefreshToken && tokenExpiresAt.value !== null) {
|
||||
scheduleTokenRefreshAt(tokenExpiresAt.value)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to parse saved user data:', error)
|
||||
clearAuth()
|
||||
@@ -89,6 +105,76 @@ export const useAuthStore = defineStore('auth', () => {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Schedule proactive token refresh before expiry (based on expiry timestamp)
|
||||
* @param expiresAtMs - Token expiry timestamp in milliseconds
|
||||
*/
|
||||
function scheduleTokenRefreshAt(expiresAtMs: number): void {
|
||||
// Clear any existing timeout
|
||||
if (tokenRefreshTimeoutId) {
|
||||
clearTimeout(tokenRefreshTimeoutId)
|
||||
tokenRefreshTimeoutId = null
|
||||
}
|
||||
|
||||
// Calculate remaining time until refresh (buffer time before expiry)
|
||||
const now = Date.now()
|
||||
const refreshInMs = Math.max(0, expiresAtMs - now - TOKEN_REFRESH_BUFFER)
|
||||
|
||||
if (refreshInMs <= 0) {
|
||||
// Token is about to expire or already expired, refresh immediately
|
||||
performTokenRefresh()
|
||||
return
|
||||
}
|
||||
|
||||
tokenRefreshTimeoutId = setTimeout(() => {
|
||||
performTokenRefresh()
|
||||
}, refreshInMs)
|
||||
}
|
||||
|
||||
/**
|
||||
* Schedule proactive token refresh before expiry (based on expires_in seconds)
|
||||
* @param expiresInSeconds - Token expiry time in seconds from now
|
||||
*/
|
||||
function scheduleTokenRefresh(expiresInSeconds: number): void {
|
||||
const expiresAtMs = Date.now() + expiresInSeconds * 1000
|
||||
tokenExpiresAt.value = expiresAtMs
|
||||
localStorage.setItem(TOKEN_EXPIRES_AT_KEY, String(expiresAtMs))
|
||||
scheduleTokenRefreshAt(expiresAtMs)
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform the actual token refresh
|
||||
*/
|
||||
async function performTokenRefresh(): Promise<void> {
|
||||
if (!refreshTokenValue.value) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await authAPI.refreshToken()
|
||||
|
||||
// Update state
|
||||
token.value = response.access_token
|
||||
refreshTokenValue.value = response.refresh_token
|
||||
|
||||
// Schedule next refresh (this also updates tokenExpiresAt and localStorage)
|
||||
scheduleTokenRefresh(response.expires_in)
|
||||
} catch (error) {
|
||||
console.error('Token refresh failed:', error)
|
||||
// Don't clear auth here - the interceptor will handle 401 errors
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop token refresh timeout
|
||||
*/
|
||||
function stopTokenRefresh(): void {
|
||||
if (tokenRefreshTimeoutId) {
|
||||
clearTimeout(tokenRefreshTimeoutId)
|
||||
tokenRefreshTimeoutId = null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* User login
|
||||
* @param credentials - Login credentials (email and password)
|
||||
@@ -141,6 +227,12 @@ export const useAuthStore = defineStore('auth', () => {
|
||||
// Store token and user
|
||||
token.value = response.access_token
|
||||
|
||||
// Store refresh token if present
|
||||
if (response.refresh_token) {
|
||||
refreshTokenValue.value = response.refresh_token
|
||||
localStorage.setItem(REFRESH_TOKEN_KEY, response.refresh_token)
|
||||
}
|
||||
|
||||
// Extract run_mode if present
|
||||
if (response.user.run_mode) {
|
||||
runMode.value = response.user.run_mode
|
||||
@@ -152,8 +244,14 @@ export const useAuthStore = defineStore('auth', () => {
|
||||
localStorage.setItem(AUTH_TOKEN_KEY, response.access_token)
|
||||
localStorage.setItem(AUTH_USER_KEY, JSON.stringify(userData))
|
||||
|
||||
// Start auto-refresh interval
|
||||
// Start auto-refresh interval for user data
|
||||
startAutoRefresh()
|
||||
|
||||
// Start proactive token refresh if we have refresh token and expiry info
|
||||
// scheduleTokenRefresh will also store the expiry timestamp
|
||||
if (response.refresh_token && response.expires_in) {
|
||||
scheduleTokenRefresh(response.expires_in)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -166,24 +264,10 @@ export const useAuthStore = defineStore('auth', () => {
|
||||
try {
|
||||
const response = await authAPI.register(userData)
|
||||
|
||||
// Store token and user
|
||||
token.value = response.access_token
|
||||
// Use the common helper to set auth state
|
||||
setAuthFromResponse(response)
|
||||
|
||||
// Extract run_mode if present
|
||||
if (response.user.run_mode) {
|
||||
runMode.value = response.user.run_mode
|
||||
}
|
||||
const { run_mode: _run_mode, ...userDataWithoutRunMode } = response.user
|
||||
user.value = userDataWithoutRunMode
|
||||
|
||||
// Persist to localStorage
|
||||
localStorage.setItem(AUTH_TOKEN_KEY, response.access_token)
|
||||
localStorage.setItem(AUTH_USER_KEY, JSON.stringify(userDataWithoutRunMode))
|
||||
|
||||
// Start auto-refresh interval
|
||||
startAutoRefresh()
|
||||
|
||||
return userDataWithoutRunMode
|
||||
return user.value!
|
||||
} catch (error) {
|
||||
// Clear any partial state on error
|
||||
clearAuth()
|
||||
@@ -193,18 +277,41 @@ export const useAuthStore = defineStore('auth', () => {
|
||||
|
||||
/**
|
||||
* 直接设置 token(用于 OAuth/SSO 回调),并加载当前用户信息。
|
||||
* 会自动读取 localStorage 中已设置的 refresh_token 和 token_expires_in
|
||||
* @param newToken - 后端签发的 JWT access token
|
||||
*/
|
||||
async function setToken(newToken: string): Promise<User> {
|
||||
// Clear any previous state first (avoid mixing sessions)
|
||||
clearAuth()
|
||||
// Note: Don't clear localStorage here as OAuth callback may have set refresh_token
|
||||
stopAutoRefresh()
|
||||
stopTokenRefresh()
|
||||
token.value = null
|
||||
user.value = null
|
||||
|
||||
token.value = newToken
|
||||
localStorage.setItem(AUTH_TOKEN_KEY, newToken)
|
||||
|
||||
// Read refresh token and expires_at from localStorage if set by OAuth callback
|
||||
const savedRefreshToken = localStorage.getItem(REFRESH_TOKEN_KEY)
|
||||
const savedExpiresAt = localStorage.getItem(TOKEN_EXPIRES_AT_KEY)
|
||||
|
||||
if (savedRefreshToken) {
|
||||
refreshTokenValue.value = savedRefreshToken
|
||||
}
|
||||
if (savedExpiresAt) {
|
||||
tokenExpiresAt.value = parseInt(savedExpiresAt, 10)
|
||||
}
|
||||
|
||||
try {
|
||||
const userData = await refreshUser()
|
||||
startAutoRefresh()
|
||||
|
||||
// Start proactive token refresh if we have refresh token and expiry info
|
||||
// Note: use !== null to handle case when tokenExpiresAt.value is 0 (expired)
|
||||
if (savedRefreshToken && tokenExpiresAt.value !== null) {
|
||||
scheduleTokenRefreshAt(tokenExpiresAt.value)
|
||||
}
|
||||
|
||||
return userData
|
||||
} catch (error) {
|
||||
clearAuth()
|
||||
@@ -216,9 +323,9 @@ export const useAuthStore = defineStore('auth', () => {
|
||||
* User logout
|
||||
* Clears all authentication state and persisted data
|
||||
*/
|
||||
function logout(): void {
|
||||
// Call API logout (client-side cleanup)
|
||||
authAPI.logout()
|
||||
async function logout(): Promise<void> {
|
||||
// Call API logout (revokes refresh token on server)
|
||||
await authAPI.logout()
|
||||
|
||||
// Clear state
|
||||
clearAuth()
|
||||
@@ -263,11 +370,17 @@ export const useAuthStore = defineStore('auth', () => {
|
||||
function clearAuth(): void {
|
||||
// Stop auto-refresh
|
||||
stopAutoRefresh()
|
||||
// Stop token refresh
|
||||
stopTokenRefresh()
|
||||
|
||||
token.value = null
|
||||
refreshTokenValue.value = null
|
||||
tokenExpiresAt.value = null
|
||||
user.value = null
|
||||
localStorage.removeItem(AUTH_TOKEN_KEY)
|
||||
localStorage.removeItem(AUTH_USER_KEY)
|
||||
localStorage.removeItem(REFRESH_TOKEN_KEY)
|
||||
localStorage.removeItem(TOKEN_EXPIRES_AT_KEY)
|
||||
}
|
||||
|
||||
// ==================== Return Store API ====================
|
||||
|
||||
@@ -92,6 +92,8 @@ export interface PublicSettings {
|
||||
|
||||
export interface AuthResponse {
|
||||
access_token: string
|
||||
refresh_token?: string // New: Refresh Token for token renewal
|
||||
expires_in?: number // New: Access Token expiry time in seconds
|
||||
token_type: string
|
||||
user: User & { run_mode?: 'standard' | 'simple' }
|
||||
}
|
||||
|
||||
@@ -71,6 +71,8 @@ onMounted(async () => {
|
||||
const params = parseFragmentParams()
|
||||
|
||||
const token = params.get('access_token') || ''
|
||||
const refreshToken = params.get('refresh_token') || ''
|
||||
const expiresInStr = params.get('expires_in') || ''
|
||||
const redirect = sanitizeRedirectPath(
|
||||
params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard'
|
||||
)
|
||||
@@ -92,6 +94,17 @@ onMounted(async () => {
|
||||
}
|
||||
|
||||
try {
|
||||
// Store refresh token and expires_at (convert to timestamp) if provided
|
||||
if (refreshToken) {
|
||||
localStorage.setItem('refresh_token', refreshToken)
|
||||
}
|
||||
if (expiresInStr) {
|
||||
const expiresIn = parseInt(expiresInStr, 10)
|
||||
if (!isNaN(expiresIn)) {
|
||||
localStorage.setItem('token_expires_at', String(Date.now() + expiresIn * 1000))
|
||||
}
|
||||
}
|
||||
|
||||
await authStore.setToken(token)
|
||||
appStore.showSuccess(t('auth.loginSuccess'))
|
||||
await router.replace(redirect)
|
||||
|
||||
Reference in New Issue
Block a user