feat(channel): 缓存扁平化 + 网关映射集成 + 计费模式统一 + 模型限制

- 缓存按 (groupID, platform, model) 三维 key 扁平化,避免跨平台同名模型冲突
- buildCache 批量查询 group platform,按平台过滤展开定价和映射
- model_mapping 改为嵌套格式 {platform: {src: dst}}
- channel_model_pricing 新增 platform 列
- 前端按平台维度重构:每个平台独立配置分组/映射/定价
- 迁移 086: platform 列 + model_mapping 嵌套格式迁移
This commit is contained in:
erio
2026-03-30 15:04:30 +08:00
parent 28a6adaaa4
commit 0b1ce6be8f
10 changed files with 542 additions and 320 deletions

View File

@@ -41,16 +41,17 @@ type Channel struct {
// 关联的分组 ID 列表
GroupIDs []int64
// 模型定价列表
// 模型定价列表(每条含 Platform 字段)
ModelPricing []ChannelModelPricing
// 渠道级模型映射
ModelMapping map[string]string
// 渠道级模型映射按平台分组platform → {src→dst}
ModelMapping map[string]map[string]string
}
// ChannelModelPricing 渠道模型定价条目
type ChannelModelPricing struct {
ID int64
ChannelID int64
Platform string // 所属平台anthropic/openai/gemini/...
Models []string // 绑定的模型列表
BillingMode BillingMode // 计费模式
InputPrice *float64 // 每 token 输入价格USD— 向后兼容 flat 定价
@@ -82,21 +83,26 @@ type PricingInterval struct {
}
// ResolveMappedModel 解析渠道级模型映射,返回映射后的模型名。
// platform 指定查找哪个平台的映射规则。
// 支持通配符(如 "claude-*" → "claude-sonnet-4")。
// 如果没有匹配的映射规则,返回原始模型名。
func (c *Channel) ResolveMappedModel(requestedModel string) string {
func (c *Channel) ResolveMappedModel(platform, requestedModel string) string {
if len(c.ModelMapping) == 0 {
return requestedModel
}
platformMapping, ok := c.ModelMapping[platform]
if !ok || len(platformMapping) == 0 {
return requestedModel
}
lower := strings.ToLower(requestedModel)
// 精确匹配优先
for src, dst := range c.ModelMapping {
for src, dst := range platformMapping {
if strings.ToLower(src) == lower {
return dst
}
}
// 通配符匹配
for src, dst := range c.ModelMapping {
for src, dst := range platformMapping {
srcLower := strings.ToLower(src)
if strings.HasSuffix(srcLower, "*") {
prefix := strings.TrimSuffix(srcLower, "*")
@@ -190,9 +196,13 @@ func (c *Channel) Clone() *Channel {
}
}
if c.ModelMapping != nil {
cp.ModelMapping = make(map[string]string, len(c.ModelMapping))
for k, v := range c.ModelMapping {
cp.ModelMapping[k] = v
cp.ModelMapping = make(map[string]map[string]string, len(c.ModelMapping))
for platform, mapping := range c.ModelMapping {
inner := make(map[string]string, len(mapping))
for k, v := range mapping {
inner[k] = v
}
cp.ModelMapping[platform] = inner
}
}
return &cp

View File

@@ -39,6 +39,9 @@ type ChannelRepository interface {
GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error)
GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error)
// 分组平台查询
GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error)
// 模型定价
ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error)
CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error
@@ -47,18 +50,20 @@ type ChannelRepository interface {
ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
}
// channelModelKey 渠道缓存复合键
// channelModelKey 渠道缓存复合键(显式包含 platform 防止跨平台同名模型冲突)
type channelModelKey struct {
groupID int64
model string // lowercase
groupID int64
platform string // 平台标识
model string // lowercase
}
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
type channelCache struct {
// 热路径查找
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, model) → 定价
mappingByGroupModel map[channelModelKey]string // (groupID, model) → 映射目标
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
channelByGroupID map[int64]*Channel // groupID → 渠道
groupPlatform map[int64]string // groupID → platform
// 冷路径CRUD 操作)
byID map[int64]*Channel
@@ -135,6 +140,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
mappingByGroupModel: make(map[channelModelKey]string),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: make(map[int64]string),
byID: make(map[int64]*Channel),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
}
@@ -142,10 +148,25 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
return nil, fmt.Errorf("list all channels: %w", err)
}
// 收集所有 groupID批量查询 platform
var allGroupIDs []int64
for i := range channels {
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
}
groupPlatforms := make(map[int64]string)
if len(allGroupIDs) > 0 {
groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
if err != nil {
slog.Warn("failed to load group platforms for channel cache", "error", err)
// 降级:继续构建缓存但无法按平台过滤
}
}
cache := &channelCache{
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
mappingByGroupModel: make(map[channelModelKey]string),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: groupPlatforms,
byID: make(map[int64]*Channel, len(channels)),
loadedAt: time.Now(),
}
@@ -157,20 +178,26 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
// 展开到分组维度
for _, gid := range ch.GroupIDs {
cache.channelByGroupID[gid] = ch
platform := groupPlatforms[gid] // e.g. "anthropic"
// 展开模型定价到 (groupID, model) → *ChannelModelPricing
// 展开该平台的模型定价到 (groupID, platform, model) → *ChannelModelPricing
for j := range ch.ModelPricing {
pricing := &ch.ModelPricing[j]
if pricing.Platform != platform {
continue // 跳过非本平台的定价
}
for _, model := range pricing.Models {
key := channelModelKey{groupID: gid, model: strings.ToLower(model)}
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
cache.pricingByGroupModel[key] = pricing
}
}
// 展开模型映射到 (groupID, model) → target
for src, dst := range ch.ModelMapping {
key := channelModelKey{groupID: gid, model: strings.ToLower(src)}
cache.mappingByGroupModel[key] = dst
// 展开该平台的模型映射到 (groupID, platform, model) → target
if platformMapping, ok := ch.ModelMapping[platform]; ok {
for src, dst := range platformMapping {
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)}
cache.mappingByGroupModel[key] = dst
}
}
}
}
@@ -214,7 +241,8 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int
return nil
}
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)}
platform := cache.groupPlatform[groupID]
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
pricing, ok := cache.pricingByGroupModel[key]
if !ok {
return nil
@@ -246,7 +274,8 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
result.BillingModelSource = BillingModelSourceRequested
}
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)}
platform := cache.groupPlatform[groupID]
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
if mapped, ok := cache.mappingByGroupModel[key]; ok {
result.MappedModel = mapped
result.Mapped = true
@@ -270,7 +299,8 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m
}
// 检查模型是否在定价列表中
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)}
platform := cache.groupPlatform[groupID]
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
_, exists := cache.pricingByGroupModel[key]
return !exists
}
@@ -458,7 +488,7 @@ type CreateChannelInput struct {
Description string
GroupIDs []int64
ModelPricing []ChannelModelPricing
ModelMapping map[string]string
ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string
RestrictModels bool
}
@@ -470,7 +500,7 @@ type UpdateChannelInput struct {
Status string
GroupIDs *[]int64
ModelPricing *[]ChannelModelPricing
ModelMapping map[string]string
ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string
RestrictModels *bool
}