feat(channels): add custom account stats pricing rules

Allow channels to configure independent model pricing for account
statistics cost calculation, decoupled from user billing.

Backend:
- Migration 101: channels.apply_pricing_to_account_stats toggle,
  channel_account_stats_pricing_rules/model_pricing tables,
  usage_logs.account_stats_cost column
- resolveAccountStatsCost: match rules by group/account, then channel
  pricing, fallback to original formula when unconfigured
- Integrate into both GatewayService.recordUsageCore and
  OpenAIGatewayService.RecordUsage
- Update 8 account stats SQL queries to use
  COALESCE(account_stats_cost, total_cost) * account_rate_multiplier
- 23 unit tests for matching, pricing lookup, and cost calculation

Frontend:
- Channel edit dialog: toggle + custom rules UI with group/account
  multi-select and pricing entry cards
- API types and i18n (zh/en)
This commit is contained in:
erio
2026-04-11 23:39:49 +08:00
parent 7fad9f604f
commit 7535e312e0
17 changed files with 1449 additions and 244 deletions

View File

@@ -41,14 +41,10 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
if err != nil {
return err
}
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
if err != nil {
return err
}
err = tx.QueryRowContext(ctx,
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, created_at, updated_at`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats,
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
if err != nil {
if isUniqueViolation(err) {
@@ -71,17 +67,24 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
}
}
// 设置账号统计定价规则
if len(channel.AccountStatsPricingRules) > 0 {
if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
return err
}
}
return nil
})
}
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
ch := &service.Channel{}
var modelMappingJSON, featuresConfigJSON []byte
var modelMappingJSON []byte
err := r.db.QueryRowContext(ctx,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at
FROM channels WHERE id = $1`, id,
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt)
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt)
if err == sql.ErrNoRows {
return nil, service.ErrChannelNotFound
}
@@ -89,7 +92,6 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
return nil, fmt.Errorf("get channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
groupIDs, err := r.GetGroupIDs(ctx, id)
if err != nil {
@@ -103,6 +105,12 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
}
ch.ModelPricing = pricing
statsPricingRules, err := r.loadAccountStatsPricingRules(ctx, id)
if err != nil {
return nil, err
}
ch.AccountStatsPricingRules = statsPricingRules
return ch, nil
}
@@ -112,14 +120,10 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
if err != nil {
return err
}
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
if err != nil {
return err
}
result, err := tx.ExecContext(ctx,
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, updated_at = NOW()
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, apply_pricing_to_account_stats = $8, updated_at = NOW()
WHERE id = $9`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ID,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats, channel.ID,
)
if err != nil {
if isUniqueViolation(err) {
@@ -146,6 +150,13 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
}
}
// 更新账号统计定价规则
if channel.AccountStatsPricingRules != nil {
if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
return err
}
}
return nil
})
}
@@ -196,7 +207,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
// 查询 channel 列表
dataQuery := fmt.Sprintf(
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.created_at, c.updated_at
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.apply_pricing_to_account_stats, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
whereClause, channelListOrderBy(params), argIdx, argIdx+1,
)
@@ -212,12 +223,11 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
var channelIDs []int64
for rows.Next() {
var ch service.Channel
var modelMappingJSON, featuresConfigJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
var modelMappingJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
@@ -235,9 +245,14 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
if err != nil {
return nil, nil, err
}
statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
if err != nil {
return nil, nil, err
}
for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID]
channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
}
}
@@ -283,7 +298,7 @@ func channelListOrderBy(params pagination.PaginationParams) string {
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at FROM channels ORDER BY id`,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`,
)
if err != nil {
return nil, fmt.Errorf("query all channels: %w", err)
@@ -294,12 +309,11 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
var channelIDs []int64
for rows.Next() {
var ch service.Channel
var modelMappingJSON, featuresConfigJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
var modelMappingJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
@@ -323,9 +337,16 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
return nil, err
}
// 批量加载账号统计定价规则
statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
if err != nil {
return nil, err
}
for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID]
channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
}
return channels, nil
@@ -467,28 +488,6 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string {
return m
}
func marshalFeaturesConfig(m map[string]any) ([]byte, error) {
if len(m) == 0 {
return []byte("{}"), nil
}
data, err := json.Marshal(m)
if err != nil {
return nil, fmt.Errorf("marshal features_config: %w", err)
}
return data, nil
}
func unmarshalFeaturesConfig(data []byte) map[string]any {
if len(data) == 0 {
return nil
}
var m map[string]any
if err := json.Unmarshal(data, &m); err != nil {
return nil
}
return m
}
// GetGroupPlatforms 批量查询分组 ID 对应的平台
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
if len(groupIDs) == 0 {

View File

@@ -0,0 +1,170 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
// --- 账号统计定价规则 ---
// batchLoadAccountStatsPricingRules 批量加载多个渠道的账号统计定价规则(含模型定价)
func (r *channelRepository) batchLoadAccountStatsPricingRules(ctx context.Context, channelIDs []int64) (map[int64][]service.AccountStatsPricingRule, error) {
// 1. 查询规则
rows, err := r.db.QueryContext(ctx,
`SELECT id, channel_id, name, group_ids, account_ids, sort_order, created_at, updated_at
FROM channel_account_stats_pricing_rules WHERE channel_id = ANY($1) ORDER BY channel_id, sort_order, id`,
pq.Array(channelIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load account stats pricing rules: %w", err)
}
defer func() { _ = rows.Close() }()
var allRules []service.AccountStatsPricingRule
var ruleIDs []int64
for rows.Next() {
var rule service.AccountStatsPricingRule
if err := rows.Scan(
&rule.ID, &rule.ChannelID, &rule.Name,
pq.Array(&rule.GroupIDs), pq.Array(&rule.AccountIDs),
&rule.SortOrder, &rule.CreatedAt, &rule.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan account stats pricing rule: %w", err)
}
ruleIDs = append(ruleIDs, rule.ID)
allRules = append(allRules, rule)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate account stats pricing rules: %w", err)
}
// 2. 批量加载规则的模型定价
pricingMap, err := r.batchLoadAccountStatsModelPricing(ctx, ruleIDs)
if err != nil {
return nil, err
}
// 3. 按 channelID 分组并关联定价
result := make(map[int64][]service.AccountStatsPricingRule, len(channelIDs))
for i := range allRules {
allRules[i].Pricing = pricingMap[allRules[i].ID]
result[allRules[i].ChannelID] = append(result[allRules[i].ChannelID], allRules[i])
}
return result, nil
}
// batchLoadAccountStatsModelPricing 批量加载规则的模型定价
func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Context, ruleIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
if len(ruleIDs) == 0 {
return make(map[int64][]service.ChannelModelPricing), nil
}
rows, err := r.db.QueryContext(ctx,
`SELECT id, rule_id, platform, models, billing_mode, input_price, output_price,
cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
FROM channel_account_stats_model_pricing WHERE rule_id = ANY($1) ORDER BY rule_id, id`,
pq.Array(ruleIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load account stats model pricing: %w", err)
}
defer func() { _ = rows.Close() }()
pricingMap := make(map[int64][]service.ChannelModelPricing, len(ruleIDs))
for rows.Next() {
var p service.ChannelModelPricing
var ruleID int64
var modelsJSON []byte
if err := rows.Scan(
&p.ID, &ruleID, &p.Platform, &modelsJSON, &p.BillingMode,
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
&p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan account stats model pricing: %w", err)
}
if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
p.Models = []string{}
}
pricingMap[ruleID] = append(pricingMap[ruleID], p)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate account stats model pricing: %w", err)
}
return pricingMap, nil
}
// loadAccountStatsPricingRules 加载单个渠道的账号统计定价规则(供 GetByID 使用)
func (r *channelRepository) loadAccountStatsPricingRules(ctx context.Context, channelID int64) ([]service.AccountStatsPricingRule, error) {
result, err := r.batchLoadAccountStatsPricingRules(ctx, []int64{channelID})
if err != nil {
return nil, err
}
return result[channelID], nil
}
// replaceAccountStatsPricingRulesTx 在事务中替换渠道的账号统计定价规则(删除旧的 + 插入新的)
func replaceAccountStatsPricingRulesTx(ctx context.Context, tx *sql.Tx, channelID int64, rules []service.AccountStatsPricingRule) error {
// CASCADE 会自动删除关联的 model_pricing
if _, err := tx.ExecContext(ctx,
`DELETE FROM channel_account_stats_pricing_rules WHERE channel_id = $1`, channelID,
); err != nil {
return fmt.Errorf("delete old account stats pricing rules: %w", err)
}
for i := range rules {
rules[i].ChannelID = channelID
if err := createAccountStatsPricingRuleTx(ctx, tx, &rules[i]); err != nil {
return fmt.Errorf("insert account stats pricing rule: %w", err)
}
}
return nil
}
// createAccountStatsPricingRuleTx 在事务中创建单条账号统计定价规则及其模型定价
func createAccountStatsPricingRuleTx(ctx context.Context, tx *sql.Tx, rule *service.AccountStatsPricingRule) error {
err := tx.QueryRowContext(ctx,
`INSERT INTO channel_account_stats_pricing_rules (channel_id, name, group_ids, account_ids, sort_order)
VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at`,
rule.ChannelID, rule.Name, pq.Array(rule.GroupIDs), pq.Array(rule.AccountIDs), rule.SortOrder,
).Scan(&rule.ID, &rule.CreatedAt, &rule.UpdatedAt)
if err != nil {
return fmt.Errorf("insert account stats pricing rule: %w", err)
}
for j := range rule.Pricing {
if err := createAccountStatsModelPricingTx(ctx, tx, rule.ID, &rule.Pricing[j]); err != nil {
return err
}
}
return nil
}
// createAccountStatsModelPricingTx 在事务中创建单条账号统计模型定价
func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID int64, pricing *service.ChannelModelPricing) error {
modelsJSON, err := json.Marshal(pricing.Models)
if err != nil {
return fmt.Errorf("marshal models: %w", err)
}
billingMode := pricing.BillingMode
if billingMode == "" {
billingMode = service.BillingModeToken
}
platform := pricing.Platform
err = tx.QueryRowContext(ctx,
`INSERT INTO channel_account_stats_model_pricing (rule_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
ruleID, platform, modelsJSON, billingMode,
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
pricing.ImageOutputPrice, pricing.PerRequestPrice,
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
if err != nil {
return fmt.Errorf("insert account stats model pricing: %w", err)
}
return nil
}

View File

@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
@@ -82,6 +82,7 @@ var usageLogInsertArgTypes = [...]string{
"text", // model_mapping_chain
"text", // billing_tier
"text", // billing_mode
"numeric", // account_stats_cost
"timestamptz", // created_at
}
@@ -360,6 +361,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
@@ -367,7 +369,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -797,6 +799,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
) AS (VALUES `)
@@ -873,6 +876,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
)
SELECT
@@ -920,6 +924,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
@@ -1007,10 +1012,11 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*45)
args := make([]any, 0, len(preparedList)*46)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
@@ -1080,6 +1086,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
)
SELECT
@@ -1127,6 +1134,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
@@ -1182,6 +1190,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
@@ -1189,7 +1198,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
@@ -1285,6 +1294,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
modelMappingChain,
billingTier,
billingMode,
log.AccountStatsCost, // account_stats_cost
createdAt,
},
}
@@ -1959,7 +1969,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
@@ -1989,7 +1999,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
@@ -2026,7 +2036,7 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
account_id,
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
@@ -2990,7 +3000,7 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
modelExpr := resolveModelDimensionExpression(source)
@@ -3358,7 +3368,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
@@ -3433,7 +3443,7 @@ type EndpointStat = usagestats.EndpointStat
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
@@ -3500,7 +3510,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
@@ -3591,7 +3601,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
@@ -4069,6 +4079,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
modelMappingChain sql.NullString
billingTier sql.NullString
billingMode sql.NullString
accountStatsCost sql.NullFloat64
createdAt time.Time
)
@@ -4118,6 +4129,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&modelMappingChain,
&billingTier,
&billingMode,
&accountStatsCost,
&createdAt,
); err != nil {
return nil, err
@@ -4214,6 +4226,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if billingMode.Valid {
log.BillingMode = &billingMode.String
}
if accountStatsCost.Valid {
log.AccountStatsCost = &accountStatsCost.Float64
}
return log, nil
}

View File

@@ -85,6 +85,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
sqlmock.AnyArg(), // account_stats_cost
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
@@ -163,6 +164,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
sqlmock.AnyArg(), // account_stats_cost
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
@@ -483,10 +485,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
sql.NullFloat64{}, // account_stats_cost
now,
}})
require.NoError(t, err)
@@ -530,10 +533,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
sql.NullFloat64{}, // account_stats_cost
now,
}})
require.NoError(t, err)
@@ -577,10 +581,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
sql.NullFloat64{}, // account_stats_cost
now,
}})
require.NoError(t, err)