- 渠道模型映射:支持精确匹配和通配符映射,按平台隔离 - 渠道模型定价:支持 token/按次/图片三种计费模式,区间分层定价 - 模型限制:渠道可限制仅允许定价列表中的模型 - 计费模型来源:支持 requested/upstream 两种计费模型选择 - 用量统计:usage_logs 新增 channel_id/model_mapping_chain/billing_tier/billing_mode 字段 - Dashboard 支持 model_source 维度(requested/upstream/mapping)查看模型统计 - 全部 gateway handler 统一接入 ResolveChannelMappingAndRestrict - 修复测试:同步 SoraGenerationRepository 接口、SQL INSERT 参数、scan 字段
292 lines
10 KiB
Go
292 lines
10 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
"github.com/lib/pq"
|
|
)
|
|
|
|
// --- 模型定价 ---
|
|
|
|
func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) {
|
|
rows, err := r.db.QueryContext(ctx,
|
|
`SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
|
|
FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list model pricing: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
result, pricingIDs, err := scanModelPricingRows(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(pricingIDs) > 0 {
|
|
intervalMap, err := r.batchLoadIntervals(ctx, pricingIDs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range result {
|
|
result[i].Intervals = intervalMap[result[i].ID]
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (r *channelRepository) CreateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
|
|
return createModelPricingExec(ctx, r.db, pricing)
|
|
}
|
|
|
|
func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
|
|
modelsJSON, err := json.Marshal(pricing.Models)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal models: %w", err)
|
|
}
|
|
billingMode := pricing.BillingMode
|
|
if billingMode == "" {
|
|
billingMode = service.BillingModeToken
|
|
}
|
|
result, err := r.db.ExecContext(ctx,
|
|
`UPDATE channel_model_pricing
|
|
SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, platform = $9, updated_at = NOW()
|
|
WHERE id = $10`,
|
|
modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
|
|
pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.Platform, pricing.ID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("update model pricing: %w", err)
|
|
}
|
|
rows, _ := result.RowsAffected()
|
|
if rows == 0 {
|
|
return fmt.Errorf("pricing entry not found: %d", pricing.ID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *channelRepository) DeleteModelPricing(ctx context.Context, id int64) error {
|
|
_, err := r.db.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE id = $1`, id)
|
|
if err != nil {
|
|
return fmt.Errorf("delete model pricing: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []service.ChannelModelPricing) error {
|
|
return r.runInTx(ctx, func(tx *sql.Tx) error {
|
|
return replaceModelPricingTx(ctx, tx, channelID, pricingList)
|
|
})
|
|
}
|
|
|
|
// --- 批量加载辅助方法 ---
|
|
|
|
// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间)
|
|
func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
|
|
rows, err := r.db.QueryContext(ctx,
|
|
`SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
|
|
FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`,
|
|
pq.Array(channelIDs),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("batch load model pricing: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
allPricing, allPricingIDs, err := scanModelPricingRows(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 按 channelID 分组
|
|
pricingMap := make(map[int64][]service.ChannelModelPricing, len(channelIDs))
|
|
for _, p := range allPricing {
|
|
pricingMap[p.ChannelID] = append(pricingMap[p.ChannelID], p)
|
|
}
|
|
|
|
// 批量加载所有区间
|
|
if len(allPricingIDs) > 0 {
|
|
intervalMap, err := r.batchLoadIntervals(ctx, allPricingIDs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for chID := range pricingMap {
|
|
for i := range pricingMap[chID] {
|
|
pricingMap[chID][i].Intervals = intervalMap[pricingMap[chID][i].ID]
|
|
}
|
|
}
|
|
}
|
|
|
|
return pricingMap, nil
|
|
}
|
|
|
|
// batchLoadIntervals 批量加载多个定价条目的区间
|
|
func (r *channelRepository) batchLoadIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
|
|
rows, err := r.db.QueryContext(ctx,
|
|
`SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
|
|
input_price, output_price, cache_write_price, cache_read_price,
|
|
per_request_price, sort_order, created_at, updated_at
|
|
FROM channel_pricing_intervals
|
|
WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
|
|
pq.Array(pricingIDs),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("batch load intervals: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
intervalMap := make(map[int64][]service.PricingInterval, len(pricingIDs))
|
|
for rows.Next() {
|
|
var iv service.PricingInterval
|
|
if err := rows.Scan(
|
|
&iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
|
|
&iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
|
|
&iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("scan interval: %w", err)
|
|
}
|
|
intervalMap[iv.PricingID] = append(intervalMap[iv.PricingID], iv)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate intervals: %w", err)
|
|
}
|
|
return intervalMap, nil
|
|
}
|
|
|
|
// --- 共享 scan 辅助 ---
|
|
|
|
// scanModelPricingRows 扫描 model pricing 行,返回结果列表和 ID 列表
|
|
func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int64, error) {
|
|
var result []service.ChannelModelPricing
|
|
var pricingIDs []int64
|
|
for rows.Next() {
|
|
var p service.ChannelModelPricing
|
|
var modelsJSON []byte
|
|
if err := rows.Scan(
|
|
&p.ID, &p.ChannelID, &p.Platform, &modelsJSON, &p.BillingMode,
|
|
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
|
|
&p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
|
|
); err != nil {
|
|
return nil, nil, fmt.Errorf("scan model pricing: %w", err)
|
|
}
|
|
if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
|
|
p.Models = []string{}
|
|
}
|
|
pricingIDs = append(pricingIDs, p.ID)
|
|
result = append(result, p)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, nil, fmt.Errorf("iterate model pricing: %w", err)
|
|
}
|
|
return result, pricingIDs, nil
|
|
}
|
|
|
|
// --- 事务内辅助方法 ---
|
|
|
|
// dbExec 是 *sql.DB 和 *sql.Tx 共享的最小 SQL 执行接口
|
|
type dbExec interface {
|
|
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
|
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
|
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
|
}
|
|
|
|
func setGroupIDsTx(ctx context.Context, exec dbExec, channelID int64, groupIDs []int64) error {
|
|
if _, err := exec.ExecContext(ctx, `DELETE FROM channel_groups WHERE channel_id = $1`, channelID); err != nil {
|
|
return fmt.Errorf("delete old group associations: %w", err)
|
|
}
|
|
if len(groupIDs) == 0 {
|
|
return nil
|
|
}
|
|
_, err := exec.ExecContext(ctx,
|
|
`INSERT INTO channel_groups (channel_id, group_id)
|
|
SELECT $1, unnest($2::bigint[])`,
|
|
channelID, pq.Array(groupIDs),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("insert group associations: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.ChannelModelPricing) error {
|
|
modelsJSON, err := json.Marshal(pricing.Models)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal models: %w", err)
|
|
}
|
|
billingMode := pricing.BillingMode
|
|
if billingMode == "" {
|
|
billingMode = service.BillingModeToken
|
|
}
|
|
platform := pricing.Platform
|
|
if platform == "" {
|
|
platform = "anthropic"
|
|
}
|
|
err = exec.QueryRowContext(ctx,
|
|
`INSERT INTO channel_model_pricing (channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
|
|
pricing.ChannelID, platform, modelsJSON, billingMode,
|
|
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
|
|
pricing.ImageOutputPrice, pricing.PerRequestPrice,
|
|
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
|
|
if err != nil {
|
|
return fmt.Errorf("insert model pricing: %w", err)
|
|
}
|
|
|
|
for i := range pricing.Intervals {
|
|
pricing.Intervals[i].PricingID = pricing.ID
|
|
if err := createIntervalExec(ctx, exec, &pricing.Intervals[i]); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func createIntervalExec(ctx context.Context, exec dbExec, iv *service.PricingInterval) error {
|
|
return exec.QueryRowContext(ctx,
|
|
`INSERT INTO channel_pricing_intervals
|
|
(pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
|
|
iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
|
|
iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
|
|
iv.PerRequestPrice, iv.SortOrder,
|
|
).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt)
|
|
}
|
|
|
|
func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pricingList []service.ChannelModelPricing) error {
|
|
if _, err := exec.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE channel_id = $1`, channelID); err != nil {
|
|
return fmt.Errorf("delete old model pricing: %w", err)
|
|
}
|
|
for i := range pricingList {
|
|
pricingList[i].ChannelID = channelID
|
|
if err := createModelPricingExec(ctx, exec, &pricingList[i]); err != nil {
|
|
return fmt.Errorf("insert model pricing: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// isUniqueViolation 检查 pq 唯一约束违反错误
|
|
func isUniqueViolation(err error) bool {
|
|
var pqErr *pq.Error
|
|
if errors.As(err, &pqErr) && pqErr != nil {
|
|
return pqErr.Code == "23505"
|
|
}
|
|
return false
|
|
}
|
|
|
|
// escapeLike 转义 LIKE/ILIKE 模式中的特殊字符
|
|
func escapeLike(s string) string {
|
|
s = strings.ReplaceAll(s, `\`, `\\`)
|
|
s = strings.ReplaceAll(s, `%`, `\%`)
|
|
s = strings.ReplaceAll(s, `_`, `\_`)
|
|
return s
|
|
}
|