- 新增 Access Token + Refresh Token 双令牌认证 - 支持 Token 自动刷新和轮转 - 添加登出和撤销所有会话接口 - 前端实现无感刷新和主动刷新定时器
159 lines
4.5 KiB
Go
159 lines
4.5 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
const (
|
|
refreshTokenKeyPrefix = "refresh_token:"
|
|
userRefreshTokensPrefix = "user_refresh_tokens:"
|
|
tokenFamilyPrefix = "token_family:"
|
|
)
|
|
|
|
// refreshTokenKey generates the Redis key for a refresh token.
|
|
func refreshTokenKey(tokenHash string) string {
|
|
return refreshTokenKeyPrefix + tokenHash
|
|
}
|
|
|
|
// userRefreshTokensKey generates the Redis key for user's token set.
|
|
func userRefreshTokensKey(userID int64) string {
|
|
return fmt.Sprintf("%s%d", userRefreshTokensPrefix, userID)
|
|
}
|
|
|
|
// tokenFamilyKey generates the Redis key for token family set.
|
|
func tokenFamilyKey(familyID string) string {
|
|
return tokenFamilyPrefix + familyID
|
|
}
|
|
|
|
type refreshTokenCache struct {
|
|
rdb *redis.Client
|
|
}
|
|
|
|
// NewRefreshTokenCache creates a new RefreshTokenCache implementation.
|
|
func NewRefreshTokenCache(rdb *redis.Client) service.RefreshTokenCache {
|
|
return &refreshTokenCache{rdb: rdb}
|
|
}
|
|
|
|
func (c *refreshTokenCache) StoreRefreshToken(ctx context.Context, tokenHash string, data *service.RefreshTokenData, ttl time.Duration) error {
|
|
key := refreshTokenKey(tokenHash)
|
|
val, err := json.Marshal(data)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal refresh token data: %w", err)
|
|
}
|
|
return c.rdb.Set(ctx, key, val, ttl).Err()
|
|
}
|
|
|
|
func (c *refreshTokenCache) GetRefreshToken(ctx context.Context, tokenHash string) (*service.RefreshTokenData, error) {
|
|
key := refreshTokenKey(tokenHash)
|
|
val, err := c.rdb.Get(ctx, key).Result()
|
|
if err != nil {
|
|
if err == redis.Nil {
|
|
return nil, service.ErrRefreshTokenNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
var data service.RefreshTokenData
|
|
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
|
return nil, fmt.Errorf("unmarshal refresh token data: %w", err)
|
|
}
|
|
return &data, nil
|
|
}
|
|
|
|
func (c *refreshTokenCache) DeleteRefreshToken(ctx context.Context, tokenHash string) error {
|
|
key := refreshTokenKey(tokenHash)
|
|
return c.rdb.Del(ctx, key).Err()
|
|
}
|
|
|
|
func (c *refreshTokenCache) DeleteUserRefreshTokens(ctx context.Context, userID int64) error {
|
|
// Get all token hashes for this user
|
|
tokenHashes, err := c.GetUserTokenHashes(ctx, userID)
|
|
if err != nil && err != redis.Nil {
|
|
return fmt.Errorf("get user token hashes: %w", err)
|
|
}
|
|
|
|
if len(tokenHashes) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Build keys to delete
|
|
keys := make([]string, 0, len(tokenHashes)+1)
|
|
for _, hash := range tokenHashes {
|
|
keys = append(keys, refreshTokenKey(hash))
|
|
}
|
|
keys = append(keys, userRefreshTokensKey(userID))
|
|
|
|
// Delete all keys in a pipeline
|
|
pipe := c.rdb.Pipeline()
|
|
for _, key := range keys {
|
|
pipe.Del(ctx, key)
|
|
}
|
|
_, err = pipe.Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
func (c *refreshTokenCache) DeleteTokenFamily(ctx context.Context, familyID string) error {
|
|
// Get all token hashes in this family
|
|
tokenHashes, err := c.GetFamilyTokenHashes(ctx, familyID)
|
|
if err != nil && err != redis.Nil {
|
|
return fmt.Errorf("get family token hashes: %w", err)
|
|
}
|
|
|
|
if len(tokenHashes) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Build keys to delete
|
|
keys := make([]string, 0, len(tokenHashes)+1)
|
|
for _, hash := range tokenHashes {
|
|
keys = append(keys, refreshTokenKey(hash))
|
|
}
|
|
keys = append(keys, tokenFamilyKey(familyID))
|
|
|
|
// Delete all keys in a pipeline
|
|
pipe := c.rdb.Pipeline()
|
|
for _, key := range keys {
|
|
pipe.Del(ctx, key)
|
|
}
|
|
_, err = pipe.Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
func (c *refreshTokenCache) AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error {
|
|
key := userRefreshTokensKey(userID)
|
|
pipe := c.rdb.Pipeline()
|
|
pipe.SAdd(ctx, key, tokenHash)
|
|
pipe.Expire(ctx, key, ttl)
|
|
_, err := pipe.Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
func (c *refreshTokenCache) AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error {
|
|
key := tokenFamilyKey(familyID)
|
|
pipe := c.rdb.Pipeline()
|
|
pipe.SAdd(ctx, key, tokenHash)
|
|
pipe.Expire(ctx, key, ttl)
|
|
_, err := pipe.Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
func (c *refreshTokenCache) GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error) {
|
|
key := userRefreshTokensKey(userID)
|
|
return c.rdb.SMembers(ctx, key).Result()
|
|
}
|
|
|
|
func (c *refreshTokenCache) GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error) {
|
|
key := tokenFamilyKey(familyID)
|
|
return c.rdb.SMembers(ctx, key).Result()
|
|
}
|
|
|
|
func (c *refreshTokenCache) IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error) {
|
|
key := tokenFamilyKey(familyID)
|
|
return c.rdb.SIsMember(ctx, key, tokenHash).Result()
|
|
}
|