fix: audit fixes for websearch, notifications, and channel pricing

P0: fix wildcard matching test assertion (config order, not longest prefix)
P0: add TotalRecharged to auth cache snapshot (v5) for percentage threshold
P1: move pricing rules into per-platform sections in ChannelsView
P1: populate account name cache when editing existing channel rules
P1: sanitize email subject headers to prevent SMTP injection
P1: make Redis INCR+EXPIRE idempotent for rate limiting
P1: deep copy FeaturesConfig in Channel.Clone()
P2: clean up stale email="" placeholder comments
P2: replace log.Printf with slog in email_service.go
This commit is contained in:
erio
2026-04-13 13:59:35 +08:00
parent a68df457d8
commit b7fb2e4387
13 changed files with 273 additions and 118 deletions

View File

@@ -145,14 +145,14 @@ func TestFindPricingForModel(t *testing.T) {
wantNil: true,
},
{
name: "longer wildcard prefix wins over shorter",
name: "wildcard matches by config order (first match wins)",
list: []ChannelModelPricing{
{ID: 10, Models: []string{"claude-*"}},
{ID: 11, Models: []string{"claude-opus-*"}},
},
platform: "",
model: "claude-opus-4",
wantID: 11, // "claude-opus-" (12 chars) > "claude-" (7 chars)
wantID: 10, // config order: "claude-*" is first and matches, so it wins
},
{
name: "shorter wildcard used when longer does not match",

View File

@@ -42,6 +42,7 @@ type APIKeyAuthUserSnapshot struct {
BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
TotalRecharged float64 `json:"total_recharged"`
}
// APIKeyAuthGroupSnapshot 分组快照

View File

@@ -13,7 +13,7 @@ import (
"github.com/dgraph-io/ristretto"
)
const apiKeyAuthSnapshotVersion = 4 // v4: added balance notification fields to UserSnapshot
const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold
type apiKeyAuthCacheConfig struct {
l1Size int
@@ -230,6 +230,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
BalanceNotifyThresholdType: apiKey.User.BalanceNotifyThresholdType,
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
TotalRecharged: apiKey.User.TotalRecharged,
},
}
if apiKey.Group != nil {
@@ -291,6 +292,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
BalanceNotifyThresholdType: snapshot.User.BalanceNotifyThresholdType,
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
TotalRecharged: snapshot.User.TotalRecharged,
},
}
if snapshot.Group != nil {

View File

@@ -309,7 +309,7 @@ func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userNam
if displayName == "" {
displayName = userEmail
}
subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", siteName)
subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", sanitizeEmailHeader(siteName))
body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName))
s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance)
}
@@ -321,11 +321,16 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun
dimLabel = dimension
}
subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", siteName, accountName)
subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName))
body := s.buildQuotaAlertEmailBody(html.EscapeString(accountName), html.EscapeString(dimLabel), used, limit, threshold, html.EscapeString(siteName))
s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dimension)
}
// sanitizeEmailHeader removes CR/LF characters to prevent SMTP header injection.
func sanitizeEmailHeader(s string) string {
return strings.NewReplacer("\r", "", "\n", "").Replace(s)
}
// balanceLowEmailTemplate is the HTML template for balance low notifications.
// Format args: siteName, userName, userName, balance, threshold, threshold.
const balanceLowEmailTemplate = `<!DOCTYPE html>

View File

@@ -196,6 +196,9 @@ func (c *Channel) Clone() *Channel {
cp.ModelMapping[platform] = inner
}
}
if c.FeaturesConfig != nil {
cp.FeaturesConfig = deepCopyFeaturesConfig(c.FeaturesConfig)
}
if c.AccountStatsPricingRules != nil {
cp.AccountStatsPricingRules = make([]AccountStatsPricingRule, len(c.AccountStatsPricingRules))
for i, rule := range c.AccountStatsPricingRules {
@@ -219,6 +222,19 @@ func (c *Channel) Clone() *Channel {
return &cp
}
// deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution.
func deepCopyFeaturesConfig(src map[string]any) map[string]any {
dst := make(map[string]any, len(src))
for k, v := range src {
if inner, ok := v.(map[string]any); ok {
dst[k] = deepCopyFeaturesConfig(inner)
} else {
dst[k] = v
}
}
return dst
}
// ValidateIntervals 校验区间列表的合法性。
// 规则MinTokens >= 0MaxTokens 若非 nil 则 > 0 且 > MinTokens
// 所有价格字段 >= 0区间按 MinTokens 排序后无重叠((min, max] 语义);

View File

@@ -7,7 +7,7 @@ import (
"crypto/tls"
"encoding/hex"
"fmt"
"log"
"log/slog"
"math/big"
"net/smtp"
"net/url"
@@ -292,7 +292,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
log.Printf("[Email] Failed to update verification attempt count: %v", err)
slog.Error("failed to update verification attempt count", "email", email, "error", err)
}
if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts
@@ -302,7 +302,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证成功,删除验证码
if err := s.cache.DeleteVerificationCode(ctx, email); err != nil {
log.Printf("[Email] Failed to delete verification code after success: %v", err)
slog.Error("failed to delete verification code after success", "email", email, "error", err)
}
return nil
}
@@ -452,7 +452,7 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
// Check email cooldown to prevent email bombing
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
log.Printf("[Email] Password reset email skipped (cooldown): %s", email)
slog.Info("password reset email skipped due to cooldown", "email", email)
return nil // Silent success to prevent revealing cooldown to attackers
}
@@ -463,7 +463,7 @@ func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, e
// Set cooldown marker (Redis TTL handles expiration)
if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil {
log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err)
slog.Error("failed to set password reset cooldown", "email", email, "error", err)
}
return nil
@@ -493,7 +493,7 @@ func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, tok
// Delete after verification (one-time use)
if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil {
log.Printf("[Email] Failed to delete password reset token after consumption: %v", err)
slog.Error("failed to delete password reset token after consumption", "email", email, "error", err)
}
return nil
}

View File

@@ -6,7 +6,7 @@ import (
)
// NotifyEmailEntry represents a notification email with enable/disable and verification state.
// Email="" is a placeholder for the "primary email" (user's registration email or first admin email).
// All emails are user-managed; maximum 3 entries per user.
type NotifyEmailEntry struct {
Email string `json:"email"`
Disabled bool `json:"disabled"`

View File

@@ -4,7 +4,7 @@ import (
"context"
"crypto/subtle"
"fmt"
"log"
"log/slog"
"strings"
"time"
@@ -13,12 +13,19 @@ import (
)
var (
ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
)
const maxNotifyEmails = 3 // Total limit: primary (email="") + up to 2 extra
const (
maxNotifyEmails = 3 // Maximum number of notification emails per user
// User-level rate limiting for notify email verification codes
notifyCodeUserRateLimit = 5
notifyCodeUserRateWindow = 10 * time.Minute
)
// UserListFilters contains all filter options for listing users
type UserListFilters struct {
@@ -220,7 +227,7 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil {
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
slog.Error("invalidate user balance cache failed", "user_id", userID, "error", err)
}
}()
}
@@ -270,21 +277,44 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error {
// SendNotifyEmailCode sends a verification code to the extra notification email.
func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache) error {
// Check cooldown
if err := checkNotifyCodeRateLimit(ctx, cache, userID, email); err != nil {
return err
}
code, err := emailService.GenerateVerifyCode()
if err != nil {
return fmt.Errorf("generate code: %w", err)
}
if err := saveNotifyVerifyCode(ctx, cache, email, code); err != nil {
return err
}
// Increment user-level counter after successful save
if _, err := cache.IncrNotifyCodeUserRate(ctx, userID, notifyCodeUserRateWindow); err != nil {
slog.Error("failed to increment notify code user rate", "user_id", userID, "error", err)
}
return s.sendNotifyVerifyEmail(ctx, emailService, email, code)
}
// checkNotifyCodeRateLimit checks both email cooldown and user-level rate limit.
func checkNotifyCodeRateLimit(ctx context.Context, cache EmailCache, userID int64, email string) error {
existing, err := cache.GetNotifyVerifyCode(ctx, email)
if err == nil && existing != nil {
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
return ErrVerifyCodeTooFrequent
}
}
// Generate code
code, err := emailService.GenerateVerifyCode()
if err != nil {
return fmt.Errorf("generate code: %w", err)
count, err := cache.GetNotifyCodeUserRate(ctx, userID)
if err == nil && count >= notifyCodeUserRateLimit {
return ErrNotifyCodeUserRateLimit
}
return nil
}
// Save to cache
// saveNotifyVerifyCode saves the verification code to cache.
func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code string) error {
data := &VerificationCodeData{
Code: code,
Attempts: 0,
@@ -293,16 +323,17 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema
if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil {
return fmt.Errorf("save verify code: %w", err)
}
return nil
}
// Get site name
// sendNotifyVerifyEmail builds and sends the verification email.
func (s *UserService) sendNotifyVerifyEmail(ctx context.Context, emailService *EmailService, email, code string) error {
siteName := "Sub2API"
if s.settingRepo != nil {
if name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName); err == nil && name != "" {
siteName = name
}
}
// Build and send email
subject := fmt.Sprintf("[%s] 通知邮箱验证码 / Notification Email Verification", siteName)
body := buildNotifyVerifyEmailBody(code, siteName)
return emailService.SendEmail(ctx, email, subject, body)
@@ -310,7 +341,15 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema
// VerifyAndAddNotifyEmail verifies the code and adds the email to user's extra emails.
func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64, email, code string, cache EmailCache) error {
// Verify code
if err := verifyNotifyCode(ctx, cache, email, code); err != nil {
return err
}
_ = cache.DeleteNotifyVerifyCode(ctx, email)
return s.addOrVerifyNotifyEmail(ctx, userID, email)
}
// verifyNotifyCode validates the verification code against the cached data.
func verifyNotifyCode(ctx context.Context, cache EmailCache, email, code string) error {
data, err := cache.GetNotifyVerifyCode(ctx, email)
if err != nil || data == nil {
return ErrInvalidVerifyCode
@@ -326,17 +365,18 @@ func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64,
}
return ErrInvalidVerifyCode
}
return nil
}
// Delete code after verification
_ = cache.DeleteNotifyVerifyCode(ctx, email)
// Add to user's extra emails
// addOrVerifyNotifyEmail adds the email to user's extra notification emails or marks it as verified.
// Note: concurrent calls for the same user could race on the read-modify-write of
// BalanceNotifyExtraEmails. The window is small (requires two verify flows completing
// simultaneously), and the worst case is a duplicate entry which is harmless.
func (s *UserService) addOrVerifyNotifyEmail(ctx context.Context, userID int64, email string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
// Check if already exists — if unverified, mark as verified
for i, e := range user.BalanceNotifyExtraEmails {
if strings.EqualFold(e.Email, email) {
if !e.Verified {
@@ -346,12 +386,9 @@ func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64,
return nil // Already verified
}
}
// Check limit
if len(user.BalanceNotifyExtraEmails) >= maxNotifyEmails {
return infraerrors.BadRequest("TOO_MANY_NOTIFY_EMAILS", fmt.Sprintf("maximum %d notification emails allowed", maxNotifyEmails))
}
user.BalanceNotifyExtraEmails = append(user.BalanceNotifyExtraEmails, NotifyEmailEntry{
Email: email,
Disabled: false,
@@ -399,10 +436,9 @@ func (s *UserService) ToggleNotifyEmail(ctx context.Context, userID int64, email
return s.userRepo.Update(ctx, user)
}
// buildNotifyVerifyEmailBody builds the HTML email body for notify email verification.
func buildNotifyVerifyEmailBody(code, siteName string) string {
return fmt.Sprintf(`
<!DOCTYPE html>
// notifyVerifyEmailTemplate is the HTML template for notify email verification.
// Format args: siteName, code.
const notifyVerifyEmailTemplate = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
@@ -439,6 +475,9 @@ func buildNotifyVerifyEmailBody(code, siteName string) string {
</div>
</div>
</body>
</html>
`, siteName, code)
</html>`
// buildNotifyVerifyEmailBody builds the HTML email body for notify email verification.
func buildNotifyVerifyEmailBody(code, siteName string) string {
return fmt.Sprintf(notifyVerifyEmailTemplate, siteName, code)
}