feat(channel): 渠道管理全链路集成 — 模型映射、定价、限制、用量统计
- 渠道模型映射:支持精确匹配和通配符映射,按平台隔离 - 渠道模型定价:支持 token/按次/图片三种计费模式,区间分层定价 - 模型限制:渠道可限制仅允许定价列表中的模型 - 计费模型来源:支持 requested/upstream 两种计费模型选择 - 用量统计:usage_logs 新增 channel_id/model_mapping_chain/billing_tier/billing_mode 字段 - Dashboard 支持 model_source 维度(requested/upstream/mapping)查看模型统计 - 全部 gateway handler 统一接入 ResolveChannelMappingAndRestrict - 修复测试:同步 SoraGenerationRepository 接口、SQL INSERT 参数、scan 字段
This commit is contained in:
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -17,8 +16,8 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found")
|
||||
ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists")
|
||||
ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found")
|
||||
ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists")
|
||||
ErrGroupAlreadyInChannel = infraerrors.Conflict(
|
||||
"GROUP_ALREADY_IN_CHANNEL",
|
||||
"one or more groups already belong to another channel",
|
||||
@@ -81,12 +80,12 @@ type wildcardMappingEntry struct {
|
||||
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
|
||||
type channelCache struct {
|
||||
// 热路径查找
|
||||
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
|
||||
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
|
||||
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序)
|
||||
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
|
||||
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序)
|
||||
channelByGroupID map[int64]*Channel // groupID → 渠道
|
||||
groupPlatform map[int64]string // groupID → platform
|
||||
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
|
||||
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序)
|
||||
channelByGroupID map[int64]*Channel // groupID → 渠道
|
||||
groupPlatform map[int64]string // groupID → platform
|
||||
|
||||
// 冷路径(CRUD 操作)
|
||||
byID map[int64]*Channel
|
||||
@@ -118,9 +117,19 @@ func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel str
|
||||
return reqModel + "→" + r.MappedModel
|
||||
}
|
||||
|
||||
// ToUsageFields 将渠道映射结果转为使用记录字段
|
||||
func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields {
|
||||
return ChannelUsageFields{
|
||||
ChannelID: r.ChannelID,
|
||||
OriginalModel: reqModel,
|
||||
BillingModelSource: r.BillingModelSource,
|
||||
ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel),
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
channelCacheTTL = 60 * time.Second
|
||||
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
|
||||
channelCacheTTL = 60 * time.Second
|
||||
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
|
||||
channelCacheDBTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
@@ -177,14 +186,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
||||
slog.Warn("failed to build channel cache", "error", err)
|
||||
errorCache := &channelCache{
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: make(map[int64]string),
|
||||
byID: make(map[int64]*Channel),
|
||||
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: make(map[int64]string),
|
||||
byID: make(map[int64]*Channel),
|
||||
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
|
||||
}
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("list all channels: %w", err)
|
||||
@@ -205,14 +214,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
}
|
||||
|
||||
cache := &channelCache{
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: groupPlatforms,
|
||||
byID: make(map[int64]*Channel, len(channels)),
|
||||
loadedAt: time.Now(),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: groupPlatforms,
|
||||
byID: make(map[int64]*Channel, len(channels)),
|
||||
loadedAt: time.Now(),
|
||||
}
|
||||
|
||||
for i := range channels {
|
||||
@@ -266,19 +275,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
}
|
||||
}
|
||||
|
||||
// 通配符条目按前缀长度降序排列(最长前缀优先匹配)
|
||||
for gpKey, entries := range cache.wildcardByGroupPlatform {
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return len(entries[i].prefix) > len(entries[j].prefix)
|
||||
})
|
||||
cache.wildcardByGroupPlatform[gpKey] = entries
|
||||
}
|
||||
for gpKey, entries := range cache.wildcardMappingByGP {
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return len(entries[i].prefix) > len(entries[j].prefix)
|
||||
})
|
||||
cache.wildcardMappingByGP[gpKey] = entries
|
||||
}
|
||||
// 通配符条目保持配置顺序(最先匹配到优先)
|
||||
|
||||
s.cache.Store(cache)
|
||||
return cache, nil
|
||||
@@ -290,7 +287,7 @@ func (s *ChannelService) invalidateCache() {
|
||||
s.cacheSF.Forget("channel_cache")
|
||||
}
|
||||
|
||||
// matchWildcard 在通配符定价中查找匹配项(最长前缀优先)
|
||||
// matchWildcard 在通配符定价中查找匹配项(最先匹配到优先)
|
||||
func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing {
|
||||
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
|
||||
wildcards := c.wildcardByGroupPlatform[gpKey]
|
||||
@@ -302,7 +299,7 @@ func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string)
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchWildcardMapping 在通配符映射中查找匹配项(最长前缀优先)
|
||||
// matchWildcardMapping 在通配符映射中查找匹配项(最先匹配到优先)
|
||||
func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower string) string {
|
||||
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
|
||||
wildcards := c.wildcardMappingByGP[gpKey]
|
||||
@@ -479,15 +476,18 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
Status: StatusActive,
|
||||
BillingModelSource: input.BillingModelSource,
|
||||
RestrictModels: input.RestrictModels,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
}
|
||||
if channel.BillingModelSource == "" {
|
||||
channel.BillingModelSource = BillingModelSourceRequested
|
||||
}
|
||||
|
||||
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
|
||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -558,7 +558,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
|
||||
channel.BillingModelSource = input.BillingModelSource
|
||||
}
|
||||
|
||||
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
|
||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -610,16 +613,79 @@ func (s *ChannelService) List(ctx context.Context, params pagination.PaginationP
|
||||
return s.repo.List(ctx, params, status, search)
|
||||
}
|
||||
|
||||
// validateNoDuplicateModels 检查定价列表中是否有重复模型(同一平台下不允许重复)
|
||||
func validateNoDuplicateModels(pricingList []ChannelModelPricing) error {
|
||||
seen := make(map[string]bool)
|
||||
// modelEntry 表示一个模型模式条目(用于冲突检测)
|
||||
type modelEntry struct {
|
||||
pattern string // 原始模式(如 "claude-*" 或 "claude-opus-4")
|
||||
prefix string // lowercase 前缀(通配符去掉 *,精确名保持原样)
|
||||
wildcard bool
|
||||
}
|
||||
|
||||
// conflictsBetween 检查两个模型模式是否冲突
|
||||
func conflictsBetween(a, b modelEntry) bool {
|
||||
switch {
|
||||
case !a.wildcard && !b.wildcard:
|
||||
return a.prefix == b.prefix
|
||||
case a.wildcard && !b.wildcard:
|
||||
return strings.HasPrefix(b.prefix, a.prefix)
|
||||
case !a.wildcard && b.wildcard:
|
||||
return strings.HasPrefix(a.prefix, b.prefix)
|
||||
default:
|
||||
return strings.HasPrefix(a.prefix, b.prefix) ||
|
||||
strings.HasPrefix(b.prefix, a.prefix)
|
||||
}
|
||||
}
|
||||
|
||||
// toModelEntry 将模型名转换为 modelEntry
|
||||
func toModelEntry(pattern string) modelEntry {
|
||||
lower := strings.ToLower(pattern)
|
||||
isWild := strings.HasSuffix(lower, "*")
|
||||
prefix := lower
|
||||
if isWild {
|
||||
prefix = strings.TrimSuffix(lower, "*")
|
||||
}
|
||||
return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild}
|
||||
}
|
||||
|
||||
// validateNoConflictingModels 检查定价列表中是否有冲突模型模式(同一平台下)。
|
||||
// 冲突包括:精确重复、通配符之间的前缀包含、通配符与精确名的前缀匹配。
|
||||
func validateNoConflictingModels(pricingList []ChannelModelPricing) error {
|
||||
byPlatform := make(map[string][]modelEntry)
|
||||
for _, p := range pricingList {
|
||||
for _, model := range p.Models {
|
||||
key := p.Platform + ":" + strings.ToLower(model)
|
||||
if seen[key] {
|
||||
return infraerrors.BadRequest("DUPLICATE_MODEL", fmt.Sprintf("model '%s' appears in multiple pricing entries for platform '%s'", model, p.Platform))
|
||||
byPlatform[p.Platform] = append(byPlatform[p.Platform], toModelEntry(model))
|
||||
}
|
||||
}
|
||||
for platform, entries := range byPlatform {
|
||||
if err := detectConflicts(entries, platform, "MODEL_PATTERN_CONFLICT", "model patterns"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNoConflictingMappings 检查模型映射中是否有冲突的源模式
|
||||
func validateNoConflictingMappings(mapping map[string]map[string]string) error {
|
||||
for platform, platformMapping := range mapping {
|
||||
entries := make([]modelEntry, 0, len(platformMapping))
|
||||
for src := range platformMapping {
|
||||
entries = append(entries, toModelEntry(src))
|
||||
}
|
||||
if err := detectConflicts(entries, platform, "MAPPING_PATTERN_CONFLICT", "mapping source patterns"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误
|
||||
func detectConflicts(entries []modelEntry, platform, errCode, label string) error {
|
||||
for i := 0; i < len(entries); i++ {
|
||||
for j := i + 1; j < len(entries); j++ {
|
||||
if conflictsBetween(entries[i], entries[j]) {
|
||||
return infraerrors.BadRequest(errCode,
|
||||
fmt.Sprintf("%s '%s' and '%s' conflict in platform '%s': overlapping match range",
|
||||
label, entries[i].pattern, entries[j].pattern, platform))
|
||||
}
|
||||
seen[key] = true
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user