Files
sub2api/backend/internal/service/gateway_channel_restriction_test.go
erio 6d3ea64a35 test: add unit tests for channel pricing restriction in scheduling phase
20 test cases covering:
- billingModelForRestriction: 4 cases (requested/channel_mapped/upstream/empty)
- resolveAccountUpstreamModel: 3 cases (antigravity/unsupported/non-antigravity)
- checkChannelPricingRestriction: 10 cases (nil guards, 3 billing sources,
  RestrictModels disabled, no channel)
- isUpstreamModelRestrictedByChannel: 3 cases (restricted/allowed/unsupported)
2026-04-04 11:25:01 +08:00

294 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// --- billingModelForRestriction ---
func TestBillingModelForRestriction_Requested(t *testing.T) {
t.Parallel()
got := billingModelForRestriction(BillingModelSourceRequested, "claude-sonnet-4-5", "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-5", got)
}
func TestBillingModelForRestriction_ChannelMapped(t *testing.T) {
t.Parallel()
got := billingModelForRestriction(BillingModelSourceChannelMapped, "claude-sonnet-4-5", "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-6", got)
}
// TestBillingModelForRestriction_Upstream verifies that the "upstream" billing
// source yields no model at this stage: the restriction can only be decided
// per account, so the group-level helper returns an empty string.
func TestBillingModelForRestriction_Upstream(t *testing.T) {
	t.Parallel()

	const requested, mapped = "claude-sonnet-4-5", "claude-sonnet-4-6"
	model := billingModelForRestriction(BillingModelSourceUpstream, requested, mapped)
	require.Equal(t, "", model, "upstream should return empty (per-account check needed)")
}
// TestBillingModelForRestriction_Empty verifies that an unset billing source
// behaves like "channel_mapped" and restricts on the mapped model.
func TestBillingModelForRestriction_Empty(t *testing.T) {
	t.Parallel()

	const requested, mapped = "claude-sonnet-4-5", "claude-sonnet-4-6"
	model := billingModelForRestriction("", requested, mapped)
	require.Equal(t, mapped, model, "empty source defaults to channel_mapped")
}
// --- resolveAccountUpstreamModel ---
// TestResolveAccountUpstreamModel_Antigravity checks that an Antigravity
// account resolves a supported model to its upstream name.
//
// NOTE(review): the input and the expected output are the same string, so on
// its own this case cannot tell a mapping hit apart from a plain passthrough;
// the _Unsupported case below is what exercises the mapping lookup.
func TestResolveAccountUpstreamModel_Antigravity(t *testing.T) {
	t.Parallel()

	// Antigravity accounts resolve models via DefaultAntigravityModelMapping.
	const model = "claude-sonnet-4-6"
	antigravity := &Account{Platform: PlatformAntigravity}
	require.Equal(t, model, resolveAccountUpstreamModel(antigravity, model))
}
// TestResolveAccountUpstreamModel_Antigravity_Unsupported checks that an
// Antigravity account yields an empty upstream model for a model it does not
// support.
func TestResolveAccountUpstreamModel_Antigravity_Unsupported(t *testing.T) {
	t.Parallel()

	antigravity := &Account{Platform: PlatformAntigravity}
	upstream := resolveAccountUpstreamModel(antigravity, "totally-unknown-model")
	require.Equal(t, "", upstream, "unsupported model should return empty")
}
// TestResolveAccountUpstreamModel_NonAntigravity checks that an account with
// no model mapping (here an Anthropic account) passes the model through as-is.
func TestResolveAccountUpstreamModel_NonAntigravity(t *testing.T) {
	t.Parallel()

	const model = "claude-sonnet-4-6"
	anthropic := &Account{Platform: PlatformAnthropic}
	require.Equal(t, model, resolveAccountUpstreamModel(anthropic, model), "no mapping = passthrough")
}
// --- checkChannelPricingRestriction ---
// TestCheckChannelPricingRestriction_NilGroupID expects no restriction when the
// request carries no group ID.
func TestCheckChannelPricingRestriction_NilGroupID(t *testing.T) {
	t.Parallel()

	gateway := &GatewayService{channelService: &ChannelService{}}
	restricted := gateway.checkChannelPricingRestriction(context.Background(), nil, "claude-sonnet-4")
	require.False(t, restricted)
}
// TestCheckChannelPricingRestriction_NilChannelService expects no restriction
// when the gateway has no channel service configured.
func TestCheckChannelPricingRestriction_NilChannelService(t *testing.T) {
	t.Parallel()

	groupID := int64(10)
	gateway := &GatewayService{} // channelService deliberately left nil
	restricted := gateway.checkChannelPricingRestriction(context.Background(), &groupID, "claude-sonnet-4")
	require.False(t, restricted)
}
// TestCheckChannelPricingRestriction_EmptyModel expects no restriction when the
// requested model name is empty.
func TestCheckChannelPricingRestriction_EmptyModel(t *testing.T) {
	t.Parallel()

	groupID := int64(10)
	gateway := &GatewayService{channelService: &ChannelService{}}
	restricted := gateway.checkChannelPricingRestriction(context.Background(), &groupID, "")
	require.False(t, restricted)
}
// TestCheckChannelPricingRestriction_ChannelMapped_Restricted verifies that,
// with billing_model_source=channel_mapped, the restriction is evaluated on the
// model produced by the channel's ModelMapping rather than on the requested one.
//
// The pricing list deliberately contains the *requested* model
// (claude-sonnet-4-5) but not the *mapped* one (claude-sonnet-4-6). The request
// is therefore restricted only if the mapped model is the one being checked. An
// implementation that mistakenly checked the requested model would report
// "allowed" and fail here. With both models unpriced, the test could not tell
// the two apart.
func TestCheckChannelPricingRestriction_ChannelMapped_Restricted(t *testing.T) {
	t.Parallel()

	// The channel maps claude-sonnet-4-5 → claude-sonnet-4-6 but prices only the
	// requested claude-sonnet-4-5.
	ch := Channel{
		ID:                 1,
		Status:             StatusActive,
		GroupIDs:           []int64{10},
		RestrictModels:     true,
		BillingModelSource: BillingModelSourceChannelMapped,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-sonnet-4-5"}},
		},
		ModelMapping: map[string]map[string]string{
			"anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"},
		},
	}
	channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
	svc := &GatewayService{channelService: channelSvc}
	gid := int64(10)

	require.True(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
		"mapped model claude-sonnet-4-6 is NOT in pricing → restricted")
}
// TestCheckChannelPricingRestriction_ChannelMapped_Allowed verifies that, with
// billing_model_source=channel_mapped, a request passes the pre-check when its
// channel-mapped model is priced, even though the requested model is not.
func TestCheckChannelPricingRestriction_ChannelMapped_Allowed(t *testing.T) {
	t.Parallel()

	groupID := int64(10)
	// The channel maps claude-sonnet-4-5 → claude-sonnet-4-6 and prices claude-sonnet-4-6.
	channel := Channel{
		ID:                 1,
		Status:             StatusActive,
		GroupIDs:           []int64{groupID},
		RestrictModels:     true,
		BillingModelSource: BillingModelSourceChannelMapped,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
		},
		ModelMapping: map[string]map[string]string{
			"anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"},
		},
	}
	gateway := &GatewayService{
		channelService: newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"})),
	}

	restricted := gateway.checkChannelPricingRestriction(context.Background(), &groupID, "claude-sonnet-4-5")
	require.False(t, restricted, "mapped model claude-sonnet-4-6 IS in pricing → allowed")
}
// TestCheckChannelPricingRestriction_Requested_Restricted verifies that, with
// billing_model_source=requested, the restriction is evaluated on the model the
// client asked for and the channel's ModelMapping is ignored.
//
// The channel maps claude-sonnet-4-5 → claude-sonnet-4-6 and prices only
// claude-sonnet-4-6. Under channel_mapped semantics the request would be
// allowed. Under requested semantics it must be restricted. Without the
// mapping, both semantics would check the same model and the test could not
// detect an implementation that ignores the billing source.
func TestCheckChannelPricingRestriction_Requested_Restricted(t *testing.T) {
	t.Parallel()

	ch := Channel{
		ID:                 1,
		Status:             StatusActive,
		GroupIDs:           []int64{10},
		RestrictModels:     true,
		BillingModelSource: BillingModelSourceRequested,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
		},
		// The mapping target is priced. Only the requested-model check restricts.
		ModelMapping: map[string]map[string]string{
			"anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"},
		},
	}
	channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
	svc := &GatewayService{channelService: channelSvc}
	gid := int64(10)

	require.True(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
		"requested model claude-sonnet-4-5 is NOT in pricing → restricted")
}
// TestCheckChannelPricingRestriction_Requested_Allowed verifies that, with
// billing_model_source=requested, a request passes when the requested model is
// priced, regardless of what the channel maps it to.
//
// The channel maps claude-sonnet-4-5 → claude-sonnet-4-6 but prices only
// claude-sonnet-4-5. An implementation that wrongly checked the mapped model
// would report "restricted" and fail here. Without the mapping, that mistake
// would go unnoticed.
func TestCheckChannelPricingRestriction_Requested_Allowed(t *testing.T) {
	t.Parallel()

	ch := Channel{
		ID:                 1,
		Status:             StatusActive,
		GroupIDs:           []int64{10},
		RestrictModels:     true,
		BillingModelSource: BillingModelSourceRequested,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-sonnet-4-5"}},
		},
		// The mapping target is unpriced. Only the requested-model check allows.
		ModelMapping: map[string]map[string]string{
			"anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"},
		},
	}
	channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
	svc := &GatewayService{channelService: channelSvc}
	gid := int64(10)

	require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
		"requested model IS in pricing → allowed")
}
// TestCheckChannelPricingRestriction_Upstream_SkipsPreCheck verifies that in
// upstream billing mode the group-level pre-check is always skipped (returns
// false). An unpriced model must still pass here, because restriction is left
// to the per-account check.
func TestCheckChannelPricingRestriction_Upstream_SkipsPreCheck(t *testing.T) {
	t.Parallel()

	groupID := int64(10)
	channel := Channel{
		ID:                 1,
		Status:             StatusActive,
		GroupIDs:           []int64{groupID},
		RestrictModels:     true,
		BillingModelSource: BillingModelSourceUpstream,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
		},
	}
	gateway := &GatewayService{
		channelService: newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"})),
	}

	restricted := gateway.checkChannelPricingRestriction(context.Background(), &groupID, "unknown-model")
	require.False(t, restricted, "upstream mode should skip pre-check (per-account check needed)")
}
// TestCheckChannelPricingRestriction_RestrictModelsDisabled verifies that a
// channel with RestrictModels=false never restricts, even for a model its
// pricing list does not contain.
func TestCheckChannelPricingRestriction_RestrictModelsDisabled(t *testing.T) {
	t.Parallel()

	groupID := int64(10)
	channel := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{groupID},
		RestrictModels: false, // model restriction switched off
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
		},
	}
	gateway := &GatewayService{
		channelService: newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"})),
	}

	restricted := gateway.checkChannelPricingRestriction(context.Background(), &groupID, "any-model")
	require.False(t, restricted, "RestrictModels=false → always allowed")
}
// TestCheckChannelPricingRestriction_NoChannel verifies that a group with no
// channel bound to it is never restricted.
func TestCheckChannelPricingRestriction_NoChannel(t *testing.T) {
	t.Parallel()

	// The mocked listAllFn returns no channels, so group 999 has none.
	emptyRepo := &mockChannelRepository{
		listAllFn: func(context.Context) ([]Channel, error) { return nil, nil },
	}
	gateway := &GatewayService{channelService: newTestChannelService(emptyRepo)}
	groupID := int64(999)

	restricted := gateway.checkChannelPricingRestriction(context.Background(), &groupID, "any-model")
	require.False(t, restricted, "no channel for group → allowed")
}
// --- isUpstreamModelRestrictedByChannel ---
// TestIsUpstreamModelRestrictedByChannel_Restricted verifies that an account's
// resolved upstream model is restricted when the channel does not price it.
func TestIsUpstreamModelRestrictedByChannel_Restricted(t *testing.T) {
	t.Parallel()

	const groupID = 10
	channel := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{groupID},
		RestrictModels: true,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
		},
	}
	gateway := &GatewayService{
		channelService: newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"})),
	}
	antigravity := &Account{Platform: PlatformAntigravity}

	// claude-sonnet-4-6 is in DefaultAntigravityModelMapping and stays
	// claude-sonnet-4-6 after mapping, but the channel prices only claude-opus-4-6.
	restricted := gateway.isUpstreamModelRestrictedByChannel(context.Background(), groupID, antigravity, "claude-sonnet-4-6")
	require.True(t, restricted, "upstream model claude-sonnet-4-6 NOT in pricing → restricted")
}
// TestIsUpstreamModelRestrictedByChannel_Allowed verifies that an account's
// resolved upstream model passes when the channel prices it.
func TestIsUpstreamModelRestrictedByChannel_Allowed(t *testing.T) {
	t.Parallel()

	const groupID = 10
	channel := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{groupID},
		RestrictModels: true,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
		},
	}
	gateway := &GatewayService{
		channelService: newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"})),
	}
	antigravity := &Account{Platform: PlatformAntigravity}

	restricted := gateway.isUpstreamModelRestrictedByChannel(context.Background(), groupID, antigravity, "claude-sonnet-4-6")
	require.False(t, restricted, "upstream model claude-sonnet-4-6 IS in pricing → allowed")
}
// TestIsUpstreamModelRestrictedByChannel_UnsupportedModel verifies that a model
// the account cannot map is not reported as channel-restricted. Its resolved
// upstream model is empty, and filtering such accounts is left to the account
// filter.
func TestIsUpstreamModelRestrictedByChannel_UnsupportedModel(t *testing.T) {
	t.Parallel()

	const groupID = 10
	channel := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{groupID},
		RestrictModels: true,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
		},
	}
	gateway := &GatewayService{
		channelService: newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"})),
	}
	antigravity := &Account{Platform: PlatformAntigravity}

	// totally-unknown-model is absent from DefaultAntigravityModelMapping, so it maps to "".
	restricted := gateway.isUpstreamModelRestrictedByChannel(context.Background(), groupID, antigravity, "totally-unknown-model")
	require.False(t, restricted, "unmappable model → upstream model empty → not restricted (account filter handles this)")
}