refactor(channel): split long functions, extract shared validation, move billing validation to service

- Split Update (98→25 lines), buildCache (54→20 lines), Create (51→25 lines)
  into focused sub-functions: applyUpdateInput, checkGroupConflicts,
  fetchChannelData, populateChannelCache, storeErrorCache, getOldGroupIDs,
  invalidateAuthCacheForGroups
- Extract validateChannelConfig to eliminate duplicated validation calls
  between Create and Update
- Move validatePricingBillingMode from handler to service layer for
  proper separation of concerns
- Add error logging to IsModelRestricted (was silently swallowing errors)
- Add 12 new tests: ToUsageFields, billing mode validation, antigravity
  wildcard mapping isolation, Create/Update mapping conflict integration
This commit is contained in:
erio
2026-04-05 22:05:13 +08:00
parent 339d906e54
commit 9151d34d40
4 changed files with 420 additions and 288 deletions

View File

@@ -2199,3 +2199,207 @@ func TestGetChannelModelPricing_NonAntigravityUnaffected(t *testing.T) {
require.Equal(t, int64(601), result.ID)
require.InDelta(t, 5e-6, *result.InputPrice, 1e-12)
}
// ---------------------------------------------------------------------------
// 10. ToUsageFields
// ---------------------------------------------------------------------------
func TestToUsageFields_NoMapping(t *testing.T) {
r := ChannelMappingResult{
MappedModel: "claude-opus-4",
ChannelID: 1,
Mapped: false,
BillingModelSource: BillingModelSourceRequested,
}
fields := r.ToUsageFields("claude-opus-4", "claude-opus-4")
require.Equal(t, int64(1), fields.ChannelID)
require.Equal(t, "claude-opus-4", fields.OriginalModel)
require.Equal(t, "claude-opus-4", fields.ChannelMappedModel)
require.Equal(t, BillingModelSourceRequested, fields.BillingModelSource)
require.Empty(t, fields.ModelMappingChain)
}
func TestToUsageFields_WithChannelMapping(t *testing.T) {
r := ChannelMappingResult{
MappedModel: "claude-sonnet-4-20250514",
ChannelID: 2,
Mapped: true,
BillingModelSource: BillingModelSourceChannelMapped,
}
fields := r.ToUsageFields("claude-sonnet-4", "claude-sonnet-4-20250514")
require.Equal(t, int64(2), fields.ChannelID)
require.Equal(t, "claude-sonnet-4", fields.OriginalModel)
require.Equal(t, "claude-sonnet-4-20250514", fields.ChannelMappedModel)
require.Equal(t, "claude-sonnet-4→claude-sonnet-4-20250514", fields.ModelMappingChain)
}
func TestToUsageFields_WithUpstreamDifference(t *testing.T) {
r := ChannelMappingResult{
MappedModel: "claude-sonnet-4",
ChannelID: 3,
Mapped: true,
BillingModelSource: BillingModelSourceUpstream,
}
fields := r.ToUsageFields("my-alias", "claude-sonnet-4-20250514")
require.Equal(t, "my-alias", fields.OriginalModel)
require.Equal(t, "claude-sonnet-4", fields.ChannelMappedModel)
require.Equal(t, "my-alias→claude-sonnet-4→claude-sonnet-4-20250514", fields.ModelMappingChain)
}
// ---------------------------------------------------------------------------
// 11. validatePricingBillingMode (moved from handler tests)
// ---------------------------------------------------------------------------
func TestValidatePricingBillingMode(t *testing.T) {
tests := []struct {
name string
pricing []ChannelModelPricing
wantErr bool
errMsg string
}{
{
name: "token mode - valid",
pricing: []ChannelModelPricing{{BillingMode: BillingModeToken}},
},
{
name: "per_request with price - valid",
pricing: []ChannelModelPricing{{
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0.5),
}},
},
{
name: "per_request with intervals - valid",
pricing: []ChannelModelPricing{{
BillingMode: BillingModePerRequest,
Intervals: []PricingInterval{{MinTokens: 0, MaxTokens: testPtrInt(1000), PerRequestPrice: testPtrFloat64(0.1)}},
}},
},
{
name: "per_request no price no intervals - invalid",
pricing: []ChannelModelPricing{{BillingMode: BillingModePerRequest}},
wantErr: true,
errMsg: "per-request price or intervals required",
},
{
name: "image no price no intervals - invalid",
pricing: []ChannelModelPricing{{BillingMode: BillingModeImage}},
wantErr: true,
errMsg: "per-request price or intervals required",
},
{
name: "empty list - valid",
pricing: []ChannelModelPricing{},
},
{
name: "negative input_price - invalid",
pricing: []ChannelModelPricing{{
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(-0.01),
}},
wantErr: true,
errMsg: "input_price must be >= 0",
},
{
name: "interval with no price fields - invalid",
pricing: []ChannelModelPricing{{
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0.5),
Intervals: []PricingInterval{{MinTokens: 0, MaxTokens: testPtrInt(1000)}},
}},
wantErr: true,
errMsg: "has no price fields set",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validatePricingBillingMode(tt.pricing)
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errMsg)
} else {
require.NoError(t, err)
}
})
}
}
// ---------------------------------------------------------------------------
// 12. Antigravity wildcard mapping isolation
// ---------------------------------------------------------------------------
func TestResolveChannelMapping_AntigravityDoesNotSeeWildcardMappingFromOtherPlatforms(t *testing.T) {
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10, 20},
ModelMapping: map[string]map[string]string{
PlatformAnthropic: {"claude-*": "claude-override"},
PlatformGemini: {"gemini-*": "gemini-override"},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity, 20: PlatformAnthropic})
svc := newTestChannelService(repo)
// antigravity 分组不应看到 anthropic/gemini 的通配符映射
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")
require.False(t, result.Mapped)
require.Equal(t, "claude-opus-4", result.MappedModel)
result = svc.ResolveChannelMapping(context.Background(), 10, "gemini-2.5-pro")
require.False(t, result.Mapped)
require.Equal(t, "gemini-2.5-pro", result.MappedModel)
// anthropic 分组应该能看到 anthropic 的通配符映射
result = svc.ResolveChannelMapping(context.Background(), 20, "claude-opus-4")
require.True(t, result.Mapped)
require.Equal(t, "claude-override", result.MappedModel)
}
// ---------------------------------------------------------------------------
// 13. Create/Update with mapping conflict validation
// ---------------------------------------------------------------------------
func TestCreate_MappingConflict(t *testing.T) {
repo := &mockChannelRepository{}
svc := newTestChannelService(repo)
_, err := svc.Create(context.Background(), &CreateChannelInput{
Name: "test",
ModelMapping: map[string]map[string]string{
PlatformAnthropic: {
"claude-*": "target-a",
"claude-opus-*": "target-b",
},
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "MAPPING_PATTERN_CONFLICT")
}
func TestUpdate_MappingConflict(t *testing.T) {
existingChannel := &Channel{
ID: 1,
Name: "existing",
Status: StatusActive,
}
repo := &mockChannelRepository{
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
return existingChannel, nil
},
}
svc := newTestChannelService(repo)
conflictMapping := map[string]map[string]string{
PlatformAnthropic: {
"claude-*": "target-a",
"claude-opus-*": "target-b",
},
}
_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
ModelMapping: conflictMapping,
})
require.Error(t, err)
require.Contains(t, err.Error(), "MAPPING_PATTERN_CONFLICT")
}