fix: address audit findings for websearch, email verification, and pricing

- Fix websearch provider failover: proxy error from provider-specific proxy
  now continues to next provider instead of aborting the entire loop
- Fix SMTP failure locking users out: send email first, then write cache
  and increment rate counter
- Fix notify email cache key case sensitivity: normalize to lowercase
- Add OriginalPrice validation to validatePlanPatch and validatePlanRequired
- Add empty scope validation for channel pricing rules (group_ids/account_ids)
- Add platform color to account search dropdown in channel pricing rules
This commit is contained in:
erio
2026-04-13 23:35:59 +08:00
parent 1b7c295199
commit 74f8a30f86
7 changed files with 103 additions and 26 deletions

View File

@@ -1,6 +1,7 @@
package admin
import (
"fmt"
"strconv"
"strings"
@@ -351,6 +352,11 @@ func (h *ChannelHandler) Create(c *gin.Context) {
var statsRules []service.AccountStatsPricingRule
for i, r := range req.AccountStatsPricingRules {
if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE",
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
return
}
rule := accountStatsPricingRuleRequestToService(r)
rule.SortOrder = i
statsRules = append(statsRules, rule)
@@ -409,6 +415,11 @@ func (h *ChannelHandler) Update(c *gin.Context) {
if req.AccountStatsPricingRules != nil {
statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules))
for i, r := range *req.AccountStatsPricingRules {
if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE",
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
return
}
rule := accountStatsPricingRuleRequestToService(r)
rule.SortOrder = i
statsRules = append(statsRules, rule)

View File

@@ -111,9 +111,18 @@ func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest)
}
if isProxyError(err) {
m.markProxyUnavailable(ctx, cfg, req.ProxyURL)
slog.Warn("websearch: proxy error, marking unavailable",
if req.ProxyURL != "" {
// Account-level proxy is shared by all providers — no point
// trying others with the same broken proxy; signal account switch.
slog.Warn("websearch: account proxy error, aborting failover",
"provider", cfg.Type, "error", err)
return nil, "", fmt.Errorf("%w: %s", ErrProxyUnavailable, err.Error())
}
// Provider-specific proxy failed — try the next provider which
// may use a different (or no) proxy.
slog.Warn("websearch: provider proxy error, trying next provider",
"provider", cfg.Type, "error", err)
return nil, "", fmt.Errorf("%w: %s", ErrProxyUnavailable, err.Error())
continue
}
slog.Warn("websearch: provider search failed",
"provider", cfg.Type, "error", err)

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -24,8 +25,10 @@ func verifyCodeKey(email string) string {
}
// notifyVerifyKey generates the Redis key for notify email verification code.
// Email is lowercased to prevent case-sensitive key mismatch (the business layer
// uses strings.EqualFold for comparison).
func notifyVerifyKey(email string) string {
return notifyVerifyKeyPrefix + email
return notifyVerifyKeyPrefix + strings.ToLower(email)
}
// passwordResetKey generates the Redis key for password reset token.

View File

@@ -12,7 +12,7 @@ import (
)
// validatePlanRequired checks that all required fields for a plan are provided.
func validatePlanRequired(name string, groupID int64, price float64, validityDays int, validityUnit string) error {
func validatePlanRequired(name string, groupID int64, price float64, validityDays int, validityUnit string, originalPrice *float64) error {
if strings.TrimSpace(name) == "" {
return infraerrors.BadRequest("PLAN_NAME_REQUIRED", "plan name is required")
}
@@ -28,6 +28,9 @@ func validatePlanRequired(name string, groupID int64, price float64, validityDay
if strings.TrimSpace(validityUnit) == "" {
return infraerrors.BadRequest("PLAN_VALIDITY_UNIT_REQUIRED", "validity unit is required")
}
if originalPrice != nil && *originalPrice < 0 {
return infraerrors.BadRequest("PLAN_ORIGINAL_PRICE_INVALID", "original price must be >= 0")
}
return nil
}
@@ -48,6 +51,9 @@ func validatePlanPatch(req UpdatePlanRequest) error {
if req.ValidityUnit != nil && strings.TrimSpace(*req.ValidityUnit) == "" {
return infraerrors.BadRequest("PLAN_VALIDITY_UNIT_REQUIRED", "validity unit is required")
}
if req.OriginalPrice != nil && *req.OriginalPrice < 0 {
return infraerrors.BadRequest("PLAN_ORIGINAL_PRICE_INVALID", "original price must be >= 0")
}
return nil
}
@@ -115,7 +121,7 @@ func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.S
}
func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanRequest) (*dbent.SubscriptionPlan, error) {
if err := validatePlanRequired(req.Name, req.GroupID, req.Price, req.ValidityDays, req.ValidityUnit); err != nil {
if err := validatePlanRequired(req.Name, req.GroupID, req.Price, req.ValidityDays, req.ValidityUnit, req.OriginalPrice); err != nil {
return nil, err
}
b := s.entClient.SubscriptionPlan.Create().

View File

@@ -9,81 +9,122 @@ import (
)
func TestValidatePlanRequired_AllValid(t *testing.T) {
err := validatePlanRequired("Pro", 1, 9.99, 30, "days")
err := validatePlanRequired("Pro", 1, 9.99, 30, "days", nil)
require.NoError(t, err)
}
func TestValidatePlanRequired_EmptyName(t *testing.T) {
err := validatePlanRequired("", 1, 9.99, 30, "days")
err := validatePlanRequired("", 1, 9.99, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "plan name")
}
func TestValidatePlanRequired_WhitespaceName(t *testing.T) {
err := validatePlanRequired(" ", 1, 9.99, 30, "days")
err := validatePlanRequired(" ", 1, 9.99, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "plan name")
}
func TestValidatePlanRequired_ZeroGroupID(t *testing.T) {
err := validatePlanRequired("Pro", 0, 9.99, 30, "days")
err := validatePlanRequired("Pro", 0, 9.99, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "group")
}
func TestValidatePlanRequired_NegativeGroupID(t *testing.T) {
err := validatePlanRequired("Pro", -1, 9.99, 30, "days")
err := validatePlanRequired("Pro", -1, 9.99, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "group")
}
func TestValidatePlanRequired_ZeroPrice(t *testing.T) {
err := validatePlanRequired("Pro", 1, 0, 30, "days")
err := validatePlanRequired("Pro", 1, 0, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "price")
}
func TestValidatePlanRequired_NegativePrice(t *testing.T) {
err := validatePlanRequired("Pro", 1, -5, 30, "days")
err := validatePlanRequired("Pro", 1, -5, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "price")
}
func TestValidatePlanRequired_ZeroValidityDays(t *testing.T) {
err := validatePlanRequired("Pro", 1, 9.99, 0, "days")
err := validatePlanRequired("Pro", 1, 9.99, 0, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "validity days")
}
func TestValidatePlanRequired_NegativeValidityDays(t *testing.T) {
err := validatePlanRequired("Pro", 1, 9.99, -7, "days")
err := validatePlanRequired("Pro", 1, 9.99, -7, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "validity days")
}
func TestValidatePlanRequired_EmptyValidityUnit(t *testing.T) {
err := validatePlanRequired("Pro", 1, 9.99, 30, "")
err := validatePlanRequired("Pro", 1, 9.99, 30, "", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "validity unit")
}
func TestValidatePlanRequired_WhitespaceValidityUnit(t *testing.T) {
err := validatePlanRequired("Pro", 1, 9.99, 30, " ")
err := validatePlanRequired("Pro", 1, 9.99, 30, " ", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "validity unit")
}
func TestValidatePlanRequired_NameValidatedFirst(t *testing.T) {
// When multiple fields are invalid, name should be reported first
// (follows the order of checks in the function).
err := validatePlanRequired("", 0, 0, 0, "")
err := validatePlanRequired("", 0, 0, 0, "", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "plan name")
}
func TestValidatePlanRequired_TrimmedValidName(t *testing.T) {
// Whitespace-surrounded but non-empty name is accepted (trimmed check only
// rejects pure whitespace).
err := validatePlanRequired(" Pro ", 1, 9.99, 30, "days")
err := validatePlanRequired(" Pro ", 1, 9.99, 30, "days", nil)
require.NoError(t, err)
}
func TestValidatePlanRequired_NegativeOriginalPrice(t *testing.T) {
neg := -10.0
err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &neg)
require.Error(t, err)
require.Contains(t, err.Error(), "original price")
}
func TestValidatePlanRequired_ZeroOriginalPrice(t *testing.T) {
zero := 0.0
err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &zero)
require.NoError(t, err)
}
func TestValidatePlanRequired_ValidOriginalPrice(t *testing.T) {
op := 19.99
err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &op)
require.NoError(t, err)
}
// --- validatePlanPatch tests ---
func TestValidatePlanPatch_NegativeOriginalPrice(t *testing.T) {
neg := -5.0
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &neg})
require.Error(t, err)
require.Contains(t, err.Error(), "original price")
}
func TestValidatePlanPatch_ZeroOriginalPrice(t *testing.T) {
zero := 0.0
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &zero})
require.NoError(t, err)
}
func TestValidatePlanPatch_ValidOriginalPrice(t *testing.T) {
op := 29.99
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &op})
require.NoError(t, err)
}
func TestValidatePlanPatch_NilOriginalPrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: nil})
require.NoError(t, err)
}

View File

@@ -291,6 +291,12 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema
return fmt.Errorf("generate code: %w", err)
}
// Send email first — if SMTP fails, don't write cache or increment counters,
// so the user is not locked out by cooldown/rate-limit for a code they never received.
if err := s.sendNotifyVerifyEmail(ctx, emailService, email, code); err != nil {
return err
}
if err := saveNotifyVerifyCode(ctx, cache, email, code); err != nil {
return err
}
@@ -300,7 +306,7 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema
slog.Error("failed to increment notify code user rate", "user_id", userID, "error", err)
}
return s.sendNotifyVerifyEmail(ctx, emailService, email, code)
return nil
}
// checkNotifyCodeRateLimit checks both email cooldown and user-level rate limit.