From 74f8a30f861f2b5072f5916265abbf61d3448b12 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 23:35:59 +0800 Subject: [PATCH] fix: address audit findings for websearch, email verification, and pricing - Fix websearch provider failover: proxy error from provider-specific proxy now continues to next provider instead of aborting the entire loop - Fix SMTP failure locking users out: send email first, then write cache and increment rate counter - Fix notify email cache key case sensitivity: normalize to lowercase - Add OriginalPrice validation to validatePlanPatch and validatePlanRequired - Add empty scope validation for channel pricing rules (group_ids/account_ids) - Add platform color to account search dropdown in channel pricing rules --- .../internal/handler/admin/channel_handler.go | 11 +++ backend/internal/pkg/websearch/manager.go | 13 +++- backend/internal/repository/email_cache.go | 5 +- .../internal/service/payment_config_plans.go | 10 ++- .../payment_config_plans_validation_test.go | 75 ++++++++++++++----- backend/internal/service/user_service.go | 8 +- frontend/src/views/admin/ChannelsView.vue | 7 +- 7 files changed, 103 insertions(+), 26 deletions(-) diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 2d4cd56a..ee76a750 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -1,6 +1,7 @@ package admin import ( + "fmt" "strconv" "strings" @@ -351,6 +352,11 @@ func (h *ChannelHandler) Create(c *gin.Context) { var statsRules []service.AccountStatsPricingRule for i, r := range req.AccountStatsPricingRules { + if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 { + response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE", + fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1))) + return + } rule := accountStatsPricingRuleRequestToService(r) rule.SortOrder = i statsRules = append(statsRules, rule) @@ -409,6 +415,11 @@ func (h *ChannelHandler) Update(c *gin.Context) { if req.AccountStatsPricingRules != nil { statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules)) for i, r := range *req.AccountStatsPricingRules { + if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 { + response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE", + fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1))) + return + } rule := accountStatsPricingRuleRequestToService(r) rule.SortOrder = i statsRules = append(statsRules, rule) diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go index ae0683ad..27592459 100644 --- a/backend/internal/pkg/websearch/manager.go +++ b/backend/internal/pkg/websearch/manager.go @@ -111,9 +111,18 @@ func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) } if isProxyError(err) { m.markProxyUnavailable(ctx, cfg, req.ProxyURL) - slog.Warn("websearch: proxy error, marking unavailable", + if req.ProxyURL != "" { + // Account-level proxy is shared by all providers — no point + // trying others with the same broken proxy; signal account switch. + slog.Warn("websearch: account proxy error, aborting failover", + "provider", cfg.Type, "error", err) + return nil, "", fmt.Errorf("%w: %s", ErrProxyUnavailable, err.Error()) + } + // Provider-specific proxy failed — try the next provider which + // may use a different (or no) proxy. + slog.Warn("websearch: provider proxy error, trying next provider", "provider", cfg.Type, "error", err) - return nil, "", fmt.Errorf("%w: %s", ErrProxyUnavailable, err.Error()) + continue } slog.Warn("websearch: provider search failed", "provider", cfg.Type, "error", err) diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go index ed903e0d..1356163d 100644 --- a/backend/internal/repository/email_cache.go +++ b/backend/internal/repository/email_cache.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/service" @@ -24,8 +25,10 @@ func verifyCodeKey(email string) string { } // notifyVerifyKey generates the Redis key for notify email verification code. +// Email is lowercased to prevent case-sensitive key mismatch (the business layer +// uses strings.EqualFold for comparison). func notifyVerifyKey(email string) string { - return notifyVerifyKeyPrefix + email + return notifyVerifyKeyPrefix + strings.ToLower(email) } // passwordResetKey generates the Redis key for password reset token. diff --git a/backend/internal/service/payment_config_plans.go b/backend/internal/service/payment_config_plans.go index 8a5e1924..6753071d 100644 --- a/backend/internal/service/payment_config_plans.go +++ b/backend/internal/service/payment_config_plans.go @@ -12,7 +12,7 @@ import ( ) // validatePlanRequired checks that all required fields for a plan are provided. -func validatePlanRequired(name string, groupID int64, price float64, validityDays int, validityUnit string) error { +func validatePlanRequired(name string, groupID int64, price float64, validityDays int, validityUnit string, originalPrice *float64) error { if strings.TrimSpace(name) == "" { return infraerrors.BadRequest("PLAN_NAME_REQUIRED", "plan name is required") } @@ -28,6 +28,9 @@ func validatePlanRequired(name string, groupID int64, price float64, validityDay if strings.TrimSpace(validityUnit) == "" { return infraerrors.BadRequest("PLAN_VALIDITY_UNIT_REQUIRED", "validity unit is required") } + if originalPrice != nil && *originalPrice < 0 { + return infraerrors.BadRequest("PLAN_ORIGINAL_PRICE_INVALID", "original price must be >= 0") + } return nil } @@ -48,6 +51,9 @@ func validatePlanPatch(req UpdatePlanRequest) error { if req.ValidityUnit != nil && strings.TrimSpace(*req.ValidityUnit) == "" { return infraerrors.BadRequest("PLAN_VALIDITY_UNIT_REQUIRED", "validity unit is required") } + if req.OriginalPrice != nil && *req.OriginalPrice < 0 { + return infraerrors.BadRequest("PLAN_ORIGINAL_PRICE_INVALID", "original price must be >= 0") + } return nil } @@ -115,7 +121,7 @@ func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.S } func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanRequest) (*dbent.SubscriptionPlan, error) { - if err := validatePlanRequired(req.Name, req.GroupID, req.Price, req.ValidityDays, req.ValidityUnit); err != nil { + if err := validatePlanRequired(req.Name, req.GroupID, req.Price, req.ValidityDays, req.ValidityUnit, req.OriginalPrice); err != nil { return nil, err } b := s.entClient.SubscriptionPlan.Create(). diff --git a/backend/internal/service/payment_config_plans_validation_test.go b/backend/internal/service/payment_config_plans_validation_test.go index bc9c0048..9a2d8716 100644 --- a/backend/internal/service/payment_config_plans_validation_test.go +++ b/backend/internal/service/payment_config_plans_validation_test.go @@ -9,81 +9,122 @@ import ( ) func TestValidatePlanRequired_AllValid(t *testing.T) { - err := validatePlanRequired("Pro", 1, 9.99, 30, "days") + err := validatePlanRequired("Pro", 1, 9.99, 30, "days", nil) require.NoError(t, err) } func TestValidatePlanRequired_EmptyName(t *testing.T) { - err := validatePlanRequired("", 1, 9.99, 30, "days") + err := validatePlanRequired("", 1, 9.99, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "plan name") } func TestValidatePlanRequired_WhitespaceName(t *testing.T) { - err := validatePlanRequired(" ", 1, 9.99, 30, "days") + err := validatePlanRequired(" ", 1, 9.99, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "plan name") } func TestValidatePlanRequired_ZeroGroupID(t *testing.T) { - err := validatePlanRequired("Pro", 0, 9.99, 30, "days") + err := validatePlanRequired("Pro", 0, 9.99, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "group") } func TestValidatePlanRequired_NegativeGroupID(t *testing.T) { - err := validatePlanRequired("Pro", -1, 9.99, 30, "days") + err := validatePlanRequired("Pro", -1, 9.99, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "group") } func TestValidatePlanRequired_ZeroPrice(t *testing.T) { - err := validatePlanRequired("Pro", 1, 0, 30, "days") + err := validatePlanRequired("Pro", 1, 0, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "price") } func TestValidatePlanRequired_NegativePrice(t *testing.T) { - err := validatePlanRequired("Pro", 1, -5, 30, "days") + err := validatePlanRequired("Pro", 1, -5, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "price") } func TestValidatePlanRequired_ZeroValidityDays(t *testing.T) { - err := validatePlanRequired("Pro", 1, 9.99, 0, "days") + err := validatePlanRequired("Pro", 1, 9.99, 0, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "validity days") } func TestValidatePlanRequired_NegativeValidityDays(t *testing.T) { - err := validatePlanRequired("Pro", 1, 9.99, -7, "days") + err := validatePlanRequired("Pro", 1, 9.99, -7, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "validity days") } func TestValidatePlanRequired_EmptyValidityUnit(t *testing.T) { - err := validatePlanRequired("Pro", 1, 9.99, 30, "") + err := validatePlanRequired("Pro", 1, 9.99, 30, "", nil) require.Error(t, err) require.Contains(t, err.Error(), "validity unit") } func TestValidatePlanRequired_WhitespaceValidityUnit(t *testing.T) { - err := validatePlanRequired("Pro", 1, 9.99, 30, " ") + err := validatePlanRequired("Pro", 1, 9.99, 30, " ", nil) require.Error(t, err) require.Contains(t, err.Error(), "validity unit") } func TestValidatePlanRequired_NameValidatedFirst(t *testing.T) { - // When multiple fields are invalid, name should be reported first - // (follows the order of checks in the function). - err := validatePlanRequired("", 0, 0, 0, "") + err := validatePlanRequired("", 0, 0, 0, "", nil) require.Error(t, err) require.Contains(t, err.Error(), "plan name") } func TestValidatePlanRequired_TrimmedValidName(t *testing.T) { - // Whitespace-surrounded but non-empty name is accepted (trimmed check only - // rejects pure whitespace). - err := validatePlanRequired(" Pro ", 1, 9.99, 30, "days") + err := validatePlanRequired(" Pro ", 1, 9.99, 30, "days", nil) + require.NoError(t, err) +} + +func TestValidatePlanRequired_NegativeOriginalPrice(t *testing.T) { + neg := -10.0 + err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &neg) + require.Error(t, err) + require.Contains(t, err.Error(), "original price") +} + +func TestValidatePlanRequired_ZeroOriginalPrice(t *testing.T) { + zero := 0.0 + err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &zero) + require.NoError(t, err) +} + +func TestValidatePlanRequired_ValidOriginalPrice(t *testing.T) { + op := 19.99 + err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &op) + require.NoError(t, err) +} + +// --- validatePlanPatch tests --- + +func TestValidatePlanPatch_NegativeOriginalPrice(t *testing.T) { + neg := -5.0 + err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &neg}) + require.Error(t, err) + require.Contains(t, err.Error(), "original price") +} + +func TestValidatePlanPatch_ZeroOriginalPrice(t *testing.T) { + zero := 0.0 + err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &zero}) + require.NoError(t, err) +} + +func TestValidatePlanPatch_ValidOriginalPrice(t *testing.T) { + op := 29.99 + err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &op}) + require.NoError(t, err) +} + +func TestValidatePlanPatch_NilOriginalPrice(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: nil}) require.NoError(t, err) } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 0da73762..7602d162 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -291,6 +291,12 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema return fmt.Errorf("generate code: %w", err) } + // Send email first — if SMTP fails, don't write cache or increment counters, + // so the user is not locked out by cooldown/rate-limit for a code they never received. + if err := s.sendNotifyVerifyEmail(ctx, emailService, email, code); err != nil { + return err + } + if err := saveNotifyVerifyCode(ctx, cache, email, code); err != nil { return err } @@ -300,7 +306,7 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema slog.Error("failed to increment notify code user rate", "user_id", userID, "error", err) } - return s.sendNotifyVerifyEmail(ctx, emailService, email, code) + return nil } // checkNotifyCodeRateLimit checks both email cooldown and user-level rate limit. diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index 2ca1141d..60704b65 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -511,7 +511,7 @@ :class="{ 'opacity-50': rule.account_ids.includes(account.id) }" :disabled="rule.account_ids.includes(account.id)" > - {{ account.name }} + {{ account.name }} #{{ account.id }} @@ -595,6 +595,7 @@ import type { PricingFormEntry } from '@/components/admin/channel/types' import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types' import type { AdminGroup, GroupPlatform } from '@/types' import type { Column } from '@/components/common/types' +import { platformTextClass } from '@/utils/platformColors' import AppLayout from '@/components/layout/AppLayout.vue' import TablePageLayout from '@/components/layout/TablePageLayout.vue' import DataTable from '@/components/common/DataTable.vue' @@ -911,7 +912,7 @@ function getGroupNameById(groupId: number): string { } // ── Account search for pricing rules ── -interface SimpleAccount { id: number; name: string } +interface SimpleAccount { id: number; name: string; platform: string } const ruleAccountSearchKeyword = ref>({}) const ruleAccountSearchResults = ref>({}) @@ -924,7 +925,7 @@ const ruleAccountSearchRunner = useKeyedDebouncedSearch({ search: async (keyword, { key, signal }) => { const platform = key.split('-')[0] const res = await adminAPI.accounts.list(1, 20, { platform, search: keyword }, { signal }) - return res.items.map(a => ({ id: a.id, name: a.name })) + return res.items.map(a => ({ id: a.id, name: a.name, platform: a.platform })) }, onSuccess: (key, result) => { ruleAccountSearchResults.value[key] = result }, onError: (key) => { ruleAccountSearchResults.value[key] = [] },