feat(channel): 通配符定价匹配 + OpenAI BillingModelSource + 按次价格校验 + 用户端计费模式展示
- 定价查找支持通配符(suffix *),最长前缀优先匹配 - 模型限制(restrict_models)同样支持通配符匹配 - OpenAI 网关接入渠道映射/BillingModelSource/模型限制 - 按次/图片计费模式创建时强制要求价格或层级(前后端) - 用户使用记录列表增加计费模式 badge 列
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -57,13 +58,26 @@ type channelModelKey struct {
|
||||
model string // lowercase
|
||||
}
|
||||
|
||||
// channelGroupPlatformKey 通配符定价缓存键
|
||||
type channelGroupPlatformKey struct {
|
||||
groupID int64
|
||||
platform string
|
||||
}
|
||||
|
||||
// wildcardPricingEntry 通配符定价条目
|
||||
type wildcardPricingEntry struct {
|
||||
prefix string
|
||||
pricing *ChannelModelPricing
|
||||
}
|
||||
|
||||
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
|
||||
type channelCache struct {
|
||||
// 热路径查找
|
||||
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
|
||||
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
|
||||
channelByGroupID map[int64]*Channel // groupID → 渠道
|
||||
groupPlatform map[int64]string // groupID → platform
|
||||
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
|
||||
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序)
|
||||
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
|
||||
channelByGroupID map[int64]*Channel // groupID → 渠道
|
||||
groupPlatform map[int64]string // groupID → platform
|
||||
|
||||
// 冷路径(CRUD 操作)
|
||||
byID map[int64]*Channel
|
||||
@@ -137,12 +151,13 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
||||
slog.Warn("failed to build channel cache", "error", err)
|
||||
errorCache := &channelCache{
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: make(map[int64]string),
|
||||
byID: make(map[int64]*Channel),
|
||||
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: make(map[int64]string),
|
||||
byID: make(map[int64]*Channel),
|
||||
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
|
||||
}
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("list all channels: %w", err)
|
||||
@@ -163,12 +178,13 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
}
|
||||
|
||||
cache := &channelCache{
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: groupPlatforms,
|
||||
byID: make(map[int64]*Channel, len(channels)),
|
||||
loadedAt: time.Now(),
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: groupPlatforms,
|
||||
byID: make(map[int64]*Channel, len(channels)),
|
||||
loadedAt: time.Now(),
|
||||
}
|
||||
|
||||
for i := range channels {
|
||||
@@ -187,8 +203,18 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
continue // 跳过非本平台的定价
|
||||
}
|
||||
for _, model := range pricing.Models {
|
||||
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
|
||||
cache.pricingByGroupModel[key] = pricing
|
||||
if strings.HasSuffix(model, "*") {
|
||||
// 通配符模型 → 存入 wildcardByGroupPlatform
|
||||
prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
|
||||
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
|
||||
cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{
|
||||
prefix: prefix,
|
||||
pricing: pricing,
|
||||
})
|
||||
} else {
|
||||
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
|
||||
cache.pricingByGroupModel[key] = pricing
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -202,6 +228,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
}
|
||||
}
|
||||
|
||||
// 通配符条目按前缀长度降序排列(最长前缀优先匹配)
|
||||
for gpKey, entries := range cache.wildcardByGroupPlatform {
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return len(entries[i].prefix) > len(entries[j].prefix)
|
||||
})
|
||||
cache.wildcardByGroupPlatform[gpKey] = entries
|
||||
}
|
||||
|
||||
s.cache.Store(cache)
|
||||
return cache, nil
|
||||
}
|
||||
@@ -212,6 +246,18 @@ func (s *ChannelService) invalidateCache() {
|
||||
s.cacheSF.Forget("channel_cache")
|
||||
}
|
||||
|
||||
// matchWildcard 在通配符定价中查找匹配项(最长前缀优先)
|
||||
func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing {
|
||||
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
|
||||
wildcards := c.wildcardByGroupPlatform[gpKey]
|
||||
for _, wc := range wildcards {
|
||||
if strings.HasPrefix(modelLower, wc.prefix) {
|
||||
return wc.pricing
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetChannelForGroup 获取分组关联的渠道(热路径 O(1))
|
||||
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
|
||||
cache, err := s.loadCache(ctx)
|
||||
@@ -245,7 +291,11 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int
|
||||
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
|
||||
pricing, ok := cache.pricingByGroupModel[key]
|
||||
if !ok {
|
||||
return nil
|
||||
// 精确查找失败,尝试通配符匹配
|
||||
pricing = cache.matchWildcard(groupID, platform, strings.ToLower(model))
|
||||
if pricing == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
cp := pricing.Clone()
|
||||
@@ -302,7 +352,14 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m
|
||||
platform := cache.groupPlatform[groupID]
|
||||
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
|
||||
_, exists := cache.pricingByGroupModel[key]
|
||||
return !exists
|
||||
if exists {
|
||||
return false
|
||||
}
|
||||
// 精确查找失败,尝试通配符匹配
|
||||
if cache.matchWildcard(groupID, platform, strings.ToLower(model)) != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// --- CRUD ---
|
||||
|
||||
Reference in New Issue
Block a user