feat(channel): 通配符定价匹配 + OpenAI BillingModelSource + 按次价格校验 + 用户端计费模式展示

- 定价查找支持通配符(suffix *),最长前缀优先匹配
- 模型限制(restrict_models)同样支持通配符匹配
- OpenAI 网关接入渠道映射/BillingModelSource/模型限制
- 按次/图片计费模式创建时强制要求价格或层级(前后端)
- 用户使用记录列表增加计费模式 badge 列
This commit is contained in:
erio
2026-03-31 00:23:45 +08:00
parent 0fbc9a44d3
commit 8d03c52e15
11 changed files with 255 additions and 22 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
"sort"
"strings"
"sync/atomic"
"time"
@@ -57,13 +58,26 @@ type channelModelKey struct {
model string // lowercase
}
// channelGroupPlatformKey 通配符定价缓存键
type channelGroupPlatformKey struct {
groupID int64
platform string
}
// wildcardPricingEntry 通配符定价条目
type wildcardPricingEntry struct {
prefix string
pricing *ChannelModelPricing
}
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
type channelCache struct {
// 热路径查找
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
channelByGroupID map[int64]*Channel // groupID → 渠道
groupPlatform map[int64]string // groupID → platform
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序)
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
channelByGroupID map[int64]*Channel // groupID → 渠道
groupPlatform map[int64]string // groupID → platform
// 冷路径CRUD 操作)
byID map[int64]*Channel
@@ -137,12 +151,13 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
// error-TTL失败时存入短 TTL 空缓存,防止紧密重试
slog.Warn("failed to build channel cache", "error", err)
errorCache := &channelCache{
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
mappingByGroupModel: make(map[channelModelKey]string),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: make(map[int64]string),
byID: make(map[int64]*Channel),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
mappingByGroupModel: make(map[channelModelKey]string),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: make(map[int64]string),
byID: make(map[int64]*Channel),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
}
s.cache.Store(errorCache)
return nil, fmt.Errorf("list all channels: %w", err)
@@ -163,12 +178,13 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
}
cache := &channelCache{
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
mappingByGroupModel: make(map[channelModelKey]string),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: groupPlatforms,
byID: make(map[int64]*Channel, len(channels)),
loadedAt: time.Now(),
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
mappingByGroupModel: make(map[channelModelKey]string),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: groupPlatforms,
byID: make(map[int64]*Channel, len(channels)),
loadedAt: time.Now(),
}
for i := range channels {
@@ -187,8 +203,18 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
continue // 跳过非本平台的定价
}
for _, model := range pricing.Models {
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
cache.pricingByGroupModel[key] = pricing
if strings.HasSuffix(model, "*") {
// 通配符模型 → 存入 wildcardByGroupPlatform
prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{
prefix: prefix,
pricing: pricing,
})
} else {
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
cache.pricingByGroupModel[key] = pricing
}
}
}
@@ -202,6 +228,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
}
}
// 通配符条目按前缀长度降序排列(最长前缀优先匹配)
for gpKey, entries := range cache.wildcardByGroupPlatform {
sort.Slice(entries, func(i, j int) bool {
return len(entries[i].prefix) > len(entries[j].prefix)
})
cache.wildcardByGroupPlatform[gpKey] = entries
}
s.cache.Store(cache)
return cache, nil
}
@@ -212,6 +246,18 @@ func (s *ChannelService) invalidateCache() {
s.cacheSF.Forget("channel_cache")
}
// matchWildcard 在通配符定价中查找匹配项(最长前缀优先)
func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing {
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
wildcards := c.wildcardByGroupPlatform[gpKey]
for _, wc := range wildcards {
if strings.HasPrefix(modelLower, wc.prefix) {
return wc.pricing
}
}
return nil
}
// GetChannelForGroup 获取分组关联的渠道(热路径 O(1)
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
cache, err := s.loadCache(ctx)
@@ -245,7 +291,11 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
pricing, ok := cache.pricingByGroupModel[key]
if !ok {
return nil
// 精确查找失败,尝试通配符匹配
pricing = cache.matchWildcard(groupID, platform, strings.ToLower(model))
if pricing == nil {
return nil
}
}
cp := pricing.Clone()
@@ -302,7 +352,14 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m
platform := cache.groupPlatform[groupID]
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
_, exists := cache.pricingByGroupModel[key]
return !exists
if exists {
return false
}
// 精确查找失败,尝试通配符匹配
if cache.matchWildcard(groupID, platform, strings.ToLower(model)) != nil {
return false
}
return true
}
// --- CRUD ---

View File

@@ -146,6 +146,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
&DeferredService{},
nil,
nil,
nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,

View File

@@ -323,6 +323,7 @@ type OpenAIGatewayService struct {
toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver
resolver *ModelPricingResolver
channelService *ChannelService
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
@@ -359,6 +360,7 @@ func NewOpenAIGatewayService(
deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider,
resolver *ModelPricingResolver,
channelService *ChannelService,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
@@ -387,6 +389,7 @@ func NewOpenAIGatewayService(
toolCorrector: NewCodexToolCorrector(),
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
resolver: resolver,
channelService: channelService,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
@@ -394,6 +397,22 @@ func NewOpenAIGatewayService(
return svc
}
// ResolveChannelMapping 解析渠道级模型映射(代理到 ChannelService
func (s *OpenAIGatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
if s.channelService == nil {
return ChannelMappingResult{MappedModel: model}
}
return s.channelService.ResolveChannelMapping(ctx, groupID, model)
}
// IsModelRestricted 检查模型是否被渠道限制(代理到 ChannelService
func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
if s.channelService == nil {
return false
}
return s.channelService.IsModelRestricted(ctx, groupID, model)
}
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
if s != nil && s.codexSnapshotThrottle != nil {
return s.codexSnapshotThrottle
@@ -4113,6 +4132,10 @@ type OpenAIRecordUsageInput struct {
IPAddress string // 请求的客户端 IP 地址
RequestPayloadHash string
APIKeyService APIKeyQuotaUpdater
ChannelID int64
OriginalModel string
BillingModelSource string
ModelMappingChain string
}
// RecordUsage records usage and deducts balance
@@ -4158,6 +4181,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
var cost *CostBreakdown
var err error
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if result.BillingModel != "" {
billingModel = strings.TrimSpace(result.BillingModel)
}
if input.BillingModelSource == "requested" && input.OriginalModel != "" {
billingModel = input.OriginalModel
}
serviceTier := ""
if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier)
@@ -4223,6 +4252,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
}
// 设置渠道信息
usageLog.ChannelID = optionalInt64Ptr(input.ChannelID)
usageLog.ModelMappingChain = optionalTrimmedStringPtr(input.ModelMappingChain)
// 设置计费模式
if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode

View File

@@ -616,6 +616,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)