feat(channel): 缓存扁平化 + 网关映射集成 + 计费模式统一 + 模型限制

- 缓存重构为 O(1) 哈希结构 (pricingByGroupModel, mappingByGroupModel)
- 渠道模型映射接入网关流程 (Forward 前应用, a→b→c 映射链)
- 新增 billing_model_source 配置 (请求模型/最终模型计费)
- usage_logs 新增 channel_id, model_mapping_chain, billing_tier 字段
- 每种计费模式统一支持默认价格 + 区间定价
- 渠道模型限制开关 (restrict_models)
- 分组按平台分类展示 + 彩色图标
- 必填字段红色星号 + 模型映射 UI
- 去除模型通配符支持
This commit is contained in:
erio
2026-03-30 13:26:05 +08:00
parent 29d58f2414
commit ebac0dc628
23 changed files with 779 additions and 206 deletions

View File

@@ -23,14 +23,21 @@ func (m BillingMode) IsValid() bool {
return false
}
const (
BillingModelSourceRequested = "requested"
BillingModelSourceUpstream = "upstream"
)
// Channel 渠道实体
type Channel struct {
ID int64
Name string
Description string
Status string
CreatedAt time.Time
UpdatedAt time.Time
ID int64
Name string
Description string
Status string
BillingModelSource string // "requested" or "upstream"
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
CreatedAt time.Time
UpdatedAt time.Time
// 关联的分组 ID 列表
GroupIDs []int64
@@ -44,13 +51,14 @@ type Channel struct {
type ChannelModelPricing struct {
ID int64
ChannelID int64
Models []string // 绑定的模型列表
BillingMode BillingMode // 计费模式
InputPrice *float64 // 每 token 输入价格USD— 向后兼容 flat 定价
OutputPrice *float64 // 每 token 输出价格USD
CacheWritePrice *float64 // 缓存写入价格
CacheReadPrice *float64 // 缓存读取价格
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
Models []string // 绑定的模型列表
BillingMode BillingMode // 计费模式
InputPrice *float64 // 每 token 输入价格USD— 向后兼容 flat 定价
OutputPrice *float64 // 每 token 输出价格USD
CacheWritePrice *float64 // 缓存写入价格
CacheReadPrice *float64 // 缓存读取价格
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
PerRequestPrice *float64 // 默认按次计费价格USD
Intervals []PricingInterval // 区间定价列表
CreatedAt time.Time
UpdatedAt time.Time
@@ -106,12 +114,10 @@ func (c *Channel) IsActive() bool {
}
// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。
// 优先精确匹配,然后通配符匹配(如 claude-opus-*)。大小写不敏感。
// 返回值拷贝,不污染缓存。
// 精确匹配,大小写不敏感。返回值拷贝,不污染缓存
func (c *Channel) GetModelPricing(model string) *ChannelModelPricing {
modelLower := strings.ToLower(model)
// 第一轮:精确匹配
for i := range c.ModelPricing {
for _, m := range c.ModelPricing[i].Models {
if strings.ToLower(m) == modelLower {
@@ -121,20 +127,6 @@ func (c *Channel) GetModelPricing(model string) *ChannelModelPricing {
}
}
// 第二轮:通配符匹配(仅支持末尾 *
for i := range c.ModelPricing {
for _, m := range c.ModelPricing[i].Models {
mLower := strings.ToLower(m)
if strings.HasSuffix(mLower, "*") {
prefix := strings.TrimSuffix(mLower, "*")
if strings.HasPrefix(modelLower, prefix) {
cp := c.ModelPricing[i].Clone()
return &cp
}
}
}
}
return nil
}

View File

@@ -47,13 +47,30 @@ type ChannelRepository interface {
ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
}
// channelCache 渠道缓存快照
// channelModelKey 渠道缓存复合键
type channelModelKey struct {
groupID int64
model string // lowercase
}
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
type channelCache struct {
// byID: channelID -> *Channel含 ModelPricing
byID map[int64]*Channel
// byGroupID: groupID -> channelID
byGroupID map[int64]int64
loadedAt time.Time
// 热路径查找
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, model) → 定价
mappingByGroupModel map[channelModelKey]string // (groupID, model) → 映射目标
channelByGroupID map[int64]*Channel // groupID → 渠道
// 冷路径CRUD 操作)
byID map[int64]*Channel
loadedAt time.Time
}
// ChannelMappingResult 渠道映射查找结果
type ChannelMappingResult struct {
MappedModel string // 映射后的模型名(无映射时等于原始模型名)
ChannelID int64 // 渠道 ID0 = 无渠道关联)
Mapped bool // 是否发生了映射
BillingModelSource string // 计费模型来源("requested" / "upstream"
}
const (
@@ -115,25 +132,46 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
// error-TTL失败时存入短 TTL 空缓存,防止紧密重试
slog.Warn("failed to build channel cache", "error", err)
errorCache := &channelCache{
byID: make(map[int64]*Channel),
byGroupID: make(map[int64]int64),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
mappingByGroupModel: make(map[channelModelKey]string),
channelByGroupID: make(map[int64]*Channel),
byID: make(map[int64]*Channel),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
}
s.cache.Store(errorCache)
return nil, fmt.Errorf("list all channels: %w", err)
}
cache := &channelCache{
byID: make(map[int64]*Channel, len(channels)),
byGroupID: make(map[int64]int64),
loadedAt: time.Now(),
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
mappingByGroupModel: make(map[channelModelKey]string),
channelByGroupID: make(map[int64]*Channel),
byID: make(map[int64]*Channel, len(channels)),
loadedAt: time.Now(),
}
for i := range channels {
ch := &channels[i]
cache.byID[ch.ID] = ch
// 展开到分组维度
for _, gid := range ch.GroupIDs {
cache.byGroupID[gid] = ch.ID
cache.channelByGroupID[gid] = ch
// 展开模型定价到 (groupID, model) → *ChannelModelPricing
for j := range ch.ModelPricing {
pricing := &ch.ModelPricing[j]
for _, model := range pricing.Models {
key := channelModelKey{groupID: gid, model: strings.ToLower(model)}
cache.pricingByGroupModel[key] = pricing
}
}
// 展开模型映射到 (groupID, model) → target
for src, dst := range ch.ModelMapping {
key := channelModelKey{groupID: gid, model: strings.ToLower(src)}
cache.mappingByGroupModel[key] = dst
}
}
}
@@ -147,42 +185,94 @@ func (s *ChannelService) invalidateCache() {
s.cacheSF.Forget("channel_cache")
}
// GetChannelForGroup 获取分组关联的渠道(热路径,从缓存读取
// 返回深拷贝,不污染缓存。
// GetChannelForGroup 获取分组关联的渠道(热路径 O(1)
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
cache, err := s.loadCache(ctx)
if err != nil {
return nil, err
}
channelID, ok := cache.byGroupID[groupID]
if !ok {
return nil, nil
}
ch, ok := cache.byID[channelID]
if !ok {
return nil, nil
}
if !ch.IsActive() {
ch, ok := cache.channelByGroupID[groupID]
if !ok || !ch.IsActive() {
return nil, nil
}
return ch.Clone(), nil
}
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径)
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1)
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
ch, err := s.GetChannelForGroup(ctx, groupID)
cache, err := s.loadCache(ctx)
if err != nil {
slog.Warn("failed to get channel for group", "group_id", groupID, "error", err)
slog.Warn("failed to load channel cache", "group_id", groupID, "error", err)
return nil
}
if ch == nil {
// 检查渠道是否启用
ch, ok := cache.channelByGroupID[groupID]
if !ok || !ch.IsActive() {
return nil
}
return ch.GetModelPricing(model)
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)}
pricing, ok := cache.pricingByGroupModel[key]
if !ok {
return nil
}
cp := pricing.Clone()
return &cp
}
// ResolveChannelMapping 解析渠道级模型映射(热路径 O(1)
// 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。
func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
cache, err := s.loadCache(ctx)
if err != nil {
return ChannelMappingResult{MappedModel: model}
}
ch, ok := cache.channelByGroupID[groupID]
if !ok || !ch.IsActive() {
return ChannelMappingResult{MappedModel: model}
}
result := ChannelMappingResult{
MappedModel: model,
ChannelID: ch.ID,
BillingModelSource: ch.BillingModelSource,
}
if result.BillingModelSource == "" {
result.BillingModelSource = BillingModelSourceRequested
}
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)}
if mapped, ok := cache.mappingByGroupModel[key]; ok {
result.MappedModel = mapped
result.Mapped = true
}
return result
}
// IsModelRestricted 检查模型是否被渠道限制。
// 返回 true 表示模型被限制(不在允许列表中)。
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
cache, err := s.loadCache(ctx)
if err != nil {
return false // 缓存加载失败时不限制
}
ch, ok := cache.channelByGroupID[groupID]
if !ok || !ch.IsActive() || !ch.RestrictModels {
return false
}
// 检查模型是否在定价列表中
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)}
_, exists := cache.pricingByGroupModel[key]
return !exists
}
// --- CRUD ---
@@ -209,12 +299,17 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
}
channel := &Channel{
Name: input.Name,
Description: input.Description,
Status: StatusActive,
GroupIDs: input.GroupIDs,
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
Name: input.Name,
Description: input.Description,
Status: StatusActive,
BillingModelSource: input.BillingModelSource,
RestrictModels: input.RestrictModels,
GroupIDs: input.GroupIDs,
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
}
if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceRequested
}
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
@@ -260,6 +355,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
channel.Status = input.Status
}
if input.RestrictModels != nil {
channel.RestrictModels = *input.RestrictModels
}
// 检查分组冲突
if input.GroupIDs != nil {
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
@@ -280,6 +379,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
channel.ModelMapping = input.ModelMapping
}
if input.BillingModelSource != "" {
channel.BillingModelSource = input.BillingModelSource
}
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
return nil, err
}
@@ -351,19 +454,23 @@ func validateNoDuplicateModels(pricingList []ChannelModelPricing) error {
// CreateChannelInput 创建渠道输入
type CreateChannelInput struct {
Name string
Description string
GroupIDs []int64
ModelPricing []ChannelModelPricing
ModelMapping map[string]string
Name string
Description string
GroupIDs []int64
ModelPricing []ChannelModelPricing
ModelMapping map[string]string
BillingModelSource string
RestrictModels bool
}
// UpdateChannelInput 更新渠道输入
type UpdateChannelInput struct {
Name string
Description *string
Status string
GroupIDs *[]int64
ModelPricing *[]ChannelModelPricing
ModelMapping map[string]string
Name string
Description *string
Status string
GroupIDs *[]int64
ModelPricing *[]ChannelModelPricing
ModelMapping map[string]string
BillingModelSource string
RestrictModels *bool
}

View File

@@ -15,7 +15,6 @@ func TestGetModelPricing(t *testing.T) {
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(3e-6)},
{ID: 2, Models: []string{"claude-*"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(5e-6)},
{ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest},
},
}
@@ -28,9 +27,8 @@ func TestGetModelPricing(t *testing.T) {
}{
{"exact match", "claude-sonnet-4", 1, false},
{"case insensitive", "Claude-Sonnet-4", 1, false},
{"wildcard match", "claude-opus-4-20250514", 2, false},
{"exact takes priority over wildcard", "claude-sonnet-4", 1, false},
{"not found", "gemini-3.1-pro", 0, true},
{"wildcard pattern not matched", "claude-opus-4-20250514", 0, true},
{"per_request model", "gpt-5.1", 3, false},
}

View File

@@ -7413,6 +7413,12 @@ type RecordUsageInput struct {
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // 可选用于更新API Key配额
// 渠道映射信息(由 handler 在 Forward 前解析)
ChannelID int64 // 渠道 ID0 = 无渠道)
OriginalModel string // 用户原始请求模型(渠道映射前)
BillingModelSource string // 计费模型来源:"requested" / "upstream"
ModelMappingChain string // 映射链描述,如 "a→b→c"
}
// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
@@ -7732,7 +7738,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
var cost *CostBreakdown
// 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel
}
// 确定 RequestedModel渠道映射前的原始模型
requestedModel := result.Model
if input.OriginalModel != "" {
requestedModel = input.OriginalModel
}
// 根据请求类型选择计费方式
if result.MediaType == "image" || result.MediaType == "video" {
@@ -7815,7 +7831,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
RequestedModel: result.Model,
RequestedModel: requestedModel,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
@@ -7842,6 +7858,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
ImageSize: imageSize,
MediaType: mediaType,
CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
CreatedAt: time.Now(),
}
@@ -7909,6 +7927,12 @@ type RecordUsageLongContextInput struct {
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
// 渠道映射信息(由 handler 在 Forward 前解析)
ChannelID int64 // 渠道 ID0 = 无渠道)
OriginalModel string // 用户原始请求模型(渠道映射前)
BillingModelSource string // 计费模型来源:"requested" / "upstream"
ModelMappingChain string // 映射链描述,如 "a→b→c"
}
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini
@@ -7946,7 +7970,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
}
var cost *CostBreakdown
// 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel
}
// 确定 RequestedModel渠道映射前的原始模型
requestedModel := result.Model
if input.OriginalModel != "" {
requestedModel = input.OriginalModel
}
// 根据请求类型选择计费方式
if result.ImageCount > 0 {
@@ -8008,7 +8042,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
RequestedModel: result.Model,
RequestedModel: requestedModel,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
@@ -8034,6 +8068,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
ImageCount: result.ImageCount,
ImageSize: imageSize,
CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
CreatedAt: time.Now(),
}
@@ -8085,6 +8121,27 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
return nil
}
// ResolveChannelMapping 委托渠道服务解析模型映射
func (s *GatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
if s.channelService == nil {
return ChannelMappingResult{MappedModel: model}
}
return s.channelService.ResolveChannelMapping(ctx, groupID, model)
}
// ReplaceModelInBody 替换请求体中的模型名(导出供 handler 使用)
func (s *GatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
return s.replaceModelInBody(body, newModel)
}
// IsModelRestricted 检查模型是否被渠道限制
func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
if s.channelService == nil {
return false
}
return s.channelService.IsModelRestricted(ctx, groupID, model)
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {

View File

@@ -104,6 +104,12 @@ type UsageLog struct {
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Nil means no mapping was applied (requested model was used as-is).
UpstreamModel *string
// ChannelID 渠道 ID
ChannelID *int64
// ModelMappingChain 模型映射链,如 "a→b→c"
ModelMappingChain *string
// BillingTier 计费层级标签per_request/image 模式)
BillingTier *string
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string
// ReasoningEffort is the request's reasoning effort level.

View File

@@ -26,3 +26,10 @@ func forwardResultBillingModel(requestedModel, upstreamModel string) string {
}
return strings.TrimSpace(upstreamModel)
}
func optionalInt64Ptr(v int64) *int64 {
if v == 0 {
return nil
}
return &v
}