feat(channel): 缓存扁平化 + 网关映射集成 + 计费模式统一 + 模型限制
- 缓存按 (groupID, platform, model) 三维 key 扁平化,避免跨平台同名模型冲突
- buildCache 批量查询 group platform,按平台过滤展开定价和映射
- model_mapping 改为嵌套格式 {platform: {src: dst}}
- channel_model_pricing 新增 platform 列
- 前端按平台维度重构:每个平台独立配置分组/映射/定价
- 迁移 086: platform 列 + model_mapping 嵌套格式迁移
This commit is contained in:
@@ -406,8 +406,9 @@ func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channe
|
||||
return conflicting, nil
|
||||
}
|
||||
|
||||
// marshalModelMapping 将 model mapping 序列化为 JSON 字节,nil/空 map 返回 '{}'
|
||||
func marshalModelMapping(m map[string]string) ([]byte, error) {
|
||||
// marshalModelMapping 将 model mapping 序列化为嵌套 JSON 字节
|
||||
// 格式:{"platform": {"src": "dst"}, ...}
|
||||
func marshalModelMapping(m map[string]map[string]string) ([]byte, error) {
|
||||
if len(m) == 0 {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
@@ -418,14 +419,43 @@ func marshalModelMapping(m map[string]string) ([]byte, error) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// unmarshalModelMapping 将 JSON 字节反序列化为 model mapping
|
||||
func unmarshalModelMapping(data []byte) map[string]string {
|
||||
// unmarshalModelMapping 将 JSON 字节反序列化为嵌套 model mapping
|
||||
func unmarshalModelMapping(data []byte) map[string]map[string]string {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
var m map[string]string
|
||||
var m map[string]map[string]string
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return nil
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// GetGroupPlatforms 批量查询分组 ID 对应的平台
|
||||
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return make(map[int64]string), nil
|
||||
}
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, platform FROM groups WHERE id = ANY($1)`,
|
||||
pq.Array(groupIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group platforms: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := make(map[int64]string, len(groupIDs))
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var platform string
|
||||
if err := rows.Scan(&id, &platform); err != nil {
|
||||
return nil, fmt.Errorf("scan group platform: %w", err)
|
||||
}
|
||||
result[id] = platform
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate group platforms: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
|
||||
func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
|
||||
`SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
|
||||
FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -56,10 +56,10 @@ func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *ser
|
||||
}
|
||||
result, err := r.db.ExecContext(ctx,
|
||||
`UPDATE channel_model_pricing
|
||||
SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, updated_at = NOW()
|
||||
WHERE id = $9`,
|
||||
SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, platform = $9, updated_at = NOW()
|
||||
WHERE id = $10`,
|
||||
modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
|
||||
pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.ID,
|
||||
pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.Platform, pricing.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update model pricing: %w", err)
|
||||
@@ -90,7 +90,7 @@ func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID i
|
||||
// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间)
|
||||
func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
|
||||
`SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
|
||||
FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`,
|
||||
pq.Array(channelIDs),
|
||||
)
|
||||
@@ -169,7 +169,7 @@ func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int6
|
||||
var p service.ChannelModelPricing
|
||||
var modelsJSON []byte
|
||||
if err := rows.Scan(
|
||||
&p.ID, &p.ChannelID, &modelsJSON, &p.BillingMode,
|
||||
&p.ID, &p.ChannelID, &p.Platform, &modelsJSON, &p.BillingMode,
|
||||
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
|
||||
&p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
@@ -223,10 +223,14 @@ func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.C
|
||||
if billingMode == "" {
|
||||
billingMode = service.BillingModeToken
|
||||
}
|
||||
platform := pricing.Platform
|
||||
if platform == "" {
|
||||
platform = "anthropic"
|
||||
}
|
||||
err = exec.QueryRowContext(ctx,
|
||||
`INSERT INTO channel_model_pricing (channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, created_at, updated_at`,
|
||||
pricing.ChannelID, modelsJSON, billingMode,
|
||||
`INSERT INTO channel_model_pricing (channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
|
||||
pricing.ChannelID, platform, modelsJSON, billingMode,
|
||||
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
|
||||
pricing.ImageOutputPrice, pricing.PerRequestPrice,
|
||||
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
|
||||
|
||||
Reference in New Issue
Block a user