fix(profile): stabilize binding compatibility and frontend checks
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -54,6 +56,16 @@ func newAuthServiceForEmailBind(
|
||||
settings map[string]string,
|
||||
emailCache service.EmailCache,
|
||||
defaultSubAssigner service.DefaultSubscriptionAssigner,
|
||||
) (*service.AuthService, service.UserRepository, *dbent.Client) {
|
||||
return newAuthServiceForEmailBindWithRefreshCache(t, settings, emailCache, defaultSubAssigner, nil)
|
||||
}
|
||||
|
||||
func newAuthServiceForEmailBindWithRefreshCache(
|
||||
t *testing.T,
|
||||
settings map[string]string,
|
||||
emailCache service.EmailCache,
|
||||
defaultSubAssigner service.DefaultSubscriptionAssigner,
|
||||
refreshTokenCache service.RefreshTokenCache,
|
||||
) (*service.AuthService, service.UserRepository, *dbent.Client) {
|
||||
t.Helper()
|
||||
|
||||
@@ -98,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
|
||||
emailSvc = service.NewEmailService(settingRepo, emailCache)
|
||||
}
|
||||
|
||||
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
|
||||
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
|
||||
return svc, repo, client
|
||||
}
|
||||
|
||||
@@ -427,6 +439,61 @@ func TestAuthServiceBindEmailIdentity_RejectsWrongCurrentPasswordForBoundEmail(t
|
||||
require.Equal(t, 0, newIdentityCount)
|
||||
}
|
||||
|
||||
func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cache := &emailBindCacheStub{
|
||||
data: &service.VerificationCodeData{
|
||||
Code: "123456",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
|
||||
},
|
||||
}
|
||||
refreshTokenCache := newEmailBindRefreshTokenCacheStub()
|
||||
userRepo := newEmailBindUserRepoStub(&service.User{
|
||||
ID: 41,
|
||||
Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain,
|
||||
Username: "legacy-user",
|
||||
PasswordHash: "old-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
TokenVersion: 4,
|
||||
})
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-bind-email-secret",
|
||||
ExpireHour: 1,
|
||||
AccessTokenExpireMinutes: 60,
|
||||
RefreshTokenExpireDays: 7,
|
||||
},
|
||||
}
|
||||
emailService := service.NewEmailService(nil, cache)
|
||||
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil)
|
||||
|
||||
oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
|
||||
ID: 41,
|
||||
Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
TokenVersion: 4,
|
||||
}, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err := svc.BindEmailIdentity(ctx, 41, "new@example.com", "123456", "new-password")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updatedUser)
|
||||
|
||||
storedUser, err := userRepo.GetByID(ctx, 41)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "new@example.com", storedUser.Email)
|
||||
require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash))
|
||||
|
||||
_, err = svc.RefreshToken(ctx, oldTokenPair.AccessToken)
|
||||
require.ErrorIs(t, err, service.ErrTokenRevoked)
|
||||
|
||||
_, err = svc.RefreshTokenPair(ctx, oldTokenPair.RefreshToken)
|
||||
require.True(t, errors.Is(err, service.ErrTokenRevoked) || errors.Is(err, service.ErrRefreshTokenInvalid))
|
||||
}
|
||||
|
||||
type emailBindSettingRepoStub struct {
|
||||
values map[string]string
|
||||
}
|
||||
@@ -527,3 +594,260 @@ func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int6
|
||||
func (s *emailBindCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
type emailBindRefreshTokenCacheStub struct {
|
||||
mu sync.Mutex
|
||||
tokens map[string]*service.RefreshTokenData
|
||||
userSets map[int64]map[string]struct{}
|
||||
families map[string]map[string]struct{}
|
||||
}
|
||||
|
||||
func newEmailBindRefreshTokenCacheStub() *emailBindRefreshTokenCacheStub {
|
||||
return &emailBindRefreshTokenCacheStub{
|
||||
tokens: make(map[string]*service.RefreshTokenData),
|
||||
userSets: make(map[int64]map[string]struct{}),
|
||||
families: make(map[string]map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *emailBindRefreshTokenCacheStub) StoreRefreshToken(_ context.Context, tokenHash string, data *service.RefreshTokenData, _ time.Duration) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cloned := *data
|
||||
s.tokens[tokenHash] = &cloned
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindRefreshTokenCacheStub) GetRefreshToken(_ context.Context, tokenHash string) (*service.RefreshTokenData, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
data, ok := s.tokens[tokenHash]
|
||||
if !ok {
|
||||
return nil, service.ErrRefreshTokenNotFound
|
||||
}
|
||||
cloned := *data
|
||||
return &cloned, nil
|
||||
}
|
||||
|
||||
func (s *emailBindRefreshTokenCacheStub) DeleteRefreshToken(_ context.Context, tokenHash string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.tokens, tokenHash)
|
||||
for _, tokenSet := range s.userSets {
|
||||
delete(tokenSet, tokenHash)
|
||||
}
|
||||
for _, tokenSet := range s.families {
|
||||
delete(tokenSet, tokenHash)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindRefreshTokenCacheStub) DeleteUserRefreshTokens(_ context.Context, userID int64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for tokenHash := range s.userSets[userID] {
|
||||
delete(s.tokens, tokenHash)
|
||||
for _, tokenSet := range s.families {
|
||||
delete(tokenSet, tokenHash)
|
||||
}
|
||||
}
|
||||
delete(s.userSets, userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindRefreshTokenCacheStub) DeleteTokenFamily(_ context.Context, familyID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for tokenHash := range s.families[familyID] {
|
||||
delete(s.tokens, tokenHash)
|
||||
for _, tokenSet := range s.userSets {
|
||||
delete(tokenSet, tokenHash)
|
||||
}
|
||||
}
|
||||
delete(s.families, familyID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindRefreshTokenCacheStub) AddToUserTokenSet(_ context.Context, userID int64, tokenHash string, _ time.Duration) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.userSets[userID] == nil {
|
||||
s.userSets[userID] = make(map[string]struct{})
|
||||
}
|
||||
s.userSets[userID][tokenHash] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindRefreshTokenCacheStub) AddToFamilyTokenSet(_ context.Context, familyID string, tokenHash string, _ time.Duration) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.families[familyID] == nil {
|
||||
s.families[familyID] = make(map[string]struct{})
|
||||
}
|
||||
s.families[familyID][tokenHash] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindRefreshTokenCacheStub) GetUserTokenHashes(_ context.Context, userID int64) ([]string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
tokenSet := s.userSets[userID]
|
||||
out := make([]string, 0, len(tokenSet))
|
||||
for tokenHash := range tokenSet {
|
||||
out = append(out, tokenHash)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *emailBindRefreshTokenCacheStub) GetFamilyTokenHashes(_ context.Context, familyID string) ([]string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
tokenSet := s.families[familyID]
|
||||
out := make([]string, 0, len(tokenSet))
|
||||
for tokenHash := range tokenSet {
|
||||
out = append(out, tokenHash)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *emailBindRefreshTokenCacheStub) IsTokenInFamily(_ context.Context, familyID string, tokenHash string) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
_, ok := s.families[familyID][tokenHash]
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
type emailBindUserRepoStub struct {
|
||||
mu sync.Mutex
|
||||
usersByID map[int64]*service.User
|
||||
usersByEmail map[string]*service.User
|
||||
}
|
||||
|
||||
func newEmailBindUserRepoStub(user *service.User) *emailBindUserRepoStub {
|
||||
cloned := cloneEmailBindUser(user)
|
||||
return &emailBindUserRepoStub{
|
||||
usersByID: map[int64]*service.User{
|
||||
cloned.ID: cloned,
|
||||
},
|
||||
usersByEmail: map[string]*service.User{
|
||||
cloned.Email: cloned,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) Create(context.Context, *service.User) error { return nil }
|
||||
|
||||
func (s *emailBindUserRepoStub) GetByID(_ context.Context, id int64) (*service.User, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
user, ok := s.usersByID[id]
|
||||
if !ok {
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
return cloneEmailBindUser(user), nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) GetByEmail(_ context.Context, email string) (*service.User, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
user, ok := s.usersByEmail[email]
|
||||
if !ok {
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
return cloneEmailBindUser(user), nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) GetFirstAdmin(context.Context) (*service.User, error) {
|
||||
panic("unexpected GetFirstAdmin call")
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) Update(_ context.Context, user *service.User) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
existing, ok := s.usersByID[user.ID]
|
||||
if !ok {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
delete(s.usersByEmail, existing.Email)
|
||||
cloned := cloneEmailBindUser(user)
|
||||
s.usersByID[user.ID] = cloned
|
||||
s.usersByEmail[cloned.Email] = cloned
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) Delete(context.Context, int64) error { return nil }
|
||||
|
||||
func (s *emailBindUserRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
|
||||
panic("unexpected UpsertUserAvatar call")
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) DeleteUserAvatar(context.Context, int64) error {
|
||||
panic("unexpected DeleteUserAvatar call")
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
|
||||
return map[int64]*time.Time{}, nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
|
||||
func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
|
||||
func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
||||
|
||||
func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
_, ok := s.usersByEmail[email]
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]service.UserAuthIdentityRecord, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
|
||||
func (s *emailBindUserRepoStub) EnableTotp(context.Context, int64) error { return nil }
|
||||
func (s *emailBindUserRepoStub) DisableTotp(context.Context, int64) error { return nil }
|
||||
|
||||
func cloneEmailBindUser(user *service.User) *service.User {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *user
|
||||
return &cloned
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user