feat(channel): 通配符定价匹配 + OpenAI BillingModelSource + 按次价格校验 + 用户端计费模式展示
- 定价查找支持通配符(suffix *),最长前缀优先匹配 - 模型限制(restrict_models)同样支持通配符匹配 - OpenAI 网关接入渠道映射/BillingModelSource/模型限制 - 按次/图片计费模式创建时强制要求价格或层级(前后端) - 用户使用记录列表增加计费模式 badge 列
This commit is contained in:
@@ -323,6 +323,7 @@ type OpenAIGatewayService struct {
|
||||
toolCorrector *CodexToolCorrector
|
||||
openaiWSResolver OpenAIWSProtocolResolver
|
||||
resolver *ModelPricingResolver
|
||||
channelService *ChannelService
|
||||
|
||||
openaiWSPoolOnce sync.Once
|
||||
openaiWSStateStoreOnce sync.Once
|
||||
@@ -359,6 +360,7 @@ func NewOpenAIGatewayService(
|
||||
deferredService *DeferredService,
|
||||
openAITokenProvider *OpenAITokenProvider,
|
||||
resolver *ModelPricingResolver,
|
||||
channelService *ChannelService,
|
||||
) *OpenAIGatewayService {
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -387,6 +389,7 @@ func NewOpenAIGatewayService(
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
resolver: resolver,
|
||||
channelService: channelService,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
||||
}
|
||||
@@ -394,6 +397,22 @@ func NewOpenAIGatewayService(
|
||||
return svc
|
||||
}
|
||||
|
||||
// ResolveChannelMapping 解析渠道级模型映射(代理到 ChannelService)
|
||||
func (s *OpenAIGatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
|
||||
if s.channelService == nil {
|
||||
return ChannelMappingResult{MappedModel: model}
|
||||
}
|
||||
return s.channelService.ResolveChannelMapping(ctx, groupID, model)
|
||||
}
|
||||
|
||||
// IsModelRestricted 检查模型是否被渠道限制(代理到 ChannelService)
|
||||
func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
||||
if s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
return s.channelService.IsModelRestricted(ctx, groupID, model)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
|
||||
if s != nil && s.codexSnapshotThrottle != nil {
|
||||
return s.codexSnapshotThrottle
|
||||
@@ -4113,6 +4132,10 @@ type OpenAIRecordUsageInput struct {
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
ChannelID int64
|
||||
OriginalModel string
|
||||
BillingModelSource string
|
||||
ModelMappingChain string
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
@@ -4158,6 +4181,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
var cost *CostBreakdown
|
||||
var err error
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
if result.BillingModel != "" {
|
||||
billingModel = strings.TrimSpace(result.BillingModel)
|
||||
}
|
||||
if input.BillingModelSource == "requested" && input.OriginalModel != "" {
|
||||
billingModel = input.OriginalModel
|
||||
}
|
||||
serviceTier := ""
|
||||
if result.ServiceTier != nil {
|
||||
serviceTier = strings.TrimSpace(*result.ServiceTier)
|
||||
@@ -4223,6 +4252,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
// 设置渠道信息
|
||||
usageLog.ChannelID = optionalInt64Ptr(input.ChannelID)
|
||||
usageLog.ModelMappingChain = optionalTrimmedStringPtr(input.ModelMappingChain)
|
||||
// 设置计费模式
|
||||
if cost != nil && cost.BillingMode != "" {
|
||||
billingMode := cost.BillingMode
|
||||
|
||||
Reference in New Issue
Block a user