Why: channels that have model pricing entries but no model mapping (e.g. azcc, which has 3 priced Claude models and no mapping) were rendered as 未配置模型 ("no models configured") on the 'Available Channels' page, because the algorithm iterated only over ModelMapping and silently dropped any platform without a mapping entry. Changes: - channel.go: SupportedModels now returns the union of mapping and pricing entries. For an exact mapping src → target, pricing is looked up by target (the name that is actually billed), not by src. - channel_available.go: ListAvailable fills in every entry that has nil pricing via PricingService.GetModelPricing (the global LiteLLM fallback), so the popover always shows a price. - channel_service.go: NewChannelService takes *PricingService as its 4th parameter. - channel_test.go: rewrote 4 tests that encoded the old mapping-only semantics; added coverage for pricing-only, mapping-target, and missing-target cases.
2406 lines
76 KiB
Go
2406 lines
76 KiB
Go
//go:build unit
|
||
|
||
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"testing"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Mock: ChannelRepository
|
||
// ---------------------------------------------------------------------------
|
||
|
||
type mockChannelRepository struct {
|
||
listAllFn func(ctx context.Context) ([]Channel, error)
|
||
getGroupPlatformsFn func(ctx context.Context, groupIDs []int64) (map[int64]string, error)
|
||
createFn func(ctx context.Context, channel *Channel) error
|
||
getByIDFn func(ctx context.Context, id int64) (*Channel, error)
|
||
updateFn func(ctx context.Context, channel *Channel) error
|
||
deleteFn func(ctx context.Context, id int64) error
|
||
listFn func(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error)
|
||
existsByNameFn func(ctx context.Context, name string) (bool, error)
|
||
existsByNameExcludingFn func(ctx context.Context, name string, excludeID int64) (bool, error)
|
||
getGroupIDsFn func(ctx context.Context, channelID int64) ([]int64, error)
|
||
setGroupIDsFn func(ctx context.Context, channelID int64, groupIDs []int64) error
|
||
getChannelIDByGroupIDFn func(ctx context.Context, groupID int64) (int64, error)
|
||
getGroupsInOtherChannelsFn func(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error)
|
||
listModelPricingFn func(ctx context.Context, channelID int64) ([]ChannelModelPricing, error)
|
||
createModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error
|
||
updateModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error
|
||
deleteModelPricingFn func(ctx context.Context, id int64) error
|
||
replaceModelPricingFn func(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
|
||
}
|
||
|
||
func (m *mockChannelRepository) Create(ctx context.Context, channel *Channel) error {
|
||
if m.createFn != nil {
|
||
return m.createFn(ctx, channel)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) GetByID(ctx context.Context, id int64) (*Channel, error) {
|
||
if m.getByIDFn != nil {
|
||
return m.getByIDFn(ctx, id)
|
||
}
|
||
return nil, ErrChannelNotFound
|
||
}
|
||
|
||
func (m *mockChannelRepository) Update(ctx context.Context, channel *Channel) error {
|
||
if m.updateFn != nil {
|
||
return m.updateFn(ctx, channel)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) Delete(ctx context.Context, id int64) error {
|
||
if m.deleteFn != nil {
|
||
return m.deleteFn(ctx, id)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) {
|
||
if m.listFn != nil {
|
||
return m.listFn(ctx, params, status, search)
|
||
}
|
||
return nil, nil, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) ListAll(ctx context.Context) ([]Channel, error) {
|
||
if m.listAllFn != nil {
|
||
return m.listAllFn(ctx)
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||
if m.existsByNameFn != nil {
|
||
return m.existsByNameFn(ctx, name)
|
||
}
|
||
return false, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) {
|
||
if m.existsByNameExcludingFn != nil {
|
||
return m.existsByNameExcludingFn(ctx, name, excludeID)
|
||
}
|
||
return false, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) {
|
||
if m.getGroupIDsFn != nil {
|
||
return m.getGroupIDsFn(ctx, channelID)
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error {
|
||
if m.setGroupIDsFn != nil {
|
||
return m.setGroupIDsFn(ctx, channelID, groupIDs)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||
if m.getChannelIDByGroupIDFn != nil {
|
||
return m.getChannelIDByGroupIDFn(ctx, groupID)
|
||
}
|
||
return 0, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) {
|
||
if m.getGroupsInOtherChannelsFn != nil {
|
||
return m.getGroupsInOtherChannelsFn(ctx, channelID, groupIDs)
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||
if m.getGroupPlatformsFn != nil {
|
||
return m.getGroupPlatformsFn(ctx, groupIDs)
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error) {
|
||
if m.listModelPricingFn != nil {
|
||
return m.listModelPricingFn(ctx, channelID)
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error {
|
||
if m.createModelPricingFn != nil {
|
||
return m.createModelPricingFn(ctx, pricing)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) UpdateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error {
|
||
if m.updateModelPricingFn != nil {
|
||
return m.updateModelPricingFn(ctx, pricing)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) DeleteModelPricing(ctx context.Context, id int64) error {
|
||
if m.deleteModelPricingFn != nil {
|
||
return m.deleteModelPricingFn(ctx, id)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error {
|
||
if m.replaceModelPricingFn != nil {
|
||
return m.replaceModelPricingFn(ctx, channelID, pricingList)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Mock: APIKeyAuthCacheInvalidator
|
||
// ---------------------------------------------------------------------------
|
||
|
||
// mockChannelAuthCacheInvalidator records every invalidation request it
// receives so tests can assert which keys, users and groups were flushed.
type mockChannelAuthCacheInvalidator struct {
	invalidatedGroupIDs []int64  // arguments passed to InvalidateAuthCacheByGroupID, in call order
	invalidatedKeys     []string // arguments passed to InvalidateAuthCacheByKey, in call order
	invalidatedUserIDs  []int64  // arguments passed to InvalidateAuthCacheByUserID, in call order
}
|
||
|
||
func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByKey(_ context.Context, key string) {
|
||
m.invalidatedKeys = append(m.invalidatedKeys, key)
|
||
}
|
||
|
||
func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) {
|
||
m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID)
|
||
}
|
||
|
||
func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByGroupID(_ context.Context, groupID int64) {
|
||
m.invalidatedGroupIDs = append(m.invalidatedGroupIDs, groupID)
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Helpers
|
||
// ---------------------------------------------------------------------------
|
||
|
||
func newTestChannelService(repo *mockChannelRepository) *ChannelService {
|
||
return NewChannelService(repo, nil, nil, nil)
|
||
}
|
||
|
||
func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChannelAuthCacheInvalidator) *ChannelService {
|
||
return NewChannelService(repo, nil, auth, nil)
|
||
}
|
||
|
||
// makeStandardRepo returns a repo that serves one active channel with anthropic pricing
|
||
// for group 1, with the given model pricing and model mapping.
|
||
func makeStandardRepo(ch Channel, groupPlatforms map[int64]string) *mockChannelRepository {
|
||
return &mockChannelRepository{
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return []Channel{ch}, nil
|
||
},
|
||
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
|
||
return groupPlatforms, nil
|
||
},
|
||
}
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 1. BuildModelMappingChain
|
||
// ===========================================================================
|
||
|
||
func TestBuildModelMappingChain(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
result ChannelMappingResult
|
||
requestModel string
|
||
upstreamModel string
|
||
want string
|
||
}{
|
||
{
|
||
name: "no mapping, no upstream diff",
|
||
result: ChannelMappingResult{Mapped: false, MappedModel: "claude-sonnet-4"},
|
||
requestModel: "claude-sonnet-4",
|
||
upstreamModel: "claude-sonnet-4",
|
||
want: "",
|
||
},
|
||
{
|
||
name: "no mapping, upstream differs",
|
||
result: ChannelMappingResult{Mapped: false, MappedModel: "claude-sonnet-4"},
|
||
requestModel: "claude-sonnet-4",
|
||
upstreamModel: "claude-sonnet-4-20250514",
|
||
want: "claude-sonnet-4\u2192claude-sonnet-4-20250514",
|
||
},
|
||
{
|
||
name: "mapped, upstream differs",
|
||
result: ChannelMappingResult{Mapped: true, MappedModel: "claude-sonnet-4-20250514"},
|
||
requestModel: "my-model",
|
||
upstreamModel: "actual-upstream",
|
||
want: "my-model\u2192claude-sonnet-4-20250514\u2192actual-upstream",
|
||
},
|
||
{
|
||
name: "mapped, upstream same as mapped",
|
||
result: ChannelMappingResult{Mapped: true, MappedModel: "claude-sonnet-4-20250514"},
|
||
requestModel: "claude-sonnet-4",
|
||
upstreamModel: "claude-sonnet-4-20250514",
|
||
want: "claude-sonnet-4\u2192claude-sonnet-4-20250514",
|
||
},
|
||
{
|
||
name: "mapped, upstream empty",
|
||
result: ChannelMappingResult{Mapped: true, MappedModel: "target-model"},
|
||
requestModel: "my-model",
|
||
upstreamModel: "",
|
||
want: "my-model\u2192target-model",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
got := tt.result.BuildModelMappingChain(tt.requestModel, tt.upstreamModel)
|
||
require.Equal(t, tt.want, got)
|
||
})
|
||
}
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 2. ReplaceModelInBody
|
||
// ===========================================================================
|
||
|
||
func TestReplaceModelInBody(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
body []byte
|
||
newModel string
|
||
check func(t *testing.T, result []byte)
|
||
}{
|
||
{
|
||
name: "empty body",
|
||
body: []byte{},
|
||
newModel: "new-model",
|
||
check: func(t *testing.T, result []byte) {
|
||
require.Equal(t, []byte{}, result)
|
||
},
|
||
},
|
||
{
|
||
name: "model already equal",
|
||
body: []byte(`{"model":"claude-sonnet-4","temperature":0.7}`),
|
||
newModel: "claude-sonnet-4",
|
||
check: func(t *testing.T, result []byte) {
|
||
require.Equal(t, []byte(`{"model":"claude-sonnet-4","temperature":0.7}`), result)
|
||
},
|
||
},
|
||
{
|
||
name: "model different",
|
||
body: []byte(`{"model":"claude-sonnet-4","temperature":0.7}`),
|
||
newModel: "claude-opus-4",
|
||
check: func(t *testing.T, result []byte) {
|
||
require.Contains(t, string(result), `"model":"claude-opus-4"`)
|
||
require.Contains(t, string(result), `"temperature"`)
|
||
},
|
||
},
|
||
{
|
||
name: "no model field",
|
||
body: []byte(`{"temperature":0.7}`),
|
||
newModel: "claude-opus-4",
|
||
check: func(t *testing.T, result []byte) {
|
||
require.Contains(t, string(result), `"model":"claude-opus-4"`)
|
||
require.Contains(t, string(result), `"temperature"`)
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
result := ReplaceModelInBody(tt.body, tt.newModel)
|
||
tt.check(t, result)
|
||
})
|
||
}
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 3. validateNoConflictingModels + validateNoConflictingMappings
|
||
// ===========================================================================
|
||
|
||
func TestValidateNoConflictingModels(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
pricingList []ChannelModelPricing
|
||
wantErr bool
|
||
errContains string
|
||
}{
|
||
{
|
||
name: "no duplicates",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-sonnet-4", "claude-opus-4"}},
|
||
{Platform: "openai", Models: []string{"gpt-5.1"}},
|
||
},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "same platform duplicate",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
|
||
{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
|
||
},
|
||
wantErr: true,
|
||
errContains: "claude-sonnet-4",
|
||
},
|
||
{
|
||
name: "same model different platform",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"model-a"}},
|
||
{Platform: "openai", Models: []string{"model-a"}},
|
||
},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "case insensitive",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"Claude"}},
|
||
{Platform: "anthropic", Models: []string{"claude"}},
|
||
},
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "empty list (nil)",
|
||
pricingList: nil,
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "wildcard_vs_wildcard_conflict",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-*"}},
|
||
{Platform: "anthropic", Models: []string{"claude-opus-*"}},
|
||
},
|
||
wantErr: true,
|
||
errContains: "conflict",
|
||
},
|
||
{
|
||
name: "wildcard_vs_exact_conflict",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-*"}},
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
|
||
},
|
||
wantErr: true,
|
||
errContains: "conflict",
|
||
},
|
||
{
|
||
name: "no_conflict_different_platform",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-opus-*"}},
|
||
{Platform: "openai", Models: []string{"claude-*"}},
|
||
},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "no_conflict_same_platform_different_prefix",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-opus-*"}},
|
||
{Platform: "anthropic", Models: []string{"gpt-*"}},
|
||
},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "catch_all_wildcard_conflicts_with_everything",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "openai", Models: []string{"*"}},
|
||
{Platform: "openai", Models: []string{"gpt-5"}},
|
||
},
|
||
wantErr: true,
|
||
errContains: "conflict",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
err := validateNoConflictingModels(tt.pricingList)
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
require.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
|
||
// Additional sub-case: explicit empty slice
|
||
t.Run("empty list (empty slice)", func(t *testing.T) {
|
||
err := validateNoConflictingModels([]ChannelModelPricing{})
|
||
require.NoError(t, err)
|
||
})
|
||
}
|
||
|
||
func TestValidateNoConflictingMappings(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
mapping map[string]map[string]string
|
||
wantErr bool
|
||
errContains string
|
||
}{
|
||
{
|
||
name: "nil mapping",
|
||
mapping: nil,
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "empty mapping",
|
||
mapping: map[string]map[string]string{},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "no conflict",
|
||
mapping: map[string]map[string]string{
|
||
"anthropic": {"claude-opus-*": "opus", "gpt-*": "gpt"},
|
||
},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "wildcard vs wildcard conflict",
|
||
mapping: map[string]map[string]string{
|
||
"anthropic": {"claude-*": "a", "claude-opus-*": "b"},
|
||
},
|
||
wantErr: true,
|
||
errContains: "conflict",
|
||
},
|
||
{
|
||
name: "wildcard vs exact conflict",
|
||
mapping: map[string]map[string]string{
|
||
"openai": {"gpt-*": "a", "gpt-4o": "b"},
|
||
},
|
||
wantErr: true,
|
||
errContains: "conflict",
|
||
},
|
||
{
|
||
name: "exact duplicate conflict",
|
||
mapping: map[string]map[string]string{
|
||
"anthropic": {"claude-opus-4": "a"},
|
||
"openai": {"claude-opus-4": "b"},
|
||
},
|
||
wantErr: false, // different platforms
|
||
},
|
||
{
|
||
name: "different platforms no conflict",
|
||
mapping: map[string]map[string]string{
|
||
"anthropic": {"claude-*": "a"},
|
||
"openai": {"claude-*": "b"},
|
||
},
|
||
wantErr: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
err := validateNoConflictingMappings(tt.mapping)
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
require.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestConflictsBetween(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
a, b modelEntry
|
||
want bool
|
||
}{
|
||
{
|
||
name: "exact same",
|
||
a: modelEntry{prefix: "claude-opus-4", wildcard: false},
|
||
b: modelEntry{prefix: "claude-opus-4", wildcard: false},
|
||
want: true,
|
||
},
|
||
{
|
||
name: "exact different",
|
||
a: modelEntry{prefix: "claude-opus-4", wildcard: false},
|
||
b: modelEntry{prefix: "gpt-4o", wildcard: false},
|
||
want: false,
|
||
},
|
||
{
|
||
name: "wildcard matches exact",
|
||
a: modelEntry{prefix: "claude-", wildcard: true},
|
||
b: modelEntry{prefix: "claude-opus-4", wildcard: false},
|
||
want: true,
|
||
},
|
||
{
|
||
name: "exact does not match unrelated wildcard",
|
||
a: modelEntry{prefix: "gpt-4o", wildcard: false},
|
||
b: modelEntry{prefix: "claude-", wildcard: true},
|
||
want: false,
|
||
},
|
||
{
|
||
name: "wildcard prefix overlap",
|
||
a: modelEntry{prefix: "claude-", wildcard: true},
|
||
b: modelEntry{prefix: "claude-opus-", wildcard: true},
|
||
want: true,
|
||
},
|
||
{
|
||
name: "wildcards no overlap",
|
||
a: modelEntry{prefix: "claude-", wildcard: true},
|
||
b: modelEntry{prefix: "gpt-", wildcard: true},
|
||
want: false,
|
||
},
|
||
{
|
||
name: "catch-all wildcard vs any",
|
||
a: modelEntry{prefix: "", wildcard: true},
|
||
b: modelEntry{prefix: "anything", wildcard: false},
|
||
want: true,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
require.Equal(t, tt.want, conflictsBetween(tt.a, tt.b))
|
||
})
|
||
}
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 4. Cache Building + Hot Path Methods
|
||
// ===========================================================================
|
||
|
||
// --- 4.1 GetChannelForGroup ---
|
||
|
||
func TestGetChannelForGroup_Success(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Name: "test-channel",
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result, err := svc.GetChannelForGroup(context.Background(), 10)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
require.Equal(t, int64(1), result.ID)
|
||
require.Equal(t, "test-channel", result.Name)
|
||
|
||
// returned value should be a clone
|
||
result.Name = "mutated"
|
||
result2, err := svc.GetChannelForGroup(context.Background(), 10)
|
||
require.NoError(t, err)
|
||
require.Equal(t, "test-channel", result2.Name)
|
||
}
|
||
|
||
func TestGetChannelForGroup_InactiveChannel(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusDisabled,
|
||
GroupIDs: []int64{10},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result, err := svc.GetChannelForGroup(context.Background(), 10)
|
||
require.NoError(t, err)
|
||
require.Nil(t, result)
|
||
}
|
||
|
||
func TestGetChannelForGroup_NoChannel(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result, err := svc.GetChannelForGroup(context.Background(), 999)
|
||
require.NoError(t, err)
|
||
require.Nil(t, result)
|
||
}
|
||
|
||
func TestGetChannelForGroup_CacheError(t *testing.T) {
|
||
repo := &mockChannelRepository{
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, errors.New("db connection failed")
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
result, err := svc.GetChannelForGroup(context.Background(), 10)
|
||
require.Error(t, err)
|
||
require.Nil(t, result)
|
||
require.Contains(t, err.Error(), "db connection failed")
|
||
}
|
||
|
||
// --- 4.2 GetChannelModelPricing ---
|
||
|
||
func TestGetChannelModelPricing_ExactMatch(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, int64(100), result.ID)
|
||
require.InDelta(t, 15e-6, *result.InputPrice, 1e-12)
|
||
}
|
||
|
||
func TestGetChannelModelPricing_CaseInsensitive(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "Claude-Opus-4")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, int64(100), result.ID)
|
||
}
|
||
|
||
func TestGetChannelModelPricing_WildcardMatch(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 200, Platform: "anthropic", Models: []string{"claude-*"}, InputPrice: testPtrFloat64(10e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-sonnet-4")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, int64(200), result.ID)
|
||
}
|
||
|
||
func TestGetChannelModelPricing_WildcardFirstMatch(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 200, Platform: "anthropic", Models: []string{"claude-*"}, InputPrice: testPtrFloat64(10e-6)},
|
||
{ID: 300, Platform: "anthropic", Models: []string{"claude-sonnet-*"}, InputPrice: testPtrFloat64(5e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-sonnet-4-20250514")
|
||
require.NotNil(t, result)
|
||
// "claude-*" is defined first, so it matches first regardless of prefix length
|
||
require.Equal(t, int64(200), result.ID)
|
||
require.InDelta(t, 10e-6, *result.InputPrice, 1e-12)
|
||
}
|
||
|
||
func TestGetChannelModelPricing_NoMatch(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "gpt-5.1")
|
||
require.Nil(t, result)
|
||
}
|
||
|
||
func TestGetChannelModelPricing_InactiveChannel(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusDisabled,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
|
||
require.Nil(t, result)
|
||
}
|
||
|
||
func TestGetChannelModelPricing_PlatformFiltering(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10, 20},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: "openai", Models: []string{"gpt-5.1"}, InputPrice: testPtrFloat64(5e-6)},
|
||
{ID: 200, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic", 20: "openai"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
// Group 10 (anthropic) should NOT see openai pricing
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "gpt-5.1")
|
||
require.Nil(t, result)
|
||
|
||
// Group 10 (anthropic) should see anthropic pricing
|
||
result = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, int64(200), result.ID)
|
||
|
||
// Group 20 (openai) should see openai pricing
|
||
result = svc.GetChannelModelPricing(context.Background(), 20, "gpt-5.1")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, int64(100), result.ID)
|
||
|
||
// Group 20 (openai) should NOT see anthropic pricing
|
||
result = svc.GetChannelModelPricing(context.Background(), 20, "claude-opus-4")
|
||
require.Nil(t, result)
|
||
}
|
||
|
||
func TestGetChannelModelPricing_ReturnsCopy(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
|
||
require.NotNil(t, result)
|
||
|
||
// Mutate the returned pricing's slice fields — original cache should not be affected
|
||
// (Clone copies slices independently, pointer fields are shared per design)
|
||
result.Models = append(result.Models, "hacked")
|
||
result.ID = 999
|
||
|
||
// Original cache should not be affected (slice independence + struct copy)
|
||
result2 := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
|
||
require.NotNil(t, result2)
|
||
require.Equal(t, 1, len(result2.Models))
|
||
require.Equal(t, int64(100), result2.ID)
|
||
}
|
||
|
||
// --- 4.3 ResolveChannelMapping ---
|
||
|
||
func TestResolveChannelMapping_NoChannel(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
// Group 999 is not in any channel
|
||
result := svc.ResolveChannelMapping(context.Background(), 999, "claude-opus-4")
|
||
require.Equal(t, "claude-opus-4", result.MappedModel)
|
||
require.False(t, result.Mapped)
|
||
require.Equal(t, int64(0), result.ChannelID)
|
||
}
|
||
|
||
func TestResolveChannelMapping_ExactMapping(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelMapping: map[string]map[string]string{
|
||
"anthropic": {
|
||
"claude-sonnet-4": "claude-sonnet-4-20250514",
|
||
},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-sonnet-4")
|
||
require.True(t, result.Mapped)
|
||
require.Equal(t, "claude-sonnet-4-20250514", result.MappedModel)
|
||
require.Equal(t, int64(1), result.ChannelID)
|
||
}
|
||
|
||
func TestResolveChannelMapping_WildcardMapping(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelMapping: map[string]map[string]string{
|
||
"anthropic": {
|
||
"*": "gpt-5.4",
|
||
},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.ResolveChannelMapping(context.Background(), 10, "any-model-name")
|
||
require.True(t, result.Mapped)
|
||
require.Equal(t, "gpt-5.4", result.MappedModel)
|
||
}
|
||
|
||
func TestResolveChannelMapping_WildcardFirstMatch(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelMapping: map[string]map[string]string{
|
||
"anthropic": {
|
||
"claude-*": "target2",
|
||
"claude-sonnet-*": "target1",
|
||
},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-sonnet-4")
|
||
require.True(t, result.Mapped)
|
||
// map iteration order is non-deterministic, so the first-match depends on
|
||
// insertion order which Go maps don't guarantee; verify that one of the
|
||
// wildcard targets matched
|
||
require.Contains(t, []string{"target1", "target2"}, result.MappedModel)
|
||
}
|
||
|
||
func TestResolveChannelMapping_NoMapping(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelMapping: map[string]map[string]string{
|
||
"anthropic": {
|
||
"claude-sonnet-4": "mapped",
|
||
},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")
|
||
require.False(t, result.Mapped)
|
||
require.Equal(t, "claude-opus-4", result.MappedModel)
|
||
require.Equal(t, int64(1), result.ChannelID)
|
||
}
|
||
|
||
func TestResolveChannelMapping_DefaultBillingModelSource(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
BillingModelSource: "", // empty
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")
|
||
require.Equal(t, BillingModelSourceChannelMapped, result.BillingModelSource)
|
||
}
|
||
|
||
func TestResolveChannelMapping_UpstreamBillingModelSource(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
BillingModelSource: BillingModelSourceUpstream,
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")
|
||
require.Equal(t, BillingModelSourceUpstream, result.BillingModelSource)
|
||
}
|
||
|
||
func TestResolveChannelMapping_InactiveChannel(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusDisabled,
|
||
GroupIDs: []int64{10},
|
||
ModelMapping: map[string]map[string]string{
|
||
"anthropic": {
|
||
"claude-sonnet-4": "mapped",
|
||
},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-sonnet-4")
|
||
require.False(t, result.Mapped)
|
||
require.Equal(t, "claude-sonnet-4", result.MappedModel)
|
||
require.Equal(t, int64(0), result.ChannelID) // no channel
|
||
}
|
||
|
||
// --- 4.4 IsModelRestricted ---
|
||
|
||
func TestIsModelRestricted_NoChannel(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: true,
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
// Group 999 is not in any channel
|
||
restricted := svc.IsModelRestricted(context.Background(), 999, "claude-opus-4")
|
||
require.False(t, restricted)
|
||
}
|
||
|
||
func TestIsModelRestricted_RestrictDisabled(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: false,
|
||
ModelPricing: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
// Even though model is not in pricing, RestrictModels=false
|
||
restricted := svc.IsModelRestricted(context.Background(), 10, "nonexistent-model")
|
||
require.False(t, restricted)
|
||
}
|
||
|
||
func TestIsModelRestricted_InactiveChannel(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusDisabled,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: true,
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
restricted := svc.IsModelRestricted(context.Background(), 10, "any-model")
|
||
require.False(t, restricted)
|
||
}
|
||
|
||
func TestIsModelRestricted_ModelInPricing(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: true,
|
||
ModelPricing: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4", "claude-sonnet-4"}},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
restricted := svc.IsModelRestricted(context.Background(), 10, "claude-opus-4")
|
||
require.False(t, restricted)
|
||
}
|
||
|
||
func TestIsModelRestricted_ModelInWildcard(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: true,
|
||
ModelPricing: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-*"}},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
restricted := svc.IsModelRestricted(context.Background(), 10, "claude-sonnet-4")
|
||
require.False(t, restricted)
|
||
}
|
||
|
||
func TestIsModelRestricted_ModelNotFound(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: true,
|
||
ModelPricing: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
restricted := svc.IsModelRestricted(context.Background(), 10, "gpt-5.1")
|
||
require.True(t, restricted)
|
||
}
|
||
|
||
func TestIsModelRestricted_CaseInsensitive(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: true,
|
||
ModelPricing: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
restricted := svc.IsModelRestricted(context.Background(), 10, "Claude-Opus-4")
|
||
require.False(t, restricted)
|
||
}
|
||
|
||
// --- 4.5 ResolveChannelMappingAndRestrict ---
|
||
// 注意:模型限制检查已移至调度阶段(GatewayService.checkChannelPricingRestriction),
|
||
// ResolveChannelMappingAndRestrict 仅做映射,restricted 始终为 false。
|
||
|
||
func TestResolveChannelMappingAndRestrict_NilGroupID(t *testing.T) {
|
||
repo := &mockChannelRepository{
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), nil, "claude-opus-4")
|
||
require.False(t, restricted)
|
||
require.False(t, mapping.Mapped)
|
||
require.Equal(t, "claude-opus-4", mapping.MappedModel)
|
||
}
|
||
|
||
func TestResolveChannelMappingAndRestrict_WithMapping(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: true,
|
||
ModelPricing: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
|
||
},
|
||
ModelMapping: map[string]map[string]string{
|
||
"anthropic": {
|
||
"claude-sonnet-4": "claude-sonnet-4-20250514",
|
||
},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
gid := int64(10)
|
||
mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "claude-sonnet-4")
|
||
require.False(t, restricted) // restricted 始终为 false,限制检查在调度阶段
|
||
require.True(t, mapping.Mapped)
|
||
require.Equal(t, "claude-sonnet-4-20250514", mapping.MappedModel)
|
||
}
|
||
|
||
func TestResolveChannelMappingAndRestrict_NoMapping(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: true,
|
||
ModelPricing: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
gid := int64(10)
|
||
mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "unknown-model")
|
||
require.False(t, restricted) // restricted 始终为 false,限制检查在调度阶段
|
||
require.False(t, mapping.Mapped)
|
||
require.Equal(t, "unknown-model", mapping.MappedModel)
|
||
}
|
||
|
||
// --- 4.6 Cache Building Specifics ---
|
||
|
||
func TestBuildCache_DBError(t *testing.T) {
|
||
callCount := 0
|
||
repo := &mockChannelRepository{
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
callCount++
|
||
return nil, errors.New("database down")
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
// First call should fail
|
||
_, err := svc.GetChannelForGroup(context.Background(), 10)
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "database down")
|
||
require.Equal(t, 1, callCount)
|
||
|
||
// Second call within error-TTL should use error cache, but still return error
|
||
// Because buildCache stores error-TTL cache and returns error, the cached value
|
||
// is still within TTL and loadCache returns it (which is an empty cache).
|
||
// Actually, re-reading the code: buildCache returns nil, err, and the error cache
|
||
// only serves as a "don't retry immediately" mechanism. The singleflight.Do
|
||
// returns the error. On next call within error-TTL, the cache has an empty but
|
||
// valid entry, so loadCache returns it (with empty maps). GetChannelForGroup
|
||
// will find nothing and return nil, nil.
|
||
result, err := svc.GetChannelForGroup(context.Background(), 10)
|
||
require.NoError(t, err)
|
||
require.Nil(t, result)
|
||
// Should NOT have hit DB again (error-TTL cache is active)
|
||
require.Equal(t, 1, callCount)
|
||
}
|
||
|
||
func TestBuildCache_GroupPlatformError(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
},
|
||
}
|
||
repo := &mockChannelRepository{
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return []Channel{ch}, nil
|
||
},
|
||
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
|
||
return nil, errors.New("group platforms failed")
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
// Should fail-close: error propagated when group platforms cannot be loaded
|
||
result, err := svc.GetChannelForGroup(context.Background(), 10)
|
||
require.Error(t, err)
|
||
require.Nil(t, result)
|
||
|
||
// Within error-TTL, second call should hit cache (empty) and return nil, nil
|
||
result2, err2 := svc.GetChannelForGroup(context.Background(), 10)
|
||
require.NoError(t, err2)
|
||
require.Nil(t, result2)
|
||
}
|
||
|
||
func TestBuildCache_MultipleGroupsSameChannel(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10, 20, 30},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{
|
||
10: "anthropic",
|
||
20: "anthropic",
|
||
30: "anthropic",
|
||
})
|
||
svc := newTestChannelService(repo)
|
||
|
||
for _, gid := range []int64{10, 20, 30} {
|
||
result := svc.GetChannelModelPricing(context.Background(), gid, "claude-opus-4")
|
||
require.NotNil(t, result, "group %d should have pricing", gid)
|
||
require.Equal(t, int64(100), result.ID)
|
||
}
|
||
}
|
||
|
||
func TestBuildCache_PlatformFiltering(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10, 20},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
{ID: 200, Platform: "openai", Models: []string{"gpt-5.1"}},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{
|
||
10: "anthropic",
|
||
20: "openai",
|
||
})
|
||
svc := newTestChannelService(repo)
|
||
|
||
// anthropic group sees only anthropic models
|
||
require.NotNil(t, svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4"))
|
||
require.Nil(t, svc.GetChannelModelPricing(context.Background(), 10, "gpt-5.1"))
|
||
|
||
// openai group sees only openai models
|
||
require.NotNil(t, svc.GetChannelModelPricing(context.Background(), 20, "gpt-5.1"))
|
||
require.Nil(t, svc.GetChannelModelPricing(context.Background(), 20, "claude-opus-4"))
|
||
}
|
||
|
||
func TestBuildCache_WildcardPreservesConfigOrder(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
// Configuration order: shortest prefix first
|
||
{ID: 100, Platform: "anthropic", Models: []string{"c-*"}, InputPrice: testPtrFloat64(1e-6)},
|
||
{ID: 200, Platform: "anthropic", Models: []string{"c-son-*"}, InputPrice: testPtrFloat64(2e-6)},
|
||
{ID: 300, Platform: "anthropic", Models: []string{"c-son-4-*"}, InputPrice: testPtrFloat64(3e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
// "c-son-4-xxx" matches all three wildcards, but "c-*" (ID=100) is first in config
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "c-son-4-xxx")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, int64(100), result.ID)
|
||
|
||
// "c-son-yyy" matches "c-*" and "c-son-*", but "c-*" (ID=100) is first
|
||
result = svc.GetChannelModelPricing(context.Background(), 10, "c-son-yyy")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, int64(100), result.ID)
|
||
|
||
// "c-other" only matches "c-*" (ID=100)
|
||
result = svc.GetChannelModelPricing(context.Background(), 10, "c-other")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, int64(100), result.ID)
|
||
}
|
||
|
||
// --- 4.7 invalidateCache ---
|
||
|
||
func TestInvalidateCache(t *testing.T) {
|
||
callCount := 0
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
},
|
||
}
|
||
repo := &mockChannelRepository{
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
callCount++
|
||
return []Channel{ch}, nil
|
||
},
|
||
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
|
||
return map[int64]string{10: "anthropic"}, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
// First load
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, 1, callCount)
|
||
|
||
// Second call should use cache
|
||
result = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, 1, callCount) // no new DB call
|
||
|
||
// Invalidate
|
||
svc.invalidateCache()
|
||
|
||
// Next call should rebuild from DB
|
||
result = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, 2, callCount) // rebuilt
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 5. CRUD Methods
|
||
// ===========================================================================
|
||
|
||
// --- 5.1 Create ---
|
||
|
||
func TestCreate_Success(t *testing.T) {
|
||
createdID := int64(42)
|
||
repo := &mockChannelRepository{
|
||
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
|
||
return false, nil
|
||
},
|
||
getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) {
|
||
return nil, nil
|
||
},
|
||
createFn: func(_ context.Context, ch *Channel) error {
|
||
ch.ID = createdID
|
||
return nil
|
||
},
|
||
getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
|
||
return &Channel{ID: id, Name: "new-channel", Status: StatusActive}, nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
result, err := svc.Create(context.Background(), &CreateChannelInput{
|
||
Name: "new-channel",
|
||
GroupIDs: []int64{10},
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
require.Equal(t, createdID, result.ID)
|
||
}
|
||
|
||
func TestCreate_NameExists(t *testing.T) {
|
||
repo := &mockChannelRepository{
|
||
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
|
||
return true, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
_, err := svc.Create(context.Background(), &CreateChannelInput{
|
||
Name: "existing-channel",
|
||
})
|
||
require.Error(t, err)
|
||
require.ErrorIs(t, err, ErrChannelExists)
|
||
}
|
||
|
||
func TestCreate_GroupConflict(t *testing.T) {
|
||
repo := &mockChannelRepository{
|
||
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
|
||
return false, nil
|
||
},
|
||
getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) {
|
||
return []int64{10}, nil // group 10 already in another channel
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
_, err := svc.Create(context.Background(), &CreateChannelInput{
|
||
Name: "new-channel",
|
||
GroupIDs: []int64{10, 20},
|
||
})
|
||
require.Error(t, err)
|
||
require.ErrorIs(t, err, ErrGroupAlreadyInChannel)
|
||
}
|
||
|
||
func TestCreate_DuplicateModel(t *testing.T) {
|
||
repo := &mockChannelRepository{
|
||
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
|
||
return false, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
_, err := svc.Create(context.Background(), &CreateChannelInput{
|
||
Name: "new-channel",
|
||
ModelPricing: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4"}}, // duplicate
|
||
},
|
||
})
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "claude-opus-4")
|
||
}
|
||
|
||
func TestCreate_InvalidPricingIntervals(t *testing.T) {
|
||
repo := &mockChannelRepository{
|
||
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
|
||
return false, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
_, err := svc.Create(context.Background(), &CreateChannelInput{
|
||
Name: "new-channel",
|
||
ModelPricing: []ChannelModelPricing{
|
||
{
|
||
Platform: "anthropic",
|
||
Models: []string{"claude-opus-4"},
|
||
Intervals: []PricingInterval{
|
||
{MinTokens: 0, MaxTokens: testPtrInt(2000), InputPrice: testPtrFloat64(1e-6)},
|
||
{MinTokens: 1000, MaxTokens: testPtrInt(3000), InputPrice: testPtrFloat64(2e-6)},
|
||
},
|
||
},
|
||
},
|
||
})
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "INVALID_PRICING_INTERVALS")
|
||
require.Contains(t, err.Error(), "overlap")
|
||
}
|
||
|
||
func TestCreate_DefaultBillingModelSource(t *testing.T) {
|
||
var capturedChannel *Channel
|
||
repo := &mockChannelRepository{
|
||
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
|
||
return false, nil
|
||
},
|
||
createFn: func(_ context.Context, ch *Channel) error {
|
||
capturedChannel = ch
|
||
ch.ID = 1
|
||
return nil
|
||
},
|
||
getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
|
||
return capturedChannel, nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
result, err := svc.Create(context.Background(), &CreateChannelInput{
|
||
Name: "new-channel",
|
||
BillingModelSource: "", // empty, should default to "channel_mapped"
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
require.Equal(t, BillingModelSourceChannelMapped, result.BillingModelSource)
|
||
}
|
||
|
||
func TestCreate_InvalidatesCache(t *testing.T) {
|
||
loadCount := 0
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
},
|
||
}
|
||
repo := &mockChannelRepository{
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
loadCount++
|
||
return []Channel{ch}, nil
|
||
},
|
||
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
|
||
return map[int64]string{10: "anthropic"}, nil
|
||
},
|
||
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
|
||
return false, nil
|
||
},
|
||
createFn: func(_ context.Context, c *Channel) error {
|
||
c.ID = 2
|
||
return nil
|
||
},
|
||
getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
|
||
return &Channel{ID: id, Name: "new", Status: StatusActive}, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
// Load cache
|
||
_ = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
|
||
require.Equal(t, 1, loadCount)
|
||
|
||
// Create triggers cache invalidation
|
||
_, err := svc.Create(context.Background(), &CreateChannelInput{Name: "new"})
|
||
require.NoError(t, err)
|
||
|
||
// Next cache access should rebuild
|
||
_ = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
|
||
require.Equal(t, 2, loadCount)
|
||
}
|
||
|
||
// --- 5.2 Update ---
|
||
|
||
func TestUpdate_Success(t *testing.T) {
|
||
existing := &Channel{
|
||
ID: 1,
|
||
Name: "original",
|
||
Status: StatusActive,
|
||
}
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
|
||
return existing.Clone(), nil
|
||
},
|
||
updateFn: func(_ context.Context, _ *Channel) error {
|
||
return nil
|
||
},
|
||
getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
|
||
return nil, nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||
Name: "updated-name",
|
||
Description: testPtrString("new desc"),
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
}
|
||
|
||
func TestUpdate_NotFound(t *testing.T) {
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
|
||
return nil, ErrChannelNotFound
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
_, err := svc.Update(context.Background(), 999, &UpdateChannelInput{
|
||
Name: "whatever",
|
||
})
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "channel")
|
||
}
|
||
|
||
func TestUpdate_NameConflict(t *testing.T) {
|
||
existing := &Channel{
|
||
ID: 1,
|
||
Name: "original",
|
||
Status: StatusActive,
|
||
}
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
|
||
return existing.Clone(), nil
|
||
},
|
||
existsByNameExcludingFn: func(_ context.Context, _ string, _ int64) (bool, error) {
|
||
return true, nil // name conflicts with another channel
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||
Name: "conflicting-name",
|
||
})
|
||
require.Error(t, err)
|
||
require.ErrorIs(t, err, ErrChannelExists)
|
||
}
|
||
|
||
func TestUpdate_GroupConflict(t *testing.T) {
|
||
existing := &Channel{
|
||
ID: 1,
|
||
Name: "original",
|
||
Status: StatusActive,
|
||
}
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
|
||
return existing.Clone(), nil
|
||
},
|
||
getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) {
|
||
return []int64{20}, nil // group 20 in another channel
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
newGroupIDs := []int64{10, 20}
|
||
_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||
GroupIDs: &newGroupIDs,
|
||
})
|
||
require.Error(t, err)
|
||
require.ErrorIs(t, err, ErrGroupAlreadyInChannel)
|
||
}
|
||
|
||
func TestUpdate_DuplicateModel(t *testing.T) {
|
||
existing := &Channel{
|
||
ID: 1,
|
||
Name: "original",
|
||
Status: StatusActive,
|
||
}
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
|
||
return existing.Clone(), nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
dupPricing := []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
}
|
||
_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||
ModelPricing: &dupPricing,
|
||
})
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "claude-opus-4")
|
||
}
|
||
|
||
func TestUpdate_InvalidPricingIntervals(t *testing.T) {
|
||
existing := &Channel{
|
||
ID: 1,
|
||
Name: "original",
|
||
Status: StatusActive,
|
||
}
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
|
||
return existing.Clone(), nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
invalidPricing := []ChannelModelPricing{
|
||
{
|
||
Platform: "anthropic",
|
||
Models: []string{"claude-opus-4"},
|
||
Intervals: []PricingInterval{
|
||
{MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)},
|
||
{MinTokens: 2000, MaxTokens: testPtrInt(4000), InputPrice: testPtrFloat64(2e-6)},
|
||
},
|
||
},
|
||
}
|
||
_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||
ModelPricing: &invalidPricing,
|
||
})
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "INVALID_PRICING_INTERVALS")
|
||
require.Contains(t, err.Error(), "unbounded")
|
||
}
|
||
|
||
func TestUpdate_InvalidatesChannelCache(t *testing.T) {
|
||
existing := &Channel{
|
||
ID: 1,
|
||
Name: "original",
|
||
Status: StatusActive,
|
||
}
|
||
loadCount := 0
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
|
||
return existing.Clone(), nil
|
||
},
|
||
updateFn: func(_ context.Context, _ *Channel) error {
|
||
return nil
|
||
},
|
||
getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
|
||
return []int64{10, 20}, nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
loadCount++
|
||
return []Channel{*existing}, nil
|
||
},
|
||
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
// Load cache first
|
||
_, _ = svc.GetChannelForGroup(context.Background(), 10)
|
||
require.Equal(t, 1, loadCount)
|
||
|
||
result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||
Description: testPtrString("updated"),
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
|
||
// Channel cache should be invalidated (next access rebuilds)
|
||
_, _ = svc.GetChannelForGroup(context.Background(), 10)
|
||
require.Equal(t, 2, loadCount)
|
||
}
|
||
|
||
func TestUpdate_InvalidatesAuthCache(t *testing.T) {
|
||
existing := &Channel{
|
||
ID: 1,
|
||
Name: "original",
|
||
Status: StatusActive,
|
||
}
|
||
auth := &mockChannelAuthCacheInvalidator{}
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
|
||
return existing.Clone(), nil
|
||
},
|
||
updateFn: func(_ context.Context, _ *Channel) error {
|
||
return nil
|
||
},
|
||
getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
|
||
return []int64{10, 20}, nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelServiceWithAuth(repo, auth)
|
||
|
||
result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||
Description: testPtrString("updated"),
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
|
||
// Auth cache should be invalidated for both groups
|
||
require.ElementsMatch(t, []int64{10, 20}, auth.invalidatedGroupIDs)
|
||
}
|
||
|
||
// --- 5.3 Delete ---
|
||
|
||
func TestChannelDelete_Success(t *testing.T) {
|
||
deleted := false
|
||
repo := &mockChannelRepository{
|
||
getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
|
||
return nil, nil
|
||
},
|
||
deleteFn: func(_ context.Context, _ int64) error {
|
||
deleted = true
|
||
return nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
err := svc.Delete(context.Background(), 1)
|
||
require.NoError(t, err)
|
||
require.True(t, deleted)
|
||
}
|
||
|
||
func TestChannelDelete_InvalidatesCaches(t *testing.T) {
|
||
auth := &mockChannelAuthCacheInvalidator{}
|
||
loadCount := 0
|
||
repo := &mockChannelRepository{
|
||
getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
|
||
return []int64{10, 20}, nil
|
||
},
|
||
deleteFn: func(_ context.Context, _ int64) error {
|
||
return nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
loadCount++
|
||
return []Channel{{ID: 1, Status: StatusActive, GroupIDs: []int64{10, 20}}}, nil
|
||
},
|
||
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelServiceWithAuth(repo, auth)
|
||
|
||
// Load cache first
|
||
_, _ = svc.GetChannelForGroup(context.Background(), 10)
|
||
require.Equal(t, 1, loadCount)
|
||
|
||
err := svc.Delete(context.Background(), 1)
|
||
require.NoError(t, err)
|
||
|
||
// Auth cache invalidated for both groups
|
||
require.ElementsMatch(t, []int64{10, 20}, auth.invalidatedGroupIDs)
|
||
|
||
// Channel cache invalidated
|
||
_, _ = svc.GetChannelForGroup(context.Background(), 10)
|
||
require.Equal(t, 2, loadCount)
|
||
}
|
||
|
||
func TestChannelDelete_NotFound(t *testing.T) {
|
||
repo := &mockChannelRepository{
|
||
getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
|
||
return nil, nil
|
||
},
|
||
deleteFn: func(_ context.Context, _ int64) error {
|
||
return errors.New("record not found")
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
err := svc.Delete(context.Background(), 999)
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "not found")
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 6. Edge Case Tests
|
||
// ===========================================================================
|
||
|
||
// --- 6.1 Create with empty GroupIDs ---
|
||
|
||
func TestCreate_NoGroups(t *testing.T) {
|
||
createdID := int64(55)
|
||
getGroupsInOtherChannelsCalled := false
|
||
repo := &mockChannelRepository{
|
||
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
|
||
return false, nil
|
||
},
|
||
getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) {
|
||
getGroupsInOtherChannelsCalled = true
|
||
return nil, nil
|
||
},
|
||
createFn: func(_ context.Context, ch *Channel) error {
|
||
ch.ID = createdID
|
||
return nil
|
||
},
|
||
getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
|
||
return &Channel{ID: id, Name: "no-groups-channel", Status: StatusActive}, nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
result, err := svc.Create(context.Background(), &CreateChannelInput{
|
||
Name: "no-groups-channel",
|
||
GroupIDs: []int64{}, // empty slice
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
require.Equal(t, createdID, result.ID)
|
||
// GetGroupsInOtherChannels should NOT have been called (skipped by len(input.GroupIDs) > 0)
|
||
require.False(t, getGroupsInOtherChannelsCalled)
|
||
}
|
||
|
||
// --- 6.2 Update only Status ---
|
||
|
||
func TestUpdate_StatusOnly(t *testing.T) {
|
||
existing := &Channel{
|
||
ID: 1,
|
||
Name: "test-channel",
|
||
Status: StatusActive,
|
||
}
|
||
var capturedChannel *Channel
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
|
||
return existing.Clone(), nil
|
||
},
|
||
updateFn: func(_ context.Context, ch *Channel) error {
|
||
capturedChannel = ch
|
||
return nil
|
||
},
|
||
getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
|
||
return nil, nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||
Status: StatusDisabled,
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
// Verify that the channel passed to repo.Update has the new status
|
||
require.NotNil(t, capturedChannel)
|
||
require.Equal(t, StatusDisabled, capturedChannel.Status)
|
||
// Name should remain unchanged
|
||
require.Equal(t, "test-channel", capturedChannel.Name)
|
||
}
|
||
|
||
// --- 6.3 Delete when GetGroupIDs fails ---
|
||
|
||
func TestChannelDelete_GetGroupIDsError(t *testing.T) {
|
||
deleted := false
|
||
repo := &mockChannelRepository{
|
||
getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
|
||
return nil, errors.New("group IDs lookup failed")
|
||
},
|
||
deleteFn: func(_ context.Context, _ int64) error {
|
||
deleted = true
|
||
return nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
// Delete should still succeed even though GetGroupIDs returned error (degradation path L588-591)
|
||
err := svc.Delete(context.Background(), 1)
|
||
require.NoError(t, err)
|
||
require.True(t, deleted)
|
||
}
|
||
|
||
// --- 6.4 ReplaceModelInBody with invalid JSON ---
|
||
|
||
func TestReplaceModelInBody_InvalidJSON(t *testing.T) {
|
||
// Case 1: broken JSON object — gjson won't find "model", sjson does best-effort set
|
||
// (no panic, no error from sjson, but result is mutated garbage)
|
||
brokenBody := []byte("{broken")
|
||
result := ReplaceModelInBody(brokenBody, "new-model")
|
||
require.NotNil(t, result)
|
||
// sjson does not error on this input, so result differs from original — just verify no panic
|
||
|
||
// Case 2: JSON array — sjson.SetBytes returns error on non-object,
|
||
// triggering the L447 error fallback path that returns original body.
|
||
arrayBody := []byte("[]")
|
||
result2 := ReplaceModelInBody(arrayBody, "new-model")
|
||
require.Equal(t, arrayBody, result2)
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 7. isPlatformPricingMatch
|
||
// ===========================================================================
|
||
|
||
func TestIsPlatformPricingMatch(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
groupPlatform string
|
||
pricingPlatform string
|
||
want bool
|
||
}{
|
||
{"antigravity does NOT match anthropic", PlatformAntigravity, PlatformAnthropic, false},
|
||
{"antigravity does NOT match gemini", PlatformAntigravity, PlatformGemini, false},
|
||
{"antigravity matches antigravity", PlatformAntigravity, PlatformAntigravity, true},
|
||
{"antigravity does NOT match openai", PlatformAntigravity, PlatformOpenAI, false},
|
||
{"anthropic matches anthropic", PlatformAnthropic, PlatformAnthropic, true},
|
||
{"anthropic does NOT match antigravity", PlatformAnthropic, PlatformAntigravity, false},
|
||
{"anthropic does NOT match gemini", PlatformAnthropic, PlatformGemini, false},
|
||
{"gemini matches gemini", PlatformGemini, PlatformGemini, true},
|
||
{"gemini does NOT match antigravity", PlatformGemini, PlatformAntigravity, false},
|
||
{"gemini does NOT match anthropic", PlatformGemini, PlatformAnthropic, false},
|
||
{"empty string matches nothing", "", PlatformAnthropic, false},
|
||
{"empty string matches empty", "", "", true},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
require.Equal(t, tt.want, isPlatformPricingMatch(tt.groupPlatform, tt.pricingPlatform))
|
||
})
|
||
}
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 8. matchingPlatforms
|
||
// ===========================================================================
|
||
|
||
func TestMatchingPlatforms(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
groupPlatform string
|
||
want []string
|
||
}{
|
||
{"antigravity returns itself only", PlatformAntigravity, []string{PlatformAntigravity}},
|
||
{"anthropic returns itself", PlatformAnthropic, []string{PlatformAnthropic}},
|
||
{"gemini returns itself", PlatformGemini, []string{PlatformGemini}},
|
||
{"openai returns itself", PlatformOpenAI, []string{PlatformOpenAI}},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
result := matchingPlatforms(tt.groupPlatform)
|
||
require.Equal(t, tt.want, result)
|
||
})
|
||
}
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 9. Antigravity platform isolation — no cross-platform pricing leakage
|
||
// ===========================================================================
|
||
|
||
func TestGetChannelModelPricing_AntigravityDoesNotSeeCrossPlatformPricing(t *testing.T) {
|
||
// Channel has anthropic pricing for claude-opus-4-6.
|
||
// Group 10 is antigravity — should NOT see the anthropic pricing.
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: PlatformAnthropic, Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(15e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4-6")
|
||
require.Nil(t, result, "antigravity group should NOT see anthropic-platform pricing")
|
||
}
|
||
|
||
func TestGetChannelModelPricing_AnthropicCannotSeeAntigravityPricing(t *testing.T) {
|
||
// Channel has antigravity-platform pricing for claude-opus-4-6.
|
||
// Group 10 is anthropic — should NOT see antigravity pricing (no cross-platform leakage).
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: PlatformAntigravity, Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(15e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAnthropic})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4-6")
|
||
require.Nil(t, result, "anthropic group should NOT see antigravity-platform pricing")
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 10. Antigravity platform isolation — no cross-platform model mapping
|
||
// ===========================================================================
|
||
|
||
func TestResolveChannelMapping_AntigravityDoesNotSeeCrossPlatformMapping(t *testing.T) {
|
||
// Channel has anthropic model mapping: claude-opus-4-5 → claude-opus-4-6.
|
||
// Group 10 is antigravity — should NOT apply the anthropic mapping.
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelMapping: map[string]map[string]string{
|
||
PlatformAnthropic: {
|
||
"claude-opus-4-5": "claude-opus-4-6",
|
||
},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4-5")
|
||
require.False(t, result.Mapped, "antigravity group should NOT apply anthropic mapping")
|
||
require.Equal(t, "claude-opus-4-5", result.MappedModel)
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 11. Antigravity platform isolation — same-name model across platforms
|
||
// ===========================================================================
|
||
|
||
func TestGetChannelModelPricing_AntigravityDoesNotSeeSameModelFromOtherPlatforms(t *testing.T) {
|
||
// anthropic 和 gemini 都定义了同名模型 "shared-model",价格不同。
|
||
// antigravity 分组不应看到任何一个(各平台严格独立)。
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 200, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)},
|
||
{ID: 201, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model")
|
||
require.Nil(t, result, "antigravity group should NOT see anthropic/gemini-platform pricing")
|
||
}
|
||
|
||
func TestGetChannelModelPricing_AntigravityDoesNotSeeGeminiOnlyPricing(t *testing.T) {
|
||
// 只有 gemini 平台定义了模型 "gemini-model"。
|
||
// antigravity 分组不应看到 gemini 的定价。
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 300, Platform: PlatformGemini, Models: []string{"gemini-model"}, InputPrice: testPtrFloat64(2e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "gemini-model")
|
||
require.Nil(t, result, "antigravity group should NOT see gemini-platform pricing")
|
||
}
|
||
|
||
func TestGetChannelModelPricing_AntigravityDoesNotSeeWildcardFromOtherPlatforms(t *testing.T) {
|
||
// anthropic 和 gemini 都有 "shared-*" 通配符定价。
|
||
// antigravity 分组不应命中任何一个。
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 400, Platform: PlatformAnthropic, Models: []string{"shared-*"}, InputPrice: testPtrFloat64(10e-6)},
|
||
{ID: 401, Platform: PlatformGemini, Models: []string{"shared-*"}, InputPrice: testPtrFloat64(5e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model")
|
||
require.Nil(t, result, "antigravity group should NOT see wildcard pricing from other platforms")
|
||
}
|
||
|
||
func TestResolveChannelMapping_AntigravityDoesNotSeeMappingFromOtherPlatforms(t *testing.T) {
|
||
// anthropic 和 gemini 都定义了同名模型映射 "alias" → 不同目标。
|
||
// antigravity 分组不应命中任何一个。
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelMapping: map[string]map[string]string{
|
||
PlatformAnthropic: {"alias": "anthropic-target"},
|
||
PlatformGemini: {"alias": "gemini-target"},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.ResolveChannelMapping(context.Background(), 10, "alias")
|
||
require.False(t, result.Mapped, "antigravity group should NOT see mapping from other platforms")
|
||
require.Equal(t, "alias", result.MappedModel)
|
||
}
|
||
|
||
func TestCheckRestricted_AntigravityDoesNotSeeModelsFromOtherPlatforms(t *testing.T) {
|
||
// anthropic 和 gemini 都定义了同名模型 "shared-model"。
|
||
// antigravity 分组启用了 RestrictModels,"shared-model" 应被限制(各平台独立)。
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
RestrictModels: true,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 500, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)},
|
||
{ID: 501, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
|
||
svc := newTestChannelService(repo)
|
||
|
||
restricted := svc.IsModelRestricted(context.Background(), 10, "shared-model")
|
||
require.True(t, restricted, "shared-model from other platforms should be restricted for antigravity")
|
||
|
||
restricted = svc.IsModelRestricted(context.Background(), 10, "unknown-model")
|
||
require.True(t, restricted, "unknown-model should be restricted for antigravity")
|
||
}
|
||
|
||
func TestGetChannelModelPricing_AntigravityOwnPricingWorks(t *testing.T) {
|
||
// antigravity 平台自己配置的定价应正常生效(覆盖 Claude 和 Gemini 模型)。
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 600, Platform: PlatformAntigravity, Models: []string{"claude-*"}, InputPrice: testPtrFloat64(15e-6)},
|
||
{ID: 601, Platform: PlatformAntigravity, Models: []string{"gemini-*"}, InputPrice: testPtrFloat64(2e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
|
||
svc := newTestChannelService(repo)
|
||
|
||
// Claude 模型匹配 antigravity 定价
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-sonnet-4")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, int64(600), result.ID)
|
||
require.InDelta(t, 15e-6, *result.InputPrice, 1e-12)
|
||
|
||
// Gemini 模型匹配 antigravity 定价
|
||
result = svc.GetChannelModelPricing(context.Background(), 10, "gemini-2.5-flash")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, int64(601), result.ID)
|
||
require.InDelta(t, 2e-6, *result.InputPrice, 1e-12)
|
||
}
|
||
|
||
func TestGetChannelModelPricing_NonAntigravityUnaffected(t *testing.T) {
|
||
// 确保非 antigravity 平台的行为不受影响。
|
||
// anthropic 分组只能看到 anthropic 的定价,看不到 gemini 的。
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10, 20},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 600, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)},
|
||
{ID: 601, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAnthropic, 20: PlatformGemini})
|
||
svc := newTestChannelService(repo)
|
||
|
||
// anthropic 分组应该只看到 anthropic 的定价
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, int64(600), result.ID)
|
||
require.InDelta(t, 10e-6, *result.InputPrice, 1e-12)
|
||
|
||
// gemini 分组应该只看到 gemini 的定价
|
||
result = svc.GetChannelModelPricing(context.Background(), 20, "shared-model")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, int64(601), result.ID)
|
||
require.InDelta(t, 5e-6, *result.InputPrice, 1e-12)
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// 10. ToUsageFields
|
||
// ---------------------------------------------------------------------------
|
||
|
||
func TestToUsageFields_NoMapping(t *testing.T) {
|
||
r := ChannelMappingResult{
|
||
MappedModel: "claude-opus-4",
|
||
ChannelID: 1,
|
||
Mapped: false,
|
||
BillingModelSource: BillingModelSourceRequested,
|
||
}
|
||
fields := r.ToUsageFields("claude-opus-4", "claude-opus-4")
|
||
require.Equal(t, int64(1), fields.ChannelID)
|
||
require.Equal(t, "claude-opus-4", fields.OriginalModel)
|
||
require.Equal(t, "claude-opus-4", fields.ChannelMappedModel)
|
||
require.Equal(t, BillingModelSourceRequested, fields.BillingModelSource)
|
||
require.Empty(t, fields.ModelMappingChain)
|
||
}
|
||
|
||
func TestToUsageFields_WithChannelMapping(t *testing.T) {
|
||
r := ChannelMappingResult{
|
||
MappedModel: "claude-sonnet-4-20250514",
|
||
ChannelID: 2,
|
||
Mapped: true,
|
||
BillingModelSource: BillingModelSourceChannelMapped,
|
||
}
|
||
fields := r.ToUsageFields("claude-sonnet-4", "claude-sonnet-4-20250514")
|
||
require.Equal(t, int64(2), fields.ChannelID)
|
||
require.Equal(t, "claude-sonnet-4", fields.OriginalModel)
|
||
require.Equal(t, "claude-sonnet-4-20250514", fields.ChannelMappedModel)
|
||
require.Equal(t, "claude-sonnet-4→claude-sonnet-4-20250514", fields.ModelMappingChain)
|
||
}
|
||
|
||
func TestToUsageFields_WithUpstreamDifference(t *testing.T) {
|
||
r := ChannelMappingResult{
|
||
MappedModel: "claude-sonnet-4",
|
||
ChannelID: 3,
|
||
Mapped: true,
|
||
BillingModelSource: BillingModelSourceUpstream,
|
||
}
|
||
fields := r.ToUsageFields("my-alias", "claude-sonnet-4-20250514")
|
||
require.Equal(t, "my-alias", fields.OriginalModel)
|
||
require.Equal(t, "claude-sonnet-4", fields.ChannelMappedModel)
|
||
require.Equal(t, "my-alias→claude-sonnet-4→claude-sonnet-4-20250514", fields.ModelMappingChain)
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// 11. validatePricingBillingMode (moved from handler tests)
|
||
// ---------------------------------------------------------------------------
|
||
|
||
func TestValidatePricingBillingMode(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
pricing []ChannelModelPricing
|
||
wantErr bool
|
||
errMsg string
|
||
}{
|
||
{
|
||
name: "token mode - valid",
|
||
pricing: []ChannelModelPricing{{BillingMode: BillingModeToken}},
|
||
},
|
||
{
|
||
name: "per_request with price - valid",
|
||
pricing: []ChannelModelPricing{{
|
||
BillingMode: BillingModePerRequest,
|
||
PerRequestPrice: testPtrFloat64(0.5),
|
||
}},
|
||
},
|
||
{
|
||
name: "per_request with intervals - valid",
|
||
pricing: []ChannelModelPricing{{
|
||
BillingMode: BillingModePerRequest,
|
||
Intervals: []PricingInterval{{MinTokens: 0, MaxTokens: testPtrInt(1000), PerRequestPrice: testPtrFloat64(0.1)}},
|
||
}},
|
||
},
|
||
{
|
||
name: "per_request no price no intervals - invalid",
|
||
pricing: []ChannelModelPricing{{BillingMode: BillingModePerRequest}},
|
||
wantErr: true,
|
||
errMsg: "per-request price or intervals required",
|
||
},
|
||
{
|
||
name: "image no price no intervals - invalid",
|
||
pricing: []ChannelModelPricing{{BillingMode: BillingModeImage}},
|
||
wantErr: true,
|
||
errMsg: "per-request price or intervals required",
|
||
},
|
||
{
|
||
name: "empty list - valid",
|
||
pricing: []ChannelModelPricing{},
|
||
},
|
||
{
|
||
name: "negative input_price - invalid",
|
||
pricing: []ChannelModelPricing{{
|
||
BillingMode: BillingModeToken,
|
||
InputPrice: testPtrFloat64(-0.01),
|
||
}},
|
||
wantErr: true,
|
||
errMsg: "input_price must be >= 0",
|
||
},
|
||
{
|
||
name: "interval with no price fields - invalid",
|
||
pricing: []ChannelModelPricing{{
|
||
BillingMode: BillingModePerRequest,
|
||
PerRequestPrice: testPtrFloat64(0.5),
|
||
Intervals: []PricingInterval{{MinTokens: 0, MaxTokens: testPtrInt(1000)}},
|
||
}},
|
||
wantErr: true,
|
||
errMsg: "has no price fields set",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
err := validatePricingBillingMode(tt.pricing)
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), tt.errMsg)
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// 12. Antigravity wildcard mapping isolation
|
||
// ---------------------------------------------------------------------------
|
||
|
||
func TestResolveChannelMapping_AntigravityDoesNotSeeWildcardMappingFromOtherPlatforms(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10, 20},
|
||
ModelMapping: map[string]map[string]string{
|
||
PlatformAnthropic: {"claude-*": "claude-override"},
|
||
PlatformGemini: {"gemini-*": "gemini-override"},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity, 20: PlatformAnthropic})
|
||
svc := newTestChannelService(repo)
|
||
|
||
// antigravity 分组不应看到 anthropic/gemini 的通配符映射
|
||
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")
|
||
require.False(t, result.Mapped)
|
||
require.Equal(t, "claude-opus-4", result.MappedModel)
|
||
|
||
result = svc.ResolveChannelMapping(context.Background(), 10, "gemini-2.5-pro")
|
||
require.False(t, result.Mapped)
|
||
require.Equal(t, "gemini-2.5-pro", result.MappedModel)
|
||
|
||
// anthropic 分组应该能看到 anthropic 的通配符映射
|
||
result = svc.ResolveChannelMapping(context.Background(), 20, "claude-opus-4")
|
||
require.True(t, result.Mapped)
|
||
require.Equal(t, "claude-override", result.MappedModel)
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// 13. Create/Update with mapping conflict validation
|
||
// ---------------------------------------------------------------------------
|
||
|
||
func TestCreate_MappingConflict(t *testing.T) {
|
||
repo := &mockChannelRepository{}
|
||
svc := newTestChannelService(repo)
|
||
|
||
_, err := svc.Create(context.Background(), &CreateChannelInput{
|
||
Name: "test",
|
||
ModelMapping: map[string]map[string]string{
|
||
PlatformAnthropic: {
|
||
"claude-*": "target-a",
|
||
"claude-opus-*": "target-b",
|
||
},
|
||
},
|
||
})
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "MAPPING_PATTERN_CONFLICT")
|
||
}
|
||
|
||
func TestUpdate_MappingConflict(t *testing.T) {
|
||
existingChannel := &Channel{
|
||
ID: 1,
|
||
Name: "existing",
|
||
Status: StatusActive,
|
||
}
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
|
||
return existingChannel, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
conflictMapping := map[string]map[string]string{
|
||
PlatformAnthropic: {
|
||
"claude-*": "target-a",
|
||
"claude-opus-*": "target-b",
|
||
},
|
||
}
|
||
_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||
ModelMapping: conflictMapping,
|
||
})
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "MAPPING_PATTERN_CONFLICT")
|
||
}
|