Files
sub2api/backend/internal/service/channel_service_test.go
erio c5688fef9a fix: remove cross-platform pricing/mapping leakage for antigravity groups
Antigravity groups were incorrectly matching pricing and model mapping
entries from anthropic/gemini platform tabs. Each platform should be
strictly isolated — antigravity groups only use antigravity-tagged pricing.
2026-04-05 20:42:24 +08:00

2202 lines
70 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Mock: ChannelRepository
// ---------------------------------------------------------------------------

// mockChannelRepository is a test double whose behaviour is supplied per test
// through optional function fields. Each field backs the method of the
// matching name; when a field is nil that method returns a fixed default
// instead of calling it.
type mockChannelRepository struct {
// Hooks behind ListAll and GetGroupPlatforms.
listAllFn func(ctx context.Context) ([]Channel, error)
getGroupPlatformsFn func(ctx context.Context, groupIDs []int64) (map[int64]string, error)
// Hooks behind the channel CRUD and listing methods.
createFn func(ctx context.Context, channel *Channel) error
getByIDFn func(ctx context.Context, id int64) (*Channel, error)
updateFn func(ctx context.Context, channel *Channel) error
deleteFn func(ctx context.Context, id int64) error
listFn func(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error)
// Hooks behind the name-uniqueness checks.
existsByNameFn func(ctx context.Context, name string) (bool, error)
existsByNameExcludingFn func(ctx context.Context, name string, excludeID int64) (bool, error)
// Hooks behind the channel/group association methods.
getGroupIDsFn func(ctx context.Context, channelID int64) ([]int64, error)
setGroupIDsFn func(ctx context.Context, channelID int64, groupIDs []int64) error
getChannelIDByGroupIDFn func(ctx context.Context, groupID int64) (int64, error)
getGroupsInOtherChannelsFn func(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error)
// Hooks behind the model-pricing methods.
listModelPricingFn func(ctx context.Context, channelID int64) ([]ChannelModelPricing, error)
createModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error
updateModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error
deleteModelPricingFn func(ctx context.Context, id int64) error
replaceModelPricingFn func(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
}
// Create delegates to createFn when a test has installed one and otherwise
// reports success.
func (m *mockChannelRepository) Create(ctx context.Context, channel *Channel) error {
	if m.createFn == nil {
		return nil
	}
	return m.createFn(ctx, channel)
}
// GetByID delegates to getByIDFn when a test has installed one; without a
// hook it returns ErrChannelNotFound.
func (m *mockChannelRepository) GetByID(ctx context.Context, id int64) (*Channel, error) {
	if m.getByIDFn == nil {
		return nil, ErrChannelNotFound
	}
	return m.getByIDFn(ctx, id)
}
// Update delegates to updateFn when a test has installed one and otherwise
// reports success.
func (m *mockChannelRepository) Update(ctx context.Context, channel *Channel) error {
	if m.updateFn == nil {
		return nil
	}
	return m.updateFn(ctx, channel)
}
// Delete delegates to deleteFn when a test has installed one and otherwise
// reports success.
func (m *mockChannelRepository) Delete(ctx context.Context, id int64) error {
	if m.deleteFn == nil {
		return nil
	}
	return m.deleteFn(ctx, id)
}
// List delegates to listFn when a test has installed one; without a hook it
// returns nil channels, a nil pagination result and a nil error.
func (m *mockChannelRepository) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) {
	if m.listFn == nil {
		return nil, nil, nil
	}
	return m.listFn(ctx, params, status, search)
}
// ListAll delegates to listAllFn when a test has installed one; without a
// hook it returns a nil slice and a nil error.
func (m *mockChannelRepository) ListAll(ctx context.Context) ([]Channel, error) {
	if m.listAllFn == nil {
		return nil, nil
	}
	return m.listAllFn(ctx)
}
// ExistsByName delegates to existsByNameFn when a test has installed one;
// without a hook it reports that the name does not exist.
func (m *mockChannelRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
	if m.existsByNameFn == nil {
		return false, nil
	}
	return m.existsByNameFn(ctx, name)
}
// ExistsByNameExcluding delegates to existsByNameExcludingFn when a test has
// installed one; without a hook it reports that the name does not exist.
func (m *mockChannelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) {
	if m.existsByNameExcludingFn == nil {
		return false, nil
	}
	return m.existsByNameExcludingFn(ctx, name, excludeID)
}
// GetGroupIDs delegates to getGroupIDsFn when a test has installed one;
// without a hook it returns a nil slice and a nil error.
func (m *mockChannelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) {
	if m.getGroupIDsFn == nil {
		return nil, nil
	}
	return m.getGroupIDsFn(ctx, channelID)
}
// SetGroupIDs delegates to setGroupIDsFn when a test has installed one and
// otherwise reports success.
func (m *mockChannelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error {
	if m.setGroupIDsFn == nil {
		return nil
	}
	return m.setGroupIDsFn(ctx, channelID, groupIDs)
}
// GetChannelIDByGroupID delegates to getChannelIDByGroupIDFn when a test has
// installed one; without a hook it returns 0 and a nil error.
func (m *mockChannelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
	if m.getChannelIDByGroupIDFn == nil {
		return 0, nil
	}
	return m.getChannelIDByGroupIDFn(ctx, groupID)
}
// GetGroupsInOtherChannels delegates to getGroupsInOtherChannelsFn when a
// test has installed one; without a hook it returns a nil slice and a nil
// error.
func (m *mockChannelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) {
	if m.getGroupsInOtherChannelsFn == nil {
		return nil, nil
	}
	return m.getGroupsInOtherChannelsFn(ctx, channelID, groupIDs)
}
// GetGroupPlatforms delegates to getGroupPlatformsFn when a test has
// installed one; without a hook it returns a nil map and a nil error.
func (m *mockChannelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
	if m.getGroupPlatformsFn == nil {
		return nil, nil
	}
	return m.getGroupPlatformsFn(ctx, groupIDs)
}
// ListModelPricing delegates to listModelPricingFn when a test has installed
// one; without a hook it returns a nil slice and a nil error.
func (m *mockChannelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error) {
	if m.listModelPricingFn == nil {
		return nil, nil
	}
	return m.listModelPricingFn(ctx, channelID)
}
// CreateModelPricing delegates to createModelPricingFn when a test has
// installed one and otherwise reports success.
func (m *mockChannelRepository) CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error {
	if m.createModelPricingFn == nil {
		return nil
	}
	return m.createModelPricingFn(ctx, pricing)
}
// UpdateModelPricing delegates to updateModelPricingFn when a test has
// installed one and otherwise reports success.
func (m *mockChannelRepository) UpdateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error {
	if m.updateModelPricingFn == nil {
		return nil
	}
	return m.updateModelPricingFn(ctx, pricing)
}
// DeleteModelPricing delegates to deleteModelPricingFn when a test has
// installed one and otherwise reports success.
func (m *mockChannelRepository) DeleteModelPricing(ctx context.Context, id int64) error {
	if m.deleteModelPricingFn == nil {
		return nil
	}
	return m.deleteModelPricingFn(ctx, id)
}
// ReplaceModelPricing delegates to replaceModelPricingFn when a test has
// installed one and otherwise reports success.
func (m *mockChannelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error {
	if m.replaceModelPricingFn == nil {
		return nil
	}
	return m.replaceModelPricingFn(ctx, channelID, pricingList)
}
// ---------------------------------------------------------------------------
// Mock: APIKeyAuthCacheInvalidator
// ---------------------------------------------------------------------------

// mockChannelAuthCacheInvalidator records the argument of every invalidation
// call it receives, in call order, in one slice per invalidation kind.
type mockChannelAuthCacheInvalidator struct {
invalidatedGroupIDs []int64 // appended to by InvalidateAuthCacheByGroupID
invalidatedKeys []string // appended to by InvalidateAuthCacheByKey
invalidatedUserIDs []int64 // appended to by InvalidateAuthCacheByUserID
}
// InvalidateAuthCacheByKey records key in invalidatedKeys; the context is
// ignored.
func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByKey(_ context.Context, key string) {
	seen := m.invalidatedKeys
	m.invalidatedKeys = append(seen, key)
}
// InvalidateAuthCacheByUserID records userID in invalidatedUserIDs; the
// context is ignored.
func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) {
	seen := m.invalidatedUserIDs
	m.invalidatedUserIDs = append(seen, userID)
}
// InvalidateAuthCacheByGroupID records groupID in invalidatedGroupIDs; the
// context is ignored.
func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByGroupID(_ context.Context, groupID int64) {
	seen := m.invalidatedGroupIDs
	m.invalidatedGroupIDs = append(seen, groupID)
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

// newTestChannelService builds a ChannelService on top of repo, passing nil
// as the second NewChannelService argument.
func newTestChannelService(repo *mockChannelRepository) *ChannelService {
	svc := NewChannelService(repo, nil)
	return svc
}
// newTestChannelServiceWithAuth builds a ChannelService on top of repo and
// hands auth to NewChannelService as its second argument.
func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChannelAuthCacheInvalidator) *ChannelService {
	svc := NewChannelService(repo, auth)
	return svc
}
// makeStandardRepo returns a mock repository whose ListAll yields exactly the
// given channel and whose GetGroupPlatforms yields the given group-to-platform
// map, regardless of which group IDs are requested. All other methods keep
// their defaults.
func makeStandardRepo(ch Channel, groupPlatforms map[int64]string) *mockChannelRepository {
	repo := &mockChannelRepository{}
	repo.listAllFn = func(context.Context) ([]Channel, error) {
		return []Channel{ch}, nil
	}
	repo.getGroupPlatformsFn = func(context.Context, []int64) (map[int64]string, error) {
		return groupPlatforms, nil
	}
	return repo
}
// ===========================================================================
// 1. BuildModelMappingChain
// ===========================================================================

// TestBuildModelMappingChain checks the chain string returned by
// ChannelMappingResult.BuildModelMappingChain for each combination of channel
// mapping and upstream model listed in the table.
func TestBuildModelMappingChain(t *testing.T) {
	type chainCase struct {
		name          string
		result        ChannelMappingResult
		requestModel  string
		upstreamModel string
		want          string
	}
	cases := []chainCase{
		{
			name:          "no mapping, no upstream diff",
			result:        ChannelMappingResult{Mapped: false, MappedModel: "claude-sonnet-4"},
			requestModel:  "claude-sonnet-4",
			upstreamModel: "claude-sonnet-4",
			want:          "",
		},
		{
			name:          "no mapping, upstream differs",
			result:        ChannelMappingResult{Mapped: false, MappedModel: "claude-sonnet-4"},
			requestModel:  "claude-sonnet-4",
			upstreamModel: "claude-sonnet-4-20250514",
			want:          "claude-sonnet-4\u2192claude-sonnet-4-20250514",
		},
		{
			name:          "mapped, upstream differs",
			result:        ChannelMappingResult{Mapped: true, MappedModel: "claude-sonnet-4-20250514"},
			requestModel:  "my-model",
			upstreamModel: "actual-upstream",
			want:          "my-model\u2192claude-sonnet-4-20250514\u2192actual-upstream",
		},
		{
			name:          "mapped, upstream same as mapped",
			result:        ChannelMappingResult{Mapped: true, MappedModel: "claude-sonnet-4-20250514"},
			requestModel:  "claude-sonnet-4",
			upstreamModel: "claude-sonnet-4-20250514",
			want:          "claude-sonnet-4\u2192claude-sonnet-4-20250514",
		},
		{
			name:          "mapped, upstream empty",
			result:        ChannelMappingResult{Mapped: true, MappedModel: "target-model"},
			requestModel:  "my-model",
			upstreamModel: "",
			want:          "my-model\u2192target-model",
		},
	}
	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			chain := tc.result.BuildModelMappingChain(tc.requestModel, tc.upstreamModel)
			require.Equal(t, tc.want, chain)
		})
	}
}
// ===========================================================================
// 2. ReplaceModelInBody
// ===========================================================================

// TestReplaceModelInBody feeds ReplaceModelInBody an empty body, a body that
// already carries the target model, a body with a different model and a body
// with no model field, and checks the returned bytes for each.
func TestReplaceModelInBody(t *testing.T) {
	// expectExact builds a check requiring the output to equal want.
	expectExact := func(want []byte) func(*testing.T, []byte) {
		return func(t *testing.T, got []byte) {
			require.Equal(t, want, got)
		}
	}
	// expectOpusWithTemperature requires the rewritten model to be present
	// and the temperature key to survive.
	expectOpusWithTemperature := func(t *testing.T, got []byte) {
		out := string(got)
		require.Contains(t, out, `"model":"claude-opus-4"`)
		require.Contains(t, out, `"temperature"`)
	}

	cases := []struct {
		name     string
		body     []byte
		newModel string
		check    func(t *testing.T, result []byte)
	}{
		{
			name:     "empty body",
			body:     []byte{},
			newModel: "new-model",
			check:    expectExact([]byte{}),
		},
		{
			name:     "model already equal",
			body:     []byte(`{"model":"claude-sonnet-4","temperature":0.7}`),
			newModel: "claude-sonnet-4",
			check:    expectExact([]byte(`{"model":"claude-sonnet-4","temperature":0.7}`)),
		},
		{
			name:     "model different",
			body:     []byte(`{"model":"claude-sonnet-4","temperature":0.7}`),
			newModel: "claude-opus-4",
			check:    expectOpusWithTemperature,
		},
		{
			name:     "no model field",
			body:     []byte(`{"temperature":0.7}`),
			newModel: "claude-opus-4",
			check:    expectOpusWithTemperature,
		},
	}
	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			tc.check(t, ReplaceModelInBody(tc.body, tc.newModel))
		})
	}
}
// ===========================================================================
// 3. validateNoConflictingModels + validateNoConflictingMappings
// ===========================================================================

// TestValidateNoConflictingModels runs validateNoConflictingModels over
// pricing lists with and without overlapping model patterns and checks
// whether an error is returned and, where specified, what its message
// contains.
func TestValidateNoConflictingModels(t *testing.T) {
	type modelsCase struct {
		name        string
		pricingList []ChannelModelPricing
		wantErr     bool
		errContains string
	}
	cases := []modelsCase{
		{
			name: "no duplicates",
			pricingList: []ChannelModelPricing{
				{Platform: "anthropic", Models: []string{"claude-sonnet-4", "claude-opus-4"}},
				{Platform: "openai", Models: []string{"gpt-5.1"}},
			},
			wantErr: false,
		},
		{
			name: "same platform duplicate",
			pricingList: []ChannelModelPricing{
				{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
				{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
			},
			wantErr:     true,
			errContains: "claude-sonnet-4",
		},
		{
			name: "same model different platform",
			pricingList: []ChannelModelPricing{
				{Platform: "anthropic", Models: []string{"model-a"}},
				{Platform: "openai", Models: []string{"model-a"}},
			},
			wantErr: false,
		},
		{
			name: "case insensitive",
			pricingList: []ChannelModelPricing{
				{Platform: "anthropic", Models: []string{"Claude"}},
				{Platform: "anthropic", Models: []string{"claude"}},
			},
			wantErr: true,
		},
		{
			name:        "empty list (nil)",
			pricingList: nil,
			wantErr:     false,
		},
		{
			name: "wildcard_vs_wildcard_conflict",
			pricingList: []ChannelModelPricing{
				{Platform: "anthropic", Models: []string{"claude-*"}},
				{Platform: "anthropic", Models: []string{"claude-opus-*"}},
			},
			wantErr:     true,
			errContains: "conflict",
		},
		{
			name: "wildcard_vs_exact_conflict",
			pricingList: []ChannelModelPricing{
				{Platform: "anthropic", Models: []string{"claude-*"}},
				{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
			},
			wantErr:     true,
			errContains: "conflict",
		},
		{
			name: "no_conflict_different_platform",
			pricingList: []ChannelModelPricing{
				{Platform: "anthropic", Models: []string{"claude-opus-*"}},
				{Platform: "openai", Models: []string{"claude-*"}},
			},
			wantErr: false,
		},
		{
			name: "no_conflict_same_platform_different_prefix",
			pricingList: []ChannelModelPricing{
				{Platform: "anthropic", Models: []string{"claude-opus-*"}},
				{Platform: "anthropic", Models: []string{"gpt-*"}},
			},
			wantErr: false,
		},
		{
			name: "catch_all_wildcard_conflicts_with_everything",
			pricingList: []ChannelModelPricing{
				{Platform: "openai", Models: []string{"*"}},
				{Platform: "openai", Models: []string{"gpt-5"}},
			},
			wantErr:     true,
			errContains: "conflict",
		},
	}
	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			err := validateNoConflictingModels(tc.pricingList)
			if !tc.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			if tc.errContains != "" {
				require.Contains(t, err.Error(), tc.errContains)
			}
		})
	}

	// A non-nil but empty list is exercised separately from the nil case above.
	t.Run("empty list (empty slice)", func(t *testing.T) {
		require.NoError(t, validateNoConflictingModels([]ChannelModelPricing{}))
	})
}
// TestValidateNoConflictingMappings runs validateNoConflictingMappings over
// per-platform mapping tables and checks whether an error is returned and,
// where specified, what its message contains. The table expects overlapping
// keys to be rejected only when they sit under the same platform.
func TestValidateNoConflictingMappings(t *testing.T) {
	tests := []struct {
		name        string
		mapping     map[string]map[string]string
		wantErr     bool
		errContains string
	}{
		{
			name:    "nil mapping",
			mapping: nil,
			wantErr: false,
		},
		{
			name:    "empty mapping",
			mapping: map[string]map[string]string{},
			wantErr: false,
		},
		{
			name: "no conflict",
			mapping: map[string]map[string]string{
				"anthropic": {"claude-opus-*": "opus", "gpt-*": "gpt"},
			},
			wantErr: false,
		},
		{
			name: "wildcard vs wildcard conflict",
			mapping: map[string]map[string]string{
				"anthropic": {"claude-*": "a", "claude-opus-*": "b"},
			},
			wantErr:     true,
			errContains: "conflict",
		},
		{
			name: "wildcard vs exact conflict",
			mapping: map[string]map[string]string{
				"openai": {"gpt-*": "a", "gpt-4o": "b"},
			},
			wantErr:     true,
			errContains: "conflict",
		},
		{
			// Previously named "exact duplicate conflict", which contradicted
			// the expectation: the same exact key under two different
			// platforms is expected to be accepted.
			name: "same exact key on different platforms no conflict",
			mapping: map[string]map[string]string{
				"anthropic": {"claude-opus-4": "a"},
				"openai":    {"claude-opus-4": "b"},
			},
			wantErr: false,
		},
		{
			name: "different platforms no conflict",
			mapping: map[string]map[string]string{
				"anthropic": {"claude-*": "a"},
				"openai":    {"claude-*": "b"},
			},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validateNoConflictingMappings(tt.mapping)
			if !tt.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			if tt.errContains != "" {
				require.Contains(t, err.Error(), tt.errContains)
			}
		})
	}
}
// TestConflictsBetween checks conflictsBetween on pairs of modelEntry values
// covering exact/exact, wildcard/exact and wildcard/wildcard combinations.
func TestConflictsBetween(t *testing.T) {
	type pairCase struct {
		name        string
		left, right modelEntry
		want        bool
	}
	cases := []pairCase{
		{
			name:  "exact same",
			left:  modelEntry{prefix: "claude-opus-4", wildcard: false},
			right: modelEntry{prefix: "claude-opus-4", wildcard: false},
			want:  true,
		},
		{
			name:  "exact different",
			left:  modelEntry{prefix: "claude-opus-4", wildcard: false},
			right: modelEntry{prefix: "gpt-4o", wildcard: false},
			want:  false,
		},
		{
			name:  "wildcard matches exact",
			left:  modelEntry{prefix: "claude-", wildcard: true},
			right: modelEntry{prefix: "claude-opus-4", wildcard: false},
			want:  true,
		},
		{
			name:  "exact does not match unrelated wildcard",
			left:  modelEntry{prefix: "gpt-4o", wildcard: false},
			right: modelEntry{prefix: "claude-", wildcard: true},
			want:  false,
		},
		{
			name:  "wildcard prefix overlap",
			left:  modelEntry{prefix: "claude-", wildcard: true},
			right: modelEntry{prefix: "claude-opus-", wildcard: true},
			want:  true,
		},
		{
			name:  "wildcards no overlap",
			left:  modelEntry{prefix: "claude-", wildcard: true},
			right: modelEntry{prefix: "gpt-", wildcard: true},
			want:  false,
		},
		{
			name:  "catch-all wildcard vs any",
			left:  modelEntry{prefix: "", wildcard: true},
			right: modelEntry{prefix: "anything", wildcard: false},
			want:  true,
		},
	}
	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			got := conflictsBetween(tc.left, tc.right)
			require.Equal(t, tc.want, got)
		})
	}
}
// ===========================================================================
// 4. Cache Building + Hot Path Methods
// ===========================================================================
// --- 4.1 GetChannelForGroup ---

// TestGetChannelForGroup_Success verifies that the active channel bound to a
// group is returned and that mutating the returned value does not show up in
// a later lookup.
func TestGetChannelForGroup_Success(t *testing.T) {
	const groupID int64 = 10
	ctx := context.Background()
	channel := Channel{
		ID:       1,
		Name:     "test-channel",
		Status:   StatusActive,
		GroupIDs: []int64{groupID},
	}
	svc := newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"}))

	first, err := svc.GetChannelForGroup(ctx, groupID)
	require.NoError(t, err)
	require.NotNil(t, first)
	require.Equal(t, int64(1), first.ID)
	require.Equal(t, "test-channel", first.Name)

	// Scribble on the returned value; a second lookup must still report the
	// original name.
	first.Name = "mutated"
	second, err := svc.GetChannelForGroup(ctx, groupID)
	require.NoError(t, err)
	require.Equal(t, "test-channel", second.Name)
}
// TestGetChannelForGroup_InactiveChannel verifies that a disabled channel is
// not returned for its group: the lookup yields nil without an error.
func TestGetChannelForGroup_InactiveChannel(t *testing.T) {
	const groupID int64 = 10
	disabled := Channel{
		ID:       1,
		Status:   StatusDisabled,
		GroupIDs: []int64{groupID},
	}
	svc := newTestChannelService(makeStandardRepo(disabled, map[int64]string{groupID: "anthropic"}))

	got, err := svc.GetChannelForGroup(context.Background(), groupID)
	require.NoError(t, err)
	require.Nil(t, got)
}
// TestGetChannelForGroup_NoChannel verifies that looking up a group that no
// channel lists yields nil without an error.
func TestGetChannelForGroup_NoChannel(t *testing.T) {
	const (
		boundGroupID   int64 = 10
		unboundGroupID int64 = 999
	)
	channel := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{boundGroupID},
	}
	svc := newTestChannelService(makeStandardRepo(channel, map[int64]string{boundGroupID: "anthropic"}))

	got, err := svc.GetChannelForGroup(context.Background(), unboundGroupID)
	require.NoError(t, err)
	require.Nil(t, got)
}
// TestGetChannelForGroup_CacheError verifies that a ListAll failure surfaces
// from GetChannelForGroup as an error carrying the repository's message, with
// a nil channel.
func TestGetChannelForGroup_CacheError(t *testing.T) {
	repo := &mockChannelRepository{}
	repo.listAllFn = func(context.Context) ([]Channel, error) {
		return nil, errors.New("db connection failed")
	}
	svc := newTestChannelService(repo)

	got, err := svc.GetChannelForGroup(context.Background(), 10)
	require.Error(t, err)
	require.Nil(t, got)
	require.Contains(t, err.Error(), "db connection failed")
}
// --- 4.2 GetChannelModelPricing ---

// TestGetChannelModelPricing_ExactMatch verifies that a model listed verbatim
// in the channel's pricing for the group's platform resolves to that pricing
// entry, including its input price.
func TestGetChannelModelPricing_ExactMatch(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)
	result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
	require.NotNil(t, result)
	require.Equal(t, int64(100), result.ID)
	// Guard the dereference below: a missing price should fail this test
	// cleanly rather than panic the whole test binary.
	require.NotNil(t, result.InputPrice)
	require.InDelta(t, 15e-6, *result.InputPrice, 1e-12)
}
func TestGetChannelModelPricing_CaseInsensitive(t *testing.T) {
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
ModelPricing: []ChannelModelPricing{
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
svc := newTestChannelService(repo)
result := svc.GetChannelModelPricing(context.Background(), 10, "Claude-Opus-4")
require.NotNil(t, result)
require.Equal(t, int64(100), result.ID)
}
func TestGetChannelModelPricing_WildcardMatch(t *testing.T) {
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
ModelPricing: []ChannelModelPricing{
{ID: 200, Platform: "anthropic", Models: []string{"claude-*"}, InputPrice: testPtrFloat64(10e-6)},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
svc := newTestChannelService(repo)
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-sonnet-4")
require.NotNil(t, result)
require.Equal(t, int64(200), result.ID)
}
// TestGetChannelModelPricing_WildcardFirstMatch verifies that when two
// wildcard entries both match a model, the entry defined first in the
// pricing list is the one returned, even if the later entry has the longer
// prefix.
func TestGetChannelModelPricing_WildcardFirstMatch(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 200, Platform: "anthropic", Models: []string{"claude-*"}, InputPrice: testPtrFloat64(10e-6)},
			{ID: 300, Platform: "anthropic", Models: []string{"claude-sonnet-*"}, InputPrice: testPtrFloat64(5e-6)},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)
	result := svc.GetChannelModelPricing(context.Background(), 10, "claude-sonnet-4-20250514")
	require.NotNil(t, result)
	// "claude-*" is defined first, so it matches first regardless of prefix length
	require.Equal(t, int64(200), result.ID)
	// Guard the dereference below: a missing price should fail this test
	// cleanly rather than panic the whole test binary.
	require.NotNil(t, result.InputPrice)
	require.InDelta(t, 10e-6, *result.InputPrice, 1e-12)
}
func TestGetChannelModelPricing_NoMatch(t *testing.T) {
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
ModelPricing: []ChannelModelPricing{
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
svc := newTestChannelService(repo)
result := svc.GetChannelModelPricing(context.Background(), 10, "gpt-5.1")
require.Nil(t, result)
}
func TestGetChannelModelPricing_InactiveChannel(t *testing.T) {
ch := Channel{
ID: 1,
Status: StatusDisabled,
GroupIDs: []int64{10},
ModelPricing: []ChannelModelPricing{
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
svc := newTestChannelService(repo)
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
require.Nil(t, result)
}
// TestGetChannelModelPricing_PlatformFiltering verifies that a group only
// sees pricing entries tagged with its own platform: the anthropic group
// must not resolve the openai entry and vice versa.
func TestGetChannelModelPricing_PlatformFiltering(t *testing.T) {
	const (
		anthropicGroup int64 = 10
		openaiGroup    int64 = 20
	)
	channel := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{anthropicGroup, openaiGroup},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "openai", Models: []string{"gpt-5.1"}, InputPrice: testPtrFloat64(5e-6)},
			{ID: 200, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
		},
	}
	platforms := map[int64]string{anthropicGroup: "anthropic", openaiGroup: "openai"}
	svc := newTestChannelService(makeStandardRepo(channel, platforms))
	ctx := context.Background()

	// Probes run in order; found=false means the lookup must return nil.
	probes := []struct {
		groupID int64
		model   string
		found   bool
		wantID  int64
	}{
		{groupID: anthropicGroup, model: "gpt-5.1", found: false},                    // anthropic group must NOT see openai pricing
		{groupID: anthropicGroup, model: "claude-opus-4", found: true, wantID: 200}, // anthropic group sees anthropic pricing
		{groupID: openaiGroup, model: "gpt-5.1", found: true, wantID: 100},          // openai group sees openai pricing
		{groupID: openaiGroup, model: "claude-opus-4", found: false},                // openai group must NOT see anthropic pricing
	}
	for _, p := range probes {
		got := svc.GetChannelModelPricing(ctx, p.groupID, p.model)
		if !p.found {
			require.Nil(t, got)
			continue
		}
		require.NotNil(t, got)
		require.Equal(t, p.wantID, got.ID)
	}
}
// TestGetChannelModelPricing_ReturnsCopy verifies that the pricing returned
// to a caller is independent of what later lookups observe: scalar fields and
// the Models slice (both its length and its elements) can be mutated by the
// caller without affecting a subsequent lookup.
func TestGetChannelModelPricing_ReturnsCopy(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)
	result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
	require.NotNil(t, result)
	require.Len(t, result.Models, 1)
	// Mutate the returned pricing. Appending alone can never change the
	// length seen through another slice header, so also overwrite an element
	// in place: that write is only invisible to later lookups if the backing
	// array was really copied. (Pointer fields are shared per design and are
	// deliberately not touched here.)
	result.Models[0] = "hacked-in-place"
	result.Models = append(result.Models, "hacked")
	result.ID = 999
	// A fresh lookup must be unaffected by the mutations above.
	result2 := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
	require.NotNil(t, result2)
	require.Equal(t, []string{"claude-opus-4"}, result2.Models)
	require.Equal(t, int64(100), result2.ID)
}
// --- 4.3 ResolveChannelMapping ---

// TestResolveChannelMapping_NoChannel verifies that resolving for a group
// that no channel lists returns the request model unchanged, Mapped=false and
// a zero ChannelID.
func TestResolveChannelMapping_NoChannel(t *testing.T) {
	const (
		boundGroupID   int64 = 10
		unboundGroupID int64 = 999 // not listed by any channel
	)
	channel := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{boundGroupID},
	}
	svc := newTestChannelService(makeStandardRepo(channel, map[int64]string{boundGroupID: "anthropic"}))

	got := svc.ResolveChannelMapping(context.Background(), unboundGroupID, "claude-opus-4")
	require.Equal(t, "claude-opus-4", got.MappedModel)
	require.False(t, got.Mapped)
	require.Equal(t, int64(0), got.ChannelID)
}
// TestResolveChannelMapping_ExactMapping verifies that a request model with
// an exact mapping entry under the group's platform is rewritten to the
// mapped model and that the owning channel ID is reported.
func TestResolveChannelMapping_ExactMapping(t *testing.T) {
	const groupID int64 = 10
	mapping := map[string]map[string]string{
		"anthropic": {"claude-sonnet-4": "claude-sonnet-4-20250514"},
	}
	channel := Channel{ID: 1, Status: StatusActive, GroupIDs: []int64{groupID}, ModelMapping: mapping}
	svc := newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"}))

	got := svc.ResolveChannelMapping(context.Background(), groupID, "claude-sonnet-4")
	require.True(t, got.Mapped)
	require.Equal(t, "claude-sonnet-4-20250514", got.MappedModel)
	require.Equal(t, int64(1), got.ChannelID)
}
// TestResolveChannelMapping_WildcardMapping verifies that a catch-all "*"
// mapping entry rewrites an arbitrary request model to its target.
func TestResolveChannelMapping_WildcardMapping(t *testing.T) {
	const groupID int64 = 10
	mapping := map[string]map[string]string{
		"anthropic": {"*": "gpt-5.4"},
	}
	channel := Channel{ID: 1, Status: StatusActive, GroupIDs: []int64{groupID}, ModelMapping: mapping}
	svc := newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"}))

	got := svc.ResolveChannelMapping(context.Background(), groupID, "any-model-name")
	require.True(t, got.Mapped)
	require.Equal(t, "gpt-5.4", got.MappedModel)
}
// TestResolveChannelMapping_WildcardFirstMatch verifies that when two
// wildcard mapping keys both match the request model, the model is mapped to
// one of their targets. Which one wins is not asserted.
func TestResolveChannelMapping_WildcardFirstMatch(t *testing.T) {
	const groupID int64 = 10
	mapping := map[string]map[string]string{
		"anthropic": {
			"claude-*":        "target2",
			"claude-sonnet-*": "target1",
		},
	}
	channel := Channel{ID: 1, Status: StatusActive, GroupIDs: []int64{groupID}, ModelMapping: mapping}
	svc := newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"}))

	got := svc.ResolveChannelMapping(context.Background(), groupID, "claude-sonnet-4")
	require.True(t, got.Mapped)
	// The mapping is a Go map, whose iteration order is unspecified, so the
	// test cannot pin which wildcard is tried first; accept either target.
	acceptable := []string{"target1", "target2"}
	require.Contains(t, acceptable, got.MappedModel)
}
// TestResolveChannelMapping_NoMapping verifies that a request model without a
// matching mapping entry is returned unchanged with Mapped=false, while the
// owning channel ID is still reported.
func TestResolveChannelMapping_NoMapping(t *testing.T) {
	const groupID int64 = 10
	mapping := map[string]map[string]string{
		"anthropic": {"claude-sonnet-4": "mapped"},
	}
	channel := Channel{ID: 1, Status: StatusActive, GroupIDs: []int64{groupID}, ModelMapping: mapping}
	svc := newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"}))

	got := svc.ResolveChannelMapping(context.Background(), groupID, "claude-opus-4")
	require.False(t, got.Mapped)
	require.Equal(t, "claude-opus-4", got.MappedModel)
	require.Equal(t, int64(1), got.ChannelID)
}
// TestResolveChannelMapping_DefaultBillingModelSource verifies that a channel
// with an empty BillingModelSource resolves to
// BillingModelSourceChannelMapped.
func TestResolveChannelMapping_DefaultBillingModelSource(t *testing.T) {
	const groupID int64 = 10
	channel := Channel{
		ID:                 1,
		Status:             StatusActive,
		GroupIDs:           []int64{groupID},
		BillingModelSource: "", // deliberately left empty
	}
	svc := newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"}))

	got := svc.ResolveChannelMapping(context.Background(), groupID, "claude-opus-4")
	require.Equal(t, BillingModelSourceChannelMapped, got.BillingModelSource)
}
// TestResolveChannelMapping_UpstreamBillingModelSource verifies that a
// channel configured with BillingModelSourceUpstream reports that value in
// the mapping result.
func TestResolveChannelMapping_UpstreamBillingModelSource(t *testing.T) {
	const groupID int64 = 10
	channel := Channel{
		ID:                 1,
		Status:             StatusActive,
		GroupIDs:           []int64{groupID},
		BillingModelSource: BillingModelSourceUpstream,
	}
	svc := newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"}))

	got := svc.ResolveChannelMapping(context.Background(), groupID, "claude-opus-4")
	require.Equal(t, BillingModelSourceUpstream, got.BillingModelSource)
}
// TestResolveChannelMapping_InactiveChannel verifies that mappings on a
// disabled channel are ignored: the request model comes back unchanged,
// unmapped, and with a zero ChannelID.
func TestResolveChannelMapping_InactiveChannel(t *testing.T) {
	const groupID int64 = 10
	mapping := map[string]map[string]string{
		"anthropic": {"claude-sonnet-4": "mapped"},
	}
	disabled := Channel{ID: 1, Status: StatusDisabled, GroupIDs: []int64{groupID}, ModelMapping: mapping}
	svc := newTestChannelService(makeStandardRepo(disabled, map[int64]string{groupID: "anthropic"}))

	got := svc.ResolveChannelMapping(context.Background(), groupID, "claude-sonnet-4")
	require.False(t, got.Mapped)
	require.Equal(t, "claude-sonnet-4", got.MappedModel)
	require.Equal(t, int64(0), got.ChannelID) // treated as having no channel
}
// --- 4.4 IsModelRestricted ---

// TestIsModelRestricted_NoChannel verifies that a group that no channel
// lists is never reported as restricted, even though the only configured
// channel has RestrictModels enabled.
func TestIsModelRestricted_NoChannel(t *testing.T) {
	const (
		boundGroupID   int64 = 10
		unboundGroupID int64 = 999 // not listed by any channel
	)
	channel := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{boundGroupID},
		RestrictModels: true,
	}
	svc := newTestChannelService(makeStandardRepo(channel, map[int64]string{boundGroupID: "anthropic"}))

	require.False(t, svc.IsModelRestricted(context.Background(), unboundGroupID, "claude-opus-4"))
}
// TestIsModelRestricted_RestrictDisabled verifies that with
// RestrictModels=false a model missing from the pricing list is still not
// reported as restricted.
func TestIsModelRestricted_RestrictDisabled(t *testing.T) {
	const groupID int64 = 10
	channel := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{groupID},
		RestrictModels: false,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-opus-4"}},
		},
	}
	svc := newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"}))

	// The model has no pricing entry, but restriction is switched off.
	require.False(t, svc.IsModelRestricted(context.Background(), groupID, "nonexistent-model"))
}
// TestIsModelRestricted_InactiveChannel verifies that a disabled channel does
// not restrict any model, even with RestrictModels enabled.
func TestIsModelRestricted_InactiveChannel(t *testing.T) {
	const groupID int64 = 10
	disabled := Channel{
		ID:             1,
		Status:         StatusDisabled,
		GroupIDs:       []int64{groupID},
		RestrictModels: true,
	}
	svc := newTestChannelService(makeStandardRepo(disabled, map[int64]string{groupID: "anthropic"}))

	require.False(t, svc.IsModelRestricted(context.Background(), groupID, "any-model"))
}
// TestIsModelRestricted_ModelInPricing verifies that a model listed in the
// channel's pricing is not restricted when RestrictModels is enabled.
func TestIsModelRestricted_ModelInPricing(t *testing.T) {
	const groupID int64 = 10
	channel := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{groupID},
		RestrictModels: true,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-opus-4", "claude-sonnet-4"}},
		},
	}
	svc := newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"}))

	require.False(t, svc.IsModelRestricted(context.Background(), groupID, "claude-opus-4"))
}
// TestIsModelRestricted_ModelInWildcard verifies that a model covered by a
// wildcard pricing entry is not restricted when RestrictModels is enabled.
func TestIsModelRestricted_ModelInWildcard(t *testing.T) {
	const groupID int64 = 10
	channel := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{groupID},
		RestrictModels: true,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-*"}},
		},
	}
	svc := newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"}))

	require.False(t, svc.IsModelRestricted(context.Background(), groupID, "claude-sonnet-4"))
}
// TestIsModelRestricted_ModelNotFound verifies that with RestrictModels
// enabled a model absent from the pricing list is reported as restricted.
func TestIsModelRestricted_ModelNotFound(t *testing.T) {
	const groupID int64 = 10
	channel := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{groupID},
		RestrictModels: true,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-opus-4"}},
		},
	}
	svc := newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"}))

	require.True(t, svc.IsModelRestricted(context.Background(), groupID, "gpt-5.1"))
}
// TestIsModelRestricted_CaseInsensitive verifies that the restriction check
// ignores letter case: a differently-cased spelling of a priced model is not
// restricted.
func TestIsModelRestricted_CaseInsensitive(t *testing.T) {
	const groupID int64 = 10
	channel := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{groupID},
		RestrictModels: true,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-opus-4"}},
		},
	}
	svc := newTestChannelService(makeStandardRepo(channel, map[int64]string{groupID: "anthropic"}))

	require.False(t, svc.IsModelRestricted(context.Background(), groupID, "Claude-Opus-4"))
}
// --- 4.5 ResolveChannelMappingAndRestrict ---
// Note: the model-restriction check has moved to the scheduling stage
// (GatewayService.checkChannelPricingRestriction).
// ResolveChannelMappingAndRestrict only performs mapping; restricted is
// always false.

// TestResolveChannelMappingAndRestrict_NilGroupID verifies that a nil group
// ID yields an unmapped result carrying the request model, with
// restricted=false.
func TestResolveChannelMappingAndRestrict_NilGroupID(t *testing.T) {
	repo := &mockChannelRepository{}
	repo.listAllFn = func(context.Context) ([]Channel, error) {
		return nil, nil
	}
	svc := newTestChannelService(repo)

	got, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), nil, "claude-opus-4")
	require.False(t, restricted)
	require.False(t, got.Mapped)
	require.Equal(t, "claude-opus-4", got.MappedModel)
}
// TestResolveChannelMappingAndRestrict_WithMapping checks that a model with a
// configured mapping on the group's platform is rewritten to its target, and
// that restricted is reported as false.
func TestResolveChannelMappingAndRestrict_WithMapping(t *testing.T) {
	anthropicMapping := map[string]string{
		"claude-sonnet-4": "claude-sonnet-4-20250514",
	}
	channel := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{10},
		RestrictModels: true,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
		},
		ModelMapping: map[string]map[string]string{"anthropic": anthropicMapping},
	}
	channelSvc := newTestChannelService(makeStandardRepo(channel, map[int64]string{10: "anthropic"}))

	groupID := int64(10)
	got, restricted := channelSvc.ResolveChannelMappingAndRestrict(context.Background(), &groupID, "claude-sonnet-4")

	// restricted is always false here; the restriction check happens at scheduling time.
	require.False(t, restricted)
	require.True(t, got.Mapped)
	require.Equal(t, "claude-sonnet-4-20250514", got.MappedModel)
}
// TestResolveChannelMappingAndRestrict_NoMapping checks that a model without
// any mapping is returned unchanged and unmapped, and that restricted stays
// false even though the channel has RestrictModels enabled.
func TestResolveChannelMappingAndRestrict_NoMapping(t *testing.T) {
	const model = "unknown-model"

	channel := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{10},
		RestrictModels: true,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
		},
	}
	channelSvc := newTestChannelService(makeStandardRepo(channel, map[int64]string{10: "anthropic"}))

	groupID := int64(10)
	got, restricted := channelSvc.ResolveChannelMappingAndRestrict(context.Background(), &groupID, model)

	// restricted is always false here; the restriction check happens at scheduling time.
	require.False(t, restricted)
	require.False(t, got.Mapped)
	require.Equal(t, model, got.MappedModel)
}
// --- 4.6 Cache Building Specifics ---
// TestBuildCache_DBError pins the behaviour expected when ListAll fails:
//
//  1. The first lookup surfaces the repository error to the caller.
//  2. An immediately following lookup returns (nil, nil) WITHOUT calling the
//     repository again. Per the original author's notes, a failed build
//     leaves a short-lived empty cache entry (an "error TTL") so the database
//     is not retried on every request — see the service implementation for
//     the exact TTL.
func TestBuildCache_DBError(t *testing.T) {
	ctx := context.Background()
	dbCalls := 0
	failingRepo := &mockChannelRepository{
		listAllFn: func(context.Context) ([]Channel, error) {
			dbCalls++
			return nil, errors.New("database down")
		},
	}
	channelSvc := newTestChannelService(failingRepo)

	// First lookup: the build fails and the error is propagated.
	_, firstErr := channelSvc.GetChannelForGroup(ctx, 10)
	require.Error(t, firstErr)
	require.Contains(t, firstErr.Error(), "database down")
	require.Equal(t, 1, dbCalls)

	// Second lookup: served from the empty cache entry, so no channel and no error.
	cached, secondErr := channelSvc.GetChannelForGroup(ctx, 10)
	require.NoError(t, secondErr)
	require.Nil(t, cached)

	// The repository must not have been queried a second time.
	require.Equal(t, 1, dbCalls)
}
// TestBuildCache_GroupPlatformError checks the fail-closed behaviour when group
// platforms cannot be loaded: the first lookup returns an error and no channel,
// and the immediately following lookup returns (nil, nil).
func TestBuildCache_GroupPlatformError(t *testing.T) {
	ctx := context.Background()
	channel := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
		},
	}
	brokenRepo := &mockChannelRepository{
		listAllFn: func(context.Context) ([]Channel, error) {
			return []Channel{channel}, nil
		},
		getGroupPlatformsFn: func(context.Context, []int64) (map[int64]string, error) {
			return nil, errors.New("group platforms failed")
		},
	}
	channelSvc := newTestChannelService(brokenRepo)

	// Fail closed: the platform-loading error reaches the caller.
	first, firstErr := channelSvc.GetChannelForGroup(ctx, 10)
	require.Error(t, firstErr)
	require.Nil(t, first)

	// The follow-up call is expected to hit the (empty) cached entry.
	second, secondErr := channelSvc.GetChannelForGroup(ctx, 10)
	require.NoError(t, secondErr)
	require.Nil(t, second)
}
// TestBuildCache_MultipleGroupsSameChannel checks that every group attached to
// the same channel resolves the same pricing entry.
func TestBuildCache_MultipleGroupsSameChannel(t *testing.T) {
	groupIDs := []int64{10, 20, 30}
	channel := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10, 20, 30},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
		},
	}
	platforms := map[int64]string{
		10: "anthropic",
		20: "anthropic",
		30: "anthropic",
	}
	channelSvc := newTestChannelService(makeStandardRepo(channel, platforms))

	ctx := context.Background()
	for _, groupID := range groupIDs {
		pricing := channelSvc.GetChannelModelPricing(ctx, groupID, "claude-opus-4")
		require.NotNil(t, pricing, "group %d should have pricing", groupID)
		require.Equal(t, int64(100), pricing.ID)
	}
}
// TestBuildCache_PlatformFiltering checks that a group only sees pricing
// entries tagged with its own platform, even when the channel carries entries
// for several platforms.
func TestBuildCache_PlatformFiltering(t *testing.T) {
	const (
		anthropicGroup = int64(10)
		openaiGroup    = int64(20)
	)
	channel := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10, 20},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
			{ID: 200, Platform: "openai", Models: []string{"gpt-5.1"}},
		},
	}
	platforms := map[int64]string{
		10: "anthropic",
		20: "openai",
	}
	channelSvc := newTestChannelService(makeStandardRepo(channel, platforms))
	ctx := context.Background()

	// The anthropic group resolves anthropic models only.
	require.NotNil(t, channelSvc.GetChannelModelPricing(ctx, anthropicGroup, "claude-opus-4"))
	require.Nil(t, channelSvc.GetChannelModelPricing(ctx, anthropicGroup, "gpt-5.1"))

	// The openai group resolves openai models only.
	require.NotNil(t, channelSvc.GetChannelModelPricing(ctx, openaiGroup, "gpt-5.1"))
	require.Nil(t, channelSvc.GetChannelModelPricing(ctx, openaiGroup, "claude-opus-4"))
}
// TestBuildCache_WildcardPreservesConfigOrder checks that wildcard entries are
// matched in configuration order: the first configured pattern that matches
// wins, regardless of how specific later patterns are.
func TestBuildCache_WildcardPreservesConfigOrder(t *testing.T) {
	channel := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			// Configured from shortest prefix to longest.
			{ID: 100, Platform: "anthropic", Models: []string{"c-*"}, InputPrice: testPtrFloat64(1e-6)},
			{ID: 200, Platform: "anthropic", Models: []string{"c-son-*"}, InputPrice: testPtrFloat64(2e-6)},
			{ID: 300, Platform: "anthropic", Models: []string{"c-son-4-*"}, InputPrice: testPtrFloat64(3e-6)},
		},
	}
	channelSvc := newTestChannelService(makeStandardRepo(channel, map[int64]string{10: "anthropic"}))
	ctx := context.Background()

	// Every model below matches "c-*" (ID=100), which is first in the config:
	//   "c-son-4-xxx" also matches "c-son-*" and "c-son-4-*",
	//   "c-son-yyy"   also matches "c-son-*",
	//   "c-other"     matches "c-*" only.
	models := []string{"c-son-4-xxx", "c-son-yyy", "c-other"}
	for _, model := range models {
		pricing := channelSvc.GetChannelModelPricing(ctx, 10, model)
		require.NotNil(t, pricing)
		require.Equal(t, int64(100), pricing.ID)
	}
}
// --- 4.7 invalidateCache ---
// TestInvalidateCache checks that pricing lookups are served from cache until
// invalidateCache is called, after which the next lookup reloads from the
// repository.
func TestInvalidateCache(t *testing.T) {
	ctx := context.Background()
	loads := 0
	channel := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
		},
	}
	countingRepo := &mockChannelRepository{
		listAllFn: func(context.Context) ([]Channel, error) {
			loads++
			return []Channel{channel}, nil
		},
		getGroupPlatformsFn: func(context.Context, []int64) (map[int64]string, error) {
			return map[int64]string{10: "anthropic"}, nil
		},
	}
	channelSvc := newTestChannelService(countingRepo)

	// lookupExpectingLoads resolves the pricing and asserts how many times the
	// repository has been asked for the channel list so far.
	lookupExpectingLoads := func(wantLoads int) {
		require.NotNil(t, channelSvc.GetChannelModelPricing(ctx, 10, "claude-opus-4"))
		require.Equal(t, wantLoads, loads)
	}

	lookupExpectingLoads(1) // initial load
	lookupExpectingLoads(1) // cache hit, no new DB call

	channelSvc.invalidateCache()

	lookupExpectingLoads(2) // rebuilt from the repository
}
// ===========================================================================
// 5. CRUD Methods
// ===========================================================================
// --- 5.1 Create ---
// TestCreate_Success checks the happy path of Create: with no name clash and
// no group conflict, the created channel is returned carrying the ID assigned
// by the repository.
func TestCreate_Success(t *testing.T) {
	const wantID = int64(42)

	happyRepo := &mockChannelRepository{
		existsByNameFn: func(context.Context, string) (bool, error) { return false, nil },
		getGroupsInOtherChannelsFn: func(context.Context, int64, []int64) ([]int64, error) {
			return nil, nil
		},
		createFn: func(_ context.Context, created *Channel) error {
			created.ID = wantID
			return nil
		},
		getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
			return &Channel{ID: id, Name: "new-channel", Status: StatusActive}, nil
		},
		listAllFn: func(context.Context) ([]Channel, error) { return nil, nil },
	}
	channelSvc := newTestChannelService(happyRepo)

	input := &CreateChannelInput{
		Name:     "new-channel",
		GroupIDs: []int64{10},
	}
	created, err := channelSvc.Create(context.Background(), input)

	require.NoError(t, err)
	require.NotNil(t, created)
	require.Equal(t, wantID, created.ID)
}
// TestCreate_NameExists checks that Create fails with ErrChannelExists when the
// repository reports the name as already taken.
func TestCreate_NameExists(t *testing.T) {
	takenRepo := &mockChannelRepository{
		existsByNameFn: func(context.Context, string) (bool, error) { return true, nil },
	}
	channelSvc := newTestChannelService(takenRepo)

	input := &CreateChannelInput{Name: "existing-channel"}
	_, err := channelSvc.Create(context.Background(), input)

	require.Error(t, err)
	require.ErrorIs(t, err, ErrChannelExists)
}
// TestCreate_GroupConflict checks that Create fails with
// ErrGroupAlreadyInChannel when one of the requested groups already belongs to
// another channel.
func TestCreate_GroupConflict(t *testing.T) {
	conflictRepo := &mockChannelRepository{
		existsByNameFn: func(context.Context, string) (bool, error) { return false, nil },
		getGroupsInOtherChannelsFn: func(context.Context, int64, []int64) ([]int64, error) {
			// Group 10 is already attached to a different channel.
			return []int64{10}, nil
		},
	}
	channelSvc := newTestChannelService(conflictRepo)

	input := &CreateChannelInput{
		Name:     "new-channel",
		GroupIDs: []int64{10, 20},
	}
	_, err := channelSvc.Create(context.Background(), input)

	require.Error(t, err)
	require.ErrorIs(t, err, ErrGroupAlreadyInChannel)
}
// TestCreate_DuplicateModel checks that Create rejects pricing that lists the
// same model twice on the same platform, and that the error names the model.
func TestCreate_DuplicateModel(t *testing.T) {
	nameFreeRepo := &mockChannelRepository{
		existsByNameFn: func(context.Context, string) (bool, error) { return false, nil },
	}
	channelSvc := newTestChannelService(nameFreeRepo)

	duplicated := []ChannelModelPricing{
		{Platform: "anthropic", Models: []string{"claude-opus-4"}},
		{Platform: "anthropic", Models: []string{"claude-opus-4"}}, // same model again
	}
	_, err := channelSvc.Create(context.Background(), &CreateChannelInput{
		Name:         "new-channel",
		ModelPricing: duplicated,
	})

	require.Error(t, err)
	require.Contains(t, err.Error(), "claude-opus-4")
}
// TestCreate_InvalidPricingIntervals checks that Create rejects overlapping
// pricing intervals ([0,2000] and [1000,3000]) with an
// INVALID_PRICING_INTERVALS error that mentions the overlap.
func TestCreate_InvalidPricingIntervals(t *testing.T) {
	nameFreeRepo := &mockChannelRepository{
		existsByNameFn: func(context.Context, string) (bool, error) { return false, nil },
	}
	channelSvc := newTestChannelService(nameFreeRepo)

	overlapping := []PricingInterval{
		{MinTokens: 0, MaxTokens: testPtrInt(2000), InputPrice: testPtrFloat64(1e-6)},
		{MinTokens: 1000, MaxTokens: testPtrInt(3000), InputPrice: testPtrFloat64(2e-6)},
	}
	input := &CreateChannelInput{
		Name: "new-channel",
		ModelPricing: []ChannelModelPricing{
			{
				Platform:  "anthropic",
				Models:    []string{"claude-opus-4"},
				Intervals: overlapping,
			},
		},
	}
	_, err := channelSvc.Create(context.Background(), input)

	require.Error(t, err)
	msg := err.Error()
	require.Contains(t, msg, "INVALID_PRICING_INTERVALS")
	require.Contains(t, msg, "overlap")
}
// TestCreate_DefaultBillingModelSource checks that an empty BillingModelSource
// in the input is defaulted to BillingModelSourceChannelMapped on the channel
// handed to the repository. getByIDFn returns that same captured channel, so
// the assertion on the result inspects what Create persisted.
func TestCreate_DefaultBillingModelSource(t *testing.T) {
	var persisted *Channel
	capturingRepo := &mockChannelRepository{
		existsByNameFn: func(context.Context, string) (bool, error) { return false, nil },
		createFn: func(_ context.Context, created *Channel) error {
			persisted = created
			created.ID = 1
			return nil
		},
		getByIDFn: func(context.Context, int64) (*Channel, error) { return persisted, nil },
		listAllFn: func(context.Context) ([]Channel, error) { return nil, nil },
	}
	channelSvc := newTestChannelService(capturingRepo)

	input := &CreateChannelInput{
		Name:               "new-channel",
		BillingModelSource: "", // left empty on purpose; should default to "channel_mapped"
	}
	created, err := channelSvc.Create(context.Background(), input)

	require.NoError(t, err)
	require.NotNil(t, created)
	require.Equal(t, BillingModelSourceChannelMapped, created.BillingModelSource)
}
// TestCreate_InvalidatesCache checks that a successful Create drops the channel
// cache: the lookup after Create must trigger a second ListAll.
func TestCreate_InvalidatesCache(t *testing.T) {
	ctx := context.Background()
	loads := 0
	cachedChannel := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
		},
	}
	countingRepo := &mockChannelRepository{
		listAllFn: func(context.Context) ([]Channel, error) {
			loads++
			return []Channel{cachedChannel}, nil
		},
		getGroupPlatformsFn: func(context.Context, []int64) (map[int64]string, error) {
			return map[int64]string{10: "anthropic"}, nil
		},
		existsByNameFn: func(context.Context, string) (bool, error) { return false, nil },
		createFn: func(_ context.Context, created *Channel) error {
			created.ID = 2
			return nil
		},
		getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
			return &Channel{ID: id, Name: "new", Status: StatusActive}, nil
		},
	}
	channelSvc := newTestChannelService(countingRepo)

	// Warm the cache.
	_ = channelSvc.GetChannelModelPricing(ctx, 10, "claude-opus-4")
	require.Equal(t, 1, loads)

	// Creating a channel is expected to invalidate it.
	_, err := channelSvc.Create(ctx, &CreateChannelInput{Name: "new"})
	require.NoError(t, err)

	// The next access therefore has to rebuild from the repository.
	_ = channelSvc.GetChannelModelPricing(ctx, 10, "claude-opus-4")
	require.Equal(t, 2, loads)
}
// --- 5.2 Update ---
// TestUpdate_Success checks the happy path of Update: the call succeeds,
// returns a channel, and the channel handed to the repository carries the
// requested new name.
//
// The persisted channel is captured (as TestUpdate_StatusOnly does) because
// getByIDFn always returns a clone of the original; the returned value alone
// cannot prove that anything was actually updated.
func TestUpdate_Success(t *testing.T) {
	existing := &Channel{
		ID:     1,
		Name:   "original",
		Status: StatusActive,
	}
	var persisted *Channel
	repo := &mockChannelRepository{
		getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
			return existing.Clone(), nil
		},
		updateFn: func(_ context.Context, ch *Channel) error {
			persisted = ch
			return nil
		},
		getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
			return nil, nil
		},
		listAllFn: func(_ context.Context) ([]Channel, error) {
			return nil, nil
		},
	}
	svc := newTestChannelService(repo)

	result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
		Name:        "updated-name",
		Description: testPtrString("new desc"),
	})
	require.NoError(t, err)
	require.NotNil(t, result)

	// The repository must have been asked to persist the renamed channel.
	require.NotNil(t, persisted, "repo.Update should have been called")
	require.Equal(t, int64(1), persisted.ID)
	require.Equal(t, "updated-name", persisted.Name)
}
// TestUpdate_NotFound checks that Update returns an error mentioning "channel"
// when the repository cannot find the requested ID.
func TestUpdate_NotFound(t *testing.T) {
	missingRepo := &mockChannelRepository{
		getByIDFn: func(context.Context, int64) (*Channel, error) {
			return nil, ErrChannelNotFound
		},
	}
	channelSvc := newTestChannelService(missingRepo)

	input := &UpdateChannelInput{Name: "whatever"}
	_, err := channelSvc.Update(context.Background(), 999, input)

	require.Error(t, err)
	require.Contains(t, err.Error(), "channel")
}
// TestUpdate_NameConflict checks that renaming a channel to a name already used
// by another channel fails with ErrChannelExists.
func TestUpdate_NameConflict(t *testing.T) {
	stored := &Channel{ID: 1, Name: "original", Status: StatusActive}
	clashRepo := &mockChannelRepository{
		getByIDFn: func(context.Context, int64) (*Channel, error) {
			return stored.Clone(), nil
		},
		existsByNameExcludingFn: func(context.Context, string, int64) (bool, error) {
			// Another channel already owns the requested name.
			return true, nil
		},
	}
	channelSvc := newTestChannelService(clashRepo)

	input := &UpdateChannelInput{Name: "conflicting-name"}
	_, err := channelSvc.Update(context.Background(), 1, input)

	require.Error(t, err)
	require.ErrorIs(t, err, ErrChannelExists)
}
// TestUpdate_GroupConflict checks that Update fails with
// ErrGroupAlreadyInChannel when a requested group is attached to a different
// channel.
func TestUpdate_GroupConflict(t *testing.T) {
	stored := &Channel{ID: 1, Name: "original", Status: StatusActive}
	conflictRepo := &mockChannelRepository{
		getByIDFn: func(context.Context, int64) (*Channel, error) {
			return stored.Clone(), nil
		},
		getGroupsInOtherChannelsFn: func(context.Context, int64, []int64) ([]int64, error) {
			// Group 20 belongs to another channel.
			return []int64{20}, nil
		},
	}
	channelSvc := newTestChannelService(conflictRepo)

	requestedGroups := []int64{10, 20}
	input := &UpdateChannelInput{GroupIDs: &requestedGroups}
	_, err := channelSvc.Update(context.Background(), 1, input)

	require.Error(t, err)
	require.ErrorIs(t, err, ErrGroupAlreadyInChannel)
}
// TestUpdate_DuplicateModel checks that Update rejects pricing that lists the
// same model twice on one platform, and that the error names the model.
func TestUpdate_DuplicateModel(t *testing.T) {
	stored := &Channel{ID: 1, Name: "original", Status: StatusActive}
	lookupRepo := &mockChannelRepository{
		getByIDFn: func(context.Context, int64) (*Channel, error) {
			return stored.Clone(), nil
		},
	}
	channelSvc := newTestChannelService(lookupRepo)

	duplicated := []ChannelModelPricing{
		{Platform: "anthropic", Models: []string{"claude-opus-4"}},
		{Platform: "anthropic", Models: []string{"claude-opus-4"}},
	}
	input := &UpdateChannelInput{ModelPricing: &duplicated}
	_, err := channelSvc.Update(context.Background(), 1, input)

	require.Error(t, err)
	require.Contains(t, err.Error(), "claude-opus-4")
}
// TestUpdate_InvalidPricingIntervals checks that Update rejects an interval
// list in which an unbounded interval (MaxTokens nil) is followed by another
// interval; the error must carry INVALID_PRICING_INTERVALS and mention
// "unbounded".
func TestUpdate_InvalidPricingIntervals(t *testing.T) {
	stored := &Channel{ID: 1, Name: "original", Status: StatusActive}
	lookupRepo := &mockChannelRepository{
		getByIDFn: func(context.Context, int64) (*Channel, error) {
			return stored.Clone(), nil
		},
	}
	channelSvc := newTestChannelService(lookupRepo)

	badIntervals := []PricingInterval{
		{MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)},
		{MinTokens: 2000, MaxTokens: testPtrInt(4000), InputPrice: testPtrFloat64(2e-6)},
	}
	badPricing := []ChannelModelPricing{
		{
			Platform:  "anthropic",
			Models:    []string{"claude-opus-4"},
			Intervals: badIntervals,
		},
	}
	_, err := channelSvc.Update(context.Background(), 1, &UpdateChannelInput{ModelPricing: &badPricing})

	require.Error(t, err)
	msg := err.Error()
	require.Contains(t, msg, "INVALID_PRICING_INTERVALS")
	require.Contains(t, msg, "unbounded")
}
// TestUpdate_InvalidatesChannelCache checks that a successful Update drops the
// channel cache, so the next lookup triggers a second ListAll.
func TestUpdate_InvalidatesChannelCache(t *testing.T) {
	ctx := context.Background()
	stored := &Channel{ID: 1, Name: "original", Status: StatusActive}
	loads := 0
	countingRepo := &mockChannelRepository{
		getByIDFn: func(context.Context, int64) (*Channel, error) {
			return stored.Clone(), nil
		},
		updateFn: func(context.Context, *Channel) error { return nil },
		getGroupIDsFn: func(context.Context, int64) ([]int64, error) {
			return []int64{10, 20}, nil
		},
		listAllFn: func(context.Context) ([]Channel, error) {
			loads++
			return []Channel{*stored}, nil
		},
		getGroupPlatformsFn: func(context.Context, []int64) (map[int64]string, error) {
			return nil, nil
		},
	}
	channelSvc := newTestChannelService(countingRepo)

	// Warm the cache; only the load counter matters here.
	_, _ = channelSvc.GetChannelForGroup(ctx, 10)
	require.Equal(t, 1, loads)

	updated, err := channelSvc.Update(ctx, 1, &UpdateChannelInput{
		Description: testPtrString("updated"),
	})
	require.NoError(t, err)
	require.NotNil(t, updated)

	// After the update the cache must be rebuilt on the next access.
	_, _ = channelSvc.GetChannelForGroup(ctx, 10)
	require.Equal(t, 2, loads)
}
// TestUpdate_InvalidatesAuthCache checks that a successful Update invalidates
// the auth cache for every group the repository reports for the channel.
func TestUpdate_InvalidatesAuthCache(t *testing.T) {
	stored := &Channel{ID: 1, Name: "original", Status: StatusActive}
	authInvalidator := &mockChannelAuthCacheInvalidator{}
	groupedRepo := &mockChannelRepository{
		getByIDFn: func(context.Context, int64) (*Channel, error) {
			return stored.Clone(), nil
		},
		updateFn: func(context.Context, *Channel) error { return nil },
		getGroupIDsFn: func(context.Context, int64) ([]int64, error) {
			return []int64{10, 20}, nil
		},
		listAllFn: func(context.Context) ([]Channel, error) { return nil, nil },
	}
	channelSvc := newTestChannelServiceWithAuth(groupedRepo, authInvalidator)

	updated, err := channelSvc.Update(context.Background(), 1, &UpdateChannelInput{
		Description: testPtrString("updated"),
	})
	require.NoError(t, err)
	require.NotNil(t, updated)

	// Both groups must have been invalidated (order is irrelevant).
	require.ElementsMatch(t, []int64{10, 20}, authInvalidator.invalidatedGroupIDs)
}
// --- 5.3 Delete ---
// TestChannelDelete_Success checks that Delete succeeds and actually calls the
// repository's delete operation.
func TestChannelDelete_Success(t *testing.T) {
	deleteCalled := false
	deletingRepo := &mockChannelRepository{
		getGroupIDsFn: func(context.Context, int64) ([]int64, error) { return nil, nil },
		deleteFn: func(context.Context, int64) error {
			deleteCalled = true
			return nil
		},
		listAllFn: func(context.Context) ([]Channel, error) { return nil, nil },
	}
	channelSvc := newTestChannelService(deletingRepo)

	require.NoError(t, channelSvc.Delete(context.Background(), 1))
	require.True(t, deleteCalled)
}
// TestChannelDelete_InvalidatesCaches checks that Delete invalidates both the
// auth cache (for every group of the channel) and the channel cache (the next
// lookup reloads from the repository).
func TestChannelDelete_InvalidatesCaches(t *testing.T) {
	ctx := context.Background()
	authInvalidator := &mockChannelAuthCacheInvalidator{}
	loads := 0
	countingRepo := &mockChannelRepository{
		getGroupIDsFn: func(context.Context, int64) ([]int64, error) {
			return []int64{10, 20}, nil
		},
		deleteFn: func(context.Context, int64) error { return nil },
		listAllFn: func(context.Context) ([]Channel, error) {
			loads++
			return []Channel{{ID: 1, Status: StatusActive, GroupIDs: []int64{10, 20}}}, nil
		},
		getGroupPlatformsFn: func(context.Context, []int64) (map[int64]string, error) {
			return nil, nil
		},
	}
	channelSvc := newTestChannelServiceWithAuth(countingRepo, authInvalidator)

	// Warm the channel cache; only the load counter matters here.
	_, _ = channelSvc.GetChannelForGroup(ctx, 10)
	require.Equal(t, 1, loads)

	require.NoError(t, channelSvc.Delete(ctx, 1))

	// Auth cache: both groups invalidated (order is irrelevant).
	require.ElementsMatch(t, []int64{10, 20}, authInvalidator.invalidatedGroupIDs)

	// Channel cache: the next access rebuilds it.
	_, _ = channelSvc.GetChannelForGroup(ctx, 10)
	require.Equal(t, 2, loads)
}
// TestChannelDelete_NotFound checks that an error from the repository's delete
// operation is surfaced by Delete with its "not found" text intact.
func TestChannelDelete_NotFound(t *testing.T) {
	missingRepo := &mockChannelRepository{
		getGroupIDsFn: func(context.Context, int64) ([]int64, error) { return nil, nil },
		deleteFn: func(context.Context, int64) error {
			return errors.New("record not found")
		},
	}
	channelSvc := newTestChannelService(missingRepo)

	err := channelSvc.Delete(context.Background(), 999)

	require.Error(t, err)
	require.Contains(t, err.Error(), "not found")
}
// ===========================================================================
// 6. Edge Case Tests
// ===========================================================================
// --- 6.1 Create with empty GroupIDs ---
// TestCreate_NoGroups checks that Create succeeds with an empty GroupIDs slice
// and never queries the repository for group conflicts in that case.
func TestCreate_NoGroups(t *testing.T) {
	const wantID = int64(55)

	conflictCheckCalled := false
	spyRepo := &mockChannelRepository{
		existsByNameFn: func(context.Context, string) (bool, error) { return false, nil },
		getGroupsInOtherChannelsFn: func(context.Context, int64, []int64) ([]int64, error) {
			conflictCheckCalled = true
			return nil, nil
		},
		createFn: func(_ context.Context, created *Channel) error {
			created.ID = wantID
			return nil
		},
		getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
			return &Channel{ID: id, Name: "no-groups-channel", Status: StatusActive}, nil
		},
		listAllFn: func(context.Context) ([]Channel, error) { return nil, nil },
	}
	channelSvc := newTestChannelService(spyRepo)

	input := &CreateChannelInput{
		Name:     "no-groups-channel",
		GroupIDs: []int64{}, // deliberately an empty, non-nil slice
	}
	created, err := channelSvc.Create(context.Background(), input)

	require.NoError(t, err)
	require.NotNil(t, created)
	require.Equal(t, wantID, created.ID)

	// With no groups requested, the conflict lookup must be skipped entirely.
	require.False(t, conflictCheckCalled)
}
// --- 6.2 Update only Status ---
// TestUpdate_StatusOnly checks that an Update carrying only a new Status
// persists that status and leaves the channel name untouched.
func TestUpdate_StatusOnly(t *testing.T) {
	stored := &Channel{ID: 1, Name: "test-channel", Status: StatusActive}

	var persisted *Channel
	capturingRepo := &mockChannelRepository{
		getByIDFn: func(context.Context, int64) (*Channel, error) {
			return stored.Clone(), nil
		},
		updateFn: func(_ context.Context, updated *Channel) error {
			persisted = updated
			return nil
		},
		getGroupIDsFn: func(context.Context, int64) ([]int64, error) { return nil, nil },
		listAllFn:     func(context.Context) ([]Channel, error) { return nil, nil },
	}
	channelSvc := newTestChannelService(capturingRepo)

	input := &UpdateChannelInput{Status: StatusDisabled}
	updated, err := channelSvc.Update(context.Background(), 1, input)
	require.NoError(t, err)
	require.NotNil(t, updated)

	// The channel handed to repo.Update carries the new status...
	require.NotNil(t, persisted)
	require.Equal(t, StatusDisabled, persisted.Status)
	// ...while its name is unchanged.
	require.Equal(t, "test-channel", persisted.Name)
}
// --- 6.3 Delete when GetGroupIDs fails ---
// TestChannelDelete_GetGroupIDsError checks the degraded path of Delete: a
// failure while looking up the channel's group IDs must not abort the
// deletion; the repository delete is still executed and Delete returns nil.
func TestChannelDelete_GetGroupIDsError(t *testing.T) {
	deleteCalled := false
	flakyRepo := &mockChannelRepository{
		getGroupIDsFn: func(context.Context, int64) ([]int64, error) {
			return nil, errors.New("group IDs lookup failed")
		},
		deleteFn: func(context.Context, int64) error {
			deleteCalled = true
			return nil
		},
		listAllFn: func(context.Context) ([]Channel, error) { return nil, nil },
	}
	channelSvc := newTestChannelService(flakyRepo)

	// The group-ID lookup error is tolerated; deletion proceeds regardless.
	require.NoError(t, channelSvc.Delete(context.Background(), 1))
	require.True(t, deleteCalled)
}
// --- 6.4 ReplaceModelInBody with invalid JSON ---
func TestReplaceModelInBody_InvalidJSON(t *testing.T) {
// Case 1: broken JSON object — gjson won't find "model", sjson does best-effort set
// (no panic, no error from sjson, but result is mutated garbage)
brokenBody := []byte("{broken")
result := ReplaceModelInBody(brokenBody, "new-model")
require.NotNil(t, result)
// sjson does not error on this input, so result differs from original — just verify no panic
// Case 2: JSON array — sjson.SetBytes returns error on non-object,
// triggering the L447 error fallback path that returns original body.
arrayBody := []byte("[]")
result2 := ReplaceModelInBody(arrayBody, "new-model")
require.Equal(t, arrayBody, result2)
}
// ===========================================================================
// 7. isPlatformPricingMatch
// ===========================================================================
// TestIsPlatformPricingMatch is a table test for isPlatformPricingMatch. Every
// case expects a match only when the group platform and the pricing platform
// are the same string (including two empty strings).
func TestIsPlatformPricingMatch(t *testing.T) {
	type testCase struct {
		name            string
		groupPlatform   string
		pricingPlatform string
		want            bool
	}
	cases := []testCase{
		{name: "antigravity does NOT match anthropic", groupPlatform: PlatformAntigravity, pricingPlatform: PlatformAnthropic, want: false},
		{name: "antigravity does NOT match gemini", groupPlatform: PlatformAntigravity, pricingPlatform: PlatformGemini, want: false},
		{name: "antigravity matches antigravity", groupPlatform: PlatformAntigravity, pricingPlatform: PlatformAntigravity, want: true},
		{name: "antigravity does NOT match openai", groupPlatform: PlatformAntigravity, pricingPlatform: PlatformOpenAI, want: false},
		{name: "anthropic matches anthropic", groupPlatform: PlatformAnthropic, pricingPlatform: PlatformAnthropic, want: true},
		{name: "anthropic does NOT match antigravity", groupPlatform: PlatformAnthropic, pricingPlatform: PlatformAntigravity, want: false},
		{name: "anthropic does NOT match gemini", groupPlatform: PlatformAnthropic, pricingPlatform: PlatformGemini, want: false},
		{name: "gemini matches gemini", groupPlatform: PlatformGemini, pricingPlatform: PlatformGemini, want: true},
		{name: "gemini does NOT match antigravity", groupPlatform: PlatformGemini, pricingPlatform: PlatformAntigravity, want: false},
		{name: "gemini does NOT match anthropic", groupPlatform: PlatformGemini, pricingPlatform: PlatformAnthropic, want: false},
		{name: "empty string matches nothing", groupPlatform: "", pricingPlatform: PlatformAnthropic, want: false},
		{name: "empty string matches empty", groupPlatform: "", pricingPlatform: "", want: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := isPlatformPricingMatch(tc.groupPlatform, tc.pricingPlatform)
			require.Equal(t, tc.want, got)
		})
	}
}
// ===========================================================================
// 8. matchingPlatforms
// ===========================================================================
// TestMatchingPlatforms is a table test for matchingPlatforms: for each
// platform, the result must be a single-element slice holding that platform.
func TestMatchingPlatforms(t *testing.T) {
	type testCase struct {
		name          string
		groupPlatform string
		want          []string
	}
	cases := []testCase{
		{name: "antigravity returns itself only", groupPlatform: PlatformAntigravity, want: []string{PlatformAntigravity}},
		{name: "anthropic returns itself", groupPlatform: PlatformAnthropic, want: []string{PlatformAnthropic}},
		{name: "gemini returns itself", groupPlatform: PlatformGemini, want: []string{PlatformGemini}},
		{name: "openai returns itself", groupPlatform: PlatformOpenAI, want: []string{PlatformOpenAI}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.want, matchingPlatforms(tc.groupPlatform))
		})
	}
}
// ===========================================================================
// 9. Antigravity platform isolation — no cross-platform pricing leakage
// ===========================================================================
func TestGetChannelModelPricing_AntigravityDoesNotSeeCrossPlatformPricing(t *testing.T) {
// Channel has anthropic pricing for claude-opus-4-6.
// Group 10 is antigravity — should NOT see the anthropic pricing.
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
ModelPricing: []ChannelModelPricing{
{ID: 100, Platform: PlatformAnthropic, Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(15e-6)},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
svc := newTestChannelService(repo)
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4-6")
require.Nil(t, result, "antigravity group should NOT see anthropic-platform pricing")
}
func TestGetChannelModelPricing_AnthropicCannotSeeAntigravityPricing(t *testing.T) {
// Channel has antigravity-platform pricing for claude-opus-4-6.
// Group 10 is anthropic — should NOT see antigravity pricing (no cross-platform leakage).
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
ModelPricing: []ChannelModelPricing{
{ID: 100, Platform: PlatformAntigravity, Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(15e-6)},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAnthropic})
svc := newTestChannelService(repo)
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4-6")
require.Nil(t, result, "anthropic group should NOT see antigravity-platform pricing")
}
// ===========================================================================
// 10. Antigravity platform isolation — no cross-platform model mapping
// ===========================================================================
// TestResolveChannelMapping_AntigravityDoesNotSeeCrossPlatformMapping checks
// that a model mapping configured under the anthropic platform is not applied
// to an antigravity group: the model is returned unmapped and unchanged.
func TestResolveChannelMapping_AntigravityDoesNotSeeCrossPlatformMapping(t *testing.T) {
	const requested = "claude-opus-4-5"

	// anthropic maps claude-opus-4-5 -> claude-opus-4-6; group 10 is antigravity.
	channel := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelMapping: map[string]map[string]string{
			PlatformAnthropic: {"claude-opus-4-5": "claude-opus-4-6"},
		},
	}
	channelSvc := newTestChannelService(makeStandardRepo(channel, map[int64]string{10: PlatformAntigravity}))

	got := channelSvc.ResolveChannelMapping(context.Background(), 10, requested)

	require.False(t, got.Mapped, "antigravity group should NOT apply anthropic mapping")
	require.Equal(t, requested, got.MappedModel)
}
// ===========================================================================
// 11. Antigravity platform isolation — same-name model across platforms
// ===========================================================================
func TestGetChannelModelPricing_AntigravityDoesNotSeeSameModelFromOtherPlatforms(t *testing.T) {
// anthropic 和 gemini 都定义了同名模型 "shared-model",价格不同。
// antigravity 分组不应看到任何一个(各平台严格独立)。
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
ModelPricing: []ChannelModelPricing{
{ID: 200, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)},
{ID: 201, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
svc := newTestChannelService(repo)
result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model")
require.Nil(t, result, "antigravity group should NOT see anthropic/gemini-platform pricing")
}
func TestGetChannelModelPricing_AntigravityDoesNotSeeGeminiOnlyPricing(t *testing.T) {
// 只有 gemini 平台定义了模型 "gemini-model"。
// antigravity 分组不应看到 gemini 的定价。
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
ModelPricing: []ChannelModelPricing{
{ID: 300, Platform: PlatformGemini, Models: []string{"gemini-model"}, InputPrice: testPtrFloat64(2e-6)},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
svc := newTestChannelService(repo)
result := svc.GetChannelModelPricing(context.Background(), 10, "gemini-model")
require.Nil(t, result, "antigravity group should NOT see gemini-platform pricing")
}
func TestGetChannelModelPricing_AntigravityDoesNotSeeWildcardFromOtherPlatforms(t *testing.T) {
// anthropic 和 gemini 都有 "shared-*" 通配符定价。
// antigravity 分组不应命中任何一个。
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
ModelPricing: []ChannelModelPricing{
{ID: 400, Platform: PlatformAnthropic, Models: []string{"shared-*"}, InputPrice: testPtrFloat64(10e-6)},
{ID: 401, Platform: PlatformGemini, Models: []string{"shared-*"}, InputPrice: testPtrFloat64(5e-6)},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
svc := newTestChannelService(repo)
result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model")
require.Nil(t, result, "antigravity group should NOT see wildcard pricing from other platforms")
}
func TestResolveChannelMapping_AntigravityDoesNotSeeMappingFromOtherPlatforms(t *testing.T) {
// anthropic 和 gemini 都定义了同名模型映射 "alias" → 不同目标。
// antigravity 分组不应命中任何一个。
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
ModelMapping: map[string]map[string]string{
PlatformAnthropic: {"alias": "anthropic-target"},
PlatformGemini: {"alias": "gemini-target"},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
svc := newTestChannelService(repo)
result := svc.ResolveChannelMapping(context.Background(), 10, "alias")
require.False(t, result.Mapped, "antigravity group should NOT see mapping from other platforms")
require.Equal(t, "alias", result.MappedModel)
}
// TestCheckRestricted_AntigravityDoesNotSeeModelsFromOtherPlatforms checks that, with
// RestrictModels enabled, an antigravity group is reported as restricted both for a
// model that is priced only under anthropic/gemini and for a model that is not
// configured at all.
func TestCheckRestricted_AntigravityDoesNotSeeModelsFromOtherPlatforms(t *testing.T) {
	const groupID = 10

	// "shared-model" is priced for anthropic and gemini only; platforms are
	// independent, so it should not count as allowed for antigravity.
	channel := Channel{
		ID:             1,
		Status:         StatusActive,
		RestrictModels: true,
		GroupIDs:       []int64{groupID},
		ModelPricing: []ChannelModelPricing{
			{ID: 500, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)},
			{ID: 501, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)},
		},
	}
	groupPlatforms := map[int64]string{groupID: PlatformAntigravity}
	chSvc := newTestChannelService(makeStandardRepo(channel, groupPlatforms))

	// Checked in order; require stops at the first failing expectation.
	expectations := []struct {
		model   string
		failMsg string
	}{
		{model: "shared-model", failMsg: "shared-model from other platforms should be restricted for antigravity"},
		{model: "unknown-model", failMsg: "unknown-model should be restricted for antigravity"},
	}
	for _, exp := range expectations {
		isRestricted := chSvc.IsModelRestricted(context.Background(), groupID, exp.model)
		require.True(t, isRestricted, exp.failMsg)
	}
}
// TestGetChannelModelPricing_AntigravityOwnPricingWorks checks that pricing entries
// tagged with the antigravity platform are returned for an antigravity group, for both
// a Claude-style and a Gemini-style model name matched through wildcard patterns.
func TestGetChannelModelPricing_AntigravityOwnPricingWorks(t *testing.T) {
	const groupID = 10

	channel := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{groupID},
		ModelPricing: []ChannelModelPricing{
			{ID: 600, Platform: PlatformAntigravity, Models: []string{"claude-*"}, InputPrice: testPtrFloat64(15e-6)},
			{ID: 601, Platform: PlatformAntigravity, Models: []string{"gemini-*"}, InputPrice: testPtrFloat64(2e-6)},
		},
	}
	repo := makeStandardRepo(channel, map[int64]string{groupID: PlatformAntigravity})
	svc := newTestChannelService(repo)

	cases := []struct {
		model     string
		wantID    int64
		wantPrice float64
	}{
		// Claude model matches the antigravity "claude-*" pricing.
		{model: "claude-sonnet-4", wantID: 600, wantPrice: 15e-6},
		// Gemini model matches the antigravity "gemini-*" pricing.
		{model: "gemini-2.5-flash", wantID: 601, wantPrice: 2e-6},
	}
	for _, tc := range cases {
		got := svc.GetChannelModelPricing(context.Background(), groupID, tc.model)
		require.NotNil(t, got, "model %q", tc.model)
		require.Equal(t, tc.wantID, got.ID, "model %q", tc.model)
		// Guard the dereference below: a nil InputPrice should fail this test
		// cleanly rather than panic and abort the whole test binary.
		require.NotNil(t, got.InputPrice, "model %q", tc.model)
		require.InDelta(t, tc.wantPrice, *got.InputPrice, 1e-12, "model %q", tc.model)
	}
}
// TestGetChannelModelPricing_NonAntigravityUnaffected checks that platform isolation
// still holds for non-antigravity groups: with the same model name priced under both
// anthropic and gemini, the anthropic group gets the anthropic entry and the gemini
// group gets the gemini entry.
func TestGetChannelModelPricing_NonAntigravityUnaffected(t *testing.T) {
	const (
		anthropicGroupID = 10
		geminiGroupID    = 20
	)

	channel := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{anthropicGroupID, geminiGroupID},
		ModelPricing: []ChannelModelPricing{
			{ID: 600, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)},
			{ID: 601, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)},
		},
	}
	repo := makeStandardRepo(channel, map[int64]string{
		anthropicGroupID: PlatformAnthropic,
		geminiGroupID:    PlatformGemini,
	})
	svc := newTestChannelService(repo)

	cases := []struct {
		groupID   int64
		wantID    int64
		wantPrice float64
	}{
		// The anthropic group should only see the anthropic pricing.
		{groupID: anthropicGroupID, wantID: 600, wantPrice: 10e-6},
		// The gemini group should only see the gemini pricing.
		{groupID: geminiGroupID, wantID: 601, wantPrice: 5e-6},
	}
	for _, tc := range cases {
		got := svc.GetChannelModelPricing(context.Background(), tc.groupID, "shared-model")
		require.NotNil(t, got, "group %d", tc.groupID)
		require.Equal(t, tc.wantID, got.ID, "group %d", tc.groupID)
		// Guard the dereference below: a nil InputPrice should fail this test
		// cleanly rather than panic and abort the whole test binary.
		require.NotNil(t, got.InputPrice, "group %d", tc.groupID)
		require.InDelta(t, tc.wantPrice, *got.InputPrice, 1e-12, "group %d", tc.groupID)
	}
}