Move the model-pricing restriction check out of the 8 handler entry points and into the account scheduling phase (SelectAccountForModelWithExclusions / SelectAccountWithLoadAwareness), so that restriction follows the same model as billing:
- requested: check the original request model against the pricing list
- channel_mapped: check the channel-mapped model against the pricing list
- upstream: check each account individually, using its account-mapped model
The handler layer now only resolves the channel mapping (no restriction). The scheduling layer pre-checks the model for the requested/channel_mapped sources and filters accounts one by one for the upstream billing source.
1987 lines
61 KiB
Go
1987 lines
61 KiB
Go
//go:build unit
|
||
|
||
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"testing"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Mock: ChannelRepository
|
||
// ---------------------------------------------------------------------------
|
||
|
||
type mockChannelRepository struct {
|
||
listAllFn func(ctx context.Context) ([]Channel, error)
|
||
getGroupPlatformsFn func(ctx context.Context, groupIDs []int64) (map[int64]string, error)
|
||
createFn func(ctx context.Context, channel *Channel) error
|
||
getByIDFn func(ctx context.Context, id int64) (*Channel, error)
|
||
updateFn func(ctx context.Context, channel *Channel) error
|
||
deleteFn func(ctx context.Context, id int64) error
|
||
listFn func(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error)
|
||
existsByNameFn func(ctx context.Context, name string) (bool, error)
|
||
existsByNameExcludingFn func(ctx context.Context, name string, excludeID int64) (bool, error)
|
||
getGroupIDsFn func(ctx context.Context, channelID int64) ([]int64, error)
|
||
setGroupIDsFn func(ctx context.Context, channelID int64, groupIDs []int64) error
|
||
getChannelIDByGroupIDFn func(ctx context.Context, groupID int64) (int64, error)
|
||
getGroupsInOtherChannelsFn func(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error)
|
||
listModelPricingFn func(ctx context.Context, channelID int64) ([]ChannelModelPricing, error)
|
||
createModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error
|
||
updateModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error
|
||
deleteModelPricingFn func(ctx context.Context, id int64) error
|
||
replaceModelPricingFn func(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
|
||
}
|
||
|
||
func (m *mockChannelRepository) Create(ctx context.Context, channel *Channel) error {
|
||
if m.createFn != nil {
|
||
return m.createFn(ctx, channel)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) GetByID(ctx context.Context, id int64) (*Channel, error) {
|
||
if m.getByIDFn != nil {
|
||
return m.getByIDFn(ctx, id)
|
||
}
|
||
return nil, ErrChannelNotFound
|
||
}
|
||
|
||
func (m *mockChannelRepository) Update(ctx context.Context, channel *Channel) error {
|
||
if m.updateFn != nil {
|
||
return m.updateFn(ctx, channel)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) Delete(ctx context.Context, id int64) error {
|
||
if m.deleteFn != nil {
|
||
return m.deleteFn(ctx, id)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) {
|
||
if m.listFn != nil {
|
||
return m.listFn(ctx, params, status, search)
|
||
}
|
||
return nil, nil, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) ListAll(ctx context.Context) ([]Channel, error) {
|
||
if m.listAllFn != nil {
|
||
return m.listAllFn(ctx)
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||
if m.existsByNameFn != nil {
|
||
return m.existsByNameFn(ctx, name)
|
||
}
|
||
return false, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) {
|
||
if m.existsByNameExcludingFn != nil {
|
||
return m.existsByNameExcludingFn(ctx, name, excludeID)
|
||
}
|
||
return false, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) {
|
||
if m.getGroupIDsFn != nil {
|
||
return m.getGroupIDsFn(ctx, channelID)
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error {
|
||
if m.setGroupIDsFn != nil {
|
||
return m.setGroupIDsFn(ctx, channelID, groupIDs)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||
if m.getChannelIDByGroupIDFn != nil {
|
||
return m.getChannelIDByGroupIDFn(ctx, groupID)
|
||
}
|
||
return 0, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) {
|
||
if m.getGroupsInOtherChannelsFn != nil {
|
||
return m.getGroupsInOtherChannelsFn(ctx, channelID, groupIDs)
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||
if m.getGroupPlatformsFn != nil {
|
||
return m.getGroupPlatformsFn(ctx, groupIDs)
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error) {
|
||
if m.listModelPricingFn != nil {
|
||
return m.listModelPricingFn(ctx, channelID)
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error {
|
||
if m.createModelPricingFn != nil {
|
||
return m.createModelPricingFn(ctx, pricing)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) UpdateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error {
|
||
if m.updateModelPricingFn != nil {
|
||
return m.updateModelPricingFn(ctx, pricing)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) DeleteModelPricing(ctx context.Context, id int64) error {
|
||
if m.deleteModelPricingFn != nil {
|
||
return m.deleteModelPricingFn(ctx, id)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (m *mockChannelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error {
|
||
if m.replaceModelPricingFn != nil {
|
||
return m.replaceModelPricingFn(ctx, channelID, pricingList)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Mock: APIKeyAuthCacheInvalidator
|
||
// ---------------------------------------------------------------------------
|
||
|
||
type mockChannelAuthCacheInvalidator struct {
|
||
invalidatedGroupIDs []int64
|
||
invalidatedKeys []string
|
||
invalidatedUserIDs []int64
|
||
}
|
||
|
||
func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByKey(_ context.Context, key string) {
|
||
m.invalidatedKeys = append(m.invalidatedKeys, key)
|
||
}
|
||
|
||
func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) {
|
||
m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID)
|
||
}
|
||
|
||
func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByGroupID(_ context.Context, groupID int64) {
|
||
m.invalidatedGroupIDs = append(m.invalidatedGroupIDs, groupID)
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Helpers
|
||
// ---------------------------------------------------------------------------
|
||
|
||
func newTestChannelService(repo *mockChannelRepository) *ChannelService {
|
||
return NewChannelService(repo, nil)
|
||
}
|
||
|
||
func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChannelAuthCacheInvalidator) *ChannelService {
|
||
return NewChannelService(repo, auth)
|
||
}
|
||
|
||
// makeStandardRepo returns a mock repository whose ListAll always yields
// exactly one channel, ch, and whose GetGroupPlatforms always returns
// groupPlatforms unchanged.
//
// GetGroupPlatforms ignores the group IDs it is asked about, so groupPlatforms
// should map every group the test touches to its platform. All other
// repository methods keep the mock's default behavior.
func makeStandardRepo(ch Channel, groupPlatforms map[int64]string) *mockChannelRepository {
	return &mockChannelRepository{
		listAllFn: func(_ context.Context) ([]Channel, error) {
			return []Channel{ch}, nil
		},
		getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
			return groupPlatforms, nil
		},
	}
}
|
||
|
||
// ===========================================================================
|
||
// 1. BuildModelMappingChain
|
||
// ===========================================================================
|
||
|
||
func TestBuildModelMappingChain(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
result ChannelMappingResult
|
||
requestModel string
|
||
upstreamModel string
|
||
want string
|
||
}{
|
||
{
|
||
name: "no mapping, no upstream diff",
|
||
result: ChannelMappingResult{Mapped: false, MappedModel: "claude-sonnet-4"},
|
||
requestModel: "claude-sonnet-4",
|
||
upstreamModel: "claude-sonnet-4",
|
||
want: "",
|
||
},
|
||
{
|
||
name: "no mapping, upstream differs",
|
||
result: ChannelMappingResult{Mapped: false, MappedModel: "claude-sonnet-4"},
|
||
requestModel: "claude-sonnet-4",
|
||
upstreamModel: "claude-sonnet-4-20250514",
|
||
want: "claude-sonnet-4\u2192claude-sonnet-4-20250514",
|
||
},
|
||
{
|
||
name: "mapped, upstream differs",
|
||
result: ChannelMappingResult{Mapped: true, MappedModel: "claude-sonnet-4-20250514"},
|
||
requestModel: "my-model",
|
||
upstreamModel: "actual-upstream",
|
||
want: "my-model\u2192claude-sonnet-4-20250514\u2192actual-upstream",
|
||
},
|
||
{
|
||
name: "mapped, upstream same as mapped",
|
||
result: ChannelMappingResult{Mapped: true, MappedModel: "claude-sonnet-4-20250514"},
|
||
requestModel: "claude-sonnet-4",
|
||
upstreamModel: "claude-sonnet-4-20250514",
|
||
want: "claude-sonnet-4\u2192claude-sonnet-4-20250514",
|
||
},
|
||
{
|
||
name: "mapped, upstream empty",
|
||
result: ChannelMappingResult{Mapped: true, MappedModel: "target-model"},
|
||
requestModel: "my-model",
|
||
upstreamModel: "",
|
||
want: "my-model\u2192target-model",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
got := tt.result.BuildModelMappingChain(tt.requestModel, tt.upstreamModel)
|
||
require.Equal(t, tt.want, got)
|
||
})
|
||
}
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 2. ReplaceModelInBody
|
||
// ===========================================================================
|
||
|
||
func TestReplaceModelInBody(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
body []byte
|
||
newModel string
|
||
check func(t *testing.T, result []byte)
|
||
}{
|
||
{
|
||
name: "empty body",
|
||
body: []byte{},
|
||
newModel: "new-model",
|
||
check: func(t *testing.T, result []byte) {
|
||
require.Equal(t, []byte{}, result)
|
||
},
|
||
},
|
||
{
|
||
name: "model already equal",
|
||
body: []byte(`{"model":"claude-sonnet-4","temperature":0.7}`),
|
||
newModel: "claude-sonnet-4",
|
||
check: func(t *testing.T, result []byte) {
|
||
require.Equal(t, []byte(`{"model":"claude-sonnet-4","temperature":0.7}`), result)
|
||
},
|
||
},
|
||
{
|
||
name: "model different",
|
||
body: []byte(`{"model":"claude-sonnet-4","temperature":0.7}`),
|
||
newModel: "claude-opus-4",
|
||
check: func(t *testing.T, result []byte) {
|
||
require.Contains(t, string(result), `"model":"claude-opus-4"`)
|
||
require.Contains(t, string(result), `"temperature"`)
|
||
},
|
||
},
|
||
{
|
||
name: "no model field",
|
||
body: []byte(`{"temperature":0.7}`),
|
||
newModel: "claude-opus-4",
|
||
check: func(t *testing.T, result []byte) {
|
||
require.Contains(t, string(result), `"model":"claude-opus-4"`)
|
||
require.Contains(t, string(result), `"temperature"`)
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
result := ReplaceModelInBody(tt.body, tt.newModel)
|
||
tt.check(t, result)
|
||
})
|
||
}
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 3. validateNoConflictingModels + validateNoConflictingMappings
|
||
// ===========================================================================
|
||
|
||
func TestValidateNoConflictingModels(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
pricingList []ChannelModelPricing
|
||
wantErr bool
|
||
errContains string
|
||
}{
|
||
{
|
||
name: "no duplicates",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-sonnet-4", "claude-opus-4"}},
|
||
{Platform: "openai", Models: []string{"gpt-5.1"}},
|
||
},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "same platform duplicate",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
|
||
{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
|
||
},
|
||
wantErr: true,
|
||
errContains: "claude-sonnet-4",
|
||
},
|
||
{
|
||
name: "same model different platform",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"model-a"}},
|
||
{Platform: "openai", Models: []string{"model-a"}},
|
||
},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "case insensitive",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"Claude"}},
|
||
{Platform: "anthropic", Models: []string{"claude"}},
|
||
},
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "empty list (nil)",
|
||
pricingList: nil,
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "wildcard_vs_wildcard_conflict",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-*"}},
|
||
{Platform: "anthropic", Models: []string{"claude-opus-*"}},
|
||
},
|
||
wantErr: true,
|
||
errContains: "conflict",
|
||
},
|
||
{
|
||
name: "wildcard_vs_exact_conflict",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-*"}},
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
|
||
},
|
||
wantErr: true,
|
||
errContains: "conflict",
|
||
},
|
||
{
|
||
name: "no_conflict_different_platform",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-opus-*"}},
|
||
{Platform: "openai", Models: []string{"claude-*"}},
|
||
},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "no_conflict_same_platform_different_prefix",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-opus-*"}},
|
||
{Platform: "anthropic", Models: []string{"gpt-*"}},
|
||
},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "catch_all_wildcard_conflicts_with_everything",
|
||
pricingList: []ChannelModelPricing{
|
||
{Platform: "openai", Models: []string{"*"}},
|
||
{Platform: "openai", Models: []string{"gpt-5"}},
|
||
},
|
||
wantErr: true,
|
||
errContains: "conflict",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
err := validateNoConflictingModels(tt.pricingList)
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
require.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
|
||
// Additional sub-case: explicit empty slice
|
||
t.Run("empty list (empty slice)", func(t *testing.T) {
|
||
err := validateNoConflictingModels([]ChannelModelPricing{})
|
||
require.NoError(t, err)
|
||
})
|
||
}
|
||
|
||
// TestValidateNoConflictingMappings checks conflict detection between mapping
// source patterns (exact names and "prefix*" wildcards) within one platform.
// Identical patterns on different platforms are allowed.
func TestValidateNoConflictingMappings(t *testing.T) {
	tests := []struct {
		name        string
		mapping     map[string]map[string]string
		wantErr     bool
		errContains string
	}{
		{
			name:    "nil mapping",
			mapping: nil,
			wantErr: false,
		},
		{
			name:    "empty mapping",
			mapping: map[string]map[string]string{},
			wantErr: false,
		},
		{
			name: "no conflict",
			mapping: map[string]map[string]string{
				"anthropic": {"claude-opus-*": "opus", "gpt-*": "gpt"},
			},
			wantErr: false,
		},
		{
			name: "wildcard vs wildcard conflict",
			mapping: map[string]map[string]string{
				"anthropic": {"claude-*": "a", "claude-opus-*": "b"},
			},
			wantErr:     true,
			errContains: "conflict",
		},
		{
			name: "wildcard vs exact conflict",
			mapping: map[string]map[string]string{
				"openai": {"gpt-*": "a", "gpt-4o": "b"},
			},
			wantErr:     true,
			errContains: "conflict",
		},
		{
			// The same exact source model on two different platforms is not a
			// conflict. The previous name, "exact duplicate conflict",
			// contradicted the expected outcome.
			name: "same exact model on different platforms",
			mapping: map[string]map[string]string{
				"anthropic": {"claude-opus-4": "a"},
				"openai":    {"claude-opus-4": "b"},
			},
			wantErr: false,
		},
		{
			name: "different platforms no conflict",
			mapping: map[string]map[string]string{
				"anthropic": {"claude-*": "a"},
				"openai":    {"claude-*": "b"},
			},
			wantErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validateNoConflictingMappings(tt.mapping)
			if !tt.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			if tt.errContains != "" {
				require.Contains(t, err.Error(), tt.errContains)
			}
		})
	}
}
|
||
|
||
func TestConflictsBetween(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
a, b modelEntry
|
||
want bool
|
||
}{
|
||
{
|
||
name: "exact same",
|
||
a: modelEntry{prefix: "claude-opus-4", wildcard: false},
|
||
b: modelEntry{prefix: "claude-opus-4", wildcard: false},
|
||
want: true,
|
||
},
|
||
{
|
||
name: "exact different",
|
||
a: modelEntry{prefix: "claude-opus-4", wildcard: false},
|
||
b: modelEntry{prefix: "gpt-4o", wildcard: false},
|
||
want: false,
|
||
},
|
||
{
|
||
name: "wildcard matches exact",
|
||
a: modelEntry{prefix: "claude-", wildcard: true},
|
||
b: modelEntry{prefix: "claude-opus-4", wildcard: false},
|
||
want: true,
|
||
},
|
||
{
|
||
name: "exact does not match unrelated wildcard",
|
||
a: modelEntry{prefix: "gpt-4o", wildcard: false},
|
||
b: modelEntry{prefix: "claude-", wildcard: true},
|
||
want: false,
|
||
},
|
||
{
|
||
name: "wildcard prefix overlap",
|
||
a: modelEntry{prefix: "claude-", wildcard: true},
|
||
b: modelEntry{prefix: "claude-opus-", wildcard: true},
|
||
want: true,
|
||
},
|
||
{
|
||
name: "wildcards no overlap",
|
||
a: modelEntry{prefix: "claude-", wildcard: true},
|
||
b: modelEntry{prefix: "gpt-", wildcard: true},
|
||
want: false,
|
||
},
|
||
{
|
||
name: "catch-all wildcard vs any",
|
||
a: modelEntry{prefix: "", wildcard: true},
|
||
b: modelEntry{prefix: "anything", wildcard: false},
|
||
want: true,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
require.Equal(t, tt.want, conflictsBetween(tt.a, tt.b))
|
||
})
|
||
}
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 4. Cache Building + Hot Path Methods
|
||
// ===========================================================================
|
||
|
||
// --- 4.1 GetChannelForGroup ---
|
||
|
||
// TestGetChannelForGroup_Success verifies that an active channel is found via
// one of its groups, and that callers get a copy rather than the cached value.
func TestGetChannelForGroup_Success(t *testing.T) {
	ctx := context.Background()
	active := Channel{
		ID:       1,
		Name:     "test-channel",
		Status:   StatusActive,
		GroupIDs: []int64{10},
	}
	svc := newTestChannelService(makeStandardRepo(active, map[int64]string{10: "anthropic"}))

	got, err := svc.GetChannelForGroup(ctx, 10)
	require.NoError(t, err)
	require.NotNil(t, got)
	require.Equal(t, int64(1), got.ID)
	require.Equal(t, "test-channel", got.Name)

	// Mutating the returned channel must not leak back into the cache.
	got.Name = "mutated"
	again, err := svc.GetChannelForGroup(ctx, 10)
	require.NoError(t, err)
	require.Equal(t, "test-channel", again.Name)
}
|
||
|
||
// TestGetChannelForGroup_InactiveChannel verifies that a disabled channel is
// reported as absent: no error and a nil channel.
func TestGetChannelForGroup_InactiveChannel(t *testing.T) {
	disabled := Channel{ID: 1, Status: StatusDisabled, GroupIDs: []int64{10}}
	svc := newTestChannelService(makeStandardRepo(disabled, map[int64]string{10: "anthropic"}))

	got, err := svc.GetChannelForGroup(context.Background(), 10)

	require.NoError(t, err)
	require.Nil(t, got)
}
|
||
|
||
// TestGetChannelForGroup_NoChannel verifies that a group belonging to no
// channel yields no error and a nil channel.
func TestGetChannelForGroup_NoChannel(t *testing.T) {
	active := Channel{ID: 1, Status: StatusActive, GroupIDs: []int64{10}}
	svc := newTestChannelService(makeStandardRepo(active, map[int64]string{10: "anthropic"}))

	// Group 999 is not attached to the only channel.
	got, err := svc.GetChannelForGroup(context.Background(), 999)

	require.NoError(t, err)
	require.Nil(t, got)
}
|
||
|
||
// TestGetChannelForGroup_CacheError verifies that a repository failure while
// loading channels is surfaced to the caller with its original message.
func TestGetChannelForGroup_CacheError(t *testing.T) {
	failing := &mockChannelRepository{
		listAllFn: func(context.Context) ([]Channel, error) {
			return nil, errors.New("db connection failed")
		},
	}
	svc := newTestChannelService(failing)

	got, err := svc.GetChannelForGroup(context.Background(), 10)

	require.Error(t, err)
	require.Nil(t, got)
	require.Contains(t, err.Error(), "db connection failed")
}
|
||
|
||
// --- 4.2 GetChannelModelPricing ---
|
||
|
||
// TestGetChannelModelPricing_ExactMatch verifies lookup of a pricing row by an
// exact model name.
func TestGetChannelModelPricing_ExactMatch(t *testing.T) {
	opus := ChannelModelPricing{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)}
	ch := Channel{ID: 1, Status: StatusActive, GroupIDs: []int64{10}, ModelPricing: []ChannelModelPricing{opus}}
	svc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))

	got := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")

	require.NotNil(t, got)
	require.Equal(t, int64(100), got.ID)
	require.InDelta(t, 15e-6, *got.InputPrice, 1e-12)
}
|
||
|
||
// TestGetChannelModelPricing_CaseInsensitive verifies that model-name lookup
// ignores letter case.
func TestGetChannelModelPricing_CaseInsensitive(t *testing.T) {
	opus := ChannelModelPricing{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)}
	ch := Channel{ID: 1, Status: StatusActive, GroupIDs: []int64{10}, ModelPricing: []ChannelModelPricing{opus}}
	svc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))

	got := svc.GetChannelModelPricing(context.Background(), 10, "Claude-Opus-4")

	require.NotNil(t, got)
	require.Equal(t, int64(100), got.ID)
}
|
||
|
||
// TestGetChannelModelPricing_WildcardMatch verifies that a "prefix*" pricing
// entry matches a model name that starts with the prefix.
func TestGetChannelModelPricing_WildcardMatch(t *testing.T) {
	wildcard := ChannelModelPricing{ID: 200, Platform: "anthropic", Models: []string{"claude-*"}, InputPrice: testPtrFloat64(10e-6)}
	ch := Channel{ID: 1, Status: StatusActive, GroupIDs: []int64{10}, ModelPricing: []ChannelModelPricing{wildcard}}
	svc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))

	got := svc.GetChannelModelPricing(context.Background(), 10, "claude-sonnet-4")

	require.NotNil(t, got)
	require.Equal(t, int64(200), got.ID)
}
|
||
|
||
// TestGetChannelModelPricing_WildcardFirstMatch verifies that overlapping
// wildcard entries resolve by list order, not by longest prefix.
//
// validateNoConflictingModels would reject this overlapping configuration.
// The test feeds the channel to the cache directly, so the lookup's
// tie-breaking is exercised on its own.
func TestGetChannelModelPricing_WildcardFirstMatch(t *testing.T) {
	broad := ChannelModelPricing{ID: 200, Platform: "anthropic", Models: []string{"claude-*"}, InputPrice: testPtrFloat64(10e-6)}
	narrow := ChannelModelPricing{ID: 300, Platform: "anthropic", Models: []string{"claude-sonnet-*"}, InputPrice: testPtrFloat64(5e-6)}
	ch := Channel{ID: 1, Status: StatusActive, GroupIDs: []int64{10}, ModelPricing: []ChannelModelPricing{broad, narrow}}
	svc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))

	got := svc.GetChannelModelPricing(context.Background(), 10, "claude-sonnet-4-20250514")

	require.NotNil(t, got)
	// The broader "claude-*" entry comes first, so it wins.
	require.Equal(t, int64(200), got.ID)
	require.InDelta(t, 10e-6, *got.InputPrice, 1e-12)
}
|
||
|
||
// TestGetChannelModelPricing_NoMatch verifies that a model absent from the
// pricing list yields nil.
func TestGetChannelModelPricing_NoMatch(t *testing.T) {
	opus := ChannelModelPricing{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)}
	ch := Channel{ID: 1, Status: StatusActive, GroupIDs: []int64{10}, ModelPricing: []ChannelModelPricing{opus}}
	svc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))

	require.Nil(t, svc.GetChannelModelPricing(context.Background(), 10, "gpt-5.1"))
}
|
||
|
||
// TestGetChannelModelPricing_InactiveChannel verifies that pricing on a
// disabled channel is not returned, even for a listed model.
func TestGetChannelModelPricing_InactiveChannel(t *testing.T) {
	opus := ChannelModelPricing{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}}
	ch := Channel{ID: 1, Status: StatusDisabled, GroupIDs: []int64{10}, ModelPricing: []ChannelModelPricing{opus}}
	svc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))

	require.Nil(t, svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4"))
}
|
||
|
||
// TestGetChannelModelPricing_PlatformFiltering verifies that a group only sees
// pricing rows for its own platform, even when the channel carries several.
func TestGetChannelModelPricing_PlatformFiltering(t *testing.T) {
	ctx := context.Background()
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10, 20},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "openai", Models: []string{"gpt-5.1"}, InputPrice: testPtrFloat64(5e-6)},
			{ID: 200, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
		},
	}
	svc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic", 20: "openai"}))

	// wantID == 0 means the lookup must return nil.
	lookups := []struct {
		groupID int64
		model   string
		wantID  int64
	}{
		{10, "gpt-5.1", 0},         // anthropic group must not see openai pricing
		{10, "claude-opus-4", 200}, // anthropic group sees anthropic pricing
		{20, "gpt-5.1", 100},       // openai group sees openai pricing
		{20, "claude-opus-4", 0},   // openai group must not see anthropic pricing
	}

	for _, l := range lookups {
		got := svc.GetChannelModelPricing(ctx, l.groupID, l.model)
		if l.wantID == 0 {
			require.Nil(t, got)
			continue
		}
		require.NotNil(t, got)
		require.Equal(t, l.wantID, got.ID)
	}
}
|
||
|
||
// TestGetChannelModelPricing_ReturnsCopy verifies that the returned pricing is
// independent of the cached value: the struct is copied and its Models slice
// has its own backing array. Pointer fields are not checked. The original
// test's comment described them as shared by design.
func TestGetChannelModelPricing_ReturnsCopy(t *testing.T) {
	ctx := context.Background()
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result := svc.GetChannelModelPricing(ctx, 10, "claude-opus-4")
	require.NotNil(t, result)
	require.Len(t, result.Models, 1)

	// Overwrite an element in place. An append alone cannot detect a shallow
	// copy, because it never changes the length of the cached slice header.
	// An in-place write is visible through a shared backing array.
	result.Models[0] = "hacked"
	result.Models = append(result.Models, "hacked-too")
	result.ID = 999

	result2 := svc.GetChannelModelPricing(ctx, 10, "claude-opus-4")
	require.NotNil(t, result2)
	require.Equal(t, []string{"claude-opus-4"}, result2.Models)
	require.Equal(t, int64(100), result2.ID)
}
|
||
|
||
// --- 4.3 ResolveChannelMapping ---
|
||
|
||
// TestResolveChannelMapping_NoChannel verifies that a group outside every
// channel passes the model through unmapped, with a zero channel ID.
func TestResolveChannelMapping_NoChannel(t *testing.T) {
	active := Channel{ID: 1, Status: StatusActive, GroupIDs: []int64{10}}
	svc := newTestChannelService(makeStandardRepo(active, map[int64]string{10: "anthropic"}))

	// Group 999 is not in any channel.
	res := svc.ResolveChannelMapping(context.Background(), 999, "claude-opus-4")

	require.Equal(t, "claude-opus-4", res.MappedModel)
	require.False(t, res.Mapped)
	require.Equal(t, int64(0), res.ChannelID)
}
|
||
|
||
// TestResolveChannelMapping_ExactMapping verifies that an exact-name mapping
// for the group's platform rewrites the model and records the channel ID.
func TestResolveChannelMapping_ExactMapping(t *testing.T) {
	mapping := map[string]map[string]string{
		"anthropic": {"claude-sonnet-4": "claude-sonnet-4-20250514"},
	}
	ch := Channel{ID: 1, Status: StatusActive, GroupIDs: []int64{10}, ModelMapping: mapping}
	svc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))

	res := svc.ResolveChannelMapping(context.Background(), 10, "claude-sonnet-4")

	require.True(t, res.Mapped)
	require.Equal(t, "claude-sonnet-4-20250514", res.MappedModel)
	require.Equal(t, int64(1), res.ChannelID)
}
|
||
|
||
// TestResolveChannelMapping_WildcardMapping verifies that a catch-all "*"
// mapping rewrites any model name.
func TestResolveChannelMapping_WildcardMapping(t *testing.T) {
	mapping := map[string]map[string]string{
		"anthropic": {"*": "gpt-5.4"},
	}
	ch := Channel{ID: 1, Status: StatusActive, GroupIDs: []int64{10}, ModelMapping: mapping}
	svc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))

	res := svc.ResolveChannelMapping(context.Background(), 10, "any-model-name")

	require.True(t, res.Mapped)
	require.Equal(t, "gpt-5.4", res.MappedModel)
}
|
||
|
||
// TestResolveChannelMapping_WildcardFirstMatch verifies that when two
// wildcard patterns both match, one of their targets is chosen.
//
// The patterns live in a Go map, which has no defined iteration order, so
// "first" is not well defined here. The test only asserts that a wildcard
// target was applied, not which one.
func TestResolveChannelMapping_WildcardFirstMatch(t *testing.T) {
	mapping := map[string]map[string]string{
		"anthropic": {
			"claude-*":        "target2",
			"claude-sonnet-*": "target1",
		},
	}
	ch := Channel{ID: 1, Status: StatusActive, GroupIDs: []int64{10}, ModelMapping: mapping}
	svc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))

	res := svc.ResolveChannelMapping(context.Background(), 10, "claude-sonnet-4")

	require.True(t, res.Mapped)
	require.Contains(t, []string{"target1", "target2"}, res.MappedModel)
}
|
||
|
||
// TestResolveChannelMapping_NoMapping verifies that a model with no matching
// mapping is passed through unchanged. The channel ID is still reported
// because the group does belong to a channel.
func TestResolveChannelMapping_NoMapping(t *testing.T) {
	mapping := map[string]map[string]string{
		"anthropic": {"claude-sonnet-4": "mapped"},
	}
	ch := Channel{ID: 1, Status: StatusActive, GroupIDs: []int64{10}, ModelMapping: mapping}
	svc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))

	res := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")

	require.False(t, res.Mapped)
	require.Equal(t, "claude-opus-4", res.MappedModel)
	require.Equal(t, int64(1), res.ChannelID)
}
|
||
|
||
// TestResolveChannelMapping_DefaultBillingModelSource verifies that a channel
// with no billing model source configured reports
// BillingModelSourceChannelMapped.
func TestResolveChannelMapping_DefaultBillingModelSource(t *testing.T) {
	ch := Channel{
		ID:                 1,
		Status:             StatusActive,
		GroupIDs:           []int64{10},
		BillingModelSource: "", // deliberately unset
	}
	svc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))

	res := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")

	require.Equal(t, BillingModelSourceChannelMapped, res.BillingModelSource)
}
|
||
|
||
// TestResolveChannelMapping_UpstreamBillingModelSource verifies that an
// explicitly configured billing model source is passed through.
func TestResolveChannelMapping_UpstreamBillingModelSource(t *testing.T) {
	ch := Channel{
		ID:                 1,
		Status:             StatusActive,
		GroupIDs:           []int64{10},
		BillingModelSource: BillingModelSourceUpstream,
	}
	svc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))

	res := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")

	require.Equal(t, BillingModelSourceUpstream, res.BillingModelSource)
}
|
||
|
||
func TestResolveChannelMapping_InactiveChannel(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusDisabled,
|
||
GroupIDs: []int64{10},
|
||
ModelMapping: map[string]map[string]string{
|
||
"anthropic": {
|
||
"claude-sonnet-4": "mapped",
|
||
},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-sonnet-4")
|
||
require.False(t, result.Mapped)
|
||
require.Equal(t, "claude-sonnet-4", result.MappedModel)
|
||
require.Equal(t, int64(0), result.ChannelID) // no channel
|
||
}
|
||
|
||
// --- 4.4 IsModelRestricted ---
|
||
|
||
func TestIsModelRestricted_NoChannel(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: true,
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
// Group 999 is not in any channel
|
||
restricted := svc.IsModelRestricted(context.Background(), 999, "claude-opus-4")
|
||
require.False(t, restricted)
|
||
}
|
||
|
||
func TestIsModelRestricted_RestrictDisabled(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: false,
|
||
ModelPricing: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
// Even though model is not in pricing, RestrictModels=false
|
||
restricted := svc.IsModelRestricted(context.Background(), 10, "nonexistent-model")
|
||
require.False(t, restricted)
|
||
}
|
||
|
||
func TestIsModelRestricted_InactiveChannel(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusDisabled,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: true,
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
restricted := svc.IsModelRestricted(context.Background(), 10, "any-model")
|
||
require.False(t, restricted)
|
||
}
|
||
|
||
func TestIsModelRestricted_ModelInPricing(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: true,
|
||
ModelPricing: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4", "claude-sonnet-4"}},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
restricted := svc.IsModelRestricted(context.Background(), 10, "claude-opus-4")
|
||
require.False(t, restricted)
|
||
}
|
||
|
||
func TestIsModelRestricted_ModelInWildcard(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: true,
|
||
ModelPricing: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-*"}},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
restricted := svc.IsModelRestricted(context.Background(), 10, "claude-sonnet-4")
|
||
require.False(t, restricted)
|
||
}
|
||
|
||
func TestIsModelRestricted_ModelNotFound(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: true,
|
||
ModelPricing: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
restricted := svc.IsModelRestricted(context.Background(), 10, "gpt-5.1")
|
||
require.True(t, restricted)
|
||
}
|
||
|
||
func TestIsModelRestricted_CaseInsensitive(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: true,
|
||
ModelPricing: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
restricted := svc.IsModelRestricted(context.Background(), 10, "Claude-Opus-4")
|
||
require.False(t, restricted)
|
||
}
|
||
|
||
// --- 4.5 ResolveChannelMappingAndRestrict ---
|
||
// NOTE: the model restriction check has moved to the scheduling phase
// (GatewayService.checkChannelPricingRestriction). ResolveChannelMappingAndRestrict
// now only resolves the mapping, so its restricted result is always false.
|
||
|
||
func TestResolveChannelMappingAndRestrict_NilGroupID(t *testing.T) {
|
||
repo := &mockChannelRepository{
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), nil, "claude-opus-4")
|
||
require.False(t, restricted)
|
||
require.False(t, mapping.Mapped)
|
||
require.Equal(t, "claude-opus-4", mapping.MappedModel)
|
||
}
|
||
|
||
func TestResolveChannelMappingAndRestrict_WithMapping(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: true,
|
||
ModelPricing: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
|
||
},
|
||
ModelMapping: map[string]map[string]string{
|
||
"anthropic": {
|
||
"claude-sonnet-4": "claude-sonnet-4-20250514",
|
||
},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
gid := int64(10)
|
||
mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "claude-sonnet-4")
|
||
require.False(t, restricted) // restricted 始终为 false,限制检查在调度阶段
|
||
require.True(t, mapping.Mapped)
|
||
require.Equal(t, "claude-sonnet-4-20250514", mapping.MappedModel)
|
||
}
|
||
|
||
func TestResolveChannelMappingAndRestrict_NoMapping(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
RestrictModels: true,
|
||
ModelPricing: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
gid := int64(10)
|
||
mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "unknown-model")
|
||
require.False(t, restricted) // restricted 始终为 false,限制检查在调度阶段
|
||
require.False(t, mapping.Mapped)
|
||
require.Equal(t, "unknown-model", mapping.MappedModel)
|
||
}
|
||
|
||
// --- 4.6 Cache Building Specifics ---
|
||
|
||
func TestBuildCache_DBError(t *testing.T) {
|
||
callCount := 0
|
||
repo := &mockChannelRepository{
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
callCount++
|
||
return nil, errors.New("database down")
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
// First call should fail
|
||
_, err := svc.GetChannelForGroup(context.Background(), 10)
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "database down")
|
||
require.Equal(t, 1, callCount)
|
||
|
||
// Second call within error-TTL should use error cache, but still return error
|
||
// Because buildCache stores error-TTL cache and returns error, the cached value
|
||
// is still within TTL and loadCache returns it (which is an empty cache).
|
||
// Actually, re-reading the code: buildCache returns nil, err, and the error cache
|
||
// only serves as a "don't retry immediately" mechanism. The singleflight.Do
|
||
// returns the error. On next call within error-TTL, the cache has an empty but
|
||
// valid entry, so loadCache returns it (with empty maps). GetChannelForGroup
|
||
// will find nothing and return nil, nil.
|
||
result, err := svc.GetChannelForGroup(context.Background(), 10)
|
||
require.NoError(t, err)
|
||
require.Nil(t, result)
|
||
// Should NOT have hit DB again (error-TTL cache is active)
|
||
require.Equal(t, 1, callCount)
|
||
}
|
||
|
||
func TestBuildCache_GroupPlatformError(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
},
|
||
}
|
||
repo := &mockChannelRepository{
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return []Channel{ch}, nil
|
||
},
|
||
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
|
||
return nil, errors.New("group platforms failed")
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
// Should degrade gracefully: channel is found, but without platform info
|
||
// pricing won't match because platform will be "" and pricing platform is "anthropic"
|
||
result, err := svc.GetChannelForGroup(context.Background(), 10)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result) // channel still found
|
||
require.Equal(t, int64(1), result.ID)
|
||
}
|
||
|
||
func TestBuildCache_MultipleGroupsSameChannel(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10, 20, 30},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{
|
||
10: "anthropic",
|
||
20: "anthropic",
|
||
30: "anthropic",
|
||
})
|
||
svc := newTestChannelService(repo)
|
||
|
||
for _, gid := range []int64{10, 20, 30} {
|
||
result := svc.GetChannelModelPricing(context.Background(), gid, "claude-opus-4")
|
||
require.NotNil(t, result, "group %d should have pricing", gid)
|
||
require.Equal(t, int64(100), result.ID)
|
||
}
|
||
}
|
||
|
||
func TestBuildCache_PlatformFiltering(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10, 20},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
{ID: 200, Platform: "openai", Models: []string{"gpt-5.1"}},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{
|
||
10: "anthropic",
|
||
20: "openai",
|
||
})
|
||
svc := newTestChannelService(repo)
|
||
|
||
// anthropic group sees only anthropic models
|
||
require.NotNil(t, svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4"))
|
||
require.Nil(t, svc.GetChannelModelPricing(context.Background(), 10, "gpt-5.1"))
|
||
|
||
// openai group sees only openai models
|
||
require.NotNil(t, svc.GetChannelModelPricing(context.Background(), 20, "gpt-5.1"))
|
||
require.Nil(t, svc.GetChannelModelPricing(context.Background(), 20, "claude-opus-4"))
|
||
}
|
||
|
||
func TestBuildCache_WildcardPreservesConfigOrder(t *testing.T) {
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
// Configuration order: shortest prefix first
|
||
{ID: 100, Platform: "anthropic", Models: []string{"c-*"}, InputPrice: testPtrFloat64(1e-6)},
|
||
{ID: 200, Platform: "anthropic", Models: []string{"c-son-*"}, InputPrice: testPtrFloat64(2e-6)},
|
||
{ID: 300, Platform: "anthropic", Models: []string{"c-son-4-*"}, InputPrice: testPtrFloat64(3e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||
svc := newTestChannelService(repo)
|
||
|
||
// "c-son-4-xxx" matches all three wildcards, but "c-*" (ID=100) is first in config
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "c-son-4-xxx")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, int64(100), result.ID)
|
||
|
||
// "c-son-yyy" matches "c-*" and "c-son-*", but "c-*" (ID=100) is first
|
||
result = svc.GetChannelModelPricing(context.Background(), 10, "c-son-yyy")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, int64(100), result.ID)
|
||
|
||
// "c-other" only matches "c-*" (ID=100)
|
||
result = svc.GetChannelModelPricing(context.Background(), 10, "c-other")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, int64(100), result.ID)
|
||
}
|
||
|
||
// --- 4.7 invalidateCache ---
|
||
|
||
func TestInvalidateCache(t *testing.T) {
|
||
callCount := 0
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
},
|
||
}
|
||
repo := &mockChannelRepository{
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
callCount++
|
||
return []Channel{ch}, nil
|
||
},
|
||
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
|
||
return map[int64]string{10: "anthropic"}, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
// First load
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, 1, callCount)
|
||
|
||
// Second call should use cache
|
||
result = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, 1, callCount) // no new DB call
|
||
|
||
// Invalidate
|
||
svc.invalidateCache()
|
||
|
||
// Next call should rebuild from DB
|
||
result = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
|
||
require.NotNil(t, result)
|
||
require.Equal(t, 2, callCount) // rebuilt
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 5. CRUD Methods
|
||
// ===========================================================================
|
||
|
||
// --- 5.1 Create ---
|
||
|
||
func TestCreate_Success(t *testing.T) {
|
||
createdID := int64(42)
|
||
repo := &mockChannelRepository{
|
||
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
|
||
return false, nil
|
||
},
|
||
getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) {
|
||
return nil, nil
|
||
},
|
||
createFn: func(_ context.Context, ch *Channel) error {
|
||
ch.ID = createdID
|
||
return nil
|
||
},
|
||
getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
|
||
return &Channel{ID: id, Name: "new-channel", Status: StatusActive}, nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
result, err := svc.Create(context.Background(), &CreateChannelInput{
|
||
Name: "new-channel",
|
||
GroupIDs: []int64{10},
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
require.Equal(t, createdID, result.ID)
|
||
}
|
||
|
||
func TestCreate_NameExists(t *testing.T) {
|
||
repo := &mockChannelRepository{
|
||
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
|
||
return true, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
_, err := svc.Create(context.Background(), &CreateChannelInput{
|
||
Name: "existing-channel",
|
||
})
|
||
require.Error(t, err)
|
||
require.ErrorIs(t, err, ErrChannelExists)
|
||
}
|
||
|
||
func TestCreate_GroupConflict(t *testing.T) {
|
||
repo := &mockChannelRepository{
|
||
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
|
||
return false, nil
|
||
},
|
||
getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) {
|
||
return []int64{10}, nil // group 10 already in another channel
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
_, err := svc.Create(context.Background(), &CreateChannelInput{
|
||
Name: "new-channel",
|
||
GroupIDs: []int64{10, 20},
|
||
})
|
||
require.Error(t, err)
|
||
require.ErrorIs(t, err, ErrGroupAlreadyInChannel)
|
||
}
|
||
|
||
func TestCreate_DuplicateModel(t *testing.T) {
|
||
repo := &mockChannelRepository{
|
||
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
|
||
return false, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
_, err := svc.Create(context.Background(), &CreateChannelInput{
|
||
Name: "new-channel",
|
||
ModelPricing: []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4"}}, // duplicate
|
||
},
|
||
})
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "claude-opus-4")
|
||
}
|
||
|
||
func TestCreate_DefaultBillingModelSource(t *testing.T) {
|
||
var capturedChannel *Channel
|
||
repo := &mockChannelRepository{
|
||
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
|
||
return false, nil
|
||
},
|
||
createFn: func(_ context.Context, ch *Channel) error {
|
||
capturedChannel = ch
|
||
ch.ID = 1
|
||
return nil
|
||
},
|
||
getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
|
||
return capturedChannel, nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
result, err := svc.Create(context.Background(), &CreateChannelInput{
|
||
Name: "new-channel",
|
||
BillingModelSource: "", // empty, should default to "channel_mapped"
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
require.Equal(t, BillingModelSourceChannelMapped, result.BillingModelSource)
|
||
}
|
||
|
||
func TestCreate_InvalidatesCache(t *testing.T) {
|
||
loadCount := 0
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
},
|
||
}
|
||
repo := &mockChannelRepository{
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
loadCount++
|
||
return []Channel{ch}, nil
|
||
},
|
||
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
|
||
return map[int64]string{10: "anthropic"}, nil
|
||
},
|
||
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
|
||
return false, nil
|
||
},
|
||
createFn: func(_ context.Context, c *Channel) error {
|
||
c.ID = 2
|
||
return nil
|
||
},
|
||
getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
|
||
return &Channel{ID: id, Name: "new", Status: StatusActive}, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
// Load cache
|
||
_ = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
|
||
require.Equal(t, 1, loadCount)
|
||
|
||
// Create triggers cache invalidation
|
||
_, err := svc.Create(context.Background(), &CreateChannelInput{Name: "new"})
|
||
require.NoError(t, err)
|
||
|
||
// Next cache access should rebuild
|
||
_ = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
|
||
require.Equal(t, 2, loadCount)
|
||
}
|
||
|
||
// --- 5.2 Update ---
|
||
|
||
func TestUpdate_Success(t *testing.T) {
|
||
existing := &Channel{
|
||
ID: 1,
|
||
Name: "original",
|
||
Status: StatusActive,
|
||
}
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
|
||
return existing.Clone(), nil
|
||
},
|
||
updateFn: func(_ context.Context, _ *Channel) error {
|
||
return nil
|
||
},
|
||
getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
|
||
return nil, nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||
Name: "updated-name",
|
||
Description: testPtrString("new desc"),
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
}
|
||
|
||
func TestUpdate_NotFound(t *testing.T) {
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
|
||
return nil, ErrChannelNotFound
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
_, err := svc.Update(context.Background(), 999, &UpdateChannelInput{
|
||
Name: "whatever",
|
||
})
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "channel")
|
||
}
|
||
|
||
func TestUpdate_NameConflict(t *testing.T) {
|
||
existing := &Channel{
|
||
ID: 1,
|
||
Name: "original",
|
||
Status: StatusActive,
|
||
}
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
|
||
return existing.Clone(), nil
|
||
},
|
||
existsByNameExcludingFn: func(_ context.Context, _ string, _ int64) (bool, error) {
|
||
return true, nil // name conflicts with another channel
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||
Name: "conflicting-name",
|
||
})
|
||
require.Error(t, err)
|
||
require.ErrorIs(t, err, ErrChannelExists)
|
||
}
|
||
|
||
func TestUpdate_GroupConflict(t *testing.T) {
|
||
existing := &Channel{
|
||
ID: 1,
|
||
Name: "original",
|
||
Status: StatusActive,
|
||
}
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
|
||
return existing.Clone(), nil
|
||
},
|
||
getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) {
|
||
return []int64{20}, nil // group 20 in another channel
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
newGroupIDs := []int64{10, 20}
|
||
_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||
GroupIDs: &newGroupIDs,
|
||
})
|
||
require.Error(t, err)
|
||
require.ErrorIs(t, err, ErrGroupAlreadyInChannel)
|
||
}
|
||
|
||
func TestUpdate_DuplicateModel(t *testing.T) {
|
||
existing := &Channel{
|
||
ID: 1,
|
||
Name: "original",
|
||
Status: StatusActive,
|
||
}
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
|
||
return existing.Clone(), nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
dupPricing := []ChannelModelPricing{
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
{Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||
}
|
||
_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||
ModelPricing: &dupPricing,
|
||
})
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "claude-opus-4")
|
||
}
|
||
|
||
func TestUpdate_InvalidatesChannelCache(t *testing.T) {
|
||
existing := &Channel{
|
||
ID: 1,
|
||
Name: "original",
|
||
Status: StatusActive,
|
||
}
|
||
loadCount := 0
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
|
||
return existing.Clone(), nil
|
||
},
|
||
updateFn: func(_ context.Context, _ *Channel) error {
|
||
return nil
|
||
},
|
||
getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
|
||
return []int64{10, 20}, nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
loadCount++
|
||
return []Channel{*existing}, nil
|
||
},
|
||
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
// Load cache first
|
||
_, _ = svc.GetChannelForGroup(context.Background(), 10)
|
||
require.Equal(t, 1, loadCount)
|
||
|
||
result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||
Description: testPtrString("updated"),
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
|
||
// Channel cache should be invalidated (next access rebuilds)
|
||
_, _ = svc.GetChannelForGroup(context.Background(), 10)
|
||
require.Equal(t, 2, loadCount)
|
||
}
|
||
|
||
func TestUpdate_InvalidatesAuthCache(t *testing.T) {
|
||
existing := &Channel{
|
||
ID: 1,
|
||
Name: "original",
|
||
Status: StatusActive,
|
||
}
|
||
auth := &mockChannelAuthCacheInvalidator{}
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
|
||
return existing.Clone(), nil
|
||
},
|
||
updateFn: func(_ context.Context, _ *Channel) error {
|
||
return nil
|
||
},
|
||
getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
|
||
return []int64{10, 20}, nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelServiceWithAuth(repo, auth)
|
||
|
||
result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||
Description: testPtrString("updated"),
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
|
||
// Auth cache should be invalidated for both groups
|
||
require.ElementsMatch(t, []int64{10, 20}, auth.invalidatedGroupIDs)
|
||
}
|
||
|
||
// --- 5.3 Delete ---
|
||
|
||
func TestChannelDelete_Success(t *testing.T) {
|
||
deleted := false
|
||
repo := &mockChannelRepository{
|
||
getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
|
||
return nil, nil
|
||
},
|
||
deleteFn: func(_ context.Context, _ int64) error {
|
||
deleted = true
|
||
return nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
err := svc.Delete(context.Background(), 1)
|
||
require.NoError(t, err)
|
||
require.True(t, deleted)
|
||
}
|
||
|
||
func TestChannelDelete_InvalidatesCaches(t *testing.T) {
|
||
auth := &mockChannelAuthCacheInvalidator{}
|
||
loadCount := 0
|
||
repo := &mockChannelRepository{
|
||
getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
|
||
return []int64{10, 20}, nil
|
||
},
|
||
deleteFn: func(_ context.Context, _ int64) error {
|
||
return nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
loadCount++
|
||
return []Channel{{ID: 1, Status: StatusActive, GroupIDs: []int64{10, 20}}}, nil
|
||
},
|
||
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelServiceWithAuth(repo, auth)
|
||
|
||
// Load cache first
|
||
_, _ = svc.GetChannelForGroup(context.Background(), 10)
|
||
require.Equal(t, 1, loadCount)
|
||
|
||
err := svc.Delete(context.Background(), 1)
|
||
require.NoError(t, err)
|
||
|
||
// Auth cache invalidated for both groups
|
||
require.ElementsMatch(t, []int64{10, 20}, auth.invalidatedGroupIDs)
|
||
|
||
// Channel cache invalidated
|
||
_, _ = svc.GetChannelForGroup(context.Background(), 10)
|
||
require.Equal(t, 2, loadCount)
|
||
}
|
||
|
||
func TestChannelDelete_NotFound(t *testing.T) {
|
||
repo := &mockChannelRepository{
|
||
getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
|
||
return nil, nil
|
||
},
|
||
deleteFn: func(_ context.Context, _ int64) error {
|
||
return errors.New("record not found")
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
err := svc.Delete(context.Background(), 999)
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "not found")
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 6. Edge Case Tests
|
||
// ===========================================================================
|
||
|
||
// --- 6.1 Create with empty GroupIDs ---
|
||
|
||
func TestCreate_NoGroups(t *testing.T) {
|
||
createdID := int64(55)
|
||
getGroupsInOtherChannelsCalled := false
|
||
repo := &mockChannelRepository{
|
||
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
|
||
return false, nil
|
||
},
|
||
getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) {
|
||
getGroupsInOtherChannelsCalled = true
|
||
return nil, nil
|
||
},
|
||
createFn: func(_ context.Context, ch *Channel) error {
|
||
ch.ID = createdID
|
||
return nil
|
||
},
|
||
getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
|
||
return &Channel{ID: id, Name: "no-groups-channel", Status: StatusActive}, nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
result, err := svc.Create(context.Background(), &CreateChannelInput{
|
||
Name: "no-groups-channel",
|
||
GroupIDs: []int64{}, // empty slice
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
require.Equal(t, createdID, result.ID)
|
||
// GetGroupsInOtherChannels should NOT have been called (skipped by len(input.GroupIDs) > 0)
|
||
require.False(t, getGroupsInOtherChannelsCalled)
|
||
}
|
||
|
||
// --- 6.2 Update only Status ---
|
||
|
||
func TestUpdate_StatusOnly(t *testing.T) {
|
||
existing := &Channel{
|
||
ID: 1,
|
||
Name: "test-channel",
|
||
Status: StatusActive,
|
||
}
|
||
var capturedChannel *Channel
|
||
repo := &mockChannelRepository{
|
||
getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
|
||
return existing.Clone(), nil
|
||
},
|
||
updateFn: func(_ context.Context, ch *Channel) error {
|
||
capturedChannel = ch
|
||
return nil
|
||
},
|
||
getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
|
||
return nil, nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||
Status: StatusDisabled,
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
// Verify that the channel passed to repo.Update has the new status
|
||
require.NotNil(t, capturedChannel)
|
||
require.Equal(t, StatusDisabled, capturedChannel.Status)
|
||
// Name should remain unchanged
|
||
require.Equal(t, "test-channel", capturedChannel.Name)
|
||
}
|
||
|
||
// --- 6.3 Delete when GetGroupIDs fails ---
|
||
|
||
func TestChannelDelete_GetGroupIDsError(t *testing.T) {
|
||
deleted := false
|
||
repo := &mockChannelRepository{
|
||
getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
|
||
return nil, errors.New("group IDs lookup failed")
|
||
},
|
||
deleteFn: func(_ context.Context, _ int64) error {
|
||
deleted = true
|
||
return nil
|
||
},
|
||
listAllFn: func(_ context.Context) ([]Channel, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
svc := newTestChannelService(repo)
|
||
|
||
// Delete should still succeed even though GetGroupIDs returned error (degradation path L588-591)
|
||
err := svc.Delete(context.Background(), 1)
|
||
require.NoError(t, err)
|
||
require.True(t, deleted)
|
||
}
|
||
|
||
// --- 6.4 ReplaceModelInBody with invalid JSON ---
|
||
|
||
func TestReplaceModelInBody_InvalidJSON(t *testing.T) {
|
||
// Case 1: broken JSON object — gjson won't find "model", sjson does best-effort set
|
||
// (no panic, no error from sjson, but result is mutated garbage)
|
||
brokenBody := []byte("{broken")
|
||
result := ReplaceModelInBody(brokenBody, "new-model")
|
||
require.NotNil(t, result)
|
||
// sjson does not error on this input, so result differs from original — just verify no panic
|
||
|
||
// Case 2: JSON array — sjson.SetBytes returns error on non-object,
|
||
// triggering the L447 error fallback path that returns original body.
|
||
arrayBody := []byte("[]")
|
||
result2 := ReplaceModelInBody(arrayBody, "new-model")
|
||
require.Equal(t, arrayBody, result2)
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 7. isPlatformPricingMatch
|
||
// ===========================================================================
|
||
|
||
func TestIsPlatformPricingMatch(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
groupPlatform string
|
||
pricingPlatform string
|
||
want bool
|
||
}{
|
||
{"antigravity matches anthropic", PlatformAntigravity, PlatformAnthropic, true},
|
||
{"antigravity matches gemini", PlatformAntigravity, PlatformGemini, true},
|
||
{"antigravity matches antigravity", PlatformAntigravity, PlatformAntigravity, true},
|
||
{"antigravity does NOT match openai", PlatformAntigravity, PlatformOpenAI, false},
|
||
{"anthropic matches anthropic", PlatformAnthropic, PlatformAnthropic, true},
|
||
{"anthropic does NOT match antigravity", PlatformAnthropic, PlatformAntigravity, false},
|
||
{"anthropic does NOT match gemini", PlatformAnthropic, PlatformGemini, false},
|
||
{"gemini matches gemini", PlatformGemini, PlatformGemini, true},
|
||
{"gemini does NOT match antigravity", PlatformGemini, PlatformAntigravity, false},
|
||
{"gemini does NOT match anthropic", PlatformGemini, PlatformAnthropic, false},
|
||
{"empty string matches nothing", "", PlatformAnthropic, false},
|
||
{"empty string matches empty", "", "", true},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
require.Equal(t, tt.want, isPlatformPricingMatch(tt.groupPlatform, tt.pricingPlatform))
|
||
})
|
||
}
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 8. matchingPlatforms
|
||
// ===========================================================================
|
||
|
||
func TestMatchingPlatforms(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
groupPlatform string
|
||
want []string
|
||
}{
|
||
{"antigravity returns all three", PlatformAntigravity, []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini}},
|
||
{"anthropic returns itself", PlatformAnthropic, []string{PlatformAnthropic}},
|
||
{"gemini returns itself", PlatformGemini, []string{PlatformGemini}},
|
||
{"openai returns itself", PlatformOpenAI, []string{PlatformOpenAI}},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
result := matchingPlatforms(tt.groupPlatform)
|
||
require.Equal(t, tt.want, result)
|
||
})
|
||
}
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 9. Antigravity cross-platform channel pricing
|
||
// ===========================================================================
|
||
|
||
func TestGetChannelModelPricing_AntigravityCrossPlatform(t *testing.T) {
|
||
// Channel has anthropic pricing for claude-opus-4-6.
|
||
// Group 10 is antigravity — should see the anthropic pricing.
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: PlatformAnthropic, Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(15e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4-6")
|
||
require.NotNil(t, result, "antigravity group should see anthropic pricing")
|
||
require.Equal(t, int64(100), result.ID)
|
||
require.InDelta(t, 15e-6, *result.InputPrice, 1e-12)
|
||
}
|
||
|
||
func TestGetChannelModelPricing_AnthropicCannotSeeAntigravityPricing(t *testing.T) {
|
||
// Channel has antigravity-platform pricing for claude-opus-4-6.
|
||
// Group 10 is anthropic — should NOT see antigravity pricing (no cross-platform leakage).
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelPricing: []ChannelModelPricing{
|
||
{ID: 100, Platform: PlatformAntigravity, Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(15e-6)},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAnthropic})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4-6")
|
||
require.Nil(t, result, "anthropic group should NOT see antigravity-platform pricing")
|
||
}
|
||
|
||
// ===========================================================================
|
||
// 10. Antigravity cross-platform model mapping
|
||
// ===========================================================================
|
||
|
||
func TestResolveChannelMapping_AntigravityCrossPlatform(t *testing.T) {
|
||
// Channel has anthropic model mapping: claude-opus-4-5 → claude-opus-4-6.
|
||
// Group 10 is antigravity — should apply the anthropic mapping.
|
||
ch := Channel{
|
||
ID: 1,
|
||
Status: StatusActive,
|
||
GroupIDs: []int64{10},
|
||
ModelMapping: map[string]map[string]string{
|
||
PlatformAnthropic: {
|
||
"claude-opus-4-5": "claude-opus-4-6",
|
||
},
|
||
},
|
||
}
|
||
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
|
||
svc := newTestChannelService(repo)
|
||
|
||
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4-5")
|
||
require.True(t, result.Mapped, "antigravity group should apply anthropic mapping")
|
||
require.Equal(t, "claude-opus-4-6", result.MappedModel)
|
||
require.Equal(t, int64(1), result.ChannelID)
|
||
}
|