diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 1a328551..88d27c47 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -249,9 +249,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe billingMode = service.BillingModeToken } platform := r.Platform - if platform == "" { - platform = service.PlatformAnthropic - } intervals := make([]service.PricingInterval, 0, len(r.Intervals)) for _, iv := range r.Intervals { intervals = append(intervals, service.PricingInterval{ @@ -349,6 +346,12 @@ func (h *ChannelHandler) Create(c *gin.Context) { } pricing := pricingRequestToService(req.ModelPricing) + // Main model_pricing requires a platform; default to anthropic for backward compatibility. + for i := range pricing { + if pricing[i].Platform == "" { + pricing[i].Platform = service.PlatformAnthropic + } + } var statsRules []service.AccountStatsPricingRule for i, r := range req.AccountStatsPricingRules { @@ -415,6 +418,11 @@ func (h *ChannelHandler) Update(c *gin.Context) { } if req.ModelPricing != nil { pricing := pricingRequestToService(*req.ModelPricing) + for i := range pricing { + if pricing[i].Platform == "" { + pricing[i].Platform = service.PlatformAnthropic + } + } input.ModelPricing = &pricing } if req.AccountStatsPricingRules != nil { diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go index 27592459..61faa616 100644 --- a/backend/internal/pkg/websearch/manager.go +++ b/backend/internal/pkg/websearch/manager.go @@ -200,13 +200,20 @@ func sortByStableRandomWeight(items []weighted) { if len(items) <= 1 { return } - factors := make([]float64, len(items)) - for i, item := range items { - factors[i] = float64(item.weight) * (0.5 + rand.Float64()) + type entry struct { + item weighted + factor float64 } - sort.Slice(items, func(i, j int) bool { - return factors[i] > factors[j] + entries := make([]entry, len(items)) + for i, item := range items { + entries[i] = entry{item: item, factor: float64(item.weight) * (0.5 + rand.Float64())} + } + sort.Slice(entries, func(i, j int) bool { + return entries[i].factor > entries[j].factor }) + for i, e := range entries { + items[i] = e.item + } } func mergeWeightedResults(withQuota, withoutQuota []weighted, capacity int) []ProviderConfig { diff --git a/backend/internal/repository/channel_repo_account_stats_pricing.go b/backend/internal/repository/channel_repo_account_stats_pricing.go index ef8f5177..9e00fed8 100644 --- a/backend/internal/repository/channel_repo_account_stats_pricing.go +++ b/backend/internal/repository/channel_repo_account_stats_pricing.go @@ -96,6 +96,27 @@ func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Contex if err := rows.Err(); err != nil { return nil, fmt.Errorf("iterate account stats model pricing: %w", err) } + + // Load intervals for all pricing entries. + var allPricingIDs []int64 + for _, pricings := range pricingMap { + for _, p := range pricings { + allPricingIDs = append(allPricingIDs, p.ID) + } + } + if len(allPricingIDs) > 0 { + intervalsMap, err := r.batchLoadAccountStatsIntervals(ctx, allPricingIDs) + if err != nil { + return nil, err + } + for ruleID, pricings := range pricingMap { + for i := range pricings { + pricings[i].Intervals = intervalsMap[pricings[i].ID] + } + pricingMap[ruleID] = pricings + } + } + return pricingMap, nil } @@ -166,5 +187,58 @@ func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID in if err != nil { return fmt.Errorf("insert account stats model pricing: %w", err) } + // Persist intervals (mirrors channel_pricing_intervals logic). + for i := range pricing.Intervals { + iv := &pricing.Intervals[i] + iv.PricingID = pricing.ID + if err := createAccountStatsIntervalTx(ctx, tx, iv); err != nil { + return err + } + } return nil } + +// createAccountStatsIntervalTx inserts a single interval for an account stats pricing entry. +func createAccountStatsIntervalTx(ctx context.Context, tx *sql.Tx, iv *service.PricingInterval) error { + return tx.QueryRowContext(ctx, + `INSERT INTO channel_account_stats_pricing_intervals + (pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`, + iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel, + iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice, + iv.PerRequestPrice, iv.SortOrder, + ).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt) +} + +// batchLoadAccountStatsIntervals loads intervals for account stats pricing entries. +func (r *channelRepository) batchLoadAccountStatsIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) { + if len(pricingIDs) == 0 { + return nil, nil + } + rows, err := r.db.QueryContext(ctx, + `SELECT id, pricing_id, min_tokens, max_tokens, tier_label, + input_price, output_price, cache_write_price, cache_read_price, + per_request_price, sort_order, created_at, updated_at + FROM channel_account_stats_pricing_intervals + WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`, + pq.Array(pricingIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load account stats pricing intervals: %w", err) + } + defer func() { _ = rows.Close() }() + + result := make(map[int64][]service.PricingInterval) + for rows.Next() { + var iv service.PricingInterval + if err := rows.Scan( + &iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel, + &iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice, + &iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan account stats pricing interval: %w", err) + } + result[iv.PricingID] = append(result[iv.PricingID], iv) + } + return result, rows.Err() +} diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go index 0eb6bef1..96a23a8e 100644 --- a/backend/internal/repository/email_cache.go +++ b/backend/internal/repository/email_cache.go @@ -12,11 +12,11 @@ import ( ) const ( - verifyCodeKeyPrefix = "verify_code:" - notifyVerifyKeyPrefix = "notify_verify:" - passwordResetKeyPrefix = "password_reset:" - passwordResetSentAtKeyPrefix = "password_reset_sent:" - notifyCodeUserRateKeyPrefix = "notify_code_user_rate:" + verifyCodeKeyPrefix = "verify_code:" + notifyVerifyKeyPrefix = "notify_verify:" + passwordResetKeyPrefix = "password_reset:" + passwordResetSentAtKeyPrefix = "password_reset_sent:" + notifyCodeUserRateKeyPrefix = "notify_code_user_rate:" ) // verifyCodeKey generates the Redis key for email verification code. diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index 5165b059..d203bab2 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -4,6 +4,7 @@ package server import ( "context" "log" + "log/slog" "net/http" "time" @@ -82,6 +83,11 @@ func ProvideRouter( pc.ProxyID = *p.ProxyID if u, ok := proxyURLs[*p.ProxyID]; ok { pc.ProxyURL = u + } else { + // Proxy configured but not found — skip this provider to prevent direct connection. + slog.Warn("websearch: proxy not found for provider, skipping", + "provider", p.Type, "proxy_id", *p.ProxyID) + continue } } configs = append(configs, pc) diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go index 8251dede..61c318d9 100644 --- a/backend/internal/service/account_stats_pricing.go +++ b/backend/internal/service/account_stats_pricing.go @@ -195,18 +195,33 @@ func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int } // calculateTokenStatsCost Token 计费。 +// If the pricing has intervals, find the matching interval by total token count +// and use its prices instead of the flat pricing fields. func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 { - deref := func(p *float64) float64 { - if p == nil { + p := pricing + if len(pricing.Intervals) > 0 { + totalTokens := tokens.InputTokens + tokens.OutputTokens + tokens.CacheCreationTokens + tokens.CacheReadTokens + if iv := FindMatchingInterval(pricing.Intervals, totalTokens); iv != nil { + p = &ChannelModelPricing{ + InputPrice: iv.InputPrice, + OutputPrice: iv.OutputPrice, + CacheWritePrice: iv.CacheWritePrice, + CacheReadPrice: iv.CacheReadPrice, + PerRequestPrice: iv.PerRequestPrice, + } + } + } + deref := func(ptr *float64) float64 { + if ptr == nil { return 0 } - return *p + return *ptr } - cost := float64(tokens.InputTokens)*deref(pricing.InputPrice) + - float64(tokens.OutputTokens)*deref(pricing.OutputPrice) + - float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) + - float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) + - float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice) + cost := float64(tokens.InputTokens)*deref(p.InputPrice) + + float64(tokens.OutputTokens)*deref(p.OutputPrice) + + float64(tokens.CacheCreationTokens)*deref(p.CacheWritePrice) + + float64(tokens.CacheReadTokens)*deref(p.CacheReadPrice) + + float64(tokens.ImageOutputTokens)*deref(p.ImageOutputPrice) if cost <= 0 { return nil } diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 5e9afcc8..5b7e413a 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -477,4 +477,3 @@ func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountID int64, account } return fmt.Sprintf(quotaAlertEmailTemplate, siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay) } - diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index a94e0dde..425887cd 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -9,6 +9,7 @@ import ( "fmt" "log/slog" "math/big" + "net" "net/smtp" "net/url" "strconv" @@ -152,6 +153,9 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) return s.SendEmailWithConfig(config, to, subject, body) } +const smtpDialTimeout = 10 * time.Second +const smtpIOTimeout = 20 * time.Second + // SendEmailWithConfig 使用指定配置发送邮件 func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error { // Sanitize all SMTP header fields to prevent header injection (CR/LF removal). @@ -173,7 +177,46 @@ func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body return s.sendMailTLS(addr, auth, config.From, to, []byte(msg), config.Host) } - return smtp.SendMail(addr, auth, config.From, []string{to}, []byte(msg)) + return s.sendMailPlain(addr, auth, config.From, to, []byte(msg), config.Host) +} + +// sendMailPlain sends mail without TLS using a dialer with timeout. +func (s *EmailService) sendMailPlain(addr string, auth smtp.Auth, from, to string, msg []byte, host string) error { + dialer := &net.Dialer{Timeout: smtpDialTimeout} + conn, err := dialer.Dial("tcp", addr) + if err != nil { + return fmt.Errorf("smtp dial: %w", err) + } + _ = conn.SetDeadline(time.Now().Add(smtpIOTimeout)) + defer func() { _ = conn.Close() }() + + client, err := smtp.NewClient(conn, host) + if err != nil { + return fmt.Errorf("new smtp client: %w", err) + } + defer func() { _ = client.Close() }() + + if err = client.Auth(auth); err != nil { + return fmt.Errorf("smtp auth: %w", err) + } + if err = client.Mail(from); err != nil { + return fmt.Errorf("smtp mail: %w", err) + } + if err = client.Rcpt(to); err != nil { + return fmt.Errorf("smtp rcpt: %w", err) + } + w, err := client.Data() + if err != nil { + return fmt.Errorf("smtp data: %w", err) + } + if _, err = w.Write(msg); err != nil { + return fmt.Errorf("write msg: %w", err) + } + if err = w.Close(); err != nil { + return fmt.Errorf("close writer: %w", err) + } + _ = client.Quit() + return nil } // sendMailTLS 使用TLS发送邮件 @@ -184,10 +227,12 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string, MinVersion: tls.VersionTLS12, } - conn, err := tls.Dial("tcp", addr, tlsConfig) + dialer := &net.Dialer{Timeout: smtpDialTimeout} + conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig) if err != nil { return fmt.Errorf("tls dial: %w", err) } + _ = conn.SetDeadline(time.Now().Add(smtpIOTimeout)) defer func() { _ = conn.Close() }() client, err := smtp.NewClient(conn, host) diff --git a/backend/internal/service/notify_email_entry.go b/backend/internal/service/notify_email_entry.go index d181200b..625185b2 100644 --- a/backend/internal/service/notify_email_entry.go +++ b/backend/internal/service/notify_email_entry.go @@ -79,4 +79,3 @@ func MarshalNotifyEmails(entries []NotifyEmailEntry) string { } return string(data) } - diff --git a/backend/migrations/106_add_account_stats_pricing_intervals.sql b/backend/migrations/106_add_account_stats_pricing_intervals.sql new file mode 100644 index 00000000..5ae10655 --- /dev/null +++ b/backend/migrations/106_add_account_stats_pricing_intervals.sql @@ -0,0 +1,19 @@ +-- Add intervals table for account stats pricing rules (mirrors channel_pricing_intervals). +CREATE TABLE IF NOT EXISTS channel_account_stats_pricing_intervals ( + id BIGSERIAL PRIMARY KEY, + pricing_id BIGINT NOT NULL REFERENCES channel_account_stats_model_pricing(id) ON DELETE CASCADE, + min_tokens INT NOT NULL DEFAULT 0, + max_tokens INT, + tier_label VARCHAR(50), + input_price NUMERIC(20,12), + output_price NUMERIC(20,12), + cache_write_price NUMERIC(20,12), + cache_read_price NUMERIC(20,12), + per_request_price NUMERIC(20,12), + sort_order INT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_account_stats_pricing_intervals_pricing_id + ON channel_account_stats_pricing_intervals (pricing_id); diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index 0b37a20d..e4452b98 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -328,10 +328,10 @@