feat(settings): add default subscriptions for new users
- add default subscriptions to admin settings - auto-assign subscriptions on register and admin user creation - add validation/tests and align settings UI with subscription selector patterns
This commit is contained in:
@@ -19,10 +19,18 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
|
||||
ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
|
||||
ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
|
||||
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
|
||||
ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
|
||||
ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
|
||||
ErrDefaultSubGroupInvalid = infraerrors.BadRequest(
|
||||
"DEFAULT_SUBSCRIPTION_GROUP_INVALID",
|
||||
"default subscription group must exist and be subscription type",
|
||||
)
|
||||
ErrDefaultSubGroupDuplicate = infraerrors.BadRequest(
|
||||
"DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE",
|
||||
"default subscription group cannot be duplicated",
|
||||
)
|
||||
)
|
||||
|
||||
type SettingRepository interface {
|
||||
@@ -56,13 +64,19 @@ const minVersionErrorTTL = 5 * time.Second
|
||||
// minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context
|
||||
const minVersionDBTimeout = 5 * time.Second
|
||||
|
||||
// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
|
||||
type DefaultSubscriptionGroupReader interface {
|
||||
GetByID(ctx context.Context, id int64) (*Group, error)
|
||||
}
|
||||
|
||||
// SettingService 系统设置服务
|
||||
type SettingService struct {
|
||||
settingRepo SettingRepository
|
||||
cfg *config.Config
|
||||
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
||||
onS3Update func() // Callback when Sora S3 settings are updated
|
||||
version string // Application version
|
||||
settingRepo SettingRepository
|
||||
defaultSubGroupReader DefaultSubscriptionGroupReader
|
||||
cfg *config.Config
|
||||
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
||||
onS3Update func() // Callback when Sora S3 settings are updated
|
||||
version string // Application version
|
||||
}
|
||||
|
||||
// NewSettingService 创建系统设置服务实例
|
||||
@@ -73,6 +87,11 @@ func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *Setti
|
||||
}
|
||||
}
|
||||
|
||||
// SetDefaultSubscriptionGroupReader injects an optional group reader for default subscription validation.
|
||||
func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscriptionGroupReader) {
|
||||
s.defaultSubGroupReader = reader
|
||||
}
|
||||
|
||||
// GetAllSettings 获取所有系统设置
|
||||
func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) {
|
||||
settings, err := s.settingRepo.GetAll(ctx)
|
||||
@@ -222,6 +241,10 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
|
||||
if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updates := make(map[string]string)
|
||||
|
||||
// 注册设置
|
||||
@@ -274,6 +297,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
// 默认配置
|
||||
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
|
||||
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal default subscriptions: %w", err)
|
||||
}
|
||||
updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON)
|
||||
|
||||
// Model fallback configuration
|
||||
updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback)
|
||||
@@ -297,7 +325,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
// Claude Code version check
|
||||
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
|
||||
|
||||
err := s.settingRepo.SetMultiple(ctx, updates)
|
||||
err = s.settingRepo.SetMultiple(ctx, updates)
|
||||
if err == nil {
|
||||
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
|
||||
minVersionSF.Forget("min_version")
|
||||
@@ -312,6 +340,45 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
checked := make(map[int64]struct{}, len(items))
|
||||
for _, item := range items {
|
||||
if item.GroupID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := checked[item.GroupID]; ok {
|
||||
return ErrDefaultSubGroupDuplicate.WithMetadata(map[string]string{
|
||||
"group_id": strconv.FormatInt(item.GroupID, 10),
|
||||
})
|
||||
}
|
||||
checked[item.GroupID] = struct{}{}
|
||||
if s.defaultSubGroupReader == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
group, err := s.defaultSubGroupReader.GetByID(ctx, item.GroupID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrGroupNotFound) {
|
||||
return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{
|
||||
"group_id": strconv.FormatInt(item.GroupID, 10),
|
||||
})
|
||||
}
|
||||
return fmt.Errorf("get default subscription group %d: %w", item.GroupID, err)
|
||||
}
|
||||
if !group.IsSubscriptionType() {
|
||||
return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{
|
||||
"group_id": strconv.FormatInt(item.GroupID, 10),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsRegistrationEnabled 检查是否开放注册
|
||||
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
|
||||
@@ -411,6 +478,15 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
|
||||
return s.cfg.Default.UserBalance
|
||||
}
|
||||
|
||||
// GetDefaultSubscriptions 获取新用户默认订阅配置列表。
|
||||
func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return parseDefaultSubscriptions(value)
|
||||
}
|
||||
|
||||
// InitializeDefaultSettings 初始化默认设置
|
||||
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
// 检查是否已有设置
|
||||
@@ -435,6 +511,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeySoraClientEnabled: "false",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
SettingKeyDefaultSubscriptions: "[]",
|
||||
SettingKeySMTPPort: "587",
|
||||
SettingKeySMTPUseTLS: "false",
|
||||
// Model fallback defaults
|
||||
@@ -511,6 +588,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
} else {
|
||||
result.DefaultBalance = s.cfg.Default.UserBalance
|
||||
}
|
||||
result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions])
|
||||
|
||||
// 敏感信息直接返回,方便测试连接时使用
|
||||
result.SMTPPassword = settings[SettingKeySMTPPassword]
|
||||
@@ -595,6 +673,31 @@ func isFalseSettingValue(value string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var items []DefaultSubscriptionSetting
|
||||
if err := json.Unmarshal([]byte(raw), &items); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalized := make([]DefaultSubscriptionSetting, 0, len(items))
|
||||
for _, item := range items {
|
||||
if item.GroupID <= 0 || item.ValidityDays <= 0 {
|
||||
continue
|
||||
}
|
||||
if item.ValidityDays > MaxValidityDays {
|
||||
item.ValidityDays = MaxValidityDays
|
||||
}
|
||||
normalized = append(normalized, item)
|
||||
}
|
||||
|
||||
return normalized
|
||||
}
|
||||
|
||||
// getStringOrDefault 获取字符串值或默认值
|
||||
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
|
||||
if value, ok := settings[key]; ok && value != "" {
|
||||
|
||||
Reference in New Issue
Block a user