feat(settings): add default subscriptions for new users
- add default subscriptions to admin settings - auto-assign subscriptions on register and admin user creation - add validation/tests and align settings UI with subscription selector patterns
This commit is contained in:
@@ -51,6 +51,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
|
||||
// Check if ops monitoring is enabled (respects config.ops.enabled)
|
||||
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
|
||||
defaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(settings.DefaultSubscriptions))
|
||||
for _, sub := range settings.DefaultSubscriptions {
|
||||
defaultSubscriptions = append(defaultSubscriptions, dto.DefaultSubscriptionSetting{
|
||||
GroupID: sub.GroupID,
|
||||
ValidityDays: sub.ValidityDays,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
@@ -87,6 +94,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: settings.FallbackModelOpenAI,
|
||||
@@ -146,8 +154,9 @@ type UpdateSettingsRequest struct {
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
@@ -194,6 +203,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
}
|
||||
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
|
||||
|
||||
// Turnstile 参数验证
|
||||
if req.TurnstileEnabled {
|
||||
@@ -300,6 +310,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
req.OpsMetricsIntervalSeconds = &v
|
||||
}
|
||||
defaultSubscriptions := make([]service.DefaultSubscriptionSetting, 0, len(req.DefaultSubscriptions))
|
||||
for _, sub := range req.DefaultSubscriptions {
|
||||
defaultSubscriptions = append(defaultSubscriptions, service.DefaultSubscriptionSetting{
|
||||
GroupID: sub.GroupID,
|
||||
ValidityDays: sub.ValidityDays,
|
||||
})
|
||||
}
|
||||
|
||||
// 验证最低版本号格式(空字符串=禁用,或合法 semver)
|
||||
if req.MinClaudeCodeVersion != "" {
|
||||
@@ -343,6 +360,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
SoraClientEnabled: req.SoraClientEnabled,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
@@ -390,6 +408,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
|
||||
for _, sub := range updatedSettings.DefaultSubscriptions {
|
||||
updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
|
||||
GroupID: sub.GroupID,
|
||||
ValidityDays: sub.ValidityDays,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
@@ -426,6 +451,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
SoraClientEnabled: updatedSettings.SoraClientEnabled,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
DefaultSubscriptions: updatedDefaultSubscriptions,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
|
||||
@@ -547,6 +573,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.DefaultBalance != after.DefaultBalance {
|
||||
changed = append(changed, "default_balance")
|
||||
}
|
||||
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
|
||||
changed = append(changed, "default_subscriptions")
|
||||
}
|
||||
if before.EnableModelFallback != after.EnableModelFallback {
|
||||
changed = append(changed, "enable_model_fallback")
|
||||
}
|
||||
@@ -586,6 +615,35 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
return changed
|
||||
}
|
||||
|
||||
func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto.DefaultSubscriptionSetting {
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalized := make([]dto.DefaultSubscriptionSetting, 0, len(input))
|
||||
for _, item := range input {
|
||||
if item.GroupID <= 0 || item.ValidityDays <= 0 {
|
||||
continue
|
||||
}
|
||||
if item.ValidityDays > service.MaxValidityDays {
|
||||
item.ValidityDays = service.MaxValidityDays
|
||||
}
|
||||
normalized = append(normalized, item)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i].GroupID != b[i].GroupID || a[i].ValidityDays != b[i].ValidityDays {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TestSMTPRequest 测试SMTP连接请求
|
||||
type TestSMTPRequest struct {
|
||||
SMTPHost string `json:"smtp_host" binding:"required"`
|
||||
|
||||
@@ -39,8 +39,9 @@ type SystemSettings struct {
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
@@ -62,6 +63,11 @@ type SystemSettings struct {
|
||||
MinClaudeCodeVersion string `json:"min_claude_code_version"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
|
||||
@@ -499,6 +499,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"doc_url": "https://docs.example.com",
|
||||
"default_concurrency": 5,
|
||||
"default_balance": 1.25,
|
||||
"default_subscriptions": [],
|
||||
"enable_model_fallback": false,
|
||||
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
|
||||
"fallback_model_antigravity": "gemini-2.5-pro",
|
||||
@@ -620,7 +621,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil)
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
|
||||
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
admin := &service.User{
|
||||
ID: 1,
|
||||
|
||||
@@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
|
||||
cfg.JWT.AccessTokenExpireMinutes = 60
|
||||
|
||||
userRepo := &stubJWTUserRepo{users: users}
|
||||
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil)
|
||||
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil)
|
||||
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
||||
|
||||
|
||||
@@ -420,6 +420,8 @@ type adminServiceImpl struct {
|
||||
proxyLatencyCache ProxyLatencyCache
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
entClient *dbent.Client // 用于开启数据库事务
|
||||
settingService *SettingService
|
||||
defaultSubAssigner DefaultSubscriptionAssigner
|
||||
}
|
||||
|
||||
type userGroupRateBatchReader interface {
|
||||
@@ -445,6 +447,8 @@ func NewAdminService(
|
||||
proxyLatencyCache ProxyLatencyCache,
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
||||
entClient *dbent.Client,
|
||||
settingService *SettingService,
|
||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||
) AdminService {
|
||||
return &adminServiceImpl{
|
||||
userRepo: userRepo,
|
||||
@@ -460,6 +464,8 @@ func NewAdminService(
|
||||
proxyLatencyCache: proxyLatencyCache,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
entClient: entClient,
|
||||
settingService: settingService,
|
||||
defaultSubAssigner: defaultSubAssigner,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -544,9 +550,27 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userID int64) {
|
||||
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
|
||||
return
|
||||
}
|
||||
items := s.settingService.GetDefaultSubscriptions(ctx)
|
||||
for _, item := range items {
|
||||
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
|
||||
UserID: userID,
|
||||
GroupID: item.GroupID,
|
||||
ValidityDays: item.ValidityDays,
|
||||
Notes: "auto assigned by default user subscriptions setting",
|
||||
}); err != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -65,3 +66,32 @@ func TestAdminService_CreateUser_CreateError(t *testing.T) {
|
||||
require.ErrorIs(t, err, createErr)
|
||||
require.Empty(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateUser_AssignsDefaultSubscriptions(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 21}
|
||||
assigner := &defaultSubscriptionAssignerStub{}
|
||||
cfg := &config.Config{
|
||||
Default: config.DefaultConfig{
|
||||
UserBalance: 0,
|
||||
UserConcurrency: 1,
|
||||
},
|
||||
}
|
||||
settingService := NewSettingService(&settingRepoStub{values: map[string]string{
|
||||
SettingKeyDefaultSubscriptions: `[{"group_id":5,"validity_days":30}]`,
|
||||
}}, cfg)
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: repo,
|
||||
settingService: settingService,
|
||||
defaultSubAssigner: assigner,
|
||||
}
|
||||
|
||||
_, err := svc.CreateUser(context.Background(), &CreateUserInput{
|
||||
Email: "new-user@test.com",
|
||||
Password: "password",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, assigner.calls, 1)
|
||||
require.Equal(t, int64(21), assigner.calls[0].UserID)
|
||||
require.Equal(t, int64(5), assigner.calls[0].GroupID)
|
||||
require.Equal(t, 30, assigner.calls[0].ValidityDays)
|
||||
}
|
||||
|
||||
@@ -65,6 +65,11 @@ type AuthService struct {
|
||||
turnstileService *TurnstileService
|
||||
emailQueueService *EmailQueueService
|
||||
promoService *PromoService
|
||||
defaultSubAssigner DefaultSubscriptionAssigner
|
||||
}
|
||||
|
||||
type DefaultSubscriptionAssigner interface {
|
||||
AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error)
|
||||
}
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
@@ -78,6 +83,7 @@ func NewAuthService(
|
||||
turnstileService *TurnstileService,
|
||||
emailQueueService *EmailQueueService,
|
||||
promoService *PromoService,
|
||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
userRepo: userRepo,
|
||||
@@ -89,6 +95,7 @@ func NewAuthService(
|
||||
turnstileService: turnstileService,
|
||||
emailQueueService: emailQueueService,
|
||||
promoService: promoService,
|
||||
defaultSubAssigner: defaultSubAssigner,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,6 +195,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
|
||||
// 标记邀请码为已使用(如果使用了邀请码)
|
||||
if invitationRedeemCode != nil {
|
||||
@@ -477,6 +485,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
|
||||
@@ -572,6 +581,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
|
||||
@@ -597,6 +607,23 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
return tokenPair, user, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
|
||||
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
|
||||
return
|
||||
}
|
||||
items := s.settingService.GetDefaultSubscriptions(ctx)
|
||||
for _, item := range items {
|
||||
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
|
||||
UserID: userID,
|
||||
GroupID: item.GroupID,
|
||||
ValidityDays: item.ValidityDays,
|
||||
Notes: "auto assigned by default user subscriptions setting",
|
||||
}); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateToken 验证JWT token并返回用户声明
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
|
||||
|
||||
@@ -56,6 +56,21 @@ type emailCacheStub struct {
|
||||
err error
|
||||
}
|
||||
|
||||
type defaultSubscriptionAssignerStub struct {
|
||||
calls []AssignSubscriptionInput
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
|
||||
if input != nil {
|
||||
s.calls = append(s.calls, *input)
|
||||
}
|
||||
if s.err != nil {
|
||||
return nil, false, s.err
|
||||
}
|
||||
return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
@@ -123,6 +138,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
|
||||
nil,
|
||||
nil,
|
||||
nil, // promoService
|
||||
nil, // defaultSubAssigner
|
||||
)
|
||||
}
|
||||
|
||||
@@ -381,3 +397,23 @@ func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) {
|
||||
|
||||
require.WithinDuration(t, claims.IssuedAt.Time.Add(90*time.Minute), claims.ExpiresAt.Time, 2*time.Second)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 42}
|
||||
assigner := &defaultSubscriptionAssignerStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
|
||||
}, nil)
|
||||
service.defaultSubAssigner = assigner
|
||||
|
||||
_, user, err := service.Register(context.Background(), "default-sub@test.com", "password")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
require.Len(t, assigner.calls, 2)
|
||||
require.Equal(t, int64(42), assigner.calls[0].UserID)
|
||||
require.Equal(t, int64(11), assigner.calls[0].GroupID)
|
||||
require.Equal(t, 30, assigner.calls[0].ValidityDays)
|
||||
require.Equal(t, int64(12), assigner.calls[1].GroupID)
|
||||
require.Equal(t, 7, assigner.calls[1].ValidityDays)
|
||||
}
|
||||
|
||||
@@ -117,8 +117,9 @@ const (
|
||||
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL(作为 iframe src)
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
|
||||
|
||||
// 管理员 API Key
|
||||
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
|
||||
@@ -19,10 +19,18 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
|
||||
ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
|
||||
ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
|
||||
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
|
||||
ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
|
||||
ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
|
||||
ErrDefaultSubGroupInvalid = infraerrors.BadRequest(
|
||||
"DEFAULT_SUBSCRIPTION_GROUP_INVALID",
|
||||
"default subscription group must exist and be subscription type",
|
||||
)
|
||||
ErrDefaultSubGroupDuplicate = infraerrors.BadRequest(
|
||||
"DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE",
|
||||
"default subscription group cannot be duplicated",
|
||||
)
|
||||
)
|
||||
|
||||
type SettingRepository interface {
|
||||
@@ -56,13 +64,19 @@ const minVersionErrorTTL = 5 * time.Second
|
||||
// minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context
|
||||
const minVersionDBTimeout = 5 * time.Second
|
||||
|
||||
// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
|
||||
type DefaultSubscriptionGroupReader interface {
|
||||
GetByID(ctx context.Context, id int64) (*Group, error)
|
||||
}
|
||||
|
||||
// SettingService 系统设置服务
|
||||
type SettingService struct {
|
||||
settingRepo SettingRepository
|
||||
cfg *config.Config
|
||||
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
||||
onS3Update func() // Callback when Sora S3 settings are updated
|
||||
version string // Application version
|
||||
settingRepo SettingRepository
|
||||
defaultSubGroupReader DefaultSubscriptionGroupReader
|
||||
cfg *config.Config
|
||||
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
||||
onS3Update func() // Callback when Sora S3 settings are updated
|
||||
version string // Application version
|
||||
}
|
||||
|
||||
// NewSettingService 创建系统设置服务实例
|
||||
@@ -73,6 +87,11 @@ func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *Setti
|
||||
}
|
||||
}
|
||||
|
||||
// SetDefaultSubscriptionGroupReader injects an optional group reader for default subscription validation.
|
||||
func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscriptionGroupReader) {
|
||||
s.defaultSubGroupReader = reader
|
||||
}
|
||||
|
||||
// GetAllSettings 获取所有系统设置
|
||||
func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) {
|
||||
settings, err := s.settingRepo.GetAll(ctx)
|
||||
@@ -222,6 +241,10 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
|
||||
if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updates := make(map[string]string)
|
||||
|
||||
// 注册设置
|
||||
@@ -274,6 +297,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
// 默认配置
|
||||
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
|
||||
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal default subscriptions: %w", err)
|
||||
}
|
||||
updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON)
|
||||
|
||||
// Model fallback configuration
|
||||
updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback)
|
||||
@@ -297,7 +325,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
// Claude Code version check
|
||||
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
|
||||
|
||||
err := s.settingRepo.SetMultiple(ctx, updates)
|
||||
err = s.settingRepo.SetMultiple(ctx, updates)
|
||||
if err == nil {
|
||||
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
|
||||
minVersionSF.Forget("min_version")
|
||||
@@ -312,6 +340,45 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
checked := make(map[int64]struct{}, len(items))
|
||||
for _, item := range items {
|
||||
if item.GroupID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := checked[item.GroupID]; ok {
|
||||
return ErrDefaultSubGroupDuplicate.WithMetadata(map[string]string{
|
||||
"group_id": strconv.FormatInt(item.GroupID, 10),
|
||||
})
|
||||
}
|
||||
checked[item.GroupID] = struct{}{}
|
||||
if s.defaultSubGroupReader == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
group, err := s.defaultSubGroupReader.GetByID(ctx, item.GroupID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrGroupNotFound) {
|
||||
return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{
|
||||
"group_id": strconv.FormatInt(item.GroupID, 10),
|
||||
})
|
||||
}
|
||||
return fmt.Errorf("get default subscription group %d: %w", item.GroupID, err)
|
||||
}
|
||||
if !group.IsSubscriptionType() {
|
||||
return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{
|
||||
"group_id": strconv.FormatInt(item.GroupID, 10),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsRegistrationEnabled 检查是否开放注册
|
||||
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
|
||||
@@ -411,6 +478,15 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
|
||||
return s.cfg.Default.UserBalance
|
||||
}
|
||||
|
||||
// GetDefaultSubscriptions 获取新用户默认订阅配置列表。
|
||||
func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return parseDefaultSubscriptions(value)
|
||||
}
|
||||
|
||||
// InitializeDefaultSettings 初始化默认设置
|
||||
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
// 检查是否已有设置
|
||||
@@ -435,6 +511,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeySoraClientEnabled: "false",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
SettingKeyDefaultSubscriptions: "[]",
|
||||
SettingKeySMTPPort: "587",
|
||||
SettingKeySMTPUseTLS: "false",
|
||||
// Model fallback defaults
|
||||
@@ -511,6 +588,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
} else {
|
||||
result.DefaultBalance = s.cfg.Default.UserBalance
|
||||
}
|
||||
result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions])
|
||||
|
||||
// 敏感信息直接返回,方便测试连接时使用
|
||||
result.SMTPPassword = settings[SettingKeySMTPPassword]
|
||||
@@ -595,6 +673,31 @@ func isFalseSettingValue(value string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var items []DefaultSubscriptionSetting
|
||||
if err := json.Unmarshal([]byte(raw), &items); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalized := make([]DefaultSubscriptionSetting, 0, len(items))
|
||||
for _, item := range items {
|
||||
if item.GroupID <= 0 || item.ValidityDays <= 0 {
|
||||
continue
|
||||
}
|
||||
if item.ValidityDays > MaxValidityDays {
|
||||
item.ValidityDays = MaxValidityDays
|
||||
}
|
||||
normalized = append(normalized, item)
|
||||
}
|
||||
|
||||
return normalized
|
||||
}
|
||||
|
||||
// getStringOrDefault 获取字符串值或默认值
|
||||
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
|
||||
if value, ok := settings[key]; ok && value != "" {
|
||||
|
||||
182
backend/internal/service/setting_service_update_test.go
Normal file
182
backend/internal/service/setting_service_update_test.go
Normal file
@@ -0,0 +1,182 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type settingUpdateRepoStub struct {
|
||||
updates map[string]string
|
||||
}
|
||||
|
||||
func (s *settingUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *settingUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
panic("unexpected GetValue call")
|
||||
}
|
||||
|
||||
func (s *settingUpdateRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *settingUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *settingUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
s.updates = make(map[string]string, len(settings))
|
||||
for k, v := range settings {
|
||||
s.updates[k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *settingUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *settingUpdateRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
type defaultSubGroupReaderStub struct {
|
||||
byID map[int64]*Group
|
||||
errBy map[int64]error
|
||||
calls []int64
|
||||
}
|
||||
|
||||
func (s *defaultSubGroupReaderStub) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
s.calls = append(s.calls, id)
|
||||
if err, ok := s.errBy[id]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if g, ok := s.byID[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func TestSettingService_UpdateSettings_DefaultSubscriptions_ValidGroup(t *testing.T) {
|
||||
repo := &settingUpdateRepoStub{}
|
||||
groupReader := &defaultSubGroupReaderStub{
|
||||
byID: map[int64]*Group{
|
||||
11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription},
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
svc.SetDefaultSubscriptionGroupReader(groupReader)
|
||||
|
||||
err := svc.UpdateSettings(context.Background(), &SystemSettings{
|
||||
DefaultSubscriptions: []DefaultSubscriptionSetting{
|
||||
{GroupID: 11, ValidityDays: 30},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{11}, groupReader.calls)
|
||||
|
||||
raw, ok := repo.updates[SettingKeyDefaultSubscriptions]
|
||||
require.True(t, ok)
|
||||
|
||||
var got []DefaultSubscriptionSetting
|
||||
require.NoError(t, json.Unmarshal([]byte(raw), &got))
|
||||
require.Equal(t, []DefaultSubscriptionSetting{
|
||||
{GroupID: 11, ValidityDays: 30},
|
||||
}, got)
|
||||
}
|
||||
|
||||
func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNonSubscriptionGroup(t *testing.T) {
|
||||
repo := &settingUpdateRepoStub{}
|
||||
groupReader := &defaultSubGroupReaderStub{
|
||||
byID: map[int64]*Group{
|
||||
12: {ID: 12, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
svc.SetDefaultSubscriptionGroupReader(groupReader)
|
||||
|
||||
err := svc.UpdateSettings(context.Background(), &SystemSettings{
|
||||
DefaultSubscriptions: []DefaultSubscriptionSetting{
|
||||
{GroupID: 12, ValidityDays: 7},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err))
|
||||
require.Nil(t, repo.updates)
|
||||
}
|
||||
|
||||
func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNotFoundGroup(t *testing.T) {
|
||||
repo := &settingUpdateRepoStub{}
|
||||
groupReader := &defaultSubGroupReaderStub{
|
||||
errBy: map[int64]error{
|
||||
13: ErrGroupNotFound,
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
svc.SetDefaultSubscriptionGroupReader(groupReader)
|
||||
|
||||
err := svc.UpdateSettings(context.Background(), &SystemSettings{
|
||||
DefaultSubscriptions: []DefaultSubscriptionSetting{
|
||||
{GroupID: 13, ValidityDays: 7},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err))
|
||||
require.Equal(t, "13", infraerrors.FromError(err).Metadata["group_id"])
|
||||
require.Nil(t, repo.updates)
|
||||
}
|
||||
|
||||
func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroup(t *testing.T) {
|
||||
repo := &settingUpdateRepoStub{}
|
||||
groupReader := &defaultSubGroupReaderStub{
|
||||
byID: map[int64]*Group{
|
||||
11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription},
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
svc.SetDefaultSubscriptionGroupReader(groupReader)
|
||||
|
||||
err := svc.UpdateSettings(context.Background(), &SystemSettings{
|
||||
DefaultSubscriptions: []DefaultSubscriptionSetting{
|
||||
{GroupID: 11, ValidityDays: 30},
|
||||
{GroupID: 11, ValidityDays: 60},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err))
|
||||
require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"])
|
||||
require.Nil(t, repo.updates)
|
||||
}
|
||||
|
||||
func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroupWithoutGroupReader(t *testing.T) {
|
||||
repo := &settingUpdateRepoStub{}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
err := svc.UpdateSettings(context.Background(), &SystemSettings{
|
||||
DefaultSubscriptions: []DefaultSubscriptionSetting{
|
||||
{GroupID: 11, ValidityDays: 30},
|
||||
{GroupID: 11, ValidityDays: 60},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err))
|
||||
require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"])
|
||||
require.Nil(t, repo.updates)
|
||||
}
|
||||
|
||||
func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) {
|
||||
got := parseDefaultSubscriptions(`[{"group_id":11,"validity_days":30},{"group_id":11,"validity_days":60},{"group_id":0,"validity_days":10},{"group_id":12,"validity_days":99999}]`)
|
||||
require.Equal(t, []DefaultSubscriptionSetting{
|
||||
{GroupID: 11, ValidityDays: 30},
|
||||
{GroupID: 11, ValidityDays: 60},
|
||||
{GroupID: 12, ValidityDays: MaxValidityDays},
|
||||
}, got)
|
||||
}
|
||||
@@ -43,6 +43,7 @@ type SystemSettings struct {
|
||||
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
@@ -65,6 +66,11 @@ type SystemSettings struct {
|
||||
MinClaudeCodeVersion string
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
|
||||
@@ -284,6 +284,13 @@ func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthC
|
||||
return apiKeyService
|
||||
}
|
||||
|
||||
// ProvideSettingService wires SettingService with group reader for default subscription validation.
|
||||
func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService {
|
||||
svc := NewSettingService(settingRepo, cfg)
|
||||
svc.SetDefaultSubscriptionGroupReader(groupRepo)
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all services
|
||||
var ProviderSet = wire.NewSet(
|
||||
// Core services
|
||||
@@ -326,7 +333,7 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideRateLimitService,
|
||||
NewAccountUsageService,
|
||||
NewAccountTestService,
|
||||
NewSettingService,
|
||||
ProvideSettingService,
|
||||
NewDataManagementService,
|
||||
ProvideOpsSystemLogSink,
|
||||
NewOpsService,
|
||||
@@ -339,6 +346,7 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideEmailQueueService,
|
||||
NewTurnstileService,
|
||||
NewSubscriptionService,
|
||||
wire.Bind(new(DefaultSubscriptionAssigner), new(*SubscriptionService)),
|
||||
ProvideConcurrencyService,
|
||||
NewUsageRecordWorkerPool,
|
||||
ProvideSchedulerSnapshotService,
|
||||
|
||||
Reference in New Issue
Block a user