fix settings auth source default persistence
This commit is contained in:
@@ -1011,10 +1011,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}(),
|
}(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
authSourceDefaults := &service.AuthSourceDefaultSettings{
|
authSourceDefaults := &service.AuthSourceDefaultSettings{
|
||||||
Email: service.ProviderDefaultGrantSettings{
|
Email: service.ProviderDefaultGrantSettings{
|
||||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance),
|
Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance),
|
||||||
@@ -1046,7 +1042,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
|
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
|
||||||
}
|
}
|
||||||
if err := h.settingService.UpdateAuthSourceDefaultSettings(c.Request.Context(), authSourceDefaults); err != nil {
|
if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1086,7 +1082,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h.auditSettingsUpdate(c, previousSettings, settings, req)
|
h.auditSettingsUpdate(c, previousSettings, settings, previousAuthSourceDefaults, authSourceDefaults, req)
|
||||||
|
|
||||||
// 重新获取设置返回
|
// 重新获取设置返回
|
||||||
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
|
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
|
||||||
@@ -1245,12 +1241,12 @@ func hasPaymentFields(req UpdateSettingsRequest) bool {
|
|||||||
req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil
|
req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) {
|
func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) {
|
||||||
if before == nil || after == nil {
|
if before == nil || after == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
changed := diffSettings(before, after, req)
|
changed := diffSettings(before, after, beforeAuthSourceDefaults, afterAuthSourceDefaults, req)
|
||||||
if len(changed) == 0 {
|
if len(changed) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1265,7 +1261,7 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string {
|
func diffSettings(before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) []string {
|
||||||
changed := make([]string, 0, 20)
|
changed := make([]string, 0, 20)
|
||||||
if before.RegistrationEnabled != after.RegistrationEnabled {
|
if before.RegistrationEnabled != after.RegistrationEnabled {
|
||||||
changed = append(changed, "registration_enabled")
|
changed = append(changed, "registration_enabled")
|
||||||
@@ -1535,6 +1531,50 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) {
|
if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) {
|
||||||
changed = append(changed, "account_quota_notify_emails")
|
changed = append(changed, "account_quota_notify_emails")
|
||||||
}
|
}
|
||||||
|
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
|
||||||
|
return changed
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSourceDefaultSettings, after *service.AuthSourceDefaultSettings) []string {
|
||||||
|
if before == nil {
|
||||||
|
before = &service.AuthSourceDefaultSettings{}
|
||||||
|
}
|
||||||
|
if after == nil {
|
||||||
|
after = &service.AuthSourceDefaultSettings{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type providerDefaultGrantField struct {
|
||||||
|
name string
|
||||||
|
before service.ProviderDefaultGrantSettings
|
||||||
|
after service.ProviderDefaultGrantSettings
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := []providerDefaultGrantField{
|
||||||
|
{name: "email", before: before.Email, after: after.Email},
|
||||||
|
{name: "linuxdo", before: before.LinuxDo, after: after.LinuxDo},
|
||||||
|
{name: "oidc", before: before.OIDC, after: after.OIDC},
|
||||||
|
{name: "wechat", before: before.WeChat, after: after.WeChat},
|
||||||
|
}
|
||||||
|
for _, field := range fields {
|
||||||
|
if field.before.Balance != field.after.Balance {
|
||||||
|
changed = append(changed, "auth_source_default_"+field.name+"_balance")
|
||||||
|
}
|
||||||
|
if field.before.Concurrency != field.after.Concurrency {
|
||||||
|
changed = append(changed, "auth_source_default_"+field.name+"_concurrency")
|
||||||
|
}
|
||||||
|
if !equalDefaultSubscriptions(field.before.Subscriptions, field.after.Subscriptions) {
|
||||||
|
changed = append(changed, "auth_source_default_"+field.name+"_subscriptions")
|
||||||
|
}
|
||||||
|
if field.before.GrantOnSignup != field.after.GrantOnSignup {
|
||||||
|
changed = append(changed, "auth_source_default_"+field.name+"_grant_on_signup")
|
||||||
|
}
|
||||||
|
if field.before.GrantOnFirstBind != field.after.GrantOnFirstBind {
|
||||||
|
changed = append(changed, "auth_source_default_"+field.name+"_grant_on_first_bind")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if before.ForceEmailOnThirdPartySignup != after.ForceEmailOnThirdPartySignup {
|
||||||
|
changed = append(changed, "force_email_on_third_party_signup")
|
||||||
|
}
|
||||||
return changed
|
return changed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -66,6 +67,58 @@ func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error {
|
|||||||
panic("unexpected Delete call")
|
panic("unexpected Delete call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type failingAuthSourceSettingsRepoStub struct {
|
||||||
|
values map[string]string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *failingAuthSourceSettingsRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
|
||||||
|
panic("unexpected Get call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *failingAuthSourceSettingsRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||||
|
panic("unexpected GetValue call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *failingAuthSourceSettingsRepoStub) Set(ctx context.Context, key, value string) error {
|
||||||
|
panic("unexpected Set call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *failingAuthSourceSettingsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||||
|
out := make(map[string]string, len(keys))
|
||||||
|
for _, key := range keys {
|
||||||
|
if value, ok := s.values[key]; ok {
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *failingAuthSourceSettingsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||||
|
if _, ok := settings[service.SettingKeyAuthSourceDefaultEmailBalance]; ok {
|
||||||
|
return s.err
|
||||||
|
}
|
||||||
|
for key, value := range settings {
|
||||||
|
if s.values == nil {
|
||||||
|
s.values = map[string]string{}
|
||||||
|
}
|
||||||
|
s.values[key] = value
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *failingAuthSourceSettingsRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||||
|
out := make(map[string]string, len(s.values))
|
||||||
|
for key, value := range s.values {
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *failingAuthSourceSettingsRepoStub) Delete(ctx context.Context, key string) error {
|
||||||
|
panic("unexpected Delete call")
|
||||||
|
}
|
||||||
|
|
||||||
func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
|
func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
repo := &settingHandlerRepoStub{
|
repo := &settingHandlerRepoStub{
|
||||||
@@ -221,3 +274,73 @@ func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(
|
|||||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
require.NotContains(t, repo.values, service.SettingPaymentVisibleMethodAlipaySource)
|
require.NotContains(t, repo.values, service.SettingPaymentVisibleMethodAlipaySource)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSettingHandler_UpdateSettings_DoesNotPersistPartialSystemSettingsWhenAuthSourceDefaultsFail(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
repo := &failingAuthSourceSettingsRepoStub{
|
||||||
|
values: map[string]string{
|
||||||
|
service.SettingKeyRegistrationEnabled: "false",
|
||||||
|
service.SettingKeyPromoCodeEnabled: "true",
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
|
||||||
|
},
|
||||||
|
err: errors.New("write auth source defaults failed"),
|
||||||
|
}
|
||||||
|
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||||
|
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
|
body := map[string]any{
|
||||||
|
"registration_enabled": true,
|
||||||
|
"promo_code_enabled": true,
|
||||||
|
"auth_source_default_email_balance": 12.75,
|
||||||
|
}
|
||||||
|
rawBody, err := json.Marshal(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
handler.UpdateSettings(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||||
|
require.Equal(t, "false", repo.values[service.SettingKeyRegistrationEnabled])
|
||||||
|
require.Equal(t, "9.5", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiffSettings_IncludesAuthSourceDefaultsAndForceEmail(t *testing.T) {
|
||||||
|
changed := diffSettings(
|
||||||
|
&service.SystemSettings{},
|
||||||
|
&service.SystemSettings{},
|
||||||
|
&service.AuthSourceDefaultSettings{
|
||||||
|
Email: service.ProviderDefaultGrantSettings{
|
||||||
|
Balance: 0,
|
||||||
|
Concurrency: 5,
|
||||||
|
Subscriptions: nil,
|
||||||
|
GrantOnSignup: true,
|
||||||
|
GrantOnFirstBind: false,
|
||||||
|
},
|
||||||
|
ForceEmailOnThirdPartySignup: false,
|
||||||
|
},
|
||||||
|
&service.AuthSourceDefaultSettings{
|
||||||
|
Email: service.ProviderDefaultGrantSettings{
|
||||||
|
Balance: 12.5,
|
||||||
|
Concurrency: 7,
|
||||||
|
Subscriptions: []service.DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 30}},
|
||||||
|
GrantOnSignup: false,
|
||||||
|
GrantOnFirstBind: true,
|
||||||
|
},
|
||||||
|
ForceEmailOnThirdPartySignup: true,
|
||||||
|
},
|
||||||
|
UpdateSettingsRequest{},
|
||||||
|
)
|
||||||
|
|
||||||
|
require.Contains(t, changed, "auth_source_default_email_balance")
|
||||||
|
require.Contains(t, changed, "auth_source_default_email_concurrency")
|
||||||
|
require.Contains(t, changed, "auth_source_default_email_subscriptions")
|
||||||
|
require.Contains(t, changed, "auth_source_default_email_grant_on_signup")
|
||||||
|
require.Contains(t, changed, "auth_source_default_email_grant_on_first_bind")
|
||||||
|
require.Contains(t, changed, "force_email_on_third_party_signup")
|
||||||
|
}
|
||||||
|
|||||||
@@ -44,13 +44,11 @@ func (s *AuthService) applyProviderDefaultSettingsOnFirstBind(
|
|||||||
userID int64,
|
userID int64,
|
||||||
providerType string,
|
providerType string,
|
||||||
) error {
|
) error {
|
||||||
defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx)
|
providerDefaults, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, providerType, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("load auth source defaults: %w", err)
|
return fmt.Errorf("load auth source defaults: %w", err)
|
||||||
}
|
}
|
||||||
|
if !enabled {
|
||||||
providerDefaults, ok := authSourceSignupSettings(defaults, providerType)
|
|
||||||
if !ok || !providerDefaults.GrantOnFirstBind {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -716,20 +716,18 @@ func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource s
|
|||||||
plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx)
|
plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx)
|
||||||
plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx)
|
plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx)
|
||||||
|
|
||||||
defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx)
|
resolved, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, signupSource, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err)
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err)
|
||||||
return plan
|
return plan
|
||||||
}
|
}
|
||||||
|
if !enabled {
|
||||||
providerDefaults, ok := authSourceSignupSettings(defaults, signupSource)
|
|
||||||
if !ok || !providerDefaults.GrantOnSignup {
|
|
||||||
return plan
|
return plan
|
||||||
}
|
}
|
||||||
|
|
||||||
plan.Balance = providerDefaults.Balance
|
plan.Balance = resolved.Balance
|
||||||
plan.Concurrency = providerDefaults.Concurrency
|
plan.Concurrency = resolved.Concurrency
|
||||||
plan.Subscriptions = providerDefaults.Subscriptions
|
plan.Subscriptions = resolved.Subscriptions
|
||||||
return plan
|
return plan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -321,6 +321,47 @@ func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNe
|
|||||||
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthServiceLogin_MergesEmailFirstBindSourceOverridesWithGlobalDefaults(t *testing.T) {
|
||||||
|
assigner := &authIdentityDefaultSubAssignerStub{}
|
||||||
|
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
|
||||||
|
service.SettingKeyRegistrationEnabled: "true",
|
||||||
|
service.SettingKeyDefaultSubscriptions: `[{"group_id":21,"validity_days":14}]`,
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailConcurrency: "5",
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`,
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
|
||||||
|
}, assigner)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
passwordHash, err := svc.HashPassword("password")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user, err := client.User.Create().
|
||||||
|
SetEmail("merged-first-bind@example.com").
|
||||||
|
SetUsername("merged-user").
|
||||||
|
SetPasswordHash(passwordHash).
|
||||||
|
SetBalance(1.5).
|
||||||
|
SetConcurrency(2).
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token, gotUser, err := svc.Login(ctx, user.Email, "password")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, token)
|
||||||
|
require.NotNil(t, gotUser)
|
||||||
|
svc.RecordSuccessfulLogin(ctx, user.ID)
|
||||||
|
|
||||||
|
storedUser, err := client.User.Get(ctx, user.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 10.0, storedUser.Balance)
|
||||||
|
require.Equal(t, 4, storedUser.Concurrency)
|
||||||
|
require.Len(t, assigner.calls, 1)
|
||||||
|
require.Equal(t, int64(21), assigner.calls[0].GroupID)
|
||||||
|
require.Equal(t, 14, assigner.calls[0].ValidityDays)
|
||||||
|
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
||||||
|
}
|
||||||
|
|
||||||
func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) {
|
func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) {
|
||||||
assigner := &authIdentityDefaultSubAssignerStub{}
|
assigner := &authIdentityDefaultSubAssignerStub{}
|
||||||
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
|
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
|
||||||
|
|||||||
@@ -584,6 +584,29 @@ func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *tes
|
|||||||
require.Equal(t, 5, assigner.calls[0].ValidityDays)
|
require.Equal(t, 5, assigner.calls[0].ValidityDays)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaults(t *testing.T) {
|
||||||
|
repo := &userRepoStub{nextID: 54}
|
||||||
|
assigner := &defaultSubscriptionAssignerStub{}
|
||||||
|
service := newAuthService(repo, map[string]string{
|
||||||
|
SettingKeyRegistrationEnabled: "true",
|
||||||
|
SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`,
|
||||||
|
SettingKeyAuthSourceDefaultEmailBalance: "9.5",
|
||||||
|
SettingKeyAuthSourceDefaultEmailConcurrency: "5",
|
||||||
|
SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`,
|
||||||
|
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
|
||||||
|
}, nil)
|
||||||
|
service.defaultSubAssigner = assigner
|
||||||
|
|
||||||
|
_, user, err := service.Register(context.Background(), "email-merged@test.com", "password")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, user)
|
||||||
|
require.Equal(t, 9.5, user.Balance)
|
||||||
|
require.Equal(t, 2, user.Concurrency)
|
||||||
|
require.Len(t, assigner.calls, 1)
|
||||||
|
require.Equal(t, int64(31), assigner.calls[0].GroupID)
|
||||||
|
require.Equal(t, 5, assigner.calls[0].ValidityDays)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) {
|
func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) {
|
||||||
repo := &userRepoStub{nextID: 61}
|
repo := &userRepoStub{nextID: 61}
|
||||||
assigner := &defaultSubscriptionAssignerStub{}
|
assigner := &defaultSubscriptionAssignerStub{}
|
||||||
|
|||||||
@@ -569,12 +569,47 @@ func parseCustomMenuItemURLs(raw string) []string {
|
|||||||
|
|
||||||
// UpdateSettings 更新系统设置
|
// UpdateSettings 更新系统设置
|
||||||
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
|
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
|
||||||
if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
|
updates, err := s.buildSystemSettingsUpdates(ctx, settings)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = s.settingRepo.SetMultiple(ctx, updates)
|
||||||
|
if err == nil {
|
||||||
|
s.refreshCachedSettings(settings)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSettingsWithAuthSourceDefaults persists system settings and auth-source defaults in a single write.
|
||||||
|
func (s *SettingService) UpdateSettingsWithAuthSourceDefaults(ctx context.Context, settings *SystemSettings, authDefaults *AuthSourceDefaultSettings) error {
|
||||||
|
updates, err := s.buildSystemSettingsUpdates(ctx, settings)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
authSourceUpdates, err := s.buildAuthSourceDefaultUpdates(ctx, authDefaults)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for key, value := range authSourceUpdates {
|
||||||
|
updates[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.settingRepo.SetMultiple(ctx, updates)
|
||||||
|
if err == nil {
|
||||||
|
s.refreshCachedSettings(settings)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, settings *SystemSettings) (map[string]string, error) {
|
||||||
|
if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
normalizedWhitelist, err := NormalizeRegistrationEmailSuffixWhitelist(settings.RegistrationEmailSuffixWhitelist)
|
normalizedWhitelist, err := NormalizeRegistrationEmailSuffixWhitelist(settings.RegistrationEmailSuffixWhitelist)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error())
|
return nil, infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error())
|
||||||
}
|
}
|
||||||
if normalizedWhitelist == nil {
|
if normalizedWhitelist == nil {
|
||||||
normalizedWhitelist = []string{}
|
normalizedWhitelist = []string{}
|
||||||
@@ -582,11 +617,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
|||||||
settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist
|
settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist
|
||||||
alipaySource, err := normalizeVisibleMethodSettingSource("alipay", settings.PaymentVisibleMethodAlipaySource, settings.PaymentVisibleMethodAlipayEnabled)
|
alipaySource, err := normalizeVisibleMethodSettingSource("alipay", settings.PaymentVisibleMethodAlipaySource, settings.PaymentVisibleMethodAlipayEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
wxpaySource, err := normalizeVisibleMethodSettingSource("wxpay", settings.PaymentVisibleMethodWxpaySource, settings.PaymentVisibleMethodWxpayEnabled)
|
wxpaySource, err := normalizeVisibleMethodSettingSource("wxpay", settings.PaymentVisibleMethodWxpaySource, settings.PaymentVisibleMethodWxpayEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
settings.PaymentVisibleMethodAlipaySource = alipaySource
|
settings.PaymentVisibleMethodAlipaySource = alipaySource
|
||||||
settings.PaymentVisibleMethodWxpaySource = wxpaySource
|
settings.PaymentVisibleMethodWxpaySource = wxpaySource
|
||||||
@@ -598,7 +633,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
|||||||
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
|
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
|
||||||
registrationEmailSuffixWhitelistJSON, err := json.Marshal(settings.RegistrationEmailSuffixWhitelist)
|
registrationEmailSuffixWhitelistJSON, err := json.Marshal(settings.RegistrationEmailSuffixWhitelist)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("marshal registration email suffix whitelist: %w", err)
|
return nil, fmt.Errorf("marshal registration email suffix whitelist: %w", err)
|
||||||
}
|
}
|
||||||
updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON)
|
updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON)
|
||||||
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
|
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
|
||||||
@@ -677,7 +712,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
|||||||
updates[SettingKeyTableDefaultPageSize] = strconv.Itoa(tableDefaultPageSize)
|
updates[SettingKeyTableDefaultPageSize] = strconv.Itoa(tableDefaultPageSize)
|
||||||
tablePageSizeOptionsJSON, err := json.Marshal(tablePageSizeOptions)
|
tablePageSizeOptionsJSON, err := json.Marshal(tablePageSizeOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("marshal table page size options: %w", err)
|
return nil, fmt.Errorf("marshal table page size options: %w", err)
|
||||||
}
|
}
|
||||||
updates[SettingKeyTablePageSizeOptions] = string(tablePageSizeOptionsJSON)
|
updates[SettingKeyTablePageSizeOptions] = string(tablePageSizeOptionsJSON)
|
||||||
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
|
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
|
||||||
@@ -688,7 +723,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
|||||||
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
|
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
|
||||||
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
|
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("marshal default subscriptions: %w", err)
|
return nil, fmt.Errorf("marshal default subscriptions: %w", err)
|
||||||
}
|
}
|
||||||
updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON)
|
updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON)
|
||||||
|
|
||||||
@@ -738,37 +773,66 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
|||||||
updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled)
|
updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled)
|
||||||
updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails)
|
updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails)
|
||||||
|
|
||||||
err = s.settingRepo.SetMultiple(ctx, updates)
|
return updates, nil
|
||||||
if err == nil {
|
}
|
||||||
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
|
|
||||||
versionBoundsSF.Forget("version_bounds")
|
func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, settings *AuthSourceDefaultSettings) (map[string]string, error) {
|
||||||
versionBoundsCache.Store(&cachedVersionBounds{
|
if settings == nil {
|
||||||
min: settings.MinClaudeCodeVersion,
|
return nil, nil
|
||||||
max: settings.MaxClaudeCodeVersion,
|
}
|
||||||
expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(),
|
|
||||||
})
|
for _, subscriptions := range [][]DefaultSubscriptionSetting{
|
||||||
backendModeSF.Forget("backend_mode")
|
settings.Email.Subscriptions,
|
||||||
backendModeCache.Store(&cachedBackendMode{
|
settings.LinuxDo.Subscriptions,
|
||||||
value: settings.BackendModeEnabled,
|
settings.OIDC.Subscriptions,
|
||||||
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
|
settings.WeChat.Subscriptions,
|
||||||
})
|
} {
|
||||||
gatewayForwardingSF.Forget("gateway_forwarding")
|
if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
|
||||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
|
return nil, err
|
||||||
fingerprintUnification: settings.EnableFingerprintUnification,
|
|
||||||
metadataPassthrough: settings.EnableMetadataPassthrough,
|
|
||||||
cchSigning: settings.EnableCCHSigning,
|
|
||||||
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
|
|
||||||
})
|
|
||||||
openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
|
|
||||||
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
|
|
||||||
enabled: settings.OpenAIAdvancedSchedulerEnabled,
|
|
||||||
expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
|
|
||||||
})
|
|
||||||
if s.onUpdate != nil {
|
|
||||||
s.onUpdate() // Invalidate cache after settings update
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return err
|
|
||||||
|
updates := make(map[string]string, 21)
|
||||||
|
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
|
||||||
|
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
|
||||||
|
writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
|
||||||
|
writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
|
||||||
|
updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
|
||||||
|
return updates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
|
||||||
|
if settings == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
|
||||||
|
versionBoundsSF.Forget("version_bounds")
|
||||||
|
versionBoundsCache.Store(&cachedVersionBounds{
|
||||||
|
min: settings.MinClaudeCodeVersion,
|
||||||
|
max: settings.MaxClaudeCodeVersion,
|
||||||
|
expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(),
|
||||||
|
})
|
||||||
|
backendModeSF.Forget("backend_mode")
|
||||||
|
backendModeCache.Store(&cachedBackendMode{
|
||||||
|
value: settings.BackendModeEnabled,
|
||||||
|
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
|
||||||
|
})
|
||||||
|
gatewayForwardingSF.Forget("gateway_forwarding")
|
||||||
|
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
|
||||||
|
fingerprintUnification: settings.EnableFingerprintUnification,
|
||||||
|
metadataPassthrough: settings.EnableMetadataPassthrough,
|
||||||
|
cchSigning: settings.EnableCCHSigning,
|
||||||
|
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
|
||||||
|
})
|
||||||
|
openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
|
||||||
|
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
|
||||||
|
enabled: settings.OpenAIAdvancedSchedulerEnabled,
|
||||||
|
expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
|
||||||
|
})
|
||||||
|
if s.onUpdate != nil {
|
||||||
|
s.onUpdate() // Invalidate cache after settings update
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error {
|
func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error {
|
||||||
@@ -1067,29 +1131,43 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SettingService) ResolveAuthSourceGrantSettings(ctx context.Context, signupSource string, firstBind bool) (ProviderDefaultGrantSettings, bool, error) {
|
||||||
|
result := ProviderDefaultGrantSettings{
|
||||||
|
Balance: s.GetDefaultBalance(ctx),
|
||||||
|
Concurrency: s.GetDefaultConcurrency(ctx),
|
||||||
|
Subscriptions: s.GetDefaultSubscriptions(ctx),
|
||||||
|
}
|
||||||
|
|
||||||
|
defaults, err := s.GetAuthSourceDefaultSettings(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return result, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
providerDefaults, ok := authSourceSignupSettings(defaults, signupSource)
|
||||||
|
if !ok {
|
||||||
|
return result, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
enabled := providerDefaults.GrantOnSignup
|
||||||
|
if firstBind {
|
||||||
|
enabled = providerDefaults.GrantOnFirstBind
|
||||||
|
}
|
||||||
|
if !enabled {
|
||||||
|
return result, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return mergeProviderDefaultGrantSettings(result, providerDefaults), true, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error {
|
func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error {
|
||||||
if settings == nil {
|
updates, err := s.buildAuthSourceDefaultUpdates(ctx, settings)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(updates) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, subscriptions := range [][]DefaultSubscriptionSetting{
|
|
||||||
settings.Email.Subscriptions,
|
|
||||||
settings.LinuxDo.Subscriptions,
|
|
||||||
settings.OIDC.Subscriptions,
|
|
||||||
settings.WeChat.Subscriptions,
|
|
||||||
} {
|
|
||||||
if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
updates := make(map[string]string, 21)
|
|
||||||
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
|
|
||||||
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
|
|
||||||
writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
|
|
||||||
writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
|
|
||||||
updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
|
|
||||||
|
|
||||||
if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
|
if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
|
||||||
return fmt.Errorf("update auth source default settings: %w", err)
|
return fmt.Errorf("update auth source default settings: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1594,6 +1672,28 @@ func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSource
|
|||||||
updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind)
|
updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mergeProviderDefaultGrantSettings(globalDefaults ProviderDefaultGrantSettings, providerDefaults ProviderDefaultGrantSettings) ProviderDefaultGrantSettings {
|
||||||
|
result := ProviderDefaultGrantSettings{
|
||||||
|
Balance: globalDefaults.Balance,
|
||||||
|
Concurrency: globalDefaults.Concurrency,
|
||||||
|
Subscriptions: append([]DefaultSubscriptionSetting(nil), globalDefaults.Subscriptions...),
|
||||||
|
GrantOnSignup: providerDefaults.GrantOnSignup,
|
||||||
|
GrantOnFirstBind: providerDefaults.GrantOnFirstBind,
|
||||||
|
}
|
||||||
|
|
||||||
|
if providerDefaults.Balance != defaultAuthSourceBalance {
|
||||||
|
result.Balance = providerDefaults.Balance
|
||||||
|
}
|
||||||
|
if providerDefaults.Concurrency > 0 && providerDefaults.Concurrency != defaultAuthSourceConcurrency {
|
||||||
|
result.Concurrency = providerDefaults.Concurrency
|
||||||
|
}
|
||||||
|
if len(providerDefaults.Subscriptions) > 0 {
|
||||||
|
result.Subscriptions = append([]DefaultSubscriptionSetting(nil), providerDefaults.Subscriptions...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) {
|
func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) {
|
||||||
defaultPageSize := 20
|
defaultPageSize := 20
|
||||||
if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil {
|
if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user