Files
sub2api/backend/internal/repository/channel_repo_account_stats_pricing.go
erio 0a4ece5f5b fix: audit round-3 — proxy safety, intervals persistence, SMTP timeout, sort fix
- Skip websearch provider when ProxyID is set but proxy not found (prevent
  silent direct connection bypass)
- Fix sortByStableRandomWeight: pair factors with items so sort.Slice swap
  keeps weights aligned
- Allow empty platform in account_stats_pricing_rules (wildcard matching),
  only force anthropic default for main model_pricing
- Add channel_account_stats_pricing_intervals table and repo layer support
  for interval-based pricing in account stats rules
- calculateTokenStatsCost now uses interval pricing when available
- Replace smtp.SendMail/tls.Dial with net.Dialer timeout (10s dial, 20s IO)
  to prevent goroutine leak on SMTP hang
- Fix gofmt formatting issues
- Web Search label: black text with red warning hint
2026-04-14 09:35:20 +08:00

245 lines
9.2 KiB
Go

package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
// --- 账号统计定价规则 ---
// batchLoadAccountStatsPricingRules loads the account stats pricing rules
// (each with its model pricing attached) for several channels at once.
//
// The returned map is keyed by channel ID. Within a channel the rules keep the
// query order (sort_order, then id). Channels that have no rules get no entry.
// An empty channelIDs yields an empty, non-nil map without touching the
// database, matching the guard in batchLoadAccountStatsModelPricing.
func (r *channelRepository) batchLoadAccountStatsPricingRules(ctx context.Context, channelIDs []int64) (map[int64][]service.AccountStatsPricingRule, error) {
	// Nothing to look up: skip the round trip entirely.
	if len(channelIDs) == 0 {
		return make(map[int64][]service.AccountStatsPricingRule), nil
	}

	// 1. Fetch the rule rows of all requested channels.
	rows, err := r.db.QueryContext(ctx,
		`SELECT id, channel_id, name, group_ids, account_ids, sort_order, created_at, updated_at
FROM channel_account_stats_pricing_rules WHERE channel_id = ANY($1) ORDER BY channel_id, sort_order, id`,
		pq.Array(channelIDs),
	)
	if err != nil {
		return nil, fmt.Errorf("batch load account stats pricing rules: %w", err)
	}
	// Close errors are intentionally ignored; iteration errors are checked below.
	defer func() { _ = rows.Close() }()

	var allRules []service.AccountStatsPricingRule
	var ruleIDs []int64
	for rows.Next() {
		var rule service.AccountStatsPricingRule
		if err := rows.Scan(
			&rule.ID, &rule.ChannelID, &rule.Name,
			pq.Array(&rule.GroupIDs), pq.Array(&rule.AccountIDs),
			&rule.SortOrder, &rule.CreatedAt, &rule.UpdatedAt,
		); err != nil {
			return nil, fmt.Errorf("scan account stats pricing rule: %w", err)
		}
		ruleIDs = append(ruleIDs, rule.ID)
		allRules = append(allRules, rule)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate account stats pricing rules: %w", err)
	}

	// 2. Load the model pricing of every rule in a single batch.
	pricingMap, err := r.batchLoadAccountStatsModelPricing(ctx, ruleIDs)
	if err != nil {
		return nil, err
	}

	// 3. Attach pricing to each rule and group the rules by channel ID.
	result := make(map[int64][]service.AccountStatsPricingRule, len(channelIDs))
	for i := range allRules {
		allRules[i].Pricing = pricingMap[allRules[i].ID]
		result[allRules[i].ChannelID] = append(result[allRules[i].ChannelID], allRules[i])
	}
	return result, nil
}
// batchLoadAccountStatsModelPricing loads the model pricing entries that
// belong to the given rules, including each entry's pricing intervals.
//
// The returned map is keyed by rule ID; the entries of a rule follow the query
// order (by id). An empty ruleIDs yields an empty, non-nil map without querying.
func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Context, ruleIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
	byRule := make(map[int64][]service.ChannelModelPricing, len(ruleIDs))
	if len(ruleIDs) == 0 {
		return byRule, nil
	}

	rows, err := r.db.QueryContext(ctx,
		`SELECT id, rule_id, platform, models, billing_mode, input_price, output_price,
cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
FROM channel_account_stats_model_pricing WHERE rule_id = ANY($1) ORDER BY rule_id, id`,
		pq.Array(ruleIDs),
	)
	if err != nil {
		return nil, fmt.Errorf("batch load account stats model pricing: %w", err)
	}
	defer func() { _ = rows.Close() }()

	// IDs of every pricing entry seen, used afterwards to fetch their intervals.
	var pricingIDs []int64
	for rows.Next() {
		var entry service.ChannelModelPricing
		var ownerRuleID int64
		var rawModels []byte
		if scanErr := rows.Scan(
			&entry.ID, &ownerRuleID, &entry.Platform, &rawModels, &entry.BillingMode,
			&entry.InputPrice, &entry.OutputPrice, &entry.CacheWritePrice, &entry.CacheReadPrice,
			&entry.ImageOutputPrice, &entry.PerRequestPrice, &entry.CreatedAt, &entry.UpdatedAt,
		); scanErr != nil {
			return nil, fmt.Errorf("scan account stats model pricing: %w", scanErr)
		}
		// Undecodable models JSON is tolerated: the entry is kept with an
		// empty model list instead of failing the whole load.
		if json.Unmarshal(rawModels, &entry.Models) != nil {
			entry.Models = []string{}
		}
		pricingIDs = append(pricingIDs, entry.ID)
		byRule[ownerRuleID] = append(byRule[ownerRuleID], entry)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate account stats model pricing: %w", iterErr)
	}

	// No pricing entries means there are no intervals to attach.
	if len(pricingIDs) == 0 {
		return byRule, nil
	}
	intervalsByPricing, err := r.batchLoadAccountStatsIntervals(ctx, pricingIDs)
	if err != nil {
		return nil, err
	}
	// Writing through the index updates the backing array the map's slices
	// already point at, so the map itself needs no reassignment.
	for _, entries := range byRule {
		for i := range entries {
			entries[i].Intervals = intervalsByPricing[entries[i].ID]
		}
	}
	return byRule, nil
}
// loadAccountStatsPricingRules loads the account stats pricing rules of a
// single channel (intended for GetByID) by delegating to the batch loader.
// It returns a nil slice when the channel has no rules.
func (r *channelRepository) loadAccountStatsPricingRules(ctx context.Context, channelID int64) ([]service.AccountStatsPricingRule, error) {
	ids := []int64{channelID}
	rulesByChannel, err := r.batchLoadAccountStatsPricingRules(ctx, ids)
	if err != nil {
		return nil, err
	}
	return rulesByChannel[channelID], nil
}
// replaceAccountStatsPricingRulesTx replaces, inside the given transaction,
// all account stats pricing rules of a channel: the existing rules are deleted
// and the supplied ones are inserted in slice order.
//
// Each element of rules is mutated in place: ChannelID is set to channelID,
// and the insert helper fills in the generated ID and timestamps.
// The first failure aborts the replacement; the caller owns commit/rollback.
func replaceAccountStatsPricingRulesTx(ctx context.Context, tx *sql.Tx, channelID int64, rules []service.AccountStatsPricingRule) error {
	// Only the rule rows are deleted explicitly. Per the original author's
	// note, the schema's ON DELETE CASCADE (not visible here) removes the
	// associated model pricing rows.
	if _, err := tx.ExecContext(ctx,
		`DELETE FROM channel_account_stats_pricing_rules WHERE channel_id = $1`, channelID,
	); err != nil {
		return fmt.Errorf("delete old account stats pricing rules: %w", err)
	}
	for i := range rules {
		rules[i].ChannelID = channelID
		if err := createAccountStatsPricingRuleTx(ctx, tx, &rules[i]); err != nil {
			// The callee already prefixes its errors with "insert account stats
			// pricing rule"; add replace-level context (which rule failed)
			// instead of repeating that same prefix.
			return fmt.Errorf("replace account stats pricing rules (rule index %d): %w", i, err)
		}
	}
	return nil
}
// createAccountStatsPricingRuleTx inserts one account stats pricing rule and
// all of its model pricing entries inside the given transaction.
//
// On success rule.ID, rule.CreatedAt and rule.UpdatedAt hold the values
// returned by the database, and every element of rule.Pricing has been
// persisted under the new rule ID.
func createAccountStatsPricingRuleTx(ctx context.Context, tx *sql.Tx, rule *service.AccountStatsPricingRule) error {
	inserted := tx.QueryRowContext(ctx,
		`INSERT INTO channel_account_stats_pricing_rules (channel_id, name, group_ids, account_ids, sort_order)
VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at`,
		rule.ChannelID, rule.Name, pq.Array(rule.GroupIDs), pq.Array(rule.AccountIDs), rule.SortOrder,
	)
	if scanErr := inserted.Scan(&rule.ID, &rule.CreatedAt, &rule.UpdatedAt); scanErr != nil {
		return fmt.Errorf("insert account stats pricing rule: %w", scanErr)
	}

	// Persist the pricing entries in order; the helper wraps its own errors.
	for idx := range rule.Pricing {
		entry := &rule.Pricing[idx]
		if err := createAccountStatsModelPricingTx(ctx, tx, rule.ID, entry); err != nil {
			return err
		}
	}
	return nil
}
// createAccountStatsModelPricingTx inserts one model pricing entry for the
// given rule, followed by its pricing intervals, inside the transaction.
//
// pricing.Models is stored as JSON. An empty BillingMode is stored as
// service.BillingModeToken, while Platform is written exactly as given (an
// empty platform is not defaulted here). On success pricing.ID, CreatedAt and
// UpdatedAt are filled from the database, and each interval's PricingID is set
// to the new entry's ID before it is inserted.
func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID int64, pricing *service.ChannelModelPricing) error {
	encodedModels, err := json.Marshal(pricing.Models)
	if err != nil {
		return fmt.Errorf("marshal models: %w", err)
	}

	mode := pricing.BillingMode
	if mode == "" {
		mode = service.BillingModeToken
	}

	inserted := tx.QueryRowContext(ctx,
		`INSERT INTO channel_account_stats_model_pricing (rule_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
		ruleID, pricing.Platform, encodedModels, mode,
		pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
		pricing.ImageOutputPrice, pricing.PerRequestPrice,
	)
	if err = inserted.Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt); err != nil {
		return fmt.Errorf("insert account stats model pricing: %w", err)
	}

	// Persist the intervals, linking each one to the freshly created entry.
	for idx := range pricing.Intervals {
		interval := &pricing.Intervals[idx]
		interval.PricingID = pricing.ID
		if ivErr := createAccountStatsIntervalTx(ctx, tx, interval); ivErr != nil {
			return ivErr
		}
	}
	return nil
}
// createAccountStatsIntervalTx inserts a single pricing interval for an
// account stats pricing entry inside the given transaction.
//
// iv.PricingID must already reference the owning pricing entry. On success
// iv.ID, iv.CreatedAt and iv.UpdatedAt are filled from the database. Failures
// are wrapped with context, like the other insert helpers in this file, so the
// failing statement can be identified from the error chain.
func createAccountStatsIntervalTx(ctx context.Context, tx *sql.Tx, iv *service.PricingInterval) error {
	if err := tx.QueryRowContext(ctx,
		`INSERT INTO channel_account_stats_pricing_intervals
(pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
		iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
		iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
		iv.PerRequestPrice, iv.SortOrder,
	).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt); err != nil {
		return fmt.Errorf("insert account stats pricing interval: %w", err)
	}
	return nil
}
// batchLoadAccountStatsIntervals loads the pricing intervals of the given
// account stats pricing entries.
//
// The returned map is keyed by pricing ID; the intervals of an entry follow
// the query order (sort_order, then id). An empty pricingIDs returns a nil map
// and no error without querying; reading from the nil map is safe for callers.
// On any error the returned map is nil.
func (r *channelRepository) batchLoadAccountStatsIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
	if len(pricingIDs) == 0 {
		return nil, nil
	}
	rows, err := r.db.QueryContext(ctx,
		`SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
input_price, output_price, cache_write_price, cache_read_price,
per_request_price, sort_order, created_at, updated_at
FROM channel_account_stats_pricing_intervals
WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
		pq.Array(pricingIDs),
	)
	if err != nil {
		return nil, fmt.Errorf("batch load account stats pricing intervals: %w", err)
	}
	// Close errors are intentionally ignored; iteration errors are checked below.
	defer func() { _ = rows.Close() }()

	result := make(map[int64][]service.PricingInterval)
	for rows.Next() {
		var iv service.PricingInterval
		if err := rows.Scan(
			&iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
			&iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
			&iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
		); err != nil {
			return nil, fmt.Errorf("scan account stats pricing interval: %w", err)
		}
		result[iv.PricingID] = append(result[iv.PricingID], iv)
	}
	// Wrap iteration failures like the sibling loaders do, and do not hand
	// back a partially populated map alongside an error.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate account stats pricing intervals: %w", err)
	}
	return result, nil
}