fix settings auth source default persistence

This commit is contained in:
IanShaw027
2026-04-21 02:08:04 +08:00
parent cd0338fbae
commit e12599c1b9
7 changed files with 398 additions and 75 deletions

View File

@@ -44,13 +44,11 @@ func (s *AuthService) applyProviderDefaultSettingsOnFirstBind(
userID int64,
providerType string,
) error {
defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx)
providerDefaults, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, providerType, true)
if err != nil {
return fmt.Errorf("load auth source defaults: %w", err)
}
providerDefaults, ok := authSourceSignupSettings(defaults, providerType)
if !ok || !providerDefaults.GrantOnFirstBind {
if !enabled {
return nil
}

View File

@@ -716,20 +716,18 @@ func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource s
plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx)
plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx)
defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx)
resolved, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, signupSource, false)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err)
return plan
}
providerDefaults, ok := authSourceSignupSettings(defaults, signupSource)
if !ok || !providerDefaults.GrantOnSignup {
if !enabled {
return plan
}
plan.Balance = providerDefaults.Balance
plan.Concurrency = providerDefaults.Concurrency
plan.Subscriptions = providerDefaults.Subscriptions
plan.Balance = resolved.Balance
plan.Concurrency = resolved.Concurrency
plan.Subscriptions = resolved.Subscriptions
return plan
}

View File

@@ -321,6 +321,47 @@ func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNe
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
func TestAuthServiceLogin_MergesEmailFirstBindSourceOverridesWithGlobalDefaults(t *testing.T) {
assigner := &authIdentityDefaultSubAssignerStub{}
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyDefaultSubscriptions: `[{"group_id":21,"validity_days":14}]`,
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "5",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`,
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
}, assigner)
ctx := context.Background()
passwordHash, err := svc.HashPassword("password")
require.NoError(t, err)
user, err := client.User.Create().
SetEmail("merged-first-bind@example.com").
SetUsername("merged-user").
SetPasswordHash(passwordHash).
SetBalance(1.5).
SetConcurrency(2).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
token, gotUser, err := svc.Login(ctx, user.Email, "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, 10.0, storedUser.Balance)
require.Equal(t, 4, storedUser.Concurrency)
require.Len(t, assigner.calls, 1)
require.Equal(t, int64(21), assigner.calls[0].GroupID)
require.Equal(t, 14, assigner.calls[0].ValidityDays)
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) {
assigner := &authIdentityDefaultSubAssignerStub{}
svc, _, client := newAuthServiceWithEnt(t, map[string]string{

View File

@@ -584,6 +584,29 @@ func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *tes
require.Equal(t, 5, assigner.calls[0].ValidityDays)
}
func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaults(t *testing.T) {
repo := &userRepoStub{nextID: 54}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`,
SettingKeyAuthSourceDefaultEmailBalance: "9.5",
SettingKeyAuthSourceDefaultEmailConcurrency: "5",
SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`,
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
}, nil)
service.defaultSubAssigner = assigner
_, user, err := service.Register(context.Background(), "email-merged@test.com", "password")
require.NoError(t, err)
require.NotNil(t, user)
require.Equal(t, 9.5, user.Balance)
require.Equal(t, 2, user.Concurrency)
require.Len(t, assigner.calls, 1)
require.Equal(t, int64(31), assigner.calls[0].GroupID)
require.Equal(t, 5, assigner.calls[0].ValidityDays)
}
func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) {
repo := &userRepoStub{nextID: 61}
assigner := &defaultSubscriptionAssignerStub{}

View File

@@ -569,12 +569,47 @@ func parseCustomMenuItemURLs(raw string) []string {
// UpdateSettings 更新系统设置
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
updates, err := s.buildSystemSettingsUpdates(ctx, settings)
if err != nil {
return err
}
err = s.settingRepo.SetMultiple(ctx, updates)
if err == nil {
s.refreshCachedSettings(settings)
}
return err
}
// UpdateSettingsWithAuthSourceDefaults persists system settings and auth-source defaults in a single write.
func (s *SettingService) UpdateSettingsWithAuthSourceDefaults(ctx context.Context, settings *SystemSettings, authDefaults *AuthSourceDefaultSettings) error {
updates, err := s.buildSystemSettingsUpdates(ctx, settings)
if err != nil {
return err
}
authSourceUpdates, err := s.buildAuthSourceDefaultUpdates(ctx, authDefaults)
if err != nil {
return err
}
for key, value := range authSourceUpdates {
updates[key] = value
}
err = s.settingRepo.SetMultiple(ctx, updates)
if err == nil {
s.refreshCachedSettings(settings)
}
return err
}
func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, settings *SystemSettings) (map[string]string, error) {
if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
return nil, err
}
normalizedWhitelist, err := NormalizeRegistrationEmailSuffixWhitelist(settings.RegistrationEmailSuffixWhitelist)
if err != nil {
return infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error())
return nil, infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error())
}
if normalizedWhitelist == nil {
normalizedWhitelist = []string{}
@@ -582,11 +617,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist
alipaySource, err := normalizeVisibleMethodSettingSource("alipay", settings.PaymentVisibleMethodAlipaySource, settings.PaymentVisibleMethodAlipayEnabled)
if err != nil {
return err
return nil, err
}
wxpaySource, err := normalizeVisibleMethodSettingSource("wxpay", settings.PaymentVisibleMethodWxpaySource, settings.PaymentVisibleMethodWxpayEnabled)
if err != nil {
return err
return nil, err
}
settings.PaymentVisibleMethodAlipaySource = alipaySource
settings.PaymentVisibleMethodWxpaySource = wxpaySource
@@ -598,7 +633,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
registrationEmailSuffixWhitelistJSON, err := json.Marshal(settings.RegistrationEmailSuffixWhitelist)
if err != nil {
return fmt.Errorf("marshal registration email suffix whitelist: %w", err)
return nil, fmt.Errorf("marshal registration email suffix whitelist: %w", err)
}
updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON)
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
@@ -677,7 +712,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyTableDefaultPageSize] = strconv.Itoa(tableDefaultPageSize)
tablePageSizeOptionsJSON, err := json.Marshal(tablePageSizeOptions)
if err != nil {
return fmt.Errorf("marshal table page size options: %w", err)
return nil, fmt.Errorf("marshal table page size options: %w", err)
}
updates[SettingKeyTablePageSizeOptions] = string(tablePageSizeOptionsJSON)
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
@@ -688,7 +723,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
if err != nil {
return fmt.Errorf("marshal default subscriptions: %w", err)
return nil, fmt.Errorf("marshal default subscriptions: %w", err)
}
updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON)
@@ -738,37 +773,66 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled)
updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails)
err = s.settingRepo.SetMultiple(ctx, updates)
if err == nil {
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
versionBoundsSF.Forget("version_bounds")
versionBoundsCache.Store(&cachedVersionBounds{
min: settings.MinClaudeCodeVersion,
max: settings.MaxClaudeCodeVersion,
expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(),
})
backendModeSF.Forget("backend_mode")
backendModeCache.Store(&cachedBackendMode{
value: settings.BackendModeEnabled,
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
})
gatewayForwardingSF.Forget("gateway_forwarding")
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
fingerprintUnification: settings.EnableFingerprintUnification,
metadataPassthrough: settings.EnableMetadataPassthrough,
cchSigning: settings.EnableCCHSigning,
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
})
openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
enabled: settings.OpenAIAdvancedSchedulerEnabled,
expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
})
if s.onUpdate != nil {
s.onUpdate() // Invalidate cache after settings update
return updates, nil
}
func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, settings *AuthSourceDefaultSettings) (map[string]string, error) {
if settings == nil {
return nil, nil
}
for _, subscriptions := range [][]DefaultSubscriptionSetting{
settings.Email.Subscriptions,
settings.LinuxDo.Subscriptions,
settings.OIDC.Subscriptions,
settings.WeChat.Subscriptions,
} {
if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
return nil, err
}
}
return err
updates := make(map[string]string, 21)
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
return updates, nil
}
func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
if settings == nil {
return
}
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
versionBoundsSF.Forget("version_bounds")
versionBoundsCache.Store(&cachedVersionBounds{
min: settings.MinClaudeCodeVersion,
max: settings.MaxClaudeCodeVersion,
expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(),
})
backendModeSF.Forget("backend_mode")
backendModeCache.Store(&cachedBackendMode{
value: settings.BackendModeEnabled,
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
})
gatewayForwardingSF.Forget("gateway_forwarding")
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
fingerprintUnification: settings.EnableFingerprintUnification,
metadataPassthrough: settings.EnableMetadataPassthrough,
cchSigning: settings.EnableCCHSigning,
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
})
openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
enabled: settings.OpenAIAdvancedSchedulerEnabled,
expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
})
if s.onUpdate != nil {
s.onUpdate() // Invalidate cache after settings update
}
}
func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error {
@@ -1067,29 +1131,43 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
}, nil
}
func (s *SettingService) ResolveAuthSourceGrantSettings(ctx context.Context, signupSource string, firstBind bool) (ProviderDefaultGrantSettings, bool, error) {
result := ProviderDefaultGrantSettings{
Balance: s.GetDefaultBalance(ctx),
Concurrency: s.GetDefaultConcurrency(ctx),
Subscriptions: s.GetDefaultSubscriptions(ctx),
}
defaults, err := s.GetAuthSourceDefaultSettings(ctx)
if err != nil {
return result, false, err
}
providerDefaults, ok := authSourceSignupSettings(defaults, signupSource)
if !ok {
return result, false, nil
}
enabled := providerDefaults.GrantOnSignup
if firstBind {
enabled = providerDefaults.GrantOnFirstBind
}
if !enabled {
return result, false, nil
}
return mergeProviderDefaultGrantSettings(result, providerDefaults), true, nil
}
func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error {
if settings == nil {
updates, err := s.buildAuthSourceDefaultUpdates(ctx, settings)
if err != nil {
return err
}
if len(updates) == 0 {
return nil
}
for _, subscriptions := range [][]DefaultSubscriptionSetting{
settings.Email.Subscriptions,
settings.LinuxDo.Subscriptions,
settings.OIDC.Subscriptions,
settings.WeChat.Subscriptions,
} {
if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
return err
}
}
updates := make(map[string]string, 21)
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
return fmt.Errorf("update auth source default settings: %w", err)
}
@@ -1594,6 +1672,28 @@ func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSource
updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind)
}
func mergeProviderDefaultGrantSettings(globalDefaults ProviderDefaultGrantSettings, providerDefaults ProviderDefaultGrantSettings) ProviderDefaultGrantSettings {
result := ProviderDefaultGrantSettings{
Balance: globalDefaults.Balance,
Concurrency: globalDefaults.Concurrency,
Subscriptions: append([]DefaultSubscriptionSetting(nil), globalDefaults.Subscriptions...),
GrantOnSignup: providerDefaults.GrantOnSignup,
GrantOnFirstBind: providerDefaults.GrantOnFirstBind,
}
if providerDefaults.Balance != defaultAuthSourceBalance {
result.Balance = providerDefaults.Balance
}
if providerDefaults.Concurrency > 0 && providerDefaults.Concurrency != defaultAuthSourceConcurrency {
result.Concurrency = providerDefaults.Concurrency
}
if len(providerDefaults.Subscriptions) > 0 {
result.Subscriptions = append([]DefaultSubscriptionSetting(nil), providerDefaults.Subscriptions...)
}
return result
}
func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) {
defaultPageSize := 20
if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil {