test: add unit tests for channel pricing restriction in scheduling phase
20 test cases covering: - billingModelForRestriction: 4 cases (requested/channel_mapped/upstream/empty) - resolveAccountUpstreamModel: 3 cases (antigravity/unsupported/non-antigravity) - checkChannelPricingRestriction: 10 cases (nil guards, 3 billing sources, RestrictModels disabled, no channel) - isUpstreamModelRestrictedByChannel: 3 cases (restricted/allowed/unsupported)
This commit is contained in:
293
backend/internal/service/gateway_channel_restriction_test.go
Normal file
293
backend/internal/service/gateway_channel_restriction_test.go
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- billingModelForRestriction ---
|
||||||
|
|
||||||
|
func TestBillingModelForRestriction_Requested(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
got := billingModelForRestriction(BillingModelSourceRequested, "claude-sonnet-4-5", "claude-sonnet-4-6")
|
||||||
|
require.Equal(t, "claude-sonnet-4-5", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingModelForRestriction_ChannelMapped(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
got := billingModelForRestriction(BillingModelSourceChannelMapped, "claude-sonnet-4-5", "claude-sonnet-4-6")
|
||||||
|
require.Equal(t, "claude-sonnet-4-6", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingModelForRestriction_Upstream(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
got := billingModelForRestriction(BillingModelSourceUpstream, "claude-sonnet-4-5", "claude-sonnet-4-6")
|
||||||
|
require.Equal(t, "", got, "upstream should return empty (per-account check needed)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingModelForRestriction_Empty(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
got := billingModelForRestriction("", "claude-sonnet-4-5", "claude-sonnet-4-6")
|
||||||
|
require.Equal(t, "claude-sonnet-4-6", got, "empty source defaults to channel_mapped")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- resolveAccountUpstreamModel ---
|
||||||
|
|
||||||
|
func TestResolveAccountUpstreamModel_Antigravity(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
}
|
||||||
|
// Antigravity 平台使用 DefaultAntigravityModelMapping
|
||||||
|
got := resolveAccountUpstreamModel(account, "claude-sonnet-4-6")
|
||||||
|
require.Equal(t, "claude-sonnet-4-6", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAccountUpstreamModel_Antigravity_Unsupported(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
}
|
||||||
|
got := resolveAccountUpstreamModel(account, "totally-unknown-model")
|
||||||
|
require.Equal(t, "", got, "unsupported model should return empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAccountUpstreamModel_NonAntigravity(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
}
|
||||||
|
got := resolveAccountUpstreamModel(account, "claude-sonnet-4-6")
|
||||||
|
require.Equal(t, "claude-sonnet-4-6", got, "no mapping = passthrough")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- checkChannelPricingRestriction ---
|
||||||
|
|
||||||
|
func TestCheckChannelPricingRestriction_NilGroupID(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
svc := &GatewayService{channelService: &ChannelService{}}
|
||||||
|
require.False(t, svc.checkChannelPricingRestriction(context.Background(), nil, "claude-sonnet-4"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckChannelPricingRestriction_NilChannelService(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
svc := &GatewayService{}
|
||||||
|
gid := int64(10)
|
||||||
|
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckChannelPricingRestriction_EmptyModel(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
svc := &GatewayService{channelService: &ChannelService{}}
|
||||||
|
gid := int64(10)
|
||||||
|
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, ""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckChannelPricingRestriction_ChannelMapped_Restricted(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// 渠道映射 claude-sonnet-4-5 → claude-sonnet-4-6,但定价列表只有 claude-opus-4-6
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
RestrictModels: true,
|
||||||
|
BillingModelSource: BillingModelSourceChannelMapped,
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
|
||||||
|
},
|
||||||
|
ModelMapping: map[string]map[string]string{
|
||||||
|
"anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||||
|
svc := &GatewayService{channelService: channelSvc}
|
||||||
|
|
||||||
|
gid := int64(10)
|
||||||
|
require.True(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
|
||||||
|
"mapped model claude-sonnet-4-6 is NOT in pricing → restricted")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckChannelPricingRestriction_ChannelMapped_Allowed(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// 渠道映射 claude-sonnet-4-5 → claude-sonnet-4-6,定价列表包含 claude-sonnet-4-6
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
RestrictModels: true,
|
||||||
|
BillingModelSource: BillingModelSourceChannelMapped,
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
|
||||||
|
},
|
||||||
|
ModelMapping: map[string]map[string]string{
|
||||||
|
"anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||||
|
svc := &GatewayService{channelService: channelSvc}
|
||||||
|
|
||||||
|
gid := int64(10)
|
||||||
|
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
|
||||||
|
"mapped model claude-sonnet-4-6 IS in pricing → allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckChannelPricingRestriction_Requested_Restricted(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// billing_model_source=requested,定价列表有 claude-sonnet-4-6 但请求的是 claude-sonnet-4-5
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
RestrictModels: true,
|
||||||
|
BillingModelSource: BillingModelSourceRequested,
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||||
|
svc := &GatewayService{channelService: channelSvc}
|
||||||
|
|
||||||
|
gid := int64(10)
|
||||||
|
require.True(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
|
||||||
|
"requested model claude-sonnet-4-5 is NOT in pricing → restricted")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckChannelPricingRestriction_Requested_Allowed(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
RestrictModels: true,
|
||||||
|
BillingModelSource: BillingModelSourceRequested,
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{Platform: "anthropic", Models: []string{"claude-sonnet-4-5"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||||
|
svc := &GatewayService{channelService: channelSvc}
|
||||||
|
|
||||||
|
gid := int64(10)
|
||||||
|
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
|
||||||
|
"requested model IS in pricing → allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckChannelPricingRestriction_Upstream_SkipsPreCheck(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// upstream 模式:预检查始终跳过(返回 false),需逐账号检查
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
RestrictModels: true,
|
||||||
|
BillingModelSource: BillingModelSourceUpstream,
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||||
|
svc := &GatewayService{channelService: channelSvc}
|
||||||
|
|
||||||
|
gid := int64(10)
|
||||||
|
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "unknown-model"),
|
||||||
|
"upstream mode should skip pre-check (per-account check needed)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckChannelPricingRestriction_RestrictModelsDisabled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
RestrictModels: false, // 未开启模型限制
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||||
|
svc := &GatewayService{channelService: channelSvc}
|
||||||
|
|
||||||
|
gid := int64(10)
|
||||||
|
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "any-model"),
|
||||||
|
"RestrictModels=false → always allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckChannelPricingRestriction_NoChannel(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// 分组没有关联渠道
|
||||||
|
repo := &mockChannelRepository{
|
||||||
|
listAllFn: func(_ context.Context) ([]Channel, error) { return nil, nil },
|
||||||
|
}
|
||||||
|
channelSvc := newTestChannelService(repo)
|
||||||
|
svc := &GatewayService{channelService: channelSvc}
|
||||||
|
|
||||||
|
gid := int64(999)
|
||||||
|
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "any-model"),
|
||||||
|
"no channel for group → allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- isUpstreamModelRestrictedByChannel ---
|
||||||
|
|
||||||
|
func TestIsUpstreamModelRestrictedByChannel_Restricted(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
RestrictModels: true,
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||||
|
svc := &GatewayService{channelService: channelSvc}
|
||||||
|
|
||||||
|
account := &Account{Platform: PlatformAntigravity}
|
||||||
|
// claude-sonnet-4-6 在 DefaultAntigravityModelMapping 中,映射后仍为 claude-sonnet-4-6
|
||||||
|
// 但定价列表只有 claude-opus-4-6
|
||||||
|
require.True(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "claude-sonnet-4-6"),
|
||||||
|
"upstream model claude-sonnet-4-6 NOT in pricing → restricted")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsUpstreamModelRestrictedByChannel_Allowed(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
RestrictModels: true,
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||||
|
svc := &GatewayService{channelService: channelSvc}
|
||||||
|
|
||||||
|
account := &Account{Platform: PlatformAntigravity}
|
||||||
|
require.False(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "claude-sonnet-4-6"),
|
||||||
|
"upstream model claude-sonnet-4-6 IS in pricing → allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsUpstreamModelRestrictedByChannel_UnsupportedModel(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
RestrictModels: true,
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
|
||||||
|
svc := &GatewayService{channelService: channelSvc}
|
||||||
|
|
||||||
|
account := &Account{Platform: PlatformAntigravity}
|
||||||
|
// totally-unknown-model 不在 DefaultAntigravityModelMapping 中 → 映射结果为空
|
||||||
|
require.False(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "totally-unknown-model"),
|
||||||
|
"unmappable model → upstream model empty → not restricted (account filter handles this)")
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user