fix: address audit findings for websearch and balance notification

- Fix GetByKeyForAuth not selecting balance notify fields (notifications
  never triggered in gateway path)
- Fix provider-level ProxyURL never resolved: inject ProxyRepository into
  SettingService, resolve proxy URLs when building Manager
- Fix admin manual balance adjustment not updating total_recharged
- Add threshold_type input validation (reject invalid values)
- Fix user threshold_type inheritance: custom threshold defaults to "fixed"
  instead of inheriting global type (prevents $5 being treated as 5%)
- Add try-catch for clipboard.writeText (fails on non-HTTPS)
- Add SetTotalRecharged to user Update for admin balance operations
This commit is contained in:
erio
2026-04-12 14:43:12 +08:00
parent f694afbbf4
commit 9e33d0c4c0
11 changed files with 102 additions and 43 deletions

View File

@@ -143,6 +143,11 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
user.FieldRole,
user.FieldBalance,
user.FieldConcurrency,
user.FieldBalanceNotifyEnabled,
user.FieldBalanceNotifyThresholdType,
user.FieldBalanceNotifyThreshold,
user.FieldBalanceNotifyExtraEmails,
user.FieldTotalRecharged,
)
}).
WithGroup(func(q *dbent.GroupQuery) {

View File

@@ -150,7 +150,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalanceNotifyEnabled(userIn.BalanceNotifyEnabled).
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails))
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
SetTotalRecharged(userIn.TotalRecharged)
if userIn.BalanceNotifyThreshold == nil {
updateOp = updateOp.ClearBalanceNotifyThreshold()
}

View File

@@ -59,7 +59,7 @@ func ProvideRouter(
}
// Wire up websearch Manager builder so it initializes on startup and rebuilds on config save.
settingService.SetWebSearchManagerBuilder(context.Background(), func(cfg *service.WebSearchEmulationConfig) {
settingService.SetWebSearchManagerBuilder(context.Background(), func(cfg *service.WebSearchEmulationConfig, proxyURLs map[int64]string) {
if cfg == nil || !cfg.Enabled || len(cfg.Providers) == 0 {
service.SetWebSearchManager(nil)
return
@@ -80,6 +80,9 @@ func ProvideRouter(
}
if p.ProxyID != nil {
pc.ProxyID = *p.ProxyID
if u, ok := proxyURLs[*p.ProxyID]; ok {
pc.ProxyURL = u
}
}
configs = append(configs, pc)
}

View File

@@ -709,6 +709,12 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance)
}
// Track cumulative recharge for percentage-based balance notifications
balanceDelta := user.Balance - oldBalance
if balanceDelta > 0 {
user.TotalRecharged += balanceDelta
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}

View File

@@ -77,12 +77,12 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u
}
// resolveEffectiveThreshold computes the actual USD threshold based on type and user settings.
// When user sets a custom threshold, their type is used independently (defaults to "fixed" if unset).
func (s *BalanceNotifyService) resolveEffectiveThreshold(user *User, globalType string, globalValue float64) float64 {
// User-level override takes full precedence
if user.BalanceNotifyThreshold != nil {
thresholdType := user.BalanceNotifyThresholdType
if thresholdType == "" {
thresholdType = globalType
thresholdType = ThresholdTypeFixed // user custom value defaults to fixed, not inherited
}
return computeThreshold(thresholdType, *user.BalanceNotifyThreshold, user.TotalRecharged)
}

View File

@@ -99,13 +99,19 @@ type DefaultSubscriptionGroupReader interface {
GetByID(ctx context.Context, id int64) (*Group, error)
}
// WebSearchManagerBuilder creates a websearch.Manager from config (injected by infra layer).
// proxyURLs maps proxy ID to resolved URL for provider-level proxy support.
type WebSearchManagerBuilder func(cfg *WebSearchEmulationConfig, proxyURLs map[int64]string)
// SettingService 系统设置服务
type SettingService struct {
settingRepo SettingRepository
defaultSubGroupReader DefaultSubscriptionGroupReader
cfg *config.Config
onUpdate func() // Callback when settings are updated (for cache invalidation)
version string // Application version
settingRepo SettingRepository
defaultSubGroupReader DefaultSubscriptionGroupReader
proxyRepo ProxyRepository // for resolving websearch provider proxy URLs
cfg *config.Config
onUpdate func() // Callback when settings are updated (for cache invalidation)
version string // Application version
webSearchManagerBuilder WebSearchManagerBuilder
}
// NewSettingService 创建系统设置服务实例
@@ -121,6 +127,11 @@ func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscri
s.defaultSubGroupReader = reader
}
// SetProxyRepository injects a proxy repo for resolving websearch provider proxy URLs.
func (s *SettingService) SetProxyRepository(repo ProxyRepository) {
s.proxyRepo = repo
}
// GetAllSettings 获取所有系统设置
func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) {
settings, err := s.settingRepo.GetAll(ctx)
@@ -598,7 +609,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// Balance low notification
updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
thresholdType := settings.BalanceLowNotifyThresholdType
if thresholdType == "" {
if thresholdType != ThresholdTypeFixed && thresholdType != ThresholdTypePercentage {
thresholdType = ThresholdTypeFixed
}
updates[SettingKeyBalanceLowNotifyThresholdType] = thresholdType
@@ -1231,6 +1242,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
// Web search emulation: quick enabled check from the JSON config
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
var wsCfg WebSearchEmulationConfig
if err := json.Unmarshal([]byte(raw), &wsCfg); err == nil {
result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
}
}
// Balance low notification
result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
result.BalanceLowNotifyThresholdType = settings[SettingKeyBalanceLowNotifyThresholdType]

View File

@@ -145,7 +145,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user.BalanceNotifyEnabled = *req.BalanceNotifyEnabled
}
if req.BalanceNotifyThresholdType != nil {
user.BalanceNotifyThresholdType = *req.BalanceNotifyThresholdType
if *req.BalanceNotifyThresholdType == ThresholdTypeFixed || *req.BalanceNotifyThresholdType == ThresholdTypePercentage {
user.BalanceNotifyThresholdType = *req.BalanceNotifyThresholdType
}
}
if req.BalanceNotifyThreshold != nil {
if *req.BalanceNotifyThreshold <= 0 {

View File

@@ -10,7 +10,6 @@ import (
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
"github.com/redis/go-redis/v9"
"golang.org/x/sync/singleflight"
)
@@ -85,8 +84,7 @@ const (
// GetWebSearchEmulationConfig returns the configuration with in-process cache + singleflight.
func (s *SettingService) GetWebSearchEmulationConfig(ctx context.Context) (*WebSearchEmulationConfig, error) {
if cached := webSearchEmulationCache.Load(); cached != nil {
c := cached.(*cachedWebSearchEmulationConfig)
if time.Now().UnixNano() < c.expiresAt {
if c, ok := cached.(*cachedWebSearchEmulationConfig); ok && time.Now().UnixNano() < c.expiresAt {
return c.config, nil
}
}
@@ -96,7 +94,10 @@ func (s *SettingService) GetWebSearchEmulationConfig(ctx context.Context) (*WebS
if err != nil {
return &WebSearchEmulationConfig{}, err
}
return result.(*WebSearchEmulationConfig), nil
if cfg, ok := result.(*WebSearchEmulationConfig); ok {
return cfg, nil
}
return &WebSearchEmulationConfig{}, nil
}
func (s *SettingService) loadWebSearchConfigFromDB() (*WebSearchEmulationConfig, error) {
@@ -154,7 +155,7 @@ func (s *SettingService) SaveWebSearchEmulationConfig(ctx context.Context, cfg *
})
// Hot-reload: rebuild the global Manager with new config
s.RebuildWebSearchManager(ctx)
s.rebuildWebSearchManager(ctx)
return nil
}
@@ -196,34 +197,51 @@ func (s *SettingService) IsWebSearchEmulationEnabled(ctx context.Context) bool {
return cfg.Enabled && len(cfg.Providers) > 0
}
// SetWebSearchRedisClient injects the Redis client used for quota tracking.
// Call after construction, before first use. Triggers initial Manager build.
func (s *SettingService) SetWebSearchRedisClient(ctx context.Context, redisClient *redis.Client) {
s.webSearchRedis = redisClient
s.RebuildWebSearchManager(ctx)
// SetWebSearchManagerBuilder injects a callback that creates and wires a websearch.Manager.
// The infra layer (main/wire) provides this builder, keeping redis out of the service layer.
// Triggers initial build.
func (s *SettingService) SetWebSearchManagerBuilder(ctx context.Context, builder WebSearchManagerBuilder) {
s.webSearchManagerBuilder = builder
s.rebuildWebSearchManager(ctx)
}
// RebuildWebSearchManager reads the current config and (re)creates the global websearch.Manager.
// Called on startup and after SaveWebSearchEmulationConfig.
func (s *SettingService) RebuildWebSearchManager(ctx context.Context) {
// rebuildWebSearchManager reads the current config, resolves proxy URLs, and invokes the builder.
func (s *SettingService) rebuildWebSearchManager(ctx context.Context) {
if s.webSearchManagerBuilder == nil {
return
}
cfg, err := s.GetWebSearchEmulationConfig(ctx)
if err != nil || !cfg.Enabled || len(cfg.Providers) == 0 {
if err != nil {
SetWebSearchManager(nil)
return
}
providerConfigs := make([]websearch.ProviderConfig, 0, len(cfg.Providers))
for _, p := range cfg.Providers {
providerConfigs = append(providerConfigs, websearch.ProviderConfig{
Type: p.Type,
APIKey: p.APIKey,
Priority: p.Priority,
QuotaLimit: p.QuotaLimit,
QuotaRefreshInterval: p.QuotaRefreshInterval,
ExpiresAt: p.ExpiresAt,
})
proxyURLs := s.resolveProviderProxyURLs(ctx, cfg)
s.webSearchManagerBuilder(cfg, proxyURLs)
}
// resolveProviderProxyURLs collects proxy IDs from providers and resolves them to URLs.
func (s *SettingService) resolveProviderProxyURLs(ctx context.Context, cfg *WebSearchEmulationConfig) map[int64]string {
if cfg == nil || s.proxyRepo == nil {
return nil
}
SetWebSearchManager(websearch.NewManager(providerConfigs, s.webSearchRedis))
slog.Info("websearch: manager rebuilt", "provider_count", len(providerConfigs))
var ids []int64
for _, p := range cfg.Providers {
if p.ProxyID != nil && *p.ProxyID > 0 {
ids = append(ids, *p.ProxyID)
}
}
if len(ids) == 0 {
return nil
}
proxies, err := s.proxyRepo.ListByIDs(ctx, ids)
if err != nil {
return nil
}
result := make(map[int64]string, len(proxies))
for _, px := range proxies {
result[px.ID] = px.URL()
}
return result
}
// WebSearchTestResult holds the result of a search test.

View File

@@ -373,10 +373,11 @@ func ProvideBackupService(
return svc
}
// ProvideSettingService wires SettingService with group reader for default subscription validation.
func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService {
// ProvideSettingService wires SettingService with group reader and proxy repo.
func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, proxyRepo ProxyRepository, cfg *config.Config) *SettingService {
svc := NewSettingService(settingRepo, cfg)
svc.SetDefaultSubscriptionGroupReader(groupRepo)
svc.SetProxyRepository(proxyRepo)
return svc
}