From 4e0e69154649def4f4149054e09ff03fc8a3e50a Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 18:39:53 +0800 Subject: [PATCH] feat: apply auth source signup defaults --- .../service/admin_service_apikey_test.go | 3 + .../service/admin_service_delete_test.go | 51 ++++-- backend/internal/service/auth_service.go | 117 +++++++++---- .../service/auth_service_register_test.go | 159 +++++++++++++++++- 4 files changed, 283 insertions(+), 47 deletions(-) diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index b802a9c2..487fb5f1 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -79,6 +79,9 @@ func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *s } func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { + panic("unexpected") +} func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { panic("unexpected") } diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index 323286b0..ac1d8ee7 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -13,15 +13,18 @@ import ( ) type userRepoStub struct { - user *User - getErr error - createErr error - deleteErr error - exists bool - existsErr error - nextID int64 - created []*User - deletedIDs []int64 + user *User + getErr error + createErr error + deleteErr error + exists bool + existsErr error + nextID int64 + created []*User + updated []*User + deletedIDs []int64 + usersByEmail map[string]*User + getByEmailErr error } func (s *userRepoStub) Create(ctx context.Context, user *User) error { @@ -32,6 +35,11 @@ func (s *userRepoStub) Create(ctx context.Context, user *User) error { user.ID = s.nextID } s.created = append(s.created, user) + if s.usersByEmail == nil { + s.usersByEmail = make(map[string]*User) + } + s.usersByEmail[user.Email] = user + s.user = user return nil } @@ -46,7 +54,18 @@ func (s *userRepoStub) GetByID(ctx context.Context, id int64) (*User, error) { } func (s *userRepoStub) GetByEmail(ctx context.Context, email string) (*User, error) { - panic("unexpected GetByEmail call") + if s.getByEmailErr != nil { + return nil, s.getByEmailErr + } + if s.usersByEmail != nil { + if user, ok := s.usersByEmail[email]; ok { + return user, nil + } + } + if s.user != nil && s.user.Email == email { + return s.user, nil + } + return nil, ErrUserNotFound } func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) { @@ -54,7 +73,13 @@ func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) { } func (s *userRepoStub) Update(ctx context.Context, user *User) error { - panic("unexpected Update call") + s.updated = append(s.updated, user) + if s.usersByEmail == nil { + s.usersByEmail = make(map[string]*User) + } + s.usersByEmail[user.Email] = user + s.user = user + return nil } func (s *userRepoStub) Delete(ctx context.Context, id int64) error { @@ -113,6 +138,10 @@ func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64 panic("unexpected AddGroupToAllowedGroups call") } +func (s *userRepoStub) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) { + panic("unexpected ListUserAuthIdentities call") +} + func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { panic("unexpected UpdateTotpSecret call") } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 962009ce..40753139 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -78,6 +78,12 @@ type DefaultSubscriptionAssigner interface { AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) } +type signupGrantPlan struct { + Balance float64 + Concurrency int + Subscriptions []DefaultSubscriptionSetting +} + // NewAuthService 创建认证服务实例 func NewAuthService( entClient *dbent.Client, @@ -187,21 +193,15 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw return "", nil, fmt.Errorf("hash password: %w", err) } - // 获取默认配置 - defaultBalance := s.cfg.Default.UserBalance - defaultConcurrency := s.cfg.Default.UserConcurrency - if s.settingService != nil { - defaultBalance = s.settingService.GetDefaultBalance(ctx) - defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) - } + grantPlan := s.resolveSignupGrantPlan(ctx, "email") // 创建用户 user := &User{ Email: email, PasswordHash: hashedPassword, Role: RoleUser, - Balance: defaultBalance, - Concurrency: defaultConcurrency, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, Status: StatusActive, } @@ -214,7 +214,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw return "", nil, ErrServiceUnavailable } s.postAuthUserBootstrap(ctx, user, "email", true) - s.assignDefaultSubscriptions(ctx, user.ID) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") // 标记邀请码为已使用(如果使用了邀请码) if invitationRedeemCode != nil { @@ -479,21 +479,16 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username return "", nil, fmt.Errorf("hash password: %w", err) } - // 新用户默认值。 - defaultBalance := s.cfg.Default.UserBalance - defaultConcurrency := s.cfg.Default.UserConcurrency - if s.settingService != nil { - defaultBalance = s.settingService.GetDefaultBalance(ctx) - defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) - } + signupSource := inferLegacySignupSource(email) + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) newUser := &User{ Email: email, Username: username, PasswordHash: hashedPassword, Role: RoleUser, - Balance: defaultBalance, - Concurrency: defaultConcurrency, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, Status: StatusActive, } @@ -511,8 +506,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username } } else { user = newUser - s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, signupSource, true) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") } } else { logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) @@ -596,20 +591,16 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, fmt.Errorf("hash password: %w", err) } - defaultBalance := s.cfg.Default.UserBalance - defaultConcurrency := s.cfg.Default.UserConcurrency - if s.settingService != nil { - defaultBalance = s.settingService.GetDefaultBalance(ctx) - defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) - } + signupSource := inferLegacySignupSource(email) + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) newUser := &User{ Email: email, Username: username, PasswordHash: hashedPassword, Role: RoleUser, - Balance: defaultBalance, - Concurrency: defaultConcurrency, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, Status: StatusActive, } @@ -642,8 +633,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, ErrServiceUnavailable } user = newUser - s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, signupSource, true) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") } } else { if err := s.userRepo.Create(ctx, newUser); err != nil { @@ -659,8 +650,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema } } else { user = newUser - s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, signupSource, true) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { return nil, nil, ErrInvitationCodeInvalid @@ -694,22 +685,78 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema } func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { + if s.settingService == nil { + return + } + s.assignSubscriptions(ctx, userID, s.settingService.GetDefaultSubscriptions(ctx), "auto assigned by default user subscriptions setting") +} + +func (s *AuthService) assignSubscriptions(ctx context.Context, userID int64, items []DefaultSubscriptionSetting, notes string) { if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { return } - items := s.settingService.GetDefaultSubscriptions(ctx) for _, item := range items { if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ UserID: userID, GroupID: item.GroupID, ValidityDays: item.ValidityDays, - Notes: "auto assigned by default user subscriptions setting", + Notes: notes, }); err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) } } } +func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource string) signupGrantPlan { + plan := signupGrantPlan{} + if s != nil && s.cfg != nil { + plan.Balance = s.cfg.Default.UserBalance + plan.Concurrency = s.cfg.Default.UserConcurrency + } + if s == nil || s.settingService == nil { + return plan + } + + plan.Balance = s.settingService.GetDefaultBalance(ctx) + plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx) + plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx) + + defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err) + return plan + } + + providerDefaults, ok := authSourceSignupSettings(defaults, signupSource) + if !ok || !providerDefaults.GrantOnSignup { + return plan + } + + plan.Balance = providerDefaults.Balance + plan.Concurrency = providerDefaults.Concurrency + plan.Subscriptions = providerDefaults.Subscriptions + return plan +} + +func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource string) (ProviderDefaultGrantSettings, bool) { + if defaults == nil { + return ProviderDefaultGrantSettings{}, false + } + + switch strings.ToLower(strings.TrimSpace(signupSource)) { + case "email": + return defaults.Email, true + case "linuxdo": + return defaults.LinuxDo, true + case "oidc": + return defaults.OIDC, true + case "wechat": + return defaults.WeChat, true + default: + return ProviderDefaultGrantSettings{}, false + } +} + func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) { if user == nil || user.ID <= 0 { return diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index 103bafe7..901b3db3 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -37,7 +37,16 @@ func (s *settingRepoStub) Set(ctx context.Context, key, value string) error { } func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { - panic("unexpected GetMultiple call") + if s.err != nil { + return nil, s.err + } + result := make(map[string]string, len(keys)) + for _, key := range keys { + if v, ok := s.values[key]; ok { + result[key] = v + } + } + return result, nil } func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { @@ -62,6 +71,8 @@ type defaultSubscriptionAssignerStub struct { err error } +type refreshTokenCacheStub struct{} + func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) { if input != nil { s.calls = append(s.calls, *input) @@ -72,6 +83,46 @@ func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.C return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil } +func (s *refreshTokenCacheStub) StoreRefreshToken(context.Context, string, *RefreshTokenData, time.Duration) error { + return nil +} + +func (s *refreshTokenCacheStub) GetRefreshToken(context.Context, string) (*RefreshTokenData, error) { + return nil, ErrRefreshTokenNotFound +} + +func (s *refreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { + return nil +} + +func (s *refreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { + return nil +} + +func (s *refreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *refreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *refreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *refreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *refreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *refreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} + func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) { if s.err != nil { return nil, s.err @@ -484,3 +535,109 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { require.Equal(t, int64(12), assigner.calls[1].GroupID) require.Equal(t, 7, assigner.calls[1].ValidityDays) } + +func TestAuthService_Register_UsesEmailAuthSourceDefaultsWhenGrantEnabled(t *testing.T) { + repo := &userRepoStub{nextID: 52} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":91,"validity_days":3}]`, + SettingKeyAuthSourceDefaultEmailBalance: "12.5", + SettingKeyAuthSourceDefaultEmailConcurrency: "7", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = assigner + + _, user, err := service.Register(context.Background(), "email-defaults@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, 12.5, user.Balance) + require.Equal(t, 7, user.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(11), assigner.calls[0].GroupID) + require.Equal(t, 30, assigner.calls[0].ValidityDays) +} + +func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *testing.T) { + repo := &userRepoStub{nextID: 53} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`, + SettingKeyAuthSourceDefaultEmailBalance: "99", + SettingKeyAuthSourceDefaultEmailConcurrency: "88", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":32,"validity_days":9}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", + }, nil) + service.defaultSubAssigner = assigner + + _, user, err := service.Register(context.Background(), "email-global@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, 3.5, user.Balance) + require.Equal(t, 2, user.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(31), assigner.calls[0].GroupID) + require.Equal(t, 5, assigner.calls[0].ValidityDays) +} + +func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) { + repo := &userRepoStub{nextID: 61} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":81,"validity_days":1}]`, + SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75", + SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9", + SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`, + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = assigner + service.refreshTokenCache = &refreshTokenCacheStub{} + + tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "") + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.NotNil(t, user) + require.Equal(t, int64(61), user.ID) + require.Equal(t, 21.75, user.Balance) + require.Equal(t, 9, user.Concurrency) + require.Len(t, repo.created, 1) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(22), assigner.calls[0].GroupID) + require.Equal(t, 14, assigner.calls[0].ValidityDays) +} + +func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantAgain(t *testing.T) { + existing := &User{ + ID: 88, + Email: "linuxdo-123@linuxdo-connect.invalid", + Username: "existing-linuxdo", + Role: RoleUser, + Status: StatusActive, + Balance: 4, + Concurrency: 1, + TokenVersion: 2, + } + repo := &userRepoStub{user: existing} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75", + SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9", + SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`, + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = assigner + service.refreshTokenCache = &refreshTokenCacheStub{} + + tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "") + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.Equal(t, existing.ID, user.ID) + require.Equal(t, 4.0, user.Balance) + require.Equal(t, 1, user.Concurrency) + require.Empty(t, repo.created) + require.Empty(t, assigner.calls) +}