fix: audit round-3 — proxy safety, intervals persistence, SMTP timeout, sort fix

- Skip websearch provider when ProxyID is set but proxy not found (prevent
  silent direct connection bypass)
- Fix sortByStableRandomWeight: pair factors with items so sort.Slice swap
  keeps weights aligned
- Allow empty platform in account_stats_pricing_rules (wildcard matching),
  only force anthropic default for main model_pricing
- Add channel_account_stats_pricing_intervals table and repo layer support
  for interval-based pricing in account stats rules
- calculateTokenStatsCost now uses interval pricing when available
- Replace smtp.SendMail/tls.Dial with net.Dialer timeout (10s dial, 20s IO)
  to prevent goroutine leak on SMTP hang
- Fix gofmt formatting issues
- Web Search label: black text with red warning hint
This commit is contained in:
erio
2026-04-14 01:10:46 +08:00
parent 9c09bd19b4
commit 0a4ece5f5b
11 changed files with 199 additions and 27 deletions

View File

@@ -195,18 +195,33 @@ func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int
}
// calculateTokenStatsCost Token 计费。
// If the pricing has intervals, find the matching interval by total token count
// and use its prices instead of the flat pricing fields.
func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
deref := func(p *float64) float64 {
if p == nil {
p := pricing
if len(pricing.Intervals) > 0 {
totalTokens := tokens.InputTokens + tokens.OutputTokens + tokens.CacheCreationTokens + tokens.CacheReadTokens
if iv := FindMatchingInterval(pricing.Intervals, totalTokens); iv != nil {
p = &ChannelModelPricing{
InputPrice: iv.InputPrice,
OutputPrice: iv.OutputPrice,
CacheWritePrice: iv.CacheWritePrice,
CacheReadPrice: iv.CacheReadPrice,
PerRequestPrice: iv.PerRequestPrice,
}
}
}
deref := func(ptr *float64) float64 {
if ptr == nil {
return 0
}
return *p
return *ptr
}
cost := float64(tokens.InputTokens)*deref(pricing.InputPrice) +
float64(tokens.OutputTokens)*deref(pricing.OutputPrice) +
float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) +
float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) +
float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice)
cost := float64(tokens.InputTokens)*deref(p.InputPrice) +
float64(tokens.OutputTokens)*deref(p.OutputPrice) +
float64(tokens.CacheCreationTokens)*deref(p.CacheWritePrice) +
float64(tokens.CacheReadTokens)*deref(p.CacheReadPrice) +
float64(tokens.ImageOutputTokens)*deref(p.ImageOutputPrice)
if cost <= 0 {
return nil
}

View File

@@ -477,4 +477,3 @@ func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountID int64, account
}
return fmt.Sprintf(quotaAlertEmailTemplate, siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay)
}

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"log/slog"
"math/big"
"net"
"net/smtp"
"net/url"
"strconv"
@@ -152,6 +153,9 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
return s.SendEmailWithConfig(config, to, subject, body)
}
const smtpDialTimeout = 10 * time.Second
const smtpIOTimeout = 20 * time.Second
// SendEmailWithConfig 使用指定配置发送邮件
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
// Sanitize all SMTP header fields to prevent header injection (CR/LF removal).
@@ -173,7 +177,46 @@ func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body
return s.sendMailTLS(addr, auth, config.From, to, []byte(msg), config.Host)
}
return smtp.SendMail(addr, auth, config.From, []string{to}, []byte(msg))
return s.sendMailPlain(addr, auth, config.From, to, []byte(msg), config.Host)
}
// sendMailPlain sends mail without TLS using a dialer with timeout.
func (s *EmailService) sendMailPlain(addr string, auth smtp.Auth, from, to string, msg []byte, host string) error {
dialer := &net.Dialer{Timeout: smtpDialTimeout}
conn, err := dialer.Dial("tcp", addr)
if err != nil {
return fmt.Errorf("smtp dial: %w", err)
}
_ = conn.SetDeadline(time.Now().Add(smtpIOTimeout))
defer func() { _ = conn.Close() }()
client, err := smtp.NewClient(conn, host)
if err != nil {
return fmt.Errorf("new smtp client: %w", err)
}
defer func() { _ = client.Close() }()
if err = client.Auth(auth); err != nil {
return fmt.Errorf("smtp auth: %w", err)
}
if err = client.Mail(from); err != nil {
return fmt.Errorf("smtp mail: %w", err)
}
if err = client.Rcpt(to); err != nil {
return fmt.Errorf("smtp rcpt: %w", err)
}
w, err := client.Data()
if err != nil {
return fmt.Errorf("smtp data: %w", err)
}
if _, err = w.Write(msg); err != nil {
return fmt.Errorf("write msg: %w", err)
}
if err = w.Close(); err != nil {
return fmt.Errorf("close writer: %w", err)
}
_ = client.Quit()
return nil
}
// sendMailTLS 使用TLS发送邮件
@@ -184,10 +227,12 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string,
MinVersion: tls.VersionTLS12,
}
conn, err := tls.Dial("tcp", addr, tlsConfig)
dialer := &net.Dialer{Timeout: smtpDialTimeout}
conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
if err != nil {
return fmt.Errorf("tls dial: %w", err)
}
_ = conn.SetDeadline(time.Now().Add(smtpIOTimeout))
defer func() { _ = conn.Close() }()
client, err := smtp.NewClient(conn, host)

View File

@@ -79,4 +79,3 @@ func MarshalNotifyEmails(entries []NotifyEmailEntry) string {
}
return string(data)
}