feat: apply auth source signup defaults

This commit is contained in:
IanShaw027
2026-04-20 18:39:53 +08:00
parent c6d8592484
commit 4e0e691546
4 changed files with 283 additions and 47 deletions

View File

@@ -79,6 +79,9 @@ func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *s
} }
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
panic("unexpected") panic("unexpected")
} }

View File

@@ -13,15 +13,18 @@ import (
) )
type userRepoStub struct { type userRepoStub struct {
user *User user *User
getErr error getErr error
createErr error createErr error
deleteErr error deleteErr error
exists bool exists bool
existsErr error existsErr error
nextID int64 nextID int64
created []*User created []*User
deletedIDs []int64 updated []*User
deletedIDs []int64
usersByEmail map[string]*User
getByEmailErr error
} }
func (s *userRepoStub) Create(ctx context.Context, user *User) error { func (s *userRepoStub) Create(ctx context.Context, user *User) error {
@@ -32,6 +35,11 @@ func (s *userRepoStub) Create(ctx context.Context, user *User) error {
user.ID = s.nextID user.ID = s.nextID
} }
s.created = append(s.created, user) s.created = append(s.created, user)
if s.usersByEmail == nil {
s.usersByEmail = make(map[string]*User)
}
s.usersByEmail[user.Email] = user
s.user = user
return nil return nil
} }
@@ -46,7 +54,18 @@ func (s *userRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
} }
func (s *userRepoStub) GetByEmail(ctx context.Context, email string) (*User, error) { func (s *userRepoStub) GetByEmail(ctx context.Context, email string) (*User, error) {
panic("unexpected GetByEmail call") if s.getByEmailErr != nil {
return nil, s.getByEmailErr
}
if s.usersByEmail != nil {
if user, ok := s.usersByEmail[email]; ok {
return user, nil
}
}
if s.user != nil && s.user.Email == email {
return s.user, nil
}
return nil, ErrUserNotFound
} }
func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) { func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
@@ -54,7 +73,13 @@ func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
} }
func (s *userRepoStub) Update(ctx context.Context, user *User) error { func (s *userRepoStub) Update(ctx context.Context, user *User) error {
panic("unexpected Update call") s.updated = append(s.updated, user)
if s.usersByEmail == nil {
s.usersByEmail = make(map[string]*User)
}
s.usersByEmail[user.Email] = user
s.user = user
return nil
} }
func (s *userRepoStub) Delete(ctx context.Context, id int64) error { func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
@@ -113,6 +138,10 @@ func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64
panic("unexpected AddGroupToAllowedGroups call") panic("unexpected AddGroupToAllowedGroups call")
} }
func (s *userRepoStub) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
panic("unexpected ListUserAuthIdentities call")
}
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call") panic("unexpected UpdateTotpSecret call")
} }

View File

@@ -78,6 +78,12 @@ type DefaultSubscriptionAssigner interface {
AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error)
} }
type signupGrantPlan struct {
Balance float64
Concurrency int
Subscriptions []DefaultSubscriptionSetting
}
// NewAuthService 创建认证服务实例 // NewAuthService 创建认证服务实例
func NewAuthService( func NewAuthService(
entClient *dbent.Client, entClient *dbent.Client,
@@ -187,21 +193,15 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, fmt.Errorf("hash password: %w", err) return "", nil, fmt.Errorf("hash password: %w", err)
} }
// 获取默认配置 grantPlan := s.resolveSignupGrantPlan(ctx, "email")
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
}
// 创建用户 // 创建用户
user := &User{ user := &User{
Email: email, Email: email,
PasswordHash: hashedPassword, PasswordHash: hashedPassword,
Role: RoleUser, Role: RoleUser,
Balance: defaultBalance, Balance: grantPlan.Balance,
Concurrency: defaultConcurrency, Concurrency: grantPlan.Concurrency,
Status: StatusActive, Status: StatusActive,
} }
@@ -214,7 +214,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable return "", nil, ErrServiceUnavailable
} }
s.postAuthUserBootstrap(ctx, user, "email", true) s.postAuthUserBootstrap(ctx, user, "email", true)
s.assignDefaultSubscriptions(ctx, user.ID) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
// 标记邀请码为已使用(如果使用了邀请码) // 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil { if invitationRedeemCode != nil {
@@ -479,21 +479,16 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
return "", nil, fmt.Errorf("hash password: %w", err) return "", nil, fmt.Errorf("hash password: %w", err)
} }
// 新用户默认值。 signupSource := inferLegacySignupSource(email)
defaultBalance := s.cfg.Default.UserBalance grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
defaultConcurrency := s.cfg.Default.UserConcurrency
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
}
newUser := &User{ newUser := &User{
Email: email, Email: email,
Username: username, Username: username,
PasswordHash: hashedPassword, PasswordHash: hashedPassword,
Role: RoleUser, Role: RoleUser,
Balance: defaultBalance, Balance: grantPlan.Balance,
Concurrency: defaultConcurrency, Concurrency: grantPlan.Concurrency,
Status: StatusActive, Status: StatusActive,
} }
@@ -511,8 +506,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
} }
} else { } else {
user = newUser user = newUser
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) s.postAuthUserBootstrap(ctx, user, signupSource, true)
s.assignDefaultSubscriptions(ctx, user.ID) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
} }
} else { } else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
@@ -596,20 +591,16 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, fmt.Errorf("hash password: %w", err) return nil, nil, fmt.Errorf("hash password: %w", err)
} }
defaultBalance := s.cfg.Default.UserBalance signupSource := inferLegacySignupSource(email)
defaultConcurrency := s.cfg.Default.UserConcurrency grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
}
newUser := &User{ newUser := &User{
Email: email, Email: email,
Username: username, Username: username,
PasswordHash: hashedPassword, PasswordHash: hashedPassword,
Role: RoleUser, Role: RoleUser,
Balance: defaultBalance, Balance: grantPlan.Balance,
Concurrency: defaultConcurrency, Concurrency: grantPlan.Concurrency,
Status: StatusActive, Status: StatusActive,
} }
@@ -642,8 +633,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable return nil, nil, ErrServiceUnavailable
} }
user = newUser user = newUser
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) s.postAuthUserBootstrap(ctx, user, signupSource, true)
s.assignDefaultSubscriptions(ctx, user.ID) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
} }
} else { } else {
if err := s.userRepo.Create(ctx, newUser); err != nil { if err := s.userRepo.Create(ctx, newUser); err != nil {
@@ -659,8 +650,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
} }
} else { } else {
user = newUser user = newUser
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) s.postAuthUserBootstrap(ctx, user, signupSource, true)
s.assignDefaultSubscriptions(ctx, user.ID) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if invitationRedeemCode != nil { if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid return nil, nil, ErrInvitationCodeInvalid
@@ -694,22 +685,78 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
} }
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
if s.settingService == nil {
return
}
s.assignSubscriptions(ctx, userID, s.settingService.GetDefaultSubscriptions(ctx), "auto assigned by default user subscriptions setting")
}
func (s *AuthService) assignSubscriptions(ctx context.Context, userID int64, items []DefaultSubscriptionSetting, notes string) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return return
} }
items := s.settingService.GetDefaultSubscriptions(ctx)
for _, item := range items { for _, item := range items {
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
UserID: userID, UserID: userID,
GroupID: item.GroupID, GroupID: item.GroupID,
ValidityDays: item.ValidityDays, ValidityDays: item.ValidityDays,
Notes: "auto assigned by default user subscriptions setting", Notes: notes,
}); err != nil { }); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
} }
} }
} }
func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource string) signupGrantPlan {
plan := signupGrantPlan{}
if s != nil && s.cfg != nil {
plan.Balance = s.cfg.Default.UserBalance
plan.Concurrency = s.cfg.Default.UserConcurrency
}
if s == nil || s.settingService == nil {
return plan
}
plan.Balance = s.settingService.GetDefaultBalance(ctx)
plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx)
plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx)
defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err)
return plan
}
providerDefaults, ok := authSourceSignupSettings(defaults, signupSource)
if !ok || !providerDefaults.GrantOnSignup {
return plan
}
plan.Balance = providerDefaults.Balance
plan.Concurrency = providerDefaults.Concurrency
plan.Subscriptions = providerDefaults.Subscriptions
return plan
}
func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource string) (ProviderDefaultGrantSettings, bool) {
if defaults == nil {
return ProviderDefaultGrantSettings{}, false
}
switch strings.ToLower(strings.TrimSpace(signupSource)) {
case "email":
return defaults.Email, true
case "linuxdo":
return defaults.LinuxDo, true
case "oidc":
return defaults.OIDC, true
case "wechat":
return defaults.WeChat, true
default:
return ProviderDefaultGrantSettings{}, false
}
}
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) { func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
if user == nil || user.ID <= 0 { if user == nil || user.ID <= 0 {
return return

View File

@@ -37,7 +37,16 @@ func (s *settingRepoStub) Set(ctx context.Context, key, value string) error {
} }
func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call") if s.err != nil {
return nil, s.err
}
result := make(map[string]string, len(keys))
for _, key := range keys {
if v, ok := s.values[key]; ok {
result[key] = v
}
}
return result, nil
} }
func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
@@ -62,6 +71,8 @@ type defaultSubscriptionAssignerStub struct {
err error err error
} }
type refreshTokenCacheStub struct{}
func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) { func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
if input != nil { if input != nil {
s.calls = append(s.calls, *input) s.calls = append(s.calls, *input)
@@ -72,6 +83,46 @@ func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.C
return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
} }
func (s *refreshTokenCacheStub) StoreRefreshToken(context.Context, string, *RefreshTokenData, time.Duration) error {
return nil
}
func (s *refreshTokenCacheStub) GetRefreshToken(context.Context, string) (*RefreshTokenData, error) {
return nil, ErrRefreshTokenNotFound
}
func (s *refreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
return nil
}
func (s *refreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
return nil
}
func (s *refreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
return nil
}
func (s *refreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
return nil
}
func (s *refreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
return nil
}
func (s *refreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
return nil, nil
}
func (s *refreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
return nil, nil
}
func (s *refreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
return false, nil
}
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) { func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
if s.err != nil { if s.err != nil {
return nil, s.err return nil, s.err
@@ -484,3 +535,109 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
require.Equal(t, int64(12), assigner.calls[1].GroupID) require.Equal(t, int64(12), assigner.calls[1].GroupID)
require.Equal(t, 7, assigner.calls[1].ValidityDays) require.Equal(t, 7, assigner.calls[1].ValidityDays)
} }
func TestAuthService_Register_UsesEmailAuthSourceDefaultsWhenGrantEnabled(t *testing.T) {
repo := &userRepoStub{nextID: 52}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultSubscriptions: `[{"group_id":91,"validity_days":3}]`,
SettingKeyAuthSourceDefaultEmailBalance: "12.5",
SettingKeyAuthSourceDefaultEmailConcurrency: "7",
SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
}, nil)
service.defaultSubAssigner = assigner
_, user, err := service.Register(context.Background(), "email-defaults@test.com", "password")
require.NoError(t, err)
require.NotNil(t, user)
require.Equal(t, 12.5, user.Balance)
require.Equal(t, 7, user.Concurrency)
require.Len(t, assigner.calls, 1)
require.Equal(t, int64(11), assigner.calls[0].GroupID)
require.Equal(t, 30, assigner.calls[0].ValidityDays)
}
func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *testing.T) {
repo := &userRepoStub{nextID: 53}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`,
SettingKeyAuthSourceDefaultEmailBalance: "99",
SettingKeyAuthSourceDefaultEmailConcurrency: "88",
SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":32,"validity_days":9}]`,
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
service.defaultSubAssigner = assigner
_, user, err := service.Register(context.Background(), "email-global@test.com", "password")
require.NoError(t, err)
require.NotNil(t, user)
require.Equal(t, 3.5, user.Balance)
require.Equal(t, 2, user.Concurrency)
require.Len(t, assigner.calls, 1)
require.Equal(t, int64(31), assigner.calls[0].GroupID)
require.Equal(t, 5, assigner.calls[0].ValidityDays)
}
func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) {
repo := &userRepoStub{nextID: 61}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultSubscriptions: `[{"group_id":81,"validity_days":1}]`,
SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75",
SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
}, nil)
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "")
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
require.Equal(t, int64(61), user.ID)
require.Equal(t, 21.75, user.Balance)
require.Equal(t, 9, user.Concurrency)
require.Len(t, repo.created, 1)
require.Len(t, assigner.calls, 1)
require.Equal(t, int64(22), assigner.calls[0].GroupID)
require.Equal(t, 14, assigner.calls[0].ValidityDays)
}
func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantAgain(t *testing.T) {
existing := &User{
ID: 88,
Email: "linuxdo-123@linuxdo-connect.invalid",
Username: "existing-linuxdo",
Role: RoleUser,
Status: StatusActive,
Balance: 4,
Concurrency: 1,
TokenVersion: 2,
}
repo := &userRepoStub{user: existing}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75",
SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
}, nil)
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "")
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.Equal(t, existing.ID, user.ID)
require.Equal(t, 4.0, user.Balance)
require.Equal(t, 1, user.Concurrency)
require.Empty(t, repo.created)
require.Empty(t, assigner.calls)
}