diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index e5681208..f0e91f3a 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1011,10 +1011,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { }(), } - if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { - response.ErrorFrom(c, err) - return - } authSourceDefaults := &service.AuthSourceDefaultSettings{ Email: service.ProviderDefaultGrantSettings{ Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance), @@ -1046,7 +1042,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { }, ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup), } - if err := h.settingService.UpdateAuthSourceDefaultSettings(c.Request.Context(), authSourceDefaults); err != nil { + if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil { response.ErrorFrom(c, err) return } @@ -1086,7 +1082,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } - h.auditSettingsUpdate(c, previousSettings, settings, req) + h.auditSettingsUpdate(c, previousSettings, settings, previousAuthSourceDefaults, authSourceDefaults, req) // 重新获取设置返回 updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context()) @@ -1245,12 +1241,12 @@ func hasPaymentFields(req UpdateSettingsRequest) bool { req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil } -func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) { +func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) { if before == nil || after == nil { return } - changed := diffSettings(before, after, req) + changed := diffSettings(before, after, beforeAuthSourceDefaults, afterAuthSourceDefaults, req) if len(changed) == 0 { return } @@ -1265,7 +1261,7 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys ) } -func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string { +func diffSettings(before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) []string { changed := make([]string, 0, 20) if before.RegistrationEnabled != after.RegistrationEnabled { changed = append(changed, "registration_enabled") @@ -1535,6 +1531,50 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) { changed = append(changed, "account_quota_notify_emails") } + changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults) + return changed +} + +func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSourceDefaultSettings, after *service.AuthSourceDefaultSettings) []string { + if before == nil { + before = &service.AuthSourceDefaultSettings{} + } + if after == nil { + after = &service.AuthSourceDefaultSettings{} + } + + type providerDefaultGrantField struct { + name string + before service.ProviderDefaultGrantSettings + after service.ProviderDefaultGrantSettings + } + + fields := []providerDefaultGrantField{ + {name: "email", before: before.Email, after: after.Email}, + {name: "linuxdo", before: before.LinuxDo, after: after.LinuxDo}, + {name: "oidc", before: before.OIDC, after: after.OIDC}, + {name: "wechat", before: before.WeChat, after: after.WeChat}, + } + for _, field := range fields { + if field.before.Balance != field.after.Balance { + changed = append(changed, "auth_source_default_"+field.name+"_balance") + } + if field.before.Concurrency != field.after.Concurrency { + changed = append(changed, "auth_source_default_"+field.name+"_concurrency") + } + if !equalDefaultSubscriptions(field.before.Subscriptions, field.after.Subscriptions) { + changed = append(changed, "auth_source_default_"+field.name+"_subscriptions") + } + if field.before.GrantOnSignup != field.after.GrantOnSignup { + changed = append(changed, "auth_source_default_"+field.name+"_grant_on_signup") + } + if field.before.GrantOnFirstBind != field.after.GrantOnFirstBind { + changed = append(changed, "auth_source_default_"+field.name+"_grant_on_first_bind") + } + } + if before.ForceEmailOnThirdPartySignup != after.ForceEmailOnThirdPartySignup { + changed = append(changed, "force_email_on_third_party_signup") + } return changed } diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go index bf51fc68..cef531e0 100644 --- a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go +++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "testing" @@ -66,6 +67,58 @@ func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error { panic("unexpected Delete call") } +type failingAuthSourceSettingsRepoStub struct { + values map[string]string + err error +} + +func (s *failingAuthSourceSettingsRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *failingAuthSourceSettingsRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *failingAuthSourceSettingsRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *failingAuthSourceSettingsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *failingAuthSourceSettingsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + if _, ok := settings[service.SettingKeyAuthSourceDefaultEmailBalance]; ok { + return s.err + } + for key, value := range settings { + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } + return nil +} + +func (s *failingAuthSourceSettingsRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for key, value := range s.values { + out[key] = value + } + return out, nil +} + +func (s *failingAuthSourceSettingsRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) { gin.SetMode(gin.TestMode) repo := &settingHandlerRepoStub{ @@ -221,3 +274,73 @@ func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource( require.Equal(t, http.StatusBadRequest, rec.Code) require.NotContains(t, repo.values, service.SettingPaymentVisibleMethodAlipaySource) } + +func TestSettingHandler_UpdateSettings_DoesNotPersistPartialSystemSettingsWhenAuthSourceDefaultsFail(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &failingAuthSourceSettingsRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "false", + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "9.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "8", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`, + }, + err: errors.New("write auth source defaults failed"), + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "registration_enabled": true, + "promo_code_enabled": true, + "auth_source_default_email_balance": 12.75, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusInternalServerError, rec.Code) + require.Equal(t, "false", repo.values[service.SettingKeyRegistrationEnabled]) + require.Equal(t, "9.5", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance]) +} + +func TestDiffSettings_IncludesAuthSourceDefaultsAndForceEmail(t *testing.T) { + changed := diffSettings( + &service.SystemSettings{}, + &service.SystemSettings{}, + &service.AuthSourceDefaultSettings{ + Email: service.ProviderDefaultGrantSettings{ + Balance: 0, + Concurrency: 5, + Subscriptions: nil, + GrantOnSignup: true, + GrantOnFirstBind: false, + }, + ForceEmailOnThirdPartySignup: false, + }, + &service.AuthSourceDefaultSettings{ + Email: service.ProviderDefaultGrantSettings{ + Balance: 12.5, + Concurrency: 7, + Subscriptions: []service.DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 30}}, + GrantOnSignup: false, + GrantOnFirstBind: true, + }, + ForceEmailOnThirdPartySignup: true, + }, + UpdateSettingsRequest{}, + ) + + require.Contains(t, changed, "auth_source_default_email_balance") + require.Contains(t, changed, "auth_source_default_email_concurrency") + require.Contains(t, changed, "auth_source_default_email_subscriptions") + require.Contains(t, changed, "auth_source_default_email_grant_on_signup") + require.Contains(t, changed, "auth_source_default_email_grant_on_first_bind") + require.Contains(t, changed, "force_email_on_third_party_signup") +} diff --git a/backend/internal/service/auth_oauth_first_bind.go b/backend/internal/service/auth_oauth_first_bind.go index 422a2a88..b0da5069 100644 --- a/backend/internal/service/auth_oauth_first_bind.go +++ b/backend/internal/service/auth_oauth_first_bind.go @@ -44,13 +44,11 @@ func (s *AuthService) applyProviderDefaultSettingsOnFirstBind( userID int64, providerType string, ) error { - defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx) + providerDefaults, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, providerType, true) if err != nil { return fmt.Errorf("load auth source defaults: %w", err) } - - providerDefaults, ok := authSourceSignupSettings(defaults, providerType) - if !ok || !providerDefaults.GrantOnFirstBind { + if !enabled { return nil } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index d63a8753..5b6e5fef 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -716,20 +716,18 @@ func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource s plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx) plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx) - defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx) + resolved, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, signupSource, false) if err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err) return plan } - - providerDefaults, ok := authSourceSignupSettings(defaults, signupSource) - if !ok || !providerDefaults.GrantOnSignup { + if !enabled { return plan } - plan.Balance = providerDefaults.Balance - plan.Concurrency = providerDefaults.Concurrency - plan.Subscriptions = providerDefaults.Subscriptions + plan.Balance = resolved.Balance + plan.Concurrency = resolved.Concurrency + plan.Subscriptions = resolved.Subscriptions return plan } diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go index 85c13604..4d2a840f 100644 --- a/backend/internal/service/auth_service_identity_sync_test.go +++ b/backend/internal/service/auth_service_identity_sync_test.go @@ -321,6 +321,47 @@ func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNe require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) } +func TestAuthServiceLogin_MergesEmailFirstBindSourceOverridesWithGlobalDefaults(t *testing.T) { + assigner := &authIdentityDefaultSubAssignerStub{} + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyDefaultSubscriptions: `[{"group_id":21,"validity_days":14}]`, + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "5", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, assigner) + ctx := context.Background() + + passwordHash, err := svc.HashPassword("password") + require.NoError(t, err) + user, err := client.User.Create(). + SetEmail("merged-first-bind@example.com"). + SetUsername("merged-user"). + SetPasswordHash(passwordHash). + SetBalance(1.5). + SetConcurrency(2). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + svc.RecordSuccessfulLogin(ctx, user.ID) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 10.0, storedUser.Balance) + require.Equal(t, 4, storedUser.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(21), assigner.calls[0].GroupID) + require.Equal(t, 14, assigner.calls[0].ValidityDays) + require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) { assigner := &authIdentityDefaultSubAssignerStub{} svc, _, client := newAuthServiceWithEnt(t, map[string]string{ diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index e0dce982..dbd18a20 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -584,6 +584,29 @@ func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *tes require.Equal(t, 5, assigner.calls[0].ValidityDays) } +func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaults(t *testing.T) { + repo := &userRepoStub{nextID: 54} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`, + SettingKeyAuthSourceDefaultEmailBalance: "9.5", + SettingKeyAuthSourceDefaultEmailConcurrency: "5", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = assigner + + _, user, err := service.Register(context.Background(), "email-merged@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, 9.5, user.Balance) + require.Equal(t, 2, user.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(31), assigner.calls[0].GroupID) + require.Equal(t, 5, assigner.calls[0].ValidityDays) +} + func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) { repo := &userRepoStub{nextID: 61} assigner := &defaultSubscriptionAssignerStub{} diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 58246111..8c879d52 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -569,12 +569,47 @@ func parseCustomMenuItemURLs(raw string) []string { // UpdateSettings 更新系统设置 func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error { - if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil { + updates, err := s.buildSystemSettingsUpdates(ctx, settings) + if err != nil { return err } + + err = s.settingRepo.SetMultiple(ctx, updates) + if err == nil { + s.refreshCachedSettings(settings) + } + return err +} + +// UpdateSettingsWithAuthSourceDefaults persists system settings and auth-source defaults in a single write. +func (s *SettingService) UpdateSettingsWithAuthSourceDefaults(ctx context.Context, settings *SystemSettings, authDefaults *AuthSourceDefaultSettings) error { + updates, err := s.buildSystemSettingsUpdates(ctx, settings) + if err != nil { + return err + } + + authSourceUpdates, err := s.buildAuthSourceDefaultUpdates(ctx, authDefaults) + if err != nil { + return err + } + for key, value := range authSourceUpdates { + updates[key] = value + } + + err = s.settingRepo.SetMultiple(ctx, updates) + if err == nil { + s.refreshCachedSettings(settings) + } + return err +} + +func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, settings *SystemSettings) (map[string]string, error) { + if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil { + return nil, err + } normalizedWhitelist, err := NormalizeRegistrationEmailSuffixWhitelist(settings.RegistrationEmailSuffixWhitelist) if err != nil { - return infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error()) + return nil, infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error()) } if normalizedWhitelist == nil { normalizedWhitelist = []string{} @@ -582,11 +617,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist alipaySource, err := normalizeVisibleMethodSettingSource("alipay", settings.PaymentVisibleMethodAlipaySource, settings.PaymentVisibleMethodAlipayEnabled) if err != nil { - return err + return nil, err } wxpaySource, err := normalizeVisibleMethodSettingSource("wxpay", settings.PaymentVisibleMethodWxpaySource, settings.PaymentVisibleMethodWxpayEnabled) if err != nil { - return err + return nil, err } settings.PaymentVisibleMethodAlipaySource = alipaySource settings.PaymentVisibleMethodWxpaySource = wxpaySource @@ -598,7 +633,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) registrationEmailSuffixWhitelistJSON, err := json.Marshal(settings.RegistrationEmailSuffixWhitelist) if err != nil { - return fmt.Errorf("marshal registration email suffix whitelist: %w", err) + return nil, fmt.Errorf("marshal registration email suffix whitelist: %w", err) } updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON) updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled) @@ -677,7 +712,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyTableDefaultPageSize] = strconv.Itoa(tableDefaultPageSize) tablePageSizeOptionsJSON, err := json.Marshal(tablePageSizeOptions) if err != nil { - return fmt.Errorf("marshal table page size options: %w", err) + return nil, fmt.Errorf("marshal table page size options: %w", err) } updates[SettingKeyTablePageSizeOptions] = string(tablePageSizeOptionsJSON) updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems @@ -688,7 +723,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions) if err != nil { - return fmt.Errorf("marshal default subscriptions: %w", err) + return nil, fmt.Errorf("marshal default subscriptions: %w", err) } updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON) @@ -738,37 +773,66 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled) updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails) - err = s.settingRepo.SetMultiple(ctx, updates) - if err == nil { - // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 - versionBoundsSF.Forget("version_bounds") - versionBoundsCache.Store(&cachedVersionBounds{ - min: settings.MinClaudeCodeVersion, - max: settings.MaxClaudeCodeVersion, - expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(), - }) - backendModeSF.Forget("backend_mode") - backendModeCache.Store(&cachedBackendMode{ - value: settings.BackendModeEnabled, - expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), - }) - gatewayForwardingSF.Forget("gateway_forwarding") - gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ - fingerprintUnification: settings.EnableFingerprintUnification, - metadataPassthrough: settings.EnableMetadataPassthrough, - cchSigning: settings.EnableCCHSigning, - expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), - }) - openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey) - openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{ - enabled: settings.OpenAIAdvancedSchedulerEnabled, - expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(), - }) - if s.onUpdate != nil { - s.onUpdate() // Invalidate cache after settings update + return updates, nil +} + +func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, settings *AuthSourceDefaultSettings) (map[string]string, error) { + if settings == nil { + return nil, nil + } + + for _, subscriptions := range [][]DefaultSubscriptionSetting{ + settings.Email.Subscriptions, + settings.LinuxDo.Subscriptions, + settings.OIDC.Subscriptions, + settings.WeChat.Subscriptions, + } { + if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil { + return nil, err } } - return err + + updates := make(map[string]string, 21) + writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email) + writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo) + writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC) + writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat) + updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup) + return updates, nil +} + +func (s *SettingService) refreshCachedSettings(settings *SystemSettings) { + if settings == nil { + return + } + + // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 + versionBoundsSF.Forget("version_bounds") + versionBoundsCache.Store(&cachedVersionBounds{ + min: settings.MinClaudeCodeVersion, + max: settings.MaxClaudeCodeVersion, + expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(), + }) + backendModeSF.Forget("backend_mode") + backendModeCache.Store(&cachedBackendMode{ + value: settings.BackendModeEnabled, + expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), + }) + gatewayForwardingSF.Forget("gateway_forwarding") + gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ + fingerprintUnification: settings.EnableFingerprintUnification, + metadataPassthrough: settings.EnableMetadataPassthrough, + cchSigning: settings.EnableCCHSigning, + expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), + }) + openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey) + openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{ + enabled: settings.OpenAIAdvancedSchedulerEnabled, + expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(), + }) + if s.onUpdate != nil { + s.onUpdate() // Invalidate cache after settings update + } } func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error { @@ -1067,29 +1131,43 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut }, nil } +func (s *SettingService) ResolveAuthSourceGrantSettings(ctx context.Context, signupSource string, firstBind bool) (ProviderDefaultGrantSettings, bool, error) { + result := ProviderDefaultGrantSettings{ + Balance: s.GetDefaultBalance(ctx), + Concurrency: s.GetDefaultConcurrency(ctx), + Subscriptions: s.GetDefaultSubscriptions(ctx), + } + + defaults, err := s.GetAuthSourceDefaultSettings(ctx) + if err != nil { + return result, false, err + } + + providerDefaults, ok := authSourceSignupSettings(defaults, signupSource) + if !ok { + return result, false, nil + } + + enabled := providerDefaults.GrantOnSignup + if firstBind { + enabled = providerDefaults.GrantOnFirstBind + } + if !enabled { + return result, false, nil + } + + return mergeProviderDefaultGrantSettings(result, providerDefaults), true, nil +} + func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error { - if settings == nil { + updates, err := s.buildAuthSourceDefaultUpdates(ctx, settings) + if err != nil { + return err + } + if len(updates) == 0 { return nil } - for _, subscriptions := range [][]DefaultSubscriptionSetting{ - settings.Email.Subscriptions, - settings.LinuxDo.Subscriptions, - settings.OIDC.Subscriptions, - settings.WeChat.Subscriptions, - } { - if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil { - return err - } - } - - updates := make(map[string]string, 21) - writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email) - writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo) - writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC) - writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat) - updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup) - if err := s.settingRepo.SetMultiple(ctx, updates); err != nil { return fmt.Errorf("update auth source default settings: %w", err) } @@ -1594,6 +1672,28 @@ func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSource updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind) } +func mergeProviderDefaultGrantSettings(globalDefaults ProviderDefaultGrantSettings, providerDefaults ProviderDefaultGrantSettings) ProviderDefaultGrantSettings { + result := ProviderDefaultGrantSettings{ + Balance: globalDefaults.Balance, + Concurrency: globalDefaults.Concurrency, + Subscriptions: append([]DefaultSubscriptionSetting(nil), globalDefaults.Subscriptions...), + GrantOnSignup: providerDefaults.GrantOnSignup, + GrantOnFirstBind: providerDefaults.GrantOnFirstBind, + } + + if providerDefaults.Balance != defaultAuthSourceBalance { + result.Balance = providerDefaults.Balance + } + if providerDefaults.Concurrency > 0 && providerDefaults.Concurrency != defaultAuthSourceConcurrency { + result.Concurrency = providerDefaults.Concurrency + } + if len(providerDefaults.Subscriptions) > 0 { + result.Subscriptions = append([]DefaultSubscriptionSetting(nil), providerDefaults.Subscriptions...) + } + + return result +} + func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) { defaultPageSize := 20 if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil {