feat: track authenticated user activity

This commit is contained in:
IanShaw027
2026-04-21 14:54:53 +08:00
parent 422f3449a2
commit ed01c59916
10 changed files with 254 additions and 26 deletions

View File

@@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -153,6 +154,18 @@ func (s *stubUserRepo) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
func (s *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
return nil, nil
}
func (s *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
panic("unexpected UpsertUserAvatar call")
}
func (s *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
panic("unexpected DeleteUserAvatar call")
}
func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
@@ -161,6 +174,18 @@ func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.Pa
panic("unexpected ListWithFilters call")
}
func (s *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
panic("unexpected GetLatestUsedAtByUserIDs call")
}
func (s *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
panic("unexpected GetLatestUsedAtByUserID call")
}
func (s *stubUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
panic("unexpected UpdateUserLastActiveAt call")
}
func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
@@ -189,6 +214,10 @@ func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64
panic("unexpected AddGroupToAllowedGroups call")
}
func (s *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
panic("unexpected ListUserAuthIdentities call")
}
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}

View File

@@ -1,6 +1,7 @@
package middleware
import (
"context"
"errors"
"strings"
@@ -11,11 +12,19 @@ import (
// NewJWTAuthMiddleware 创建 JWT 认证中间件
func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware {
return JWTAuthMiddleware(jwtAuth(authService, userService))
return JWTAuthMiddleware(jwtAuth(authService, userService, userService))
}
type jwtUserReader interface {
GetByID(ctx context.Context, id int64) (*service.User, error)
}
type userActivityToucher interface {
TouchLastActiveForUser(ctx context.Context, user *service.User)
}
// jwtAuth JWT认证中间件实现
func jwtAuth(authService *service.AuthService, userService *service.UserService) gin.HandlerFunc {
func jwtAuth(authService *service.AuthService, userService jwtUserReader, activityToucher userActivityToucher) gin.HandlerFunc {
return func(c *gin.Context) {
// 从Authorization header中提取token
authHeader := c.GetHeader("Authorization")
@@ -73,6 +82,9 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
if activityToucher != nil {
activityToucher.TouchLastActiveForUser(c.Request.Context(), user)
}
c.Next()
}

View File

@@ -9,6 +9,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -30,6 +31,25 @@ func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, e
return u, nil
}
func (r *stubJWTUserRepo) GetUserAvatar(_ context.Context, _ int64) (*service.UserAvatar, error) {
return nil, nil
}
func (r *stubJWTUserRepo) UpdateUserLastActiveAt(_ context.Context, _ int64, _ time.Time) error {
return nil
}
type recordingActivityToucher struct {
userIDs []int64
}
func (r *recordingActivityToucher) TouchLastActiveForUser(_ context.Context, user *service.User) {
if user == nil {
return
}
r.userIDs = append(r.userIDs, user.ID)
}
// newJWTTestEnv 创建 JWT 认证中间件测试环境。
// 返回 gin.Engine已注册 JWT 中间件)和 AuthService用于生成 Token
func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) {
@@ -106,6 +126,45 @@ func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code)
}
func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
user := &service.User{
ID: 1,
Email: "test@example.com",
Role: "user",
Status: service.StatusActive,
Concurrency: 5,
TokenVersion: 1,
}
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!"
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}}
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
toucher := &recordingActivityToucher{}
r := gin.New()
r.Use(jwtAuth(authSvc, userSvc, toucher))
r.GET("/protected", func(c *gin.Context) {
c.Status(http.StatusOK)
})
token, err := authSvc.GenerateToken(user)
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
r.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, []int64{1}, toucher.userIDs)
}
func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) {
router, _ := newJWTTestEnv(nil)