fix: audit round-3 — proxy safety, intervals persistence, SMTP timeout, sort fix
- Skip websearch provider when ProxyID is set but proxy not found (prevent silent direct connection bypass) - Fix sortByStableRandomWeight: pair factors with items so sort.Slice swap keeps weights aligned - Allow empty platform in account_stats_pricing_rules (wildcard matching), only force anthropic default for main model_pricing - Add channel_account_stats_pricing_intervals table and repo layer support for interval-based pricing in account stats rules - calculateTokenStatsCost now uses interval pricing when available - Replace smtp.SendMail/tls.Dial with net.Dialer timeout (10s dial, 20s IO) to prevent goroutine leak on SMTP hang - Fix gofmt formatting issues - Web Search label: black text with red warning hint
This commit is contained in:
@@ -249,9 +249,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
|
|||||||
billingMode = service.BillingModeToken
|
billingMode = service.BillingModeToken
|
||||||
}
|
}
|
||||||
platform := r.Platform
|
platform := r.Platform
|
||||||
if platform == "" {
|
|
||||||
platform = service.PlatformAnthropic
|
|
||||||
}
|
|
||||||
intervals := make([]service.PricingInterval, 0, len(r.Intervals))
|
intervals := make([]service.PricingInterval, 0, len(r.Intervals))
|
||||||
for _, iv := range r.Intervals {
|
for _, iv := range r.Intervals {
|
||||||
intervals = append(intervals, service.PricingInterval{
|
intervals = append(intervals, service.PricingInterval{
|
||||||
@@ -349,6 +346,12 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pricing := pricingRequestToService(req.ModelPricing)
|
pricing := pricingRequestToService(req.ModelPricing)
|
||||||
|
// Main model_pricing requires a platform; default to anthropic for backward compatibility.
|
||||||
|
for i := range pricing {
|
||||||
|
if pricing[i].Platform == "" {
|
||||||
|
pricing[i].Platform = service.PlatformAnthropic
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var statsRules []service.AccountStatsPricingRule
|
var statsRules []service.AccountStatsPricingRule
|
||||||
for i, r := range req.AccountStatsPricingRules {
|
for i, r := range req.AccountStatsPricingRules {
|
||||||
@@ -415,6 +418,11 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if req.ModelPricing != nil {
|
if req.ModelPricing != nil {
|
||||||
pricing := pricingRequestToService(*req.ModelPricing)
|
pricing := pricingRequestToService(*req.ModelPricing)
|
||||||
|
for i := range pricing {
|
||||||
|
if pricing[i].Platform == "" {
|
||||||
|
pricing[i].Platform = service.PlatformAnthropic
|
||||||
|
}
|
||||||
|
}
|
||||||
input.ModelPricing = &pricing
|
input.ModelPricing = &pricing
|
||||||
}
|
}
|
||||||
if req.AccountStatsPricingRules != nil {
|
if req.AccountStatsPricingRules != nil {
|
||||||
|
|||||||
@@ -200,13 +200,20 @@ func sortByStableRandomWeight(items []weighted) {
|
|||||||
if len(items) <= 1 {
|
if len(items) <= 1 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
factors := make([]float64, len(items))
|
type entry struct {
|
||||||
for i, item := range items {
|
item weighted
|
||||||
factors[i] = float64(item.weight) * (0.5 + rand.Float64())
|
factor float64
|
||||||
}
|
}
|
||||||
sort.Slice(items, func(i, j int) bool {
|
entries := make([]entry, len(items))
|
||||||
return factors[i] > factors[j]
|
for i, item := range items {
|
||||||
|
entries[i] = entry{item: item, factor: float64(item.weight) * (0.5 + rand.Float64())}
|
||||||
|
}
|
||||||
|
sort.Slice(entries, func(i, j int) bool {
|
||||||
|
return entries[i].factor > entries[j].factor
|
||||||
})
|
})
|
||||||
|
for i, e := range entries {
|
||||||
|
items[i] = e.item
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func mergeWeightedResults(withQuota, withoutQuota []weighted, capacity int) []ProviderConfig {
|
func mergeWeightedResults(withQuota, withoutQuota []weighted, capacity int) []ProviderConfig {
|
||||||
|
|||||||
@@ -96,6 +96,27 @@ func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Contex
|
|||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
return nil, fmt.Errorf("iterate account stats model pricing: %w", err)
|
return nil, fmt.Errorf("iterate account stats model pricing: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load intervals for all pricing entries.
|
||||||
|
var allPricingIDs []int64
|
||||||
|
for _, pricings := range pricingMap {
|
||||||
|
for _, p := range pricings {
|
||||||
|
allPricingIDs = append(allPricingIDs, p.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(allPricingIDs) > 0 {
|
||||||
|
intervalsMap, err := r.batchLoadAccountStatsIntervals(ctx, allPricingIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for ruleID, pricings := range pricingMap {
|
||||||
|
for i := range pricings {
|
||||||
|
pricings[i].Intervals = intervalsMap[pricings[i].ID]
|
||||||
|
}
|
||||||
|
pricingMap[ruleID] = pricings
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return pricingMap, nil
|
return pricingMap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,5 +187,58 @@ func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID in
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("insert account stats model pricing: %w", err)
|
return fmt.Errorf("insert account stats model pricing: %w", err)
|
||||||
}
|
}
|
||||||
|
// Persist intervals (mirrors channel_pricing_intervals logic).
|
||||||
|
for i := range pricing.Intervals {
|
||||||
|
iv := &pricing.Intervals[i]
|
||||||
|
iv.PricingID = pricing.ID
|
||||||
|
if err := createAccountStatsIntervalTx(ctx, tx, iv); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createAccountStatsIntervalTx inserts a single interval for an account stats pricing entry.
|
||||||
|
func createAccountStatsIntervalTx(ctx context.Context, tx *sql.Tx, iv *service.PricingInterval) error {
|
||||||
|
return tx.QueryRowContext(ctx,
|
||||||
|
`INSERT INTO channel_account_stats_pricing_intervals
|
||||||
|
(pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
|
||||||
|
iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
|
||||||
|
iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
|
||||||
|
iv.PerRequestPrice, iv.SortOrder,
|
||||||
|
).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// batchLoadAccountStatsIntervals loads intervals for account stats pricing entries.
|
||||||
|
func (r *channelRepository) batchLoadAccountStatsIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
|
||||||
|
if len(pricingIDs) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
rows, err := r.db.QueryContext(ctx,
|
||||||
|
`SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
|
||||||
|
input_price, output_price, cache_write_price, cache_read_price,
|
||||||
|
per_request_price, sort_order, created_at, updated_at
|
||||||
|
FROM channel_account_stats_pricing_intervals
|
||||||
|
WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
|
||||||
|
pq.Array(pricingIDs),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("batch load account stats pricing intervals: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
result := make(map[int64][]service.PricingInterval)
|
||||||
|
for rows.Next() {
|
||||||
|
var iv service.PricingInterval
|
||||||
|
if err := rows.Scan(
|
||||||
|
&iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
|
||||||
|
&iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
|
||||||
|
&iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan account stats pricing interval: %w", err)
|
||||||
|
}
|
||||||
|
result[iv.PricingID] = append(result[iv.PricingID], iv)
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log"
|
"log"
|
||||||
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -82,6 +83,11 @@ func ProvideRouter(
|
|||||||
pc.ProxyID = *p.ProxyID
|
pc.ProxyID = *p.ProxyID
|
||||||
if u, ok := proxyURLs[*p.ProxyID]; ok {
|
if u, ok := proxyURLs[*p.ProxyID]; ok {
|
||||||
pc.ProxyURL = u
|
pc.ProxyURL = u
|
||||||
|
} else {
|
||||||
|
// Proxy configured but not found — skip this provider to prevent direct connection.
|
||||||
|
slog.Warn("websearch: proxy not found for provider, skipping",
|
||||||
|
"provider", p.Type, "proxy_id", *p.ProxyID)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
configs = append(configs, pc)
|
configs = append(configs, pc)
|
||||||
|
|||||||
@@ -195,18 +195,33 @@ func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int
|
|||||||
}
|
}
|
||||||
|
|
||||||
// calculateTokenStatsCost Token 计费。
|
// calculateTokenStatsCost Token 计费。
|
||||||
|
// If the pricing has intervals, find the matching interval by total token count
|
||||||
|
// and use its prices instead of the flat pricing fields.
|
||||||
func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
|
func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
|
||||||
deref := func(p *float64) float64 {
|
p := pricing
|
||||||
if p == nil {
|
if len(pricing.Intervals) > 0 {
|
||||||
|
totalTokens := tokens.InputTokens + tokens.OutputTokens + tokens.CacheCreationTokens + tokens.CacheReadTokens
|
||||||
|
if iv := FindMatchingInterval(pricing.Intervals, totalTokens); iv != nil {
|
||||||
|
p = &ChannelModelPricing{
|
||||||
|
InputPrice: iv.InputPrice,
|
||||||
|
OutputPrice: iv.OutputPrice,
|
||||||
|
CacheWritePrice: iv.CacheWritePrice,
|
||||||
|
CacheReadPrice: iv.CacheReadPrice,
|
||||||
|
PerRequestPrice: iv.PerRequestPrice,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
deref := func(ptr *float64) float64 {
|
||||||
|
if ptr == nil {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
return *p
|
return *ptr
|
||||||
}
|
}
|
||||||
cost := float64(tokens.InputTokens)*deref(pricing.InputPrice) +
|
cost := float64(tokens.InputTokens)*deref(p.InputPrice) +
|
||||||
float64(tokens.OutputTokens)*deref(pricing.OutputPrice) +
|
float64(tokens.OutputTokens)*deref(p.OutputPrice) +
|
||||||
float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) +
|
float64(tokens.CacheCreationTokens)*deref(p.CacheWritePrice) +
|
||||||
float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) +
|
float64(tokens.CacheReadTokens)*deref(p.CacheReadPrice) +
|
||||||
float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice)
|
float64(tokens.ImageOutputTokens)*deref(p.ImageOutputPrice)
|
||||||
if cost <= 0 {
|
if cost <= 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -477,4 +477,3 @@ func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountID int64, account
|
|||||||
}
|
}
|
||||||
return fmt.Sprintf(quotaAlertEmailTemplate, siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay)
|
return fmt.Sprintf(quotaAlertEmailTemplate, siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"net"
|
||||||
"net/smtp"
|
"net/smtp"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -152,6 +153,9 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
|
|||||||
return s.SendEmailWithConfig(config, to, subject, body)
|
return s.SendEmailWithConfig(config, to, subject, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const smtpDialTimeout = 10 * time.Second
|
||||||
|
const smtpIOTimeout = 20 * time.Second
|
||||||
|
|
||||||
// SendEmailWithConfig 使用指定配置发送邮件
|
// SendEmailWithConfig 使用指定配置发送邮件
|
||||||
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
|
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
|
||||||
// Sanitize all SMTP header fields to prevent header injection (CR/LF removal).
|
// Sanitize all SMTP header fields to prevent header injection (CR/LF removal).
|
||||||
@@ -173,7 +177,46 @@ func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body
|
|||||||
return s.sendMailTLS(addr, auth, config.From, to, []byte(msg), config.Host)
|
return s.sendMailTLS(addr, auth, config.From, to, []byte(msg), config.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
return smtp.SendMail(addr, auth, config.From, []string{to}, []byte(msg))
|
return s.sendMailPlain(addr, auth, config.From, to, []byte(msg), config.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendMailPlain sends mail without TLS using a dialer with timeout.
|
||||||
|
func (s *EmailService) sendMailPlain(addr string, auth smtp.Auth, from, to string, msg []byte, host string) error {
|
||||||
|
dialer := &net.Dialer{Timeout: smtpDialTimeout}
|
||||||
|
conn, err := dialer.Dial("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("smtp dial: %w", err)
|
||||||
|
}
|
||||||
|
_ = conn.SetDeadline(time.Now().Add(smtpIOTimeout))
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
|
client, err := smtp.NewClient(conn, host)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("new smtp client: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = client.Close() }()
|
||||||
|
|
||||||
|
if err = client.Auth(auth); err != nil {
|
||||||
|
return fmt.Errorf("smtp auth: %w", err)
|
||||||
|
}
|
||||||
|
if err = client.Mail(from); err != nil {
|
||||||
|
return fmt.Errorf("smtp mail: %w", err)
|
||||||
|
}
|
||||||
|
if err = client.Rcpt(to); err != nil {
|
||||||
|
return fmt.Errorf("smtp rcpt: %w", err)
|
||||||
|
}
|
||||||
|
w, err := client.Data()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("smtp data: %w", err)
|
||||||
|
}
|
||||||
|
if _, err = w.Write(msg); err != nil {
|
||||||
|
return fmt.Errorf("write msg: %w", err)
|
||||||
|
}
|
||||||
|
if err = w.Close(); err != nil {
|
||||||
|
return fmt.Errorf("close writer: %w", err)
|
||||||
|
}
|
||||||
|
_ = client.Quit()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendMailTLS 使用TLS发送邮件
|
// sendMailTLS 使用TLS发送邮件
|
||||||
@@ -184,10 +227,12 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string,
|
|||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := tls.Dial("tcp", addr, tlsConfig)
|
dialer := &net.Dialer{Timeout: smtpDialTimeout}
|
||||||
|
conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("tls dial: %w", err)
|
return fmt.Errorf("tls dial: %w", err)
|
||||||
}
|
}
|
||||||
|
_ = conn.SetDeadline(time.Now().Add(smtpIOTimeout))
|
||||||
defer func() { _ = conn.Close() }()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
client, err := smtp.NewClient(conn, host)
|
client, err := smtp.NewClient(conn, host)
|
||||||
|
|||||||
@@ -79,4 +79,3 @@ func MarshalNotifyEmails(entries []NotifyEmailEntry) string {
|
|||||||
}
|
}
|
||||||
return string(data)
|
return string(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
-- Add intervals table for account stats pricing rules (mirrors channel_pricing_intervals).
|
||||||
|
CREATE TABLE IF NOT EXISTS channel_account_stats_pricing_intervals (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
pricing_id BIGINT NOT NULL REFERENCES channel_account_stats_model_pricing(id) ON DELETE CASCADE,
|
||||||
|
min_tokens INT NOT NULL DEFAULT 0,
|
||||||
|
max_tokens INT,
|
||||||
|
tier_label VARCHAR(50),
|
||||||
|
input_price NUMERIC(20,12),
|
||||||
|
output_price NUMERIC(20,12),
|
||||||
|
cache_write_price NUMERIC(20,12),
|
||||||
|
cache_read_price NUMERIC(20,12),
|
||||||
|
per_request_price NUMERIC(20,12),
|
||||||
|
sort_order INT NOT NULL DEFAULT 0,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_account_stats_pricing_intervals_pricing_id
|
||||||
|
ON channel_account_stats_pricing_intervals (pricing_id);
|
||||||
@@ -328,10 +328,10 @@
|
|||||||
<div v-if="section.platform === 'anthropic' && webSearchGlobalEnabled" class="border-t border-gray-200 pt-3 dark:border-dark-600">
|
<div v-if="section.platform === 'anthropic' && webSearchGlobalEnabled" class="border-t border-gray-200 pt-3 dark:border-dark-600">
|
||||||
<div class="flex items-center justify-between">
|
<div class="flex items-center justify-between">
|
||||||
<div>
|
<div>
|
||||||
<label class="text-xs font-medium text-orange-600 dark:text-orange-400">
|
<label class="text-xs font-medium text-gray-700 dark:text-gray-300">
|
||||||
{{ t('admin.channels.form.webSearchEmulation') }}
|
{{ t('admin.channels.form.webSearchEmulation') }}
|
||||||
</label>
|
</label>
|
||||||
<p class="mt-0.5 text-[11px] text-amber-500 dark:text-amber-400">
|
<p class="mt-0.5 text-[11px] text-red-500 dark:text-red-400">
|
||||||
{{ t('admin.channels.form.webSearchEmulationHint') }}
|
{{ t('admin.channels.form.webSearchEmulationHint') }}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
Reference in New Issue
Block a user