feat(channel): 缓存扁平化 + 网关映射集成 + 计费模式统一 + 模型限制
- 缓存重构为 O(1) 哈希结构 (pricingByGroupModel, mappingByGroupModel) - 渠道模型映射接入网关流程 (Forward 前应用, a→b→c 映射链) - 新增 billing_model_source 配置 (请求模型/最终模型计费) - usage_logs 新增 channel_id, model_mapping_chain, billing_tier 字段 - 每种计费模式统一支持默认价格 + 区间定价 - 渠道模型限制开关 (restrict_models) - 分组按平台分类展示 + 彩色图标 - 必填字段红色星号 + 模型映射 UI - 去除模型通配符支持
This commit is contained in:
@@ -47,13 +47,30 @@ type ChannelRepository interface {
|
||||
ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
|
||||
}
|
||||
|
||||
// channelCache 渠道缓存快照
|
||||
// channelModelKey 渠道缓存复合键
|
||||
type channelModelKey struct {
|
||||
groupID int64
|
||||
model string // lowercase
|
||||
}
|
||||
|
||||
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
|
||||
type channelCache struct {
|
||||
// byID: channelID -> *Channel(含 ModelPricing)
|
||||
byID map[int64]*Channel
|
||||
// byGroupID: groupID -> channelID
|
||||
byGroupID map[int64]int64
|
||||
loadedAt time.Time
|
||||
// 热路径查找
|
||||
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, model) → 定价
|
||||
mappingByGroupModel map[channelModelKey]string // (groupID, model) → 映射目标
|
||||
channelByGroupID map[int64]*Channel // groupID → 渠道
|
||||
|
||||
// 冷路径(CRUD 操作)
|
||||
byID map[int64]*Channel
|
||||
loadedAt time.Time
|
||||
}
|
||||
|
||||
// ChannelMappingResult 渠道映射查找结果
|
||||
type ChannelMappingResult struct {
|
||||
MappedModel string // 映射后的模型名(无映射时等于原始模型名)
|
||||
ChannelID int64 // 渠道 ID(0 = 无渠道关联)
|
||||
Mapped bool // 是否发生了映射
|
||||
BillingModelSource string // 计费模型来源("requested" / "upstream")
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -115,25 +132,46 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
||||
slog.Warn("failed to build channel cache", "error", err)
|
||||
errorCache := &channelCache{
|
||||
byID: make(map[int64]*Channel),
|
||||
byGroupID: make(map[int64]int64),
|
||||
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
byID: make(map[int64]*Channel),
|
||||
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
|
||||
}
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("list all channels: %w", err)
|
||||
}
|
||||
|
||||
cache := &channelCache{
|
||||
byID: make(map[int64]*Channel, len(channels)),
|
||||
byGroupID: make(map[int64]int64),
|
||||
loadedAt: time.Now(),
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
byID: make(map[int64]*Channel, len(channels)),
|
||||
loadedAt: time.Now(),
|
||||
}
|
||||
|
||||
for i := range channels {
|
||||
ch := &channels[i]
|
||||
cache.byID[ch.ID] = ch
|
||||
|
||||
// 展开到分组维度
|
||||
for _, gid := range ch.GroupIDs {
|
||||
cache.byGroupID[gid] = ch.ID
|
||||
cache.channelByGroupID[gid] = ch
|
||||
|
||||
// 展开模型定价到 (groupID, model) → *ChannelModelPricing
|
||||
for j := range ch.ModelPricing {
|
||||
pricing := &ch.ModelPricing[j]
|
||||
for _, model := range pricing.Models {
|
||||
key := channelModelKey{groupID: gid, model: strings.ToLower(model)}
|
||||
cache.pricingByGroupModel[key] = pricing
|
||||
}
|
||||
}
|
||||
|
||||
// 展开模型映射到 (groupID, model) → target
|
||||
for src, dst := range ch.ModelMapping {
|
||||
key := channelModelKey{groupID: gid, model: strings.ToLower(src)}
|
||||
cache.mappingByGroupModel[key] = dst
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,42 +185,94 @@ func (s *ChannelService) invalidateCache() {
|
||||
s.cacheSF.Forget("channel_cache")
|
||||
}
|
||||
|
||||
// GetChannelForGroup 获取分组关联的渠道(热路径,从缓存读取)
|
||||
// 返回深拷贝,不污染缓存。
|
||||
// GetChannelForGroup 获取分组关联的渠道(热路径 O(1))
|
||||
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
|
||||
cache, err := s.loadCache(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
channelID, ok := cache.byGroupID[groupID]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ch, ok := cache.byID[channelID]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if !ch.IsActive() {
|
||||
ch, ok := cache.channelByGroupID[groupID]
|
||||
if !ok || !ch.IsActive() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return ch.Clone(), nil
|
||||
}
|
||||
|
||||
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径)
|
||||
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))
|
||||
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
|
||||
ch, err := s.GetChannelForGroup(ctx, groupID)
|
||||
cache, err := s.loadCache(ctx)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get channel for group", "group_id", groupID, "error", err)
|
||||
slog.Warn("failed to load channel cache", "group_id", groupID, "error", err)
|
||||
return nil
|
||||
}
|
||||
if ch == nil {
|
||||
|
||||
// 检查渠道是否启用
|
||||
ch, ok := cache.channelByGroupID[groupID]
|
||||
if !ok || !ch.IsActive() {
|
||||
return nil
|
||||
}
|
||||
return ch.GetModelPricing(model)
|
||||
|
||||
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)}
|
||||
pricing, ok := cache.pricingByGroupModel[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
cp := pricing.Clone()
|
||||
return &cp
|
||||
}
|
||||
|
||||
// ResolveChannelMapping 解析渠道级模型映射(热路径 O(1))
|
||||
// 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。
|
||||
func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
|
||||
cache, err := s.loadCache(ctx)
|
||||
if err != nil {
|
||||
return ChannelMappingResult{MappedModel: model}
|
||||
}
|
||||
|
||||
ch, ok := cache.channelByGroupID[groupID]
|
||||
if !ok || !ch.IsActive() {
|
||||
return ChannelMappingResult{MappedModel: model}
|
||||
}
|
||||
|
||||
result := ChannelMappingResult{
|
||||
MappedModel: model,
|
||||
ChannelID: ch.ID,
|
||||
BillingModelSource: ch.BillingModelSource,
|
||||
}
|
||||
if result.BillingModelSource == "" {
|
||||
result.BillingModelSource = BillingModelSourceRequested
|
||||
}
|
||||
|
||||
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)}
|
||||
if mapped, ok := cache.mappingByGroupModel[key]; ok {
|
||||
result.MappedModel = mapped
|
||||
result.Mapped = true
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// IsModelRestricted 检查模型是否被渠道限制。
|
||||
// 返回 true 表示模型被限制(不在允许列表中)。
|
||||
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
|
||||
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
||||
cache, err := s.loadCache(ctx)
|
||||
if err != nil {
|
||||
return false // 缓存加载失败时不限制
|
||||
}
|
||||
|
||||
ch, ok := cache.channelByGroupID[groupID]
|
||||
if !ok || !ch.IsActive() || !ch.RestrictModels {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查模型是否在定价列表中
|
||||
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)}
|
||||
_, exists := cache.pricingByGroupModel[key]
|
||||
return !exists
|
||||
}
|
||||
|
||||
// --- CRUD ---
|
||||
@@ -209,12 +299,17 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
}
|
||||
|
||||
channel := &Channel{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Status: StatusActive,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Status: StatusActive,
|
||||
BillingModelSource: input.BillingModelSource,
|
||||
RestrictModels: input.RestrictModels,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
}
|
||||
if channel.BillingModelSource == "" {
|
||||
channel.BillingModelSource = BillingModelSourceRequested
|
||||
}
|
||||
|
||||
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
|
||||
@@ -260,6 +355,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
|
||||
channel.Status = input.Status
|
||||
}
|
||||
|
||||
if input.RestrictModels != nil {
|
||||
channel.RestrictModels = *input.RestrictModels
|
||||
}
|
||||
|
||||
// 检查分组冲突
|
||||
if input.GroupIDs != nil {
|
||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
|
||||
@@ -280,6 +379,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
|
||||
channel.ModelMapping = input.ModelMapping
|
||||
}
|
||||
|
||||
if input.BillingModelSource != "" {
|
||||
channel.BillingModelSource = input.BillingModelSource
|
||||
}
|
||||
|
||||
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -351,19 +454,23 @@ func validateNoDuplicateModels(pricingList []ChannelModelPricing) error {
|
||||
|
||||
// CreateChannelInput 创建渠道输入
|
||||
type CreateChannelInput struct {
|
||||
Name string
|
||||
Description string
|
||||
GroupIDs []int64
|
||||
ModelPricing []ChannelModelPricing
|
||||
ModelMapping map[string]string
|
||||
Name string
|
||||
Description string
|
||||
GroupIDs []int64
|
||||
ModelPricing []ChannelModelPricing
|
||||
ModelMapping map[string]string
|
||||
BillingModelSource string
|
||||
RestrictModels bool
|
||||
}
|
||||
|
||||
// UpdateChannelInput 更新渠道输入
|
||||
type UpdateChannelInput struct {
|
||||
Name string
|
||||
Description *string
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ModelPricing *[]ChannelModelPricing
|
||||
ModelMapping map[string]string
|
||||
Name string
|
||||
Description *string
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ModelPricing *[]ChannelModelPricing
|
||||
ModelMapping map[string]string
|
||||
BillingModelSource string
|
||||
RestrictModels *bool
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user