feat(channel): 缓存扁平化 + 网关映射集成 + 计费模式统一 + 模型限制
- 缓存按 (groupID, platform, model) 三维 key 扁平化,避免跨平台同名模型冲突
- buildCache 批量查询 group platform,按平台过滤展开定价和映射
- model_mapping 改为嵌套格式 {platform: {src: dst}}
- channel_model_pricing 新增 platform 列
- 前端按平台维度重构:每个平台独立配置分组/映射/定价
- 迁移 086: platform 列 + model_mapping 嵌套格式迁移
This commit is contained in:
@@ -24,27 +24,28 @@ func NewChannelHandler(channelService *service.ChannelService) *ChannelHandler {
|
||||
// --- Request / Response types ---
|
||||
|
||||
type createChannelRequest struct {
|
||||
Name string `json:"name" binding:"required,max=100"`
|
||||
Description string `json:"description"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
Name string `json:"name" binding:"required,max=100"`
|
||||
Description string `json:"description"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
}
|
||||
|
||||
type updateChannelRequest struct {
|
||||
Name string `json:"name" binding:"omitempty,max=100"`
|
||||
Description *string `json:"description"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
|
||||
RestrictModels *bool `json:"restrict_models"`
|
||||
Name string `json:"name" binding:"omitempty,max=100"`
|
||||
Description *string `json:"description"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
|
||||
RestrictModels *bool `json:"restrict_models"`
|
||||
}
|
||||
|
||||
type channelModelPricingRequest struct {
|
||||
Platform string `json:"platform" binding:"omitempty,max=50"`
|
||||
Models []string `json:"models" binding:"required,min=1,max=100"`
|
||||
BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"`
|
||||
InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"`
|
||||
@@ -69,21 +70,22 @@ type pricingIntervalRequest struct {
|
||||
}
|
||||
|
||||
type channelResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
BillingModelSource string `json:"billing_model_source"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
||||
ModelMapping map[string]string `json:"model_mapping"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
BillingModelSource string `json:"billing_model_source"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
type channelModelPricingResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Platform string `json:"platform"`
|
||||
Models []string `json:"models"`
|
||||
BillingMode string `json:"billing_mode"`
|
||||
InputPrice *float64 `json:"input_price"`
|
||||
@@ -131,7 +133,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
resp.GroupIDs = []int64{}
|
||||
}
|
||||
if resp.ModelMapping == nil {
|
||||
resp.ModelMapping = map[string]string{}
|
||||
resp.ModelMapping = map[string]map[string]string{}
|
||||
}
|
||||
|
||||
resp.ModelPricing = make([]channelModelPricingResponse, 0, len(ch.ModelPricing))
|
||||
@@ -144,6 +146,10 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
if billingMode == "" {
|
||||
billingMode = "token"
|
||||
}
|
||||
platform := p.Platform
|
||||
if platform == "" {
|
||||
platform = "anthropic"
|
||||
}
|
||||
intervals := make([]pricingIntervalResponse, 0, len(p.Intervals))
|
||||
for _, iv := range p.Intervals {
|
||||
intervals = append(intervals, pricingIntervalResponse{
|
||||
@@ -161,6 +167,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
}
|
||||
resp.ModelPricing = append(resp.ModelPricing, channelModelPricingResponse{
|
||||
ID: p.ID,
|
||||
Platform: platform,
|
||||
Models: models,
|
||||
BillingMode: billingMode,
|
||||
InputPrice: p.InputPrice,
|
||||
@@ -182,6 +189,10 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
|
||||
if billingMode == "" {
|
||||
billingMode = service.BillingModeToken
|
||||
}
|
||||
platform := r.Platform
|
||||
if platform == "" {
|
||||
platform = "anthropic"
|
||||
}
|
||||
intervals := make([]service.PricingInterval, 0, len(r.Intervals))
|
||||
for _, iv := range r.Intervals {
|
||||
intervals = append(intervals, service.PricingInterval{
|
||||
@@ -197,6 +208,7 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
|
||||
})
|
||||
}
|
||||
result = append(result, service.ChannelModelPricing{
|
||||
Platform: platform,
|
||||
Models: r.Models,
|
||||
BillingMode: billingMode,
|
||||
InputPrice: r.InputPrice,
|
||||
|
||||
@@ -406,8 +406,9 @@ func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channe
|
||||
return conflicting, nil
|
||||
}
|
||||
|
||||
// marshalModelMapping 将 model mapping 序列化为 JSON 字节,nil/空 map 返回 '{}'
|
||||
func marshalModelMapping(m map[string]string) ([]byte, error) {
|
||||
// marshalModelMapping 将 model mapping 序列化为嵌套 JSON 字节
|
||||
// 格式:{"platform": {"src": "dst"}, ...}
|
||||
func marshalModelMapping(m map[string]map[string]string) ([]byte, error) {
|
||||
if len(m) == 0 {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
@@ -418,14 +419,43 @@ func marshalModelMapping(m map[string]string) ([]byte, error) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// unmarshalModelMapping 将 JSON 字节反序列化为 model mapping
|
||||
func unmarshalModelMapping(data []byte) map[string]string {
|
||||
// unmarshalModelMapping 将 JSON 字节反序列化为嵌套 model mapping
|
||||
func unmarshalModelMapping(data []byte) map[string]map[string]string {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
var m map[string]string
|
||||
var m map[string]map[string]string
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return nil
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// GetGroupPlatforms 批量查询分组 ID 对应的平台
|
||||
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return make(map[int64]string), nil
|
||||
}
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, platform FROM groups WHERE id = ANY($1)`,
|
||||
pq.Array(groupIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group platforms: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := make(map[int64]string, len(groupIDs))
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var platform string
|
||||
if err := rows.Scan(&id, &platform); err != nil {
|
||||
return nil, fmt.Errorf("scan group platform: %w", err)
|
||||
}
|
||||
result[id] = platform
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate group platforms: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
|
||||
func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
|
||||
`SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
|
||||
FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -56,10 +56,10 @@ func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *ser
|
||||
}
|
||||
result, err := r.db.ExecContext(ctx,
|
||||
`UPDATE channel_model_pricing
|
||||
SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, updated_at = NOW()
|
||||
WHERE id = $9`,
|
||||
SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, platform = $9, updated_at = NOW()
|
||||
WHERE id = $10`,
|
||||
modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
|
||||
pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.ID,
|
||||
pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.Platform, pricing.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update model pricing: %w", err)
|
||||
@@ -90,7 +90,7 @@ func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID i
|
||||
// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间)
|
||||
func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
|
||||
`SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
|
||||
FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`,
|
||||
pq.Array(channelIDs),
|
||||
)
|
||||
@@ -169,7 +169,7 @@ func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int6
|
||||
var p service.ChannelModelPricing
|
||||
var modelsJSON []byte
|
||||
if err := rows.Scan(
|
||||
&p.ID, &p.ChannelID, &modelsJSON, &p.BillingMode,
|
||||
&p.ID, &p.ChannelID, &p.Platform, &modelsJSON, &p.BillingMode,
|
||||
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
|
||||
&p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
@@ -223,10 +223,14 @@ func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.C
|
||||
if billingMode == "" {
|
||||
billingMode = service.BillingModeToken
|
||||
}
|
||||
platform := pricing.Platform
|
||||
if platform == "" {
|
||||
platform = "anthropic"
|
||||
}
|
||||
err = exec.QueryRowContext(ctx,
|
||||
`INSERT INTO channel_model_pricing (channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, created_at, updated_at`,
|
||||
pricing.ChannelID, modelsJSON, billingMode,
|
||||
`INSERT INTO channel_model_pricing (channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
|
||||
pricing.ChannelID, platform, modelsJSON, billingMode,
|
||||
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
|
||||
pricing.ImageOutputPrice, pricing.PerRequestPrice,
|
||||
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
|
||||
|
||||
@@ -41,16 +41,17 @@ type Channel struct {
|
||||
|
||||
// 关联的分组 ID 列表
|
||||
GroupIDs []int64
|
||||
// 模型定价列表
|
||||
// 模型定价列表(每条含 Platform 字段)
|
||||
ModelPricing []ChannelModelPricing
|
||||
// 渠道级模型映射
|
||||
ModelMapping map[string]string
|
||||
// 渠道级模型映射(按平台分组:platform → {src→dst})
|
||||
ModelMapping map[string]map[string]string
|
||||
}
|
||||
|
||||
// ChannelModelPricing 渠道模型定价条目
|
||||
type ChannelModelPricing struct {
|
||||
ID int64
|
||||
ChannelID int64
|
||||
Platform string // 所属平台(anthropic/openai/gemini/...)
|
||||
Models []string // 绑定的模型列表
|
||||
BillingMode BillingMode // 计费模式
|
||||
InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价
|
||||
@@ -82,21 +83,26 @@ type PricingInterval struct {
|
||||
}
|
||||
|
||||
// ResolveMappedModel 解析渠道级模型映射,返回映射后的模型名。
|
||||
// platform 指定查找哪个平台的映射规则。
|
||||
// 支持通配符(如 "claude-*" → "claude-sonnet-4")。
|
||||
// 如果没有匹配的映射规则,返回原始模型名。
|
||||
func (c *Channel) ResolveMappedModel(requestedModel string) string {
|
||||
func (c *Channel) ResolveMappedModel(platform, requestedModel string) string {
|
||||
if len(c.ModelMapping) == 0 {
|
||||
return requestedModel
|
||||
}
|
||||
platformMapping, ok := c.ModelMapping[platform]
|
||||
if !ok || len(platformMapping) == 0 {
|
||||
return requestedModel
|
||||
}
|
||||
lower := strings.ToLower(requestedModel)
|
||||
// 精确匹配优先
|
||||
for src, dst := range c.ModelMapping {
|
||||
for src, dst := range platformMapping {
|
||||
if strings.ToLower(src) == lower {
|
||||
return dst
|
||||
}
|
||||
}
|
||||
// 通配符匹配
|
||||
for src, dst := range c.ModelMapping {
|
||||
for src, dst := range platformMapping {
|
||||
srcLower := strings.ToLower(src)
|
||||
if strings.HasSuffix(srcLower, "*") {
|
||||
prefix := strings.TrimSuffix(srcLower, "*")
|
||||
@@ -190,9 +196,13 @@ func (c *Channel) Clone() *Channel {
|
||||
}
|
||||
}
|
||||
if c.ModelMapping != nil {
|
||||
cp.ModelMapping = make(map[string]string, len(c.ModelMapping))
|
||||
for k, v := range c.ModelMapping {
|
||||
cp.ModelMapping[k] = v
|
||||
cp.ModelMapping = make(map[string]map[string]string, len(c.ModelMapping))
|
||||
for platform, mapping := range c.ModelMapping {
|
||||
inner := make(map[string]string, len(mapping))
|
||||
for k, v := range mapping {
|
||||
inner[k] = v
|
||||
}
|
||||
cp.ModelMapping[platform] = inner
|
||||
}
|
||||
}
|
||||
return &cp
|
||||
|
||||
@@ -39,6 +39,9 @@ type ChannelRepository interface {
|
||||
GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error)
|
||||
|
||||
// 分组平台查询
|
||||
GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error)
|
||||
|
||||
// 模型定价
|
||||
ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error)
|
||||
CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error
|
||||
@@ -47,18 +50,20 @@ type ChannelRepository interface {
|
||||
ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
|
||||
}
|
||||
|
||||
// channelModelKey 渠道缓存复合键
|
||||
// channelModelKey 渠道缓存复合键(显式包含 platform 防止跨平台同名模型冲突)
|
||||
type channelModelKey struct {
|
||||
groupID int64
|
||||
model string // lowercase
|
||||
groupID int64
|
||||
platform string // 平台标识
|
||||
model string // lowercase
|
||||
}
|
||||
|
||||
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
|
||||
type channelCache struct {
|
||||
// 热路径查找
|
||||
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, model) → 定价
|
||||
mappingByGroupModel map[channelModelKey]string // (groupID, model) → 映射目标
|
||||
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
|
||||
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
|
||||
channelByGroupID map[int64]*Channel // groupID → 渠道
|
||||
groupPlatform map[int64]string // groupID → platform
|
||||
|
||||
// 冷路径(CRUD 操作)
|
||||
byID map[int64]*Channel
|
||||
@@ -135,6 +140,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: make(map[int64]string),
|
||||
byID: make(map[int64]*Channel),
|
||||
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
|
||||
}
|
||||
@@ -142,10 +148,25 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
return nil, fmt.Errorf("list all channels: %w", err)
|
||||
}
|
||||
|
||||
// 收集所有 groupID,批量查询 platform
|
||||
var allGroupIDs []int64
|
||||
for i := range channels {
|
||||
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
|
||||
}
|
||||
groupPlatforms := make(map[int64]string)
|
||||
if len(allGroupIDs) > 0 {
|
||||
groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
|
||||
if err != nil {
|
||||
slog.Warn("failed to load group platforms for channel cache", "error", err)
|
||||
// 降级:继续构建缓存但无法按平台过滤
|
||||
}
|
||||
}
|
||||
|
||||
cache := &channelCache{
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: groupPlatforms,
|
||||
byID: make(map[int64]*Channel, len(channels)),
|
||||
loadedAt: time.Now(),
|
||||
}
|
||||
@@ -157,20 +178,26 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
// 展开到分组维度
|
||||
for _, gid := range ch.GroupIDs {
|
||||
cache.channelByGroupID[gid] = ch
|
||||
platform := groupPlatforms[gid] // e.g. "anthropic"
|
||||
|
||||
// 展开模型定价到 (groupID, model) → *ChannelModelPricing
|
||||
// 只展开该平台的模型定价到 (groupID, platform, model) → *ChannelModelPricing
|
||||
for j := range ch.ModelPricing {
|
||||
pricing := &ch.ModelPricing[j]
|
||||
if pricing.Platform != platform {
|
||||
continue // 跳过非本平台的定价
|
||||
}
|
||||
for _, model := range pricing.Models {
|
||||
key := channelModelKey{groupID: gid, model: strings.ToLower(model)}
|
||||
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
|
||||
cache.pricingByGroupModel[key] = pricing
|
||||
}
|
||||
}
|
||||
|
||||
// 展开模型映射到 (groupID, model) → target
|
||||
for src, dst := range ch.ModelMapping {
|
||||
key := channelModelKey{groupID: gid, model: strings.ToLower(src)}
|
||||
cache.mappingByGroupModel[key] = dst
|
||||
// 只展开该平台的模型映射到 (groupID, platform, model) → target
|
||||
if platformMapping, ok := ch.ModelMapping[platform]; ok {
|
||||
for src, dst := range platformMapping {
|
||||
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)}
|
||||
cache.mappingByGroupModel[key] = dst
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -214,7 +241,8 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int
|
||||
return nil
|
||||
}
|
||||
|
||||
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)}
|
||||
platform := cache.groupPlatform[groupID]
|
||||
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
|
||||
pricing, ok := cache.pricingByGroupModel[key]
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -246,7 +274,8 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
|
||||
result.BillingModelSource = BillingModelSourceRequested
|
||||
}
|
||||
|
||||
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)}
|
||||
platform := cache.groupPlatform[groupID]
|
||||
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
|
||||
if mapped, ok := cache.mappingByGroupModel[key]; ok {
|
||||
result.MappedModel = mapped
|
||||
result.Mapped = true
|
||||
@@ -270,7 +299,8 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m
|
||||
}
|
||||
|
||||
// 检查模型是否在定价列表中
|
||||
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)}
|
||||
platform := cache.groupPlatform[groupID]
|
||||
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
|
||||
_, exists := cache.pricingByGroupModel[key]
|
||||
return !exists
|
||||
}
|
||||
@@ -458,7 +488,7 @@ type CreateChannelInput struct {
|
||||
Description string
|
||||
GroupIDs []int64
|
||||
ModelPricing []ChannelModelPricing
|
||||
ModelMapping map[string]string
|
||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||
BillingModelSource string
|
||||
RestrictModels bool
|
||||
}
|
||||
@@ -470,7 +500,7 @@ type UpdateChannelInput struct {
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ModelPricing *[]ChannelModelPricing
|
||||
ModelMapping map[string]string
|
||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||
BillingModelSource string
|
||||
RestrictModels *bool
|
||||
}
|
||||
|
||||
21
backend/migrations/086_channel_platform_pricing.sql
Normal file
21
backend/migrations/086_channel_platform_pricing.sql
Normal file
@@ -0,0 +1,21 @@
|
||||
-- 086_channel_platform_pricing.sql
|
||||
-- 渠道按平台维度:model_pricing 加 platform 列,model_mapping 改为嵌套格式
|
||||
|
||||
-- 1. channel_model_pricing 加 platform 列
|
||||
ALTER TABLE channel_model_pricing
|
||||
ADD COLUMN IF NOT EXISTS platform VARCHAR(50) NOT NULL DEFAULT 'anthropic';
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_platform
|
||||
ON channel_model_pricing (platform);
|
||||
|
||||
-- 2. model_mapping: 从扁平 {"src":"dst"} 迁移为嵌套 {"anthropic":{"src":"dst"}}
|
||||
-- 仅迁移非空、非 '{}' 的旧格式数据(通过检查第一个 value 是否为字符串来判断是否为旧格式)
|
||||
UPDATE channels
|
||||
SET model_mapping = jsonb_build_object('anthropic', model_mapping)
|
||||
WHERE model_mapping IS NOT NULL
|
||||
AND model_mapping::text NOT IN ('{}', 'null', '')
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM jsonb_each(model_mapping) AS kv
|
||||
WHERE jsonb_typeof(kv.value) = 'object'
|
||||
LIMIT 1
|
||||
);
|
||||
Reference in New Issue
Block a user