fix: audit round-3 — proxy safety, intervals persistence, SMTP timeout, sort fix

- Skip websearch provider when ProxyID is set but proxy not found (prevent
  silent direct connection bypass)
- Fix sortByStableRandomWeight: pair factors with items so sort.Slice swap
  keeps weights aligned
- Allow empty platform in account_stats_pricing_rules (wildcard matching),
  only force anthropic default for main model_pricing
- Add channel_account_stats_pricing_intervals table and repo layer support
  for interval-based pricing in account stats rules
- calculateTokenStatsCost now uses interval pricing when available
- Replace smtp.SendMail/tls.Dial with net.Dialer timeout (10s dial, 20s IO)
  to prevent goroutine leak on SMTP hang
- Fix gofmt formatting issues
- Web Search label: black text with red warning hint
This commit is contained in:
erio
2026-04-14 01:10:46 +08:00
parent 9c09bd19b4
commit 0a4ece5f5b
11 changed files with 199 additions and 27 deletions

View File

@@ -96,6 +96,27 @@ func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Contex
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate account stats model pricing: %w", err)
}
// Load intervals for all pricing entries.
var allPricingIDs []int64
for _, pricings := range pricingMap {
for _, p := range pricings {
allPricingIDs = append(allPricingIDs, p.ID)
}
}
if len(allPricingIDs) > 0 {
intervalsMap, err := r.batchLoadAccountStatsIntervals(ctx, allPricingIDs)
if err != nil {
return nil, err
}
for ruleID, pricings := range pricingMap {
for i := range pricings {
pricings[i].Intervals = intervalsMap[pricings[i].ID]
}
pricingMap[ruleID] = pricings
}
}
return pricingMap, nil
}
@@ -166,5 +187,58 @@ func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID in
if err != nil {
return fmt.Errorf("insert account stats model pricing: %w", err)
}
// Persist intervals (mirrors channel_pricing_intervals logic).
for i := range pricing.Intervals {
iv := &pricing.Intervals[i]
iv.PricingID = pricing.ID
if err := createAccountStatsIntervalTx(ctx, tx, iv); err != nil {
return err
}
}
return nil
}
// createAccountStatsIntervalTx inserts a single interval for an account stats pricing entry.
func createAccountStatsIntervalTx(ctx context.Context, tx *sql.Tx, iv *service.PricingInterval) error {
return tx.QueryRowContext(ctx,
`INSERT INTO channel_account_stats_pricing_intervals
(pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
iv.PerRequestPrice, iv.SortOrder,
).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt)
}
// batchLoadAccountStatsIntervals loads intervals for account stats pricing entries.
func (r *channelRepository) batchLoadAccountStatsIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
if len(pricingIDs) == 0 {
return nil, nil
}
rows, err := r.db.QueryContext(ctx,
`SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
input_price, output_price, cache_write_price, cache_read_price,
per_request_price, sort_order, created_at, updated_at
FROM channel_account_stats_pricing_intervals
WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
pq.Array(pricingIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load account stats pricing intervals: %w", err)
}
defer func() { _ = rows.Close() }()
result := make(map[int64][]service.PricingInterval)
for rows.Next() {
var iv service.PricingInterval
if err := rows.Scan(
&iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
&iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
&iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan account stats pricing interval: %w", err)
}
result[iv.PricingID] = append(result[iv.PricingID], iv)
}
return result, rows.Err()
}

View File

@@ -12,11 +12,11 @@ import (
)
const (
verifyCodeKeyPrefix = "verify_code:"
notifyVerifyKeyPrefix = "notify_verify:"
passwordResetKeyPrefix = "password_reset:"
passwordResetSentAtKeyPrefix = "password_reset_sent:"
notifyCodeUserRateKeyPrefix = "notify_code_user_rate:"
verifyCodeKeyPrefix = "verify_code:"
notifyVerifyKeyPrefix = "notify_verify:"
passwordResetKeyPrefix = "password_reset:"
passwordResetSentAtKeyPrefix = "password_reset_sent:"
notifyCodeUserRateKeyPrefix = "notify_code_user_rate:"
)
// verifyCodeKey generates the Redis key for email verification code.