feat(channel): 渠道管理全链路集成 — 模型映射、定价、限制、用量统计
- 渠道模型映射:支持精确匹配和通配符映射,按平台隔离 - 渠道模型定价:支持 token/按次/图片三种计费模式,区间分层定价 - 模型限制:渠道可限制仅允许定价列表中的模型 - 计费模型来源:支持 requested/upstream 两种计费模型选择 - 用量统计:usage_logs 新增 channel_id/model_mapping_chain/billing_tier/billing_mode 字段 - Dashboard 支持 model_source 维度(requested/upstream/mapping)查看模型统计 - 全部 gateway handler 统一接入 ResolveChannelMappingAndRestrict - 修复测试:同步 SoraGenerationRepository 接口、SQL INSERT 参数、scan 字段
This commit is contained in:
@@ -8,13 +8,10 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func channelTestPtrFloat64(v float64) *float64 { return &v }
|
||||
func channelTestPtrInt(v int) *int { return &v }
|
||||
|
||||
func TestGetModelPricing(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(3e-6)},
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: testPtrFloat64(3e-6)},
|
||||
{ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest},
|
||||
},
|
||||
}
|
||||
@@ -48,7 +45,7 @@ func TestGetModelPricing(t *testing.T) {
|
||||
func TestGetModelPricing_ReturnsCopy(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: channelTestPtrFloat64(3e-6)},
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: testPtrFloat64(3e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -73,23 +70,23 @@ func TestGetModelPricing_EmptyPricing(t *testing.T) {
|
||||
func TestGetIntervalForContext(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: channelTestPtrInt(128000), InputPrice: channelTestPtrFloat64(1e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: channelTestPtrFloat64(2e-6)},
|
||||
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokens int
|
||||
wantPrice *float64
|
||||
wantNil bool
|
||||
name string
|
||||
tokens int
|
||||
wantPrice *float64
|
||||
wantNil bool
|
||||
}{
|
||||
{"first interval", 50000, channelTestPtrFloat64(1e-6), false},
|
||||
{"first interval", 50000, testPtrFloat64(1e-6), false},
|
||||
// (min, max] — 128000 在第一个区间的 max,包含,所以匹配第一个
|
||||
{"boundary: max of first (inclusive)", 128000, channelTestPtrFloat64(1e-6), false},
|
||||
{"boundary: max of first (inclusive)", 128000, testPtrFloat64(1e-6), false},
|
||||
// 128001 > 128000,匹配第二个区间
|
||||
{"boundary: just above first max", 128001, channelTestPtrFloat64(2e-6), false},
|
||||
{"unbounded interval", 500000, channelTestPtrFloat64(2e-6), false},
|
||||
{"boundary: just above first max", 128001, testPtrFloat64(2e-6), false},
|
||||
{"unbounded interval", 500000, testPtrFloat64(2e-6), false},
|
||||
// (0, max] — 0 不匹配任何区间(左开)
|
||||
{"zero tokens: no match", 0, nil, true},
|
||||
}
|
||||
@@ -110,11 +107,11 @@ func TestGetIntervalForContext(t *testing.T) {
|
||||
func TestGetIntervalForContext_NoMatch(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 10000, MaxTokens: channelTestPtrInt(50000)},
|
||||
{MinTokens: 10000, MaxTokens: testPtrInt(50000)},
|
||||
},
|
||||
}
|
||||
require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min
|
||||
require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open)
|
||||
require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min
|
||||
require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open)
|
||||
require.NotNil(t, p.GetIntervalForContext(50000)) // 50000 <= 50000 (right-closed)
|
||||
require.Nil(t, p.GetIntervalForContext(50001)) // 50001 > 50000
|
||||
}
|
||||
@@ -127,9 +124,9 @@ func TestGetIntervalForContext_Empty(t *testing.T) {
|
||||
func TestGetTierByLabel(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: channelTestPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: channelTestPtrFloat64(0.08)},
|
||||
{TierLabel: "HD", PerRequestPrice: channelTestPtrFloat64(0.12)},
|
||||
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
|
||||
{TierLabel: "HD", PerRequestPrice: testPtrFloat64(0.12)},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -171,7 +168,7 @@ func TestChannelClone(t *testing.T) {
|
||||
{
|
||||
ID: 100,
|
||||
Models: []string{"model-a"},
|
||||
InputPrice: channelTestPtrFloat64(5e-6),
|
||||
InputPrice: testPtrFloat64(5e-6),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -211,3 +208,102 @@ func TestChannelModelPricingClone(t *testing.T) {
|
||||
cloned.Intervals[0].TierLabel = "hacked"
|
||||
require.Equal(t, "tier1", original.Intervals[0].TierLabel)
|
||||
}
|
||||
|
||||
// --- BillingMode.IsValid ---
|
||||
|
||||
func TestBillingModeIsValid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode BillingMode
|
||||
want bool
|
||||
}{
|
||||
{"token", BillingModeToken, true},
|
||||
{"per_request", BillingModePerRequest, true},
|
||||
{"image", BillingModeImage, true},
|
||||
{"empty", BillingMode(""), true},
|
||||
{"unknown", BillingMode("unknown"), false},
|
||||
{"random", BillingMode("xyz"), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, tt.mode.IsValid())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Channel.IsActive ---
|
||||
|
||||
func TestChannelIsActive(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status string
|
||||
want bool
|
||||
}{
|
||||
{"active", StatusActive, true},
|
||||
{"disabled", "disabled", false},
|
||||
{"empty", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ch := &Channel{Status: tt.status}
|
||||
require.Equal(t, tt.want, ch.IsActive())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- ChannelModelPricing.Clone edge cases ---
|
||||
|
||||
func TestChannelModelPricingClone_EdgeCases(t *testing.T) {
|
||||
t.Run("nil models", func(t *testing.T) {
|
||||
original := ChannelModelPricing{Models: nil}
|
||||
cloned := original.Clone()
|
||||
require.Nil(t, cloned.Models)
|
||||
})
|
||||
|
||||
t.Run("nil intervals", func(t *testing.T) {
|
||||
original := ChannelModelPricing{Intervals: nil}
|
||||
cloned := original.Clone()
|
||||
require.Nil(t, cloned.Intervals)
|
||||
})
|
||||
|
||||
t.Run("empty models", func(t *testing.T) {
|
||||
original := ChannelModelPricing{Models: []string{}}
|
||||
cloned := original.Clone()
|
||||
require.NotNil(t, cloned.Models)
|
||||
require.Empty(t, cloned.Models)
|
||||
})
|
||||
}
|
||||
|
||||
// --- Channel.Clone edge cases ---
|
||||
|
||||
func TestChannelClone_EdgeCases(t *testing.T) {
|
||||
t.Run("nil model mapping", func(t *testing.T) {
|
||||
original := &Channel{ID: 1, ModelMapping: nil}
|
||||
cloned := original.Clone()
|
||||
require.Nil(t, cloned.ModelMapping)
|
||||
})
|
||||
|
||||
t.Run("nil model pricing", func(t *testing.T) {
|
||||
original := &Channel{ID: 1, ModelPricing: nil}
|
||||
cloned := original.Clone()
|
||||
require.Nil(t, cloned.ModelPricing)
|
||||
})
|
||||
|
||||
t.Run("deep copy model mapping", func(t *testing.T) {
|
||||
original := &Channel{
|
||||
ID: 1,
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"openai": {"gpt-4": "gpt-4-turbo"},
|
||||
},
|
||||
}
|
||||
cloned := original.Clone()
|
||||
|
||||
// Modify the cloned nested map
|
||||
cloned.ModelMapping["openai"]["gpt-4"] = "hacked"
|
||||
|
||||
// Original must remain unchanged
|
||||
require.Equal(t, "gpt-4-turbo", original.ModelMapping["openai"]["gpt-4"])
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user