From fecfaae8dc510cf22706f1da45fedae0ac57fb59 Mon Sep 17 00:00:00 2001 From: Payne Fu Date: Wed, 4 Feb 2026 15:56:01 +0800 Subject: [PATCH 1/3] fix: remove unsupported safety_identifier and previous_response_id fields from upstream requests Co-Authored-By: Claude Opus 4.5 --- backend/internal/service/openai_gateway_service.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 742946d8..4658c694 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -846,10 +846,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } } - // Remove prompt_cache_retention (not supported by upstream OpenAI API) - if _, has := reqBody["prompt_cache_retention"]; has { - delete(reqBody, "prompt_cache_retention") - bodyModified = true + // Remove unsupported fields (not supported by upstream OpenAI API) + for _, unsupportedField := range []string{"prompt_cache_retention", "safety_identifier", "previous_response_id"} { + if _, has := reqBody[unsupportedField]; has { + delete(reqBody, unsupportedField) + bodyModified = true + } } } From 05af95dade5ea1def96d280755646f961a46fe53 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 5 Feb 2026 09:53:20 +0800 Subject: [PATCH 2/3] =?UTF-8?q?fix(gateway):=20=E4=BF=AE=E5=A4=8D=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E5=90=8D=E8=BD=AC=E6=8D=A2=E7=A0=B4=E5=9D=8F=20Anthro?= =?UTF-8?q?pic=20=E7=89=B9=E6=AE=8A=E5=B7=A5=E5=85=B7=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 未知工具名不再进行 PascalCase/snake_case 转换,保持原样透传。 修复 text_editor_20250728 等 Anthropic 特殊工具被错误转换的问题。 --- backend/internal/service/gateway_service.go | 46 ++++----------------- 1 file changed, 9 insertions(+), 37 deletions(-) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index b2ac1efa..8c88c0a9 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -20,7 +20,6 @@ import ( "strings" "sync/atomic" "time" - "unicode" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" @@ -620,35 +619,6 @@ func stripToolPrefix(value string) string { return toolPrefixRe.ReplaceAllString(value, "") } -func toPascalCase(value string) string { - if value == "" { - return value - } - normalized := toolNameBoundaryRe.ReplaceAllString(value, " ") - tokens := make([]string, 0) - for _, token := range strings.Fields(normalized) { - expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2") - parts := strings.Fields(expanded) - if len(parts) > 0 { - tokens = append(tokens, parts...) - } - } - if len(tokens) == 0 { - return value - } - var builder strings.Builder - for _, token := range tokens { - lower := strings.ToLower(token) - if lower == "" { - continue - } - runes := []rune(lower) - runes[0] = unicode.ToUpper(runes[0]) - _, _ = builder.WriteString(string(runes)) - } - return builder.String() -} - func toSnakeCase(value string) string { if value == "" { return value @@ -664,16 +634,15 @@ func normalizeToolNameForClaude(name string, cache map[string]string) string { return name } stripped := stripToolPrefix(name) + // 只对已知的工具名进行映射,未知工具名保持原样 + // 避免破坏 Anthropic 特殊工具(如 text_editor_20250728) mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)] if !ok { - mapped = toPascalCase(stripped) - } - if mapped != "" && cache != nil && mapped != stripped { - cache[mapped] = stripped - } - if mapped == "" { return stripped } + if cache != nil && mapped != stripped { + cache[mapped] = stripped + } return mapped } @@ -682,15 +651,18 @@ func normalizeToolNameForOpenCode(name string, cache map[string]string) string { return name } stripped := stripToolPrefix(name) + // 优先从请求时建立的映射中查找 if cache != nil { if mapped, ok := cache[stripped]; ok { return mapped } } + // 已知工具名的硬编码映射 if mapped, ok := openCodeToolOverrides[stripped]; ok { return mapped } - return toSnakeCase(stripped) + // 未知工具名保持原样,避免破坏 Anthropic 特殊工具 + return stripped } func normalizeParamNameForOpenCode(name string, cache map[string]string) string { From 49a3c43741be43a84147075f1d9f1e0a7500d84b Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 5 Feb 2026 12:38:48 +0800 Subject: [PATCH 3/3] =?UTF-8?q?feat(auth):=20=E5=AE=9E=E7=8E=B0=20Refresh?= =?UTF-8?q?=20Token=20=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 Access Token + Refresh Token 双令牌认证 - 支持 Token 自动刷新和轮转 - 添加登出和撤销所有会话接口 - 前端实现无感刷新和主动刷新定时器 --- backend/cmd/jwtgen/main.go | 2 +- backend/cmd/server/wire_gen.go | 5 +- backend/internal/config/config.go | 26 ++ backend/internal/handler/auth_handler.go | 157 +++++++-- .../internal/handler/auth_linuxdo_oauth.go | 6 +- .../repository/refresh_token_cache.go | 158 +++++++++ backend/internal/repository/wire.go | 1 + backend/internal/server/routes/auth.go | 8 + backend/internal/service/auth_service.go | 317 +++++++++++++++++- .../service/auth_service_register_test.go | 1 + .../internal/service/refresh_token_cache.go | 73 ++++ frontend/src/api/auth.ts | 116 ++++++- frontend/src/api/client.ts | 137 +++++++- frontend/src/components/layout/AppHeader.vue | 7 +- frontend/src/stores/auth.ts | 163 +++++++-- frontend/src/types/index.ts | 2 + .../src/views/auth/LinuxDoCallbackView.vue | 13 + 17 files changed, 1119 insertions(+), 73 deletions(-) create mode 100644 backend/internal/repository/refresh_token_cache.go create mode 100644 backend/internal/service/refresh_token_cache.go diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index 139a3a39..ce4718bf 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -33,7 +33,7 @@ func main() { }() userRepo := repository.NewUserRepository(client, sqlDB) - authService := service.NewAuthService(userRepo, nil, cfg, nil, nil, nil, nil, nil) + authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index ab51540f..47b1e8ac 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -44,9 +44,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { } userRepository := repository.NewUserRepository(client, db) redeemCodeRepository := repository.NewRedeemCodeRepository(client) + redisClient := repository.ProvideRedis(configConfig) + refreshTokenCache := repository.NewRefreshTokenCache(redisClient) settingRepository := repository.NewSettingRepository(client) settingService := service.NewSettingService(settingRepository, configConfig) - redisClient := repository.ProvideRedis(configConfig) emailCache := repository.NewEmailCache(redisClient) emailService := service.NewEmailService(settingRepository, emailCache) turnstileVerifier := repository.NewTurnstileVerifier() @@ -62,7 +63,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) - authService := service.NewAuthService(userRepository, redeemCodeRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) + authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) redeemCache := repository.NewRedeemCache(redisClient) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 84be445b..25258b23 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -467,6 +467,13 @@ type OpsMetricsCollectorCacheConfig struct { type JWTConfig struct { Secret string `mapstructure:"secret"` ExpireHour int `mapstructure:"expire_hour"` + // AccessTokenExpireMinutes: Access Token有效期(分钟),默认15分钟 + // 短有效期减少被盗用风险,配合Refresh Token实现无感续期 + AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"` + // RefreshTokenExpireDays: Refresh Token有效期(天),默认30天 + RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"` + // RefreshWindowMinutes: 刷新窗口(分钟),在Access Token过期前多久开始允许刷新 + RefreshWindowMinutes int `mapstructure:"refresh_window_minutes"` } // TotpConfig TOTP 双因素认证配置 @@ -783,6 +790,9 @@ func setDefaults() { // JWT viper.SetDefault("jwt.secret", "") viper.SetDefault("jwt.expire_hour", 24) + viper.SetDefault("jwt.access_token_expire_minutes", 360) // 6小时Access Token有效期 + viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期 + viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新 // TOTP viper.SetDefault("totp.encryption_key", "") @@ -912,6 +922,22 @@ func (c *Config) Validate() error { if c.JWT.ExpireHour > 24 { log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour) } + // JWT Refresh Token配置验证 + if c.JWT.AccessTokenExpireMinutes <= 0 { + return fmt.Errorf("jwt.access_token_expire_minutes must be positive") + } + if c.JWT.AccessTokenExpireMinutes > 720 { + log.Printf("Warning: jwt.access_token_expire_minutes is %d (> 720). Consider shorter expiration for security.", c.JWT.AccessTokenExpireMinutes) + } + if c.JWT.RefreshTokenExpireDays <= 0 { + return fmt.Errorf("jwt.refresh_token_expire_days must be positive") + } + if c.JWT.RefreshTokenExpireDays > 90 { + log.Printf("Warning: jwt.refresh_token_expire_days is %d (> 90). Consider shorter expiration for security.", c.JWT.RefreshTokenExpireDays) + } + if c.JWT.RefreshWindowMinutes < 0 { + return fmt.Errorf("jwt.refresh_window_minutes must be non-negative") + } if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" { return fmt.Errorf("security.csp.policy is required when CSP is enabled") } diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 75ea9f08..34ed63bc 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -68,9 +68,39 @@ type LoginRequest struct { // AuthResponse 认证响应格式(匹配前端期望) type AuthResponse struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - User *dto.User `json:"user"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` // 新增:Refresh Token + ExpiresIn int `json:"expires_in,omitempty"` // 新增:Access Token有效期(秒) + TokenType string `json:"token_type"` + User *dto.User `json:"user"` +} + +// respondWithTokenPair 生成 Token 对并返回认证响应 +// 如果 Token 对生成失败,回退到只返回 Access Token(向后兼容) +func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) { + tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "") + if err != nil { + slog.Error("failed to generate token pair", "error", err, "user_id", user.ID) + // 回退到只返回Access Token + token, tokenErr := h.authService.GenerateToken(user) + if tokenErr != nil { + response.InternalError(c, "Failed to generate token") + return + } + response.Success(c, AuthResponse{ + AccessToken: token, + TokenType: "Bearer", + User: dto.UserFromService(user), + }) + return + } + response.Success(c, AuthResponse{ + AccessToken: tokenPair.AccessToken, + RefreshToken: tokenPair.RefreshToken, + ExpiresIn: tokenPair.ExpiresIn, + TokenType: "Bearer", + User: dto.UserFromService(user), + }) } // Register handles user registration @@ -90,17 +120,13 @@ func (h *AuthHandler) Register(c *gin.Context) { } } - token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode) + _, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, AuthResponse{ - AccessToken: token, - TokenType: "Bearer", - User: dto.UserFromService(user), - }) + h.respondWithTokenPair(c, user) } // SendVerifyCode 发送邮箱验证码 @@ -150,6 +176,7 @@ func (h *AuthHandler) Login(c *gin.Context) { response.ErrorFrom(c, err) return } + _ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成 // Check if TOTP 2FA is enabled for this user if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled { @@ -168,11 +195,7 @@ func (h *AuthHandler) Login(c *gin.Context) { return } - response.Success(c, AuthResponse{ - AccessToken: token, - TokenType: "Bearer", - User: dto.UserFromService(user), - }) + h.respondWithTokenPair(c, user) } // TotpLoginResponse represents the response when 2FA is required @@ -238,18 +261,7 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { return } - // Generate the JWT token - token, err := h.authService.GenerateToken(user) - if err != nil { - response.InternalError(c, "Failed to generate token") - return - } - - response.Success(c, AuthResponse{ - AccessToken: token, - TokenType: "Bearer", - User: dto.UserFromService(user), - }) + h.respondWithTokenPair(c, user) } // GetCurrentUser handles getting current authenticated user @@ -491,3 +503,96 @@ func (h *AuthHandler) ResetPassword(c *gin.Context) { Message: "Your password has been reset successfully. You can now log in with your new password.", }) } + +// ==================== Token Refresh Endpoints ==================== + +// RefreshTokenRequest 刷新Token请求 +type RefreshTokenRequest struct { + RefreshToken string `json:"refresh_token" binding:"required"` +} + +// RefreshTokenResponse 刷新Token响应 +type RefreshTokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` // Access Token有效期(秒) + TokenType string `json:"token_type"` +} + +// RefreshToken 刷新Token +// POST /api/v1/auth/refresh +func (h *AuthHandler) RefreshToken(c *gin.Context) { + var req RefreshTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, RefreshTokenResponse{ + AccessToken: tokenPair.AccessToken, + RefreshToken: tokenPair.RefreshToken, + ExpiresIn: tokenPair.ExpiresIn, + TokenType: "Bearer", + }) +} + +// LogoutRequest 登出请求 +type LogoutRequest struct { + RefreshToken string `json:"refresh_token,omitempty"` // 可选:撤销指定的Refresh Token +} + +// LogoutResponse 登出响应 +type LogoutResponse struct { + Message string `json:"message"` +} + +// Logout 用户登出 +// POST /api/v1/auth/logout +func (h *AuthHandler) Logout(c *gin.Context) { + var req LogoutRequest + // 允许空请求体(向后兼容) + _ = c.ShouldBindJSON(&req) + + // 如果提供了Refresh Token,撤销它 + if req.RefreshToken != "" { + if err := h.authService.RevokeRefreshToken(c.Request.Context(), req.RefreshToken); err != nil { + slog.Debug("failed to revoke refresh token", "error", err) + // 不影响登出流程 + } + } + + response.Success(c, LogoutResponse{ + Message: "Logged out successfully", + }) +} + +// RevokeAllSessionsResponse 撤销所有会话响应 +type RevokeAllSessionsResponse struct { + Message string `json:"message"` +} + +// RevokeAllSessions 撤销当前用户的所有会话 +// POST /api/v1/auth/revoke-all-sessions +func (h *AuthHandler) RevokeAllSessions(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil { + slog.Error("failed to revoke all sessions", "user_id", subject.UserID, "error", err) + response.InternalError(c, "Failed to revoke sessions") + return + } + + response.Success(c, RevokeAllSessionsResponse{ + Message: "All sessions have been revoked. Please log in again.", + }) +} diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index a16c4cc7..0ccf47e4 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -211,7 +211,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { email = linuxDoSyntheticEmail(subject) } - jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username) + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username) if err != nil { // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。 redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) @@ -219,7 +219,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { } fragment := url.Values{} - fragment.Set("access_token", jwtToken) + fragment.Set("access_token", tokenPair.AccessToken) + fragment.Set("refresh_token", tokenPair.RefreshToken) + fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn)) fragment.Set("token_type", "Bearer") fragment.Set("redirect", redirectTo) redirectWithFragment(c, frontendCallback, fragment) diff --git a/backend/internal/repository/refresh_token_cache.go b/backend/internal/repository/refresh_token_cache.go new file mode 100644 index 00000000..b01bd476 --- /dev/null +++ b/backend/internal/repository/refresh_token_cache.go @@ -0,0 +1,158 @@ +package repository + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const ( + refreshTokenKeyPrefix = "refresh_token:" + userRefreshTokensPrefix = "user_refresh_tokens:" + tokenFamilyPrefix = "token_family:" +) + +// refreshTokenKey generates the Redis key for a refresh token. +func refreshTokenKey(tokenHash string) string { + return refreshTokenKeyPrefix + tokenHash +} + +// userRefreshTokensKey generates the Redis key for user's token set. +func userRefreshTokensKey(userID int64) string { + return fmt.Sprintf("%s%d", userRefreshTokensPrefix, userID) +} + +// tokenFamilyKey generates the Redis key for token family set. +func tokenFamilyKey(familyID string) string { + return tokenFamilyPrefix + familyID +} + +type refreshTokenCache struct { + rdb *redis.Client +} + +// NewRefreshTokenCache creates a new RefreshTokenCache implementation. +func NewRefreshTokenCache(rdb *redis.Client) service.RefreshTokenCache { + return &refreshTokenCache{rdb: rdb} +} + +func (c *refreshTokenCache) StoreRefreshToken(ctx context.Context, tokenHash string, data *service.RefreshTokenData, ttl time.Duration) error { + key := refreshTokenKey(tokenHash) + val, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("marshal refresh token data: %w", err) + } + return c.rdb.Set(ctx, key, val, ttl).Err() +} + +func (c *refreshTokenCache) GetRefreshToken(ctx context.Context, tokenHash string) (*service.RefreshTokenData, error) { + key := refreshTokenKey(tokenHash) + val, err := c.rdb.Get(ctx, key).Result() + if err != nil { + if err == redis.Nil { + return nil, service.ErrRefreshTokenNotFound + } + return nil, err + } + var data service.RefreshTokenData + if err := json.Unmarshal([]byte(val), &data); err != nil { + return nil, fmt.Errorf("unmarshal refresh token data: %w", err) + } + return &data, nil +} + +func (c *refreshTokenCache) DeleteRefreshToken(ctx context.Context, tokenHash string) error { + key := refreshTokenKey(tokenHash) + return c.rdb.Del(ctx, key).Err() +} + +func (c *refreshTokenCache) DeleteUserRefreshTokens(ctx context.Context, userID int64) error { + // Get all token hashes for this user + tokenHashes, err := c.GetUserTokenHashes(ctx, userID) + if err != nil && err != redis.Nil { + return fmt.Errorf("get user token hashes: %w", err) + } + + if len(tokenHashes) == 0 { + return nil + } + + // Build keys to delete + keys := make([]string, 0, len(tokenHashes)+1) + for _, hash := range tokenHashes { + keys = append(keys, refreshTokenKey(hash)) + } + keys = append(keys, userRefreshTokensKey(userID)) + + // Delete all keys in a pipeline + pipe := c.rdb.Pipeline() + for _, key := range keys { + pipe.Del(ctx, key) + } + _, err = pipe.Exec(ctx) + return err +} + +func (c *refreshTokenCache) DeleteTokenFamily(ctx context.Context, familyID string) error { + // Get all token hashes in this family + tokenHashes, err := c.GetFamilyTokenHashes(ctx, familyID) + if err != nil && err != redis.Nil { + return fmt.Errorf("get family token hashes: %w", err) + } + + if len(tokenHashes) == 0 { + return nil + } + + // Build keys to delete + keys := make([]string, 0, len(tokenHashes)+1) + for _, hash := range tokenHashes { + keys = append(keys, refreshTokenKey(hash)) + } + keys = append(keys, tokenFamilyKey(familyID)) + + // Delete all keys in a pipeline + pipe := c.rdb.Pipeline() + for _, key := range keys { + pipe.Del(ctx, key) + } + _, err = pipe.Exec(ctx) + return err +} + +func (c *refreshTokenCache) AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error { + key := userRefreshTokensKey(userID) + pipe := c.rdb.Pipeline() + pipe.SAdd(ctx, key, tokenHash) + pipe.Expire(ctx, key, ttl) + _, err := pipe.Exec(ctx) + return err +} + +func (c *refreshTokenCache) AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error { + key := tokenFamilyKey(familyID) + pipe := c.rdb.Pipeline() + pipe.SAdd(ctx, key, tokenHash) + pipe.Expire(ctx, key, ttl) + _, err := pipe.Exec(ctx) + return err +} + +func (c *refreshTokenCache) GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error) { + key := userRefreshTokensKey(userID) + return c.rdb.SMembers(ctx, key).Result() +} + +func (c *refreshTokenCache) GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error) { + key := tokenFamilyKey(familyID) + return c.rdb.SMembers(ctx, key).Result() +} + +func (c *refreshTokenCache) IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error) { + key := tokenFamilyKey(familyID) + return c.rdb.SIsMember(ctx, key, tokenHash).Result() +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index e3394361..857ce3e8 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -85,6 +85,7 @@ var ProviderSet = wire.NewSet( NewSchedulerOutboxRepository, NewProxyLatencyCache, NewTotpCache, + NewRefreshTokenCache, // Encryptors NewAESEncryptor, diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 24f6d549..26d79605 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -28,6 +28,12 @@ func RegisterAuthRoutes( auth.POST("/login", h.Auth.Login) auth.POST("/login/2fa", h.Auth.Login2FA) auth.POST("/send-verify-code", h.Auth.SendVerifyCode) + // Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close) + auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.RefreshToken) + // 登出接口(公开,允许未认证用户调用以撤销Refresh Token) + auth.POST("/logout", h.Auth.Logout) // 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close) auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, @@ -59,5 +65,7 @@ func RegisterAuthRoutes( authenticated.Use(gin.HandlerFunc(jwtAuth)) { authenticated.GET("/auth/me", h.Auth.GetCurrentUser) + // 撤销所有会话(需要认证) + authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions) } } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 25604d2c..fb8aaf9c 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -3,6 +3,7 @@ package service import ( "context" "crypto/rand" + "crypto/sha256" "encoding/hex" "errors" "fmt" @@ -25,8 +26,12 @@ var ( ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") + ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired") ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked") + ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token") + ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired") + ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused") ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") @@ -37,6 +42,9 @@ var ( // maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。 const maxTokenLength = 8192 +// refreshTokenPrefix is the prefix for refresh tokens to distinguish them from access tokens. +const refreshTokenPrefix = "rt_" + // JWTClaims JWT载荷数据 type JWTClaims struct { UserID int64 `json:"user_id"` @@ -50,6 +58,7 @@ type JWTClaims struct { type AuthService struct { userRepo UserRepository redeemRepo RedeemCodeRepository + refreshTokenCache RefreshTokenCache cfg *config.Config settingService *SettingService emailService *EmailService @@ -62,6 +71,7 @@ type AuthService struct { func NewAuthService( userRepo UserRepository, redeemRepo RedeemCodeRepository, + refreshTokenCache RefreshTokenCache, cfg *config.Config, settingService *SettingService, emailService *EmailService, @@ -72,6 +82,7 @@ func NewAuthService( return &AuthService{ userRepo: userRepo, redeemRepo: redeemRepo, + refreshTokenCache: refreshTokenCache, cfg: cfg, settingService: settingService, emailService: emailService, @@ -481,6 +492,100 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username return token, user, nil } +// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair +// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token +func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) { + // 检查 refreshTokenCache 是否可用 + if s.refreshTokenCache == nil { + return nil, nil, errors.New("refresh token cache not configured") + } + + email = strings.TrimSpace(email) + if email == "" || len(email) > 255 { + return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + if _, err := mail.ParseAddress(email); err != nil { + return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + + username = strings.TrimSpace(username) + if len([]rune(username)) > 100 { + username = string([]rune(username)[:100]) + } + + user, err := s.userRepo.GetByEmail(ctx, email) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + // OAuth 首次登录视为注册 + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { + return nil, nil, ErrRegDisabled + } + + randomPassword, err := randomHexString(32) + if err != nil { + log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err) + return nil, nil, ErrServiceUnavailable + } + hashedPassword, err := s.HashPassword(randomPassword) + if err != nil { + return nil, nil, fmt.Errorf("hash password: %w", err) + } + + defaultBalance := s.cfg.Default.UserBalance + defaultConcurrency := s.cfg.Default.UserConcurrency + if s.settingService != nil { + defaultBalance = s.settingService.GetDefaultBalance(ctx) + defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) + } + + newUser := &User{ + Email: email, + Username: username, + PasswordHash: hashedPassword, + Role: RoleUser, + Balance: defaultBalance, + Concurrency: defaultConcurrency, + Status: StatusActive, + } + + if err := s.userRepo.Create(ctx, newUser); err != nil { + if errors.Is(err, ErrEmailExists) { + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + log.Printf("[Auth] Database error getting user after conflict: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + log.Printf("[Auth] Database error creating oauth user: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + user = newUser + } + } else { + log.Printf("[Auth] Database error during oauth login: %v", err) + return nil, nil, ErrServiceUnavailable + } + } + + if !user.IsActive() { + return nil, nil, ErrUserNotActive + } + + if user.Username == "" && username != "" { + user.Username = username + if err := s.userRepo.Update(ctx, user); err != nil { + log.Printf("[Auth] Failed to update username after oauth login: %v", err) + } + } + + tokenPair, err := s.GenerateTokenPair(ctx, user, "") + if err != nil { + return nil, nil, fmt.Errorf("generate token pair: %w", err) + } + return tokenPair, user, nil +} + // ValidateToken 验证JWT token并返回用户声明 func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 @@ -539,10 +644,17 @@ func isReservedEmail(email string) bool { return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) } -// GenerateToken 生成JWT token +// GenerateToken 生成JWT access token +// 使用新的access_token_expire_minutes配置项(如果配置了),否则回退到expire_hour func (s *AuthService) GenerateToken(user *User) (string, error) { now := time.Now() - expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour) + var expiresAt time.Time + if s.cfg.JWT.AccessTokenExpireMinutes > 0 { + expiresAt = now.Add(time.Duration(s.cfg.JWT.AccessTokenExpireMinutes) * time.Minute) + } else { + // 向后兼容:使用旧的expire_hour配置 + expiresAt = now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour) + } claims := &JWTClaims{ UserID: user.ID, @@ -565,6 +677,15 @@ func (s *AuthService) GenerateToken(user *User) (string, error) { return tokenString, nil } +// GetAccessTokenExpiresIn 返回Access Token的有效期(秒) +// 用于前端设置刷新定时器 +func (s *AuthService) GetAccessTokenExpiresIn() int { + if s.cfg.JWT.AccessTokenExpireMinutes > 0 { + return s.cfg.JWT.AccessTokenExpireMinutes * 60 + } + return s.cfg.JWT.ExpireHour * 3600 +} + // HashPassword 使用bcrypt加密密码 func (s *AuthService) HashPassword(password string) (string, error) { hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) @@ -755,6 +876,198 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo return ErrServiceUnavailable } + // Also revoke all refresh tokens for this user + if err := s.RevokeAllUserSessions(ctx, user.ID); err != nil { + log.Printf("[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err) + // Don't return error - password was already changed successfully + } + log.Printf("[Auth] Password reset successful for user: %s", email) return nil } + +// ==================== Refresh Token Methods ==================== + +// TokenPair 包含Access Token和Refresh Token +type TokenPair struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` // Access Token有效期(秒) +} + +// GenerateTokenPair 生成Access Token和Refresh Token对 +// familyID: 可选的Token家族ID,用于Token轮转时保持家族关系 +func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) { + // 检查 refreshTokenCache 是否可用 + if s.refreshTokenCache == nil { + return nil, errors.New("refresh token cache not configured") + } + + // 生成Access Token + accessToken, err := s.GenerateToken(user) + if err != nil { + return nil, fmt.Errorf("generate access token: %w", err) + } + + // 生成Refresh Token + refreshToken, err := s.generateRefreshToken(ctx, user, familyID) + if err != nil { + return nil, fmt.Errorf("generate refresh token: %w", err) + } + + return &TokenPair{ + AccessToken: accessToken, + RefreshToken: refreshToken, + ExpiresIn: s.GetAccessTokenExpiresIn(), + }, nil +} + +// generateRefreshToken 生成并存储Refresh Token +func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, familyID string) (string, error) { + // 生成随机Token + tokenBytes := make([]byte, 32) + if _, err := rand.Read(tokenBytes); err != nil { + return "", fmt.Errorf("generate random bytes: %w", err) + } + rawToken := refreshTokenPrefix + hex.EncodeToString(tokenBytes) + + // 计算Token哈希(存储哈希而非原始Token) + tokenHash := hashToken(rawToken) + + // 如果没有提供familyID,生成新的 + if familyID == "" { + familyBytes := make([]byte, 16) + if _, err := rand.Read(familyBytes); err != nil { + return "", fmt.Errorf("generate family id: %w", err) + } + familyID = hex.EncodeToString(familyBytes) + } + + now := time.Now() + ttl := time.Duration(s.cfg.JWT.RefreshTokenExpireDays) * 24 * time.Hour + + data := &RefreshTokenData{ + UserID: user.ID, + TokenVersion: user.TokenVersion, + FamilyID: familyID, + CreatedAt: now, + ExpiresAt: now.Add(ttl), + } + + // 存储Token数据 + if err := s.refreshTokenCache.StoreRefreshToken(ctx, tokenHash, data, ttl); err != nil { + return "", fmt.Errorf("store refresh token: %w", err) + } + + // 添加到用户Token集合 + if err := s.refreshTokenCache.AddToUserTokenSet(ctx, user.ID, tokenHash, ttl); err != nil { + log.Printf("[Auth] Failed to add token to user set: %v", err) + // 不影响主流程 + } + + // 添加到家族Token集合 + if err := s.refreshTokenCache.AddToFamilyTokenSet(ctx, familyID, tokenHash, ttl); err != nil { + log.Printf("[Auth] Failed to add token to family set: %v", err) + // 不影响主流程 + } + + return rawToken, nil +} + +// RefreshTokenPair 使用Refresh Token刷新Token对 +// 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效 +func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) { + // 检查 refreshTokenCache 是否可用 + if s.refreshTokenCache == nil { + return nil, ErrRefreshTokenInvalid + } + + // 验证Token格式 + if !strings.HasPrefix(refreshToken, refreshTokenPrefix) { + return nil, ErrRefreshTokenInvalid + } + + tokenHash := hashToken(refreshToken) + + // 获取Token数据 + data, err := s.refreshTokenCache.GetRefreshToken(ctx, tokenHash) + if err != nil { + if errors.Is(err, ErrRefreshTokenNotFound) { + // Token不存在,可能是已被使用(Token轮转)或已过期 + log.Printf("[Auth] Refresh token not found, possible reuse attack") + return nil, ErrRefreshTokenInvalid + } + log.Printf("[Auth] Error getting refresh token: %v", err) + return nil, ErrServiceUnavailable + } + + // 检查Token是否过期 + if time.Now().After(data.ExpiresAt) { + // 删除过期Token + _ = s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash) + return nil, ErrRefreshTokenExpired + } + + // 获取用户信息 + user, err := s.userRepo.GetByID(ctx, data.UserID) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + // 用户已删除,撤销整个Token家族 + _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) + return nil, ErrRefreshTokenInvalid + } + log.Printf("[Auth] Database error getting user for token refresh: %v", err) + return nil, ErrServiceUnavailable + } + + // 检查用户状态 + if !user.IsActive() { + // 用户被禁用,撤销整个Token家族 + _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) + return nil, ErrUserNotActive + } + + // 检查TokenVersion(密码更改后所有Token失效) + if data.TokenVersion != user.TokenVersion { + // TokenVersion不匹配,撤销整个Token家族 + _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) + return nil, ErrTokenRevoked + } + + // Token轮转:立即使旧Token失效 + if err := s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash); err != nil { + log.Printf("[Auth] Failed to delete old refresh token: %v", err) + // 继续处理,不影响主流程 + } + + // 生成新的Token对,保持同一个家族ID + return s.GenerateTokenPair(ctx, user, data.FamilyID) +} + +// RevokeRefreshToken 撤销单个Refresh Token +func (s *AuthService) RevokeRefreshToken(ctx context.Context, refreshToken string) error { + if s.refreshTokenCache == nil { + return nil // No-op if cache not configured + } + if !strings.HasPrefix(refreshToken, refreshTokenPrefix) { + return ErrRefreshTokenInvalid + } + + tokenHash := hashToken(refreshToken) + return s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash) +} + +// RevokeAllUserSessions 撤销用户的所有会话(所有Refresh Token) +// 用于密码更改或用户主动登出所有设备 +func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) error { + if s.refreshTokenCache == nil { + return nil // No-op if cache not configured + } + return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID) +} + +// hashToken 计算Token的SHA256哈希 +func hashToken(token string) string { + hash := sha256.Sum256([]byte(token)) + return hex.EncodeToString(hash[:]) +} diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index aa3c769e..f1685be5 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -116,6 +116,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E return NewAuthService( repo, nil, // redeemRepo + nil, // refreshTokenCache cfg, settingService, emailService, diff --git a/backend/internal/service/refresh_token_cache.go b/backend/internal/service/refresh_token_cache.go new file mode 100644 index 00000000..91b3924f --- /dev/null +++ b/backend/internal/service/refresh_token_cache.go @@ -0,0 +1,73 @@ +package service + +import ( + "context" + "errors" + "time" +) + +// ErrRefreshTokenNotFound is returned when a refresh token is not found in cache. +// This is used to abstract away the underlying cache implementation (e.g., redis.Nil). +var ErrRefreshTokenNotFound = errors.New("refresh token not found") + +// RefreshTokenData 存储在Redis中的Refresh Token数据 +type RefreshTokenData struct { + UserID int64 `json:"user_id"` + TokenVersion int64 `json:"token_version"` // 用于检测密码更改后的Token失效 + FamilyID string `json:"family_id"` // Token家族ID,用于防重放攻击 + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` +} + +// RefreshTokenCache 管理Refresh Token的Redis缓存 +// 用于JWT Token刷新机制,支持Token轮转和防重放攻击 +// +// Key 格式: +// - refresh_token:{token_hash} -> RefreshTokenData (JSON) +// - user_refresh_tokens:{user_id} -> Set +// - token_family:{family_id} -> Set +type RefreshTokenCache interface { + // StoreRefreshToken 存储Refresh Token + // tokenHash: Token的SHA256哈希值(不存储原始Token) + // data: Token关联的数据 + // ttl: Token过期时间 + StoreRefreshToken(ctx context.Context, tokenHash string, data *RefreshTokenData, ttl time.Duration) error + + // GetRefreshToken 获取Refresh Token数据 + // 返回 (data, nil) 如果Token存在 + // 返回 (nil, ErrRefreshTokenNotFound) 如果Token不存在 + // 返回 (nil, err) 如果发生其他错误 + GetRefreshToken(ctx context.Context, tokenHash string) (*RefreshTokenData, error) + + // DeleteRefreshToken 删除单个Refresh Token + // 用于Token轮转时使旧Token失效 + DeleteRefreshToken(ctx context.Context, tokenHash string) error + + // DeleteUserRefreshTokens 删除用户的所有Refresh Token + // 用于密码更改或用户主动登出所有设备 + DeleteUserRefreshTokens(ctx context.Context, userID int64) error + + // DeleteTokenFamily 删除整个Token家族 + // 用于检测到Token重放攻击时,撤销整个会话链 + DeleteTokenFamily(ctx context.Context, familyID string) error + + // AddToUserTokenSet 将Token添加到用户的Token集合 + // 用于跟踪用户的所有活跃Refresh Token + AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error + + // AddToFamilyTokenSet 将Token添加到家族Token集合 + // 用于跟踪同一登录会话的所有Token + AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error + + // GetUserTokenHashes 获取用户的所有Token哈希 + // 用于批量删除用户Token + GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error) + + // GetFamilyTokenHashes 获取家族的所有Token哈希 + // 用于批量删除家族Token + GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error) + + // IsTokenInFamily 检查Token是否属于指定家族 + // 用于验证Token家族关系 + IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error) +} diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 40c9c5a4..e196e234 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -35,6 +35,22 @@ export function setAuthToken(token: string): void { localStorage.setItem('auth_token', token) } +/** + * Store refresh token in localStorage + */ +export function setRefreshToken(token: string): void { + localStorage.setItem('refresh_token', token) +} + +/** + * Store token expiration timestamp in localStorage + * Converts expires_in (seconds) to absolute timestamp (milliseconds) + */ +export function setTokenExpiresAt(expiresIn: number): void { + const expiresAt = Date.now() + expiresIn * 1000 + localStorage.setItem('token_expires_at', String(expiresAt)) +} + /** * Get authentication token from localStorage */ @@ -42,12 +58,29 @@ export function getAuthToken(): string | null { return localStorage.getItem('auth_token') } +/** + * Get refresh token from localStorage + */ +export function getRefreshToken(): string | null { + return localStorage.getItem('refresh_token') +} + +/** + * Get token expiration timestamp from localStorage + */ +export function getTokenExpiresAt(): number | null { + const value = localStorage.getItem('token_expires_at') + return value ? parseInt(value, 10) : null +} + /** * Clear authentication token from localStorage */ export function clearAuthToken(): void { localStorage.removeItem('auth_token') + localStorage.removeItem('refresh_token') localStorage.removeItem('auth_user') + localStorage.removeItem('token_expires_at') } /** @@ -61,6 +94,12 @@ export async function login(credentials: LoginRequest): Promise { // Only store token if 2FA is not required if (!isTotp2FARequired(data)) { setAuthToken(data.access_token) + if (data.refresh_token) { + setRefreshToken(data.refresh_token) + } + if (data.expires_in) { + setTokenExpiresAt(data.expires_in) + } localStorage.setItem('auth_user', JSON.stringify(data.user)) } @@ -77,6 +116,12 @@ export async function login2FA(request: TotpLogin2FARequest): Promise // Store token and user data setAuthToken(data.access_token) + if (data.refresh_token) { + setRefreshToken(data.refresh_token) + } + if (data.expires_in) { + setTokenExpiresAt(data.expires_in) + } localStorage.setItem('auth_user', JSON.stringify(data.user)) return data @@ -108,11 +159,62 @@ export async function getCurrentUser() { /** * User logout * Clears authentication token and user data from localStorage + * Optionally revokes the refresh token on the server */ -export function logout(): void { +export async function logout(): Promise { + const refreshToken = getRefreshToken() + + // Try to revoke the refresh token on the server + if (refreshToken) { + try { + await apiClient.post('/auth/logout', { refresh_token: refreshToken }) + } catch { + // Ignore errors - we still want to clear local state + } + } + clearAuthToken() - // Optionally redirect to login page - // window.location.href = '/login'; +} + +/** + * Refresh token response + */ +export interface RefreshTokenResponse { + access_token: string + refresh_token: string + expires_in: number + token_type: string +} + +/** + * Refresh the access token using the refresh token + * @returns New token pair + */ +export async function refreshToken(): Promise { + const currentRefreshToken = getRefreshToken() + if (!currentRefreshToken) { + throw new Error('No refresh token available') + } + + const { data } = await apiClient.post('/auth/refresh', { + refresh_token: currentRefreshToken + }) + + // Update tokens in localStorage + setAuthToken(data.access_token) + setRefreshToken(data.refresh_token) + setTokenExpiresAt(data.expires_in) + + return data +} + +/** + * Revoke all sessions for the current user + * @returns Response with message + */ +export async function revokeAllSessions(): Promise<{ message: string }> { + const { data } = await apiClient.post<{ message: string }>('/auth/revoke-all-sessions') + return data } /** @@ -242,14 +344,20 @@ export const authAPI = { logout, isAuthenticated, setAuthToken, + setRefreshToken, + setTokenExpiresAt, getAuthToken, + getRefreshToken, + getTokenExpiresAt, clearAuthToken, getPublicSettings, sendVerifyCode, validatePromoCode, validateInvitationCode, forgotPassword, - resetPassword + resetPassword, + refreshToken, + revokeAllSessions } export default authAPI diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 3827498b..22db5a44 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -1,9 +1,9 @@ /** * Axios HTTP Client Configuration - * Base client with interceptors for authentication and error handling + * Base client with interceptors for authentication, token refresh, and error handling */ -import axios, { AxiosInstance, AxiosError, InternalAxiosRequestConfig } from 'axios' +import axios, { AxiosInstance, AxiosError, InternalAxiosRequestConfig, AxiosResponse } from 'axios' import type { ApiResponse } from '@/types' import { getLocale } from '@/i18n' @@ -19,6 +19,28 @@ export const apiClient: AxiosInstance = axios.create({ } }) +// ==================== Token Refresh State ==================== + +// Track if a token refresh is in progress to prevent multiple simultaneous refresh requests +let isRefreshing = false +// Queue of requests waiting for token refresh +let refreshSubscribers: Array<(token: string) => void> = [] + +/** + * Subscribe to token refresh completion + */ +function subscribeTokenRefresh(callback: (token: string) => void): void { + refreshSubscribers.push(callback) +} + +/** + * Notify all subscribers that token has been refreshed + */ +function onTokenRefreshed(token: string): void { + refreshSubscribers.forEach((callback) => callback(token)) + refreshSubscribers = [] +} + // ==================== Request Interceptor ==================== // Get user's timezone @@ -61,7 +83,7 @@ apiClient.interceptors.request.use( // ==================== Response Interceptor ==================== apiClient.interceptors.response.use( - (response) => { + (response: AxiosResponse) => { // Unwrap standard API response format { code, message, data } const apiResponse = response.data as ApiResponse if (apiResponse && typeof apiResponse === 'object' && 'code' in apiResponse) { @@ -79,13 +101,15 @@ apiClient.interceptors.response.use( } return response }, - (error: AxiosError>) => { + async (error: AxiosError>) => { // Request cancellation: keep the original axios cancellation error so callers can ignore it. // Otherwise we'd misclassify it as a generic "network error". if (error.code === 'ERR_CANCELED' || axios.isCancel(error)) { return Promise.reject(error) } + const originalRequest = error.config as InternalAxiosRequestConfig & { _retry?: boolean } + // Handle common errors if (error.response) { const { status, data } = error.response @@ -120,23 +144,116 @@ apiClient.interceptors.response.use( }) } - // 401: Unauthorized - clear token and redirect to login - if (status === 401) { - const hasToken = !!localStorage.getItem('auth_token') - const url = error.config?.url || '' + // 401: Try to refresh the token if we have a refresh token + // This handles TOKEN_EXPIRED, INVALID_TOKEN, TOKEN_REVOKED, etc. + if (status === 401 && !originalRequest._retry) { + const refreshToken = localStorage.getItem('refresh_token') const isAuthEndpoint = url.includes('/auth/login') || url.includes('/auth/register') || url.includes('/auth/refresh') + + // If we have a refresh token and this is not an auth endpoint, try to refresh + if (refreshToken && !isAuthEndpoint) { + if (isRefreshing) { + // Wait for the ongoing refresh to complete + return new Promise((resolve, reject) => { + subscribeTokenRefresh((newToken: string) => { + if (newToken) { + // Mark as retried to prevent infinite loop if retry also returns 401 + originalRequest._retry = true + if (originalRequest.headers) { + originalRequest.headers.Authorization = `Bearer ${newToken}` + } + resolve(apiClient(originalRequest)) + } else { + // Refresh failed, reject with original error + reject({ + status, + code: apiData.code, + message: apiData.message || apiData.detail || error.message + }) + } + }) + }) + } + + originalRequest._retry = true + isRefreshing = true + + try { + // Call refresh endpoint directly to avoid circular dependency + const refreshResponse = await axios.post( + `${API_BASE_URL}/auth/refresh`, + { refresh_token: refreshToken }, + { headers: { 'Content-Type': 'application/json' } } + ) + + const refreshData = refreshResponse.data as ApiResponse<{ + access_token: string + refresh_token: string + expires_in: number + }> + + if (refreshData.code === 0 && refreshData.data) { + const { access_token, refresh_token: newRefreshToken, expires_in } = refreshData.data + + // Update tokens in localStorage (convert expires_in to timestamp) + localStorage.setItem('auth_token', access_token) + localStorage.setItem('refresh_token', newRefreshToken) + localStorage.setItem('token_expires_at', String(Date.now() + expires_in * 1000)) + + // Notify subscribers with new token + onTokenRefreshed(access_token) + + // Retry the original request with new token + if (originalRequest.headers) { + originalRequest.headers.Authorization = `Bearer ${access_token}` + } + + isRefreshing = false + return apiClient(originalRequest) + } + + // Refresh response was not successful, fall through to clear auth + throw new Error('Token refresh failed') + } catch (refreshError) { + // Refresh failed - notify subscribers with empty token + onTokenRefreshed('') + isRefreshing = false + + // Clear tokens and redirect to login + localStorage.removeItem('auth_token') + localStorage.removeItem('refresh_token') + localStorage.removeItem('auth_user') + localStorage.removeItem('token_expires_at') + sessionStorage.setItem('auth_expired', '1') + + if (!window.location.pathname.includes('/login')) { + window.location.href = '/login' + } + + return Promise.reject({ + status: 401, + code: 'TOKEN_REFRESH_FAILED', + message: 'Session expired. Please log in again.' + }) + } + } + + // No refresh token or is auth endpoint - clear auth and redirect + const hasToken = !!localStorage.getItem('auth_token') const headers = error.config?.headers as Record | undefined const authHeader = headers?.Authorization ?? headers?.authorization const sentAuth = typeof authHeader === 'string' ? authHeader.trim() !== '' : Array.isArray(authHeader) - ? authHeader.length > 0 - : !!authHeader + ? authHeader.length > 0 + : !!authHeader localStorage.removeItem('auth_token') + localStorage.removeItem('refresh_token') localStorage.removeItem('auth_user') + localStorage.removeItem('token_expires_at') if ((hasToken || sentAuth) && !isAuthEndpoint) { sessionStorage.setItem('auth_expired', '1') } diff --git a/frontend/src/components/layout/AppHeader.vue b/frontend/src/components/layout/AppHeader.vue index 6b5849c0..a6b4030f 100644 --- a/frontend/src/components/layout/AppHeader.vue +++ b/frontend/src/components/layout/AppHeader.vue @@ -283,7 +283,12 @@ function closeDropdown() { async function handleLogout() { closeDropdown() - authStore.logout() + try { + await authStore.logout() + } catch (error) { + // Ignore logout errors - still redirect to login + console.error('Logout error:', error) + } await router.push('/login') } diff --git a/frontend/src/stores/auth.ts b/frontend/src/stores/auth.ts index e4612f5e..22cad50a 100644 --- a/frontend/src/stores/auth.ts +++ b/frontend/src/stores/auth.ts @@ -1,6 +1,6 @@ /** * Authentication Store - * Manages user authentication state, login/logout, and token persistence + * Manages user authentication state, login/logout, token refresh, and token persistence */ import { defineStore } from 'pinia' @@ -10,15 +10,21 @@ import type { User, LoginRequest, RegisterRequest, AuthResponse } from '@/types' const AUTH_TOKEN_KEY = 'auth_token' const AUTH_USER_KEY = 'auth_user' -const AUTO_REFRESH_INTERVAL = 60 * 1000 // 60 seconds +const REFRESH_TOKEN_KEY = 'refresh_token' +const TOKEN_EXPIRES_AT_KEY = 'token_expires_at' // 存储过期时间戳而非有效期 +const AUTO_REFRESH_INTERVAL = 60 * 1000 // 60 seconds for user data refresh +const TOKEN_REFRESH_BUFFER = 120 * 1000 // 120 seconds before expiry to refresh token export const useAuthStore = defineStore('auth', () => { // ==================== State ==================== const user = ref(null) const token = ref(null) + const refreshTokenValue = ref(null) + const tokenExpiresAt = ref(null) // 过期时间戳(毫秒) const runMode = ref<'standard' | 'simple'>('standard') let refreshIntervalId: ReturnType | null = null + let tokenRefreshTimeoutId: ReturnType | null = null // ==================== Computed ==================== @@ -42,19 +48,29 @@ export const useAuthStore = defineStore('auth', () => { function checkAuth(): void { const savedToken = localStorage.getItem(AUTH_TOKEN_KEY) const savedUser = localStorage.getItem(AUTH_USER_KEY) + const savedRefreshToken = localStorage.getItem(REFRESH_TOKEN_KEY) + const savedExpiresAt = localStorage.getItem(TOKEN_EXPIRES_AT_KEY) if (savedToken && savedUser) { try { token.value = savedToken user.value = JSON.parse(savedUser) + refreshTokenValue.value = savedRefreshToken + tokenExpiresAt.value = savedExpiresAt ? parseInt(savedExpiresAt, 10) : null // Immediately refresh user data from backend (async, don't block) refreshUser().catch((error) => { console.error('Failed to refresh user on init:', error) }) - // Start auto-refresh interval + // Start auto-refresh interval for user data startAutoRefresh() + + // Start proactive token refresh if we have refresh token and expiry info + // Note: use !== null to handle case when tokenExpiresAt.value is 0 (expired) + if (savedRefreshToken && tokenExpiresAt.value !== null) { + scheduleTokenRefreshAt(tokenExpiresAt.value) + } } catch (error) { console.error('Failed to parse saved user data:', error) clearAuth() @@ -89,6 +105,76 @@ export const useAuthStore = defineStore('auth', () => { } } + /** + * Schedule proactive token refresh before expiry (based on expiry timestamp) + * @param expiresAtMs - Token expiry timestamp in milliseconds + */ + function scheduleTokenRefreshAt(expiresAtMs: number): void { + // Clear any existing timeout + if (tokenRefreshTimeoutId) { + clearTimeout(tokenRefreshTimeoutId) + tokenRefreshTimeoutId = null + } + + // Calculate remaining time until refresh (buffer time before expiry) + const now = Date.now() + const refreshInMs = Math.max(0, expiresAtMs - now - TOKEN_REFRESH_BUFFER) + + if (refreshInMs <= 0) { + // Token is about to expire or already expired, refresh immediately + performTokenRefresh() + return + } + + tokenRefreshTimeoutId = setTimeout(() => { + performTokenRefresh() + }, refreshInMs) + } + + /** + * Schedule proactive token refresh before expiry (based on expires_in seconds) + * @param expiresInSeconds - Token expiry time in seconds from now + */ + function scheduleTokenRefresh(expiresInSeconds: number): void { + const expiresAtMs = Date.now() + expiresInSeconds * 1000 + tokenExpiresAt.value = expiresAtMs + localStorage.setItem(TOKEN_EXPIRES_AT_KEY, String(expiresAtMs)) + scheduleTokenRefreshAt(expiresAtMs) + } + + /** + * Perform the actual token refresh + */ + async function performTokenRefresh(): Promise { + if (!refreshTokenValue.value) { + return + } + + try { + const response = await authAPI.refreshToken() + + // Update state + token.value = response.access_token + refreshTokenValue.value = response.refresh_token + + // Schedule next refresh (this also updates tokenExpiresAt and localStorage) + scheduleTokenRefresh(response.expires_in) + } catch (error) { + console.error('Token refresh failed:', error) + // Don't clear auth here - the interceptor will handle 401 errors + } + } + + /** + * Stop token refresh timeout + */ + function stopTokenRefresh(): void { + if (tokenRefreshTimeoutId) { + clearTimeout(tokenRefreshTimeoutId) + tokenRefreshTimeoutId = null + } + } + /** * User login * @param credentials - Login credentials (email and password) @@ -141,6 +227,12 @@ export const useAuthStore = defineStore('auth', () => { // Store token and user token.value = response.access_token + // Store refresh token if present + if (response.refresh_token) { + refreshTokenValue.value = response.refresh_token + localStorage.setItem(REFRESH_TOKEN_KEY, response.refresh_token) + } + // Extract run_mode if present if (response.user.run_mode) { runMode.value = response.user.run_mode @@ -152,8 +244,14 @@ export const useAuthStore = defineStore('auth', () => { localStorage.setItem(AUTH_TOKEN_KEY, response.access_token) localStorage.setItem(AUTH_USER_KEY, JSON.stringify(userData)) - // Start auto-refresh interval + // Start auto-refresh interval for user data startAutoRefresh() + + // Start proactive token refresh if we have refresh token and expiry info + // scheduleTokenRefresh will also store the expiry timestamp + if (response.refresh_token && response.expires_in) { + scheduleTokenRefresh(response.expires_in) + } } /** @@ -166,24 +264,10 @@ export const useAuthStore = defineStore('auth', () => { try { const response = await authAPI.register(userData) - // Store token and user - token.value = response.access_token + // Use the common helper to set auth state + setAuthFromResponse(response) - // Extract run_mode if present - if (response.user.run_mode) { - runMode.value = response.user.run_mode - } - const { run_mode: _run_mode, ...userDataWithoutRunMode } = response.user - user.value = userDataWithoutRunMode - - // Persist to localStorage - localStorage.setItem(AUTH_TOKEN_KEY, response.access_token) - localStorage.setItem(AUTH_USER_KEY, JSON.stringify(userDataWithoutRunMode)) - - // Start auto-refresh interval - startAutoRefresh() - - return userDataWithoutRunMode + return user.value! } catch (error) { // Clear any partial state on error clearAuth() @@ -193,18 +277,41 @@ export const useAuthStore = defineStore('auth', () => { /** * 直接设置 token(用于 OAuth/SSO 回调),并加载当前用户信息。 + * 会自动读取 localStorage 中已设置的 refresh_token 和 token_expires_in * @param newToken - 后端签发的 JWT access token */ async function setToken(newToken: string): Promise { // Clear any previous state first (avoid mixing sessions) - clearAuth() + // Note: Don't clear localStorage here as OAuth callback may have set refresh_token + stopAutoRefresh() + stopTokenRefresh() + token.value = null + user.value = null token.value = newToken localStorage.setItem(AUTH_TOKEN_KEY, newToken) + // Read refresh token and expires_at from localStorage if set by OAuth callback + const savedRefreshToken = localStorage.getItem(REFRESH_TOKEN_KEY) + const savedExpiresAt = localStorage.getItem(TOKEN_EXPIRES_AT_KEY) + + if (savedRefreshToken) { + refreshTokenValue.value = savedRefreshToken + } + if (savedExpiresAt) { + tokenExpiresAt.value = parseInt(savedExpiresAt, 10) + } + try { const userData = await refreshUser() startAutoRefresh() + + // Start proactive token refresh if we have refresh token and expiry info + // Note: use !== null to handle case when tokenExpiresAt.value is 0 (expired) + if (savedRefreshToken && tokenExpiresAt.value !== null) { + scheduleTokenRefreshAt(tokenExpiresAt.value) + } + return userData } catch (error) { clearAuth() @@ -216,9 +323,9 @@ export const useAuthStore = defineStore('auth', () => { * User logout * Clears all authentication state and persisted data */ - function logout(): void { - // Call API logout (client-side cleanup) - authAPI.logout() + async function logout(): Promise { + // Call API logout (revokes refresh token on server) + await authAPI.logout() // Clear state clearAuth() @@ -263,11 +370,17 @@ export const useAuthStore = defineStore('auth', () => { function clearAuth(): void { // Stop auto-refresh stopAutoRefresh() + // Stop token refresh + stopTokenRefresh() token.value = null + refreshTokenValue.value = null + tokenExpiresAt.value = null user.value = null localStorage.removeItem(AUTH_TOKEN_KEY) localStorage.removeItem(AUTH_USER_KEY) + localStorage.removeItem(REFRESH_TOKEN_KEY) + localStorage.removeItem(TOKEN_EXPIRES_AT_KEY) } // ==================== Return Store API ==================== diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 12449d3c..eb53de44 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -92,6 +92,8 @@ export interface PublicSettings { export interface AuthResponse { access_token: string + refresh_token?: string // New: Refresh Token for token renewal + expires_in?: number // New: Access Token expiry time in seconds token_type: string user: User & { run_mode?: 'standard' | 'simple' } } diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue index c6f93e6b..4dbca1df 100644 --- a/frontend/src/views/auth/LinuxDoCallbackView.vue +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -71,6 +71,8 @@ onMounted(async () => { const params = parseFragmentParams() const token = params.get('access_token') || '' + const refreshToken = params.get('refresh_token') || '' + const expiresInStr = params.get('expires_in') || '' const redirect = sanitizeRedirectPath( params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard' ) @@ -92,6 +94,17 @@ onMounted(async () => { } try { + // Store refresh token and expires_at (convert to timestamp) if provided + if (refreshToken) { + localStorage.setItem('refresh_token', refreshToken) + } + if (expiresInStr) { + const expiresIn = parseInt(expiresInStr, 10) + if (!isNaN(expiresIn)) { + localStorage.setItem('token_expires_at', String(Date.now() + expiresIn * 1000)) + } + } + await authStore.setToken(token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirect)