//go:build unit package service import ( "context" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" ) // --------------------------------------------------------------------------- // CalculateCostUnified // --------------------------------------------------------------------------- func TestCalculateCostUnified_NilResolver_FallsBackToOldPath(t *testing.T) { svc := newTestBillingService() tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} input := CostInput{ Model: "claude-sonnet-4", Tokens: tokens, RateMultiplier: 1.0, Resolver: nil, // no resolver } cost, err := svc.CalculateCostUnified(input) require.NoError(t, err) // Should match the old-path result exactly expected, err := svc.calculateCostInternal("claude-sonnet-4", tokens, 1.0, "", nil) require.NoError(t, err) require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-10) require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10) // BillingMode is NOT set by old path through CalculateCostUnified (resolver == nil) require.Empty(t, cost.BillingMode) } func TestCalculateCostUnified_TokenMode(t *testing.T) { bs := newTestBillingService() resolver := NewModelPricingResolver(nil, bs) tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} input := CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: tokens, RateMultiplier: 1.5, Resolver: resolver, } cost, err := bs.CalculateCostUnified(input) require.NoError(t, err) require.NotNil(t, cost) // Verify token billing: Input: 1000*3e-6=0.003, Output: 500*15e-6=0.0075 expectedTotal := 1000*3e-6 + 500*15e-6 require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10) require.InDelta(t, expectedTotal*1.5, cost.ActualCost, 1e-10) require.Equal(t, string(BillingModeToken), cost.BillingMode) } func TestCalculateCostUnified_PerRequestMode(t *testing.T) { // Set up a ChannelService with a per-request pricing channel cs := newTestChannelServiceWithCache(t, &channelCache{ pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{ {groupID: 1, model: "claude-sonnet-4"}: { BillingMode: BillingModePerRequest, PerRequestPrice: testPtrFloat64(0.05), }, }, channelByGroupID: map[int64]*Channel{ 1: {ID: 1, Status: StatusActive}, }, groupPlatform: map[int64]string{1: ""}, wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{}, mappingByGroupModel: map[channelModelKey]string{}, wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{}, byID: map[int64]*Channel{}, }) bs := newTestBillingService() resolver := NewModelPricingResolver(cs, bs) groupID := int64(1) input := CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", GroupID: &groupID, Tokens: UsageTokens{InputTokens: 100, OutputTokens: 50}, RequestCount: 3, RateMultiplier: 2.0, Resolver: resolver, } cost, err := bs.CalculateCostUnified(input) require.NoError(t, err) require.NotNil(t, cost) // 3 requests * $0.05 = $0.15 require.InDelta(t, 0.15, cost.TotalCost, 1e-10) // ActualCost = 0.15 * 2.0 = 0.30 require.InDelta(t, 0.30, cost.ActualCost, 1e-10) require.Equal(t, string(BillingModePerRequest), cost.BillingMode) } func TestCalculateCostUnified_ImageMode(t *testing.T) { cs := newTestChannelServiceWithCache(t, &channelCache{ pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{ {groupID: 2, model: "gemini-image"}: { BillingMode: BillingModeImage, PerRequestPrice: testPtrFloat64(0.10), }, }, channelByGroupID: map[int64]*Channel{ 2: {ID: 2, Status: StatusActive}, }, groupPlatform: map[int64]string{2: ""}, wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{}, mappingByGroupModel: map[channelModelKey]string{}, wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{}, byID: map[int64]*Channel{}, }) bs := &BillingService{ cfg: &config.Config{}, fallbackPrices: map[string]*ModelPricing{}, } resolver := NewModelPricingResolver(cs, bs) groupID := int64(2) input := CostInput{ Ctx: context.Background(), Model: "gemini-image", GroupID: &groupID, Tokens: UsageTokens{}, RequestCount: 2, RateMultiplier: 1.0, Resolver: resolver, } cost, err := bs.CalculateCostUnified(input) require.NoError(t, err) require.NotNil(t, cost) // 2 * $0.10 = $0.20 require.InDelta(t, 0.20, cost.TotalCost, 1e-10) require.InDelta(t, 0.20, cost.ActualCost, 1e-10) require.Equal(t, string(BillingModeImage), cost.BillingMode) } func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) { bs := newTestBillingService() resolver := NewModelPricingResolver(nil, bs) tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} costZero, err := bs.CalculateCostUnified(CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: tokens, RateMultiplier: 0, // should default to 1.0 Resolver: resolver, }) require.NoError(t, err) costOne, err := bs.CalculateCostUnified(CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: tokens, RateMultiplier: 1.0, Resolver: resolver, }) require.NoError(t, err) require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) } func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) { bs := newTestBillingService() resolver := NewModelPricingResolver(nil, bs) tokens := UsageTokens{InputTokens: 1000} costNeg, err := bs.CalculateCostUnified(CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: tokens, RateMultiplier: -5.0, Resolver: resolver, }) require.NoError(t, err) costOne, err := bs.CalculateCostUnified(CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: tokens, RateMultiplier: 1.0, Resolver: resolver, }) require.NoError(t, err) require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) } func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) { bs := newTestBillingService() resolver := NewModelPricingResolver(nil, bs) cost, err := bs.CalculateCostUnified(CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: UsageTokens{InputTokens: 100}, RateMultiplier: 1.0, Resolver: resolver, }) require.NoError(t, err) require.Equal(t, "token", cost.BillingMode) } func TestCalculateCostUnified_UsesPreResolvedPricing(t *testing.T) { bs := newTestBillingService() resolver := NewModelPricingResolver(nil, bs) // Pre-resolve with per_request mode to verify it's used instead of re-resolving preResolved := &ResolvedPricing{ Mode: BillingModePerRequest, DefaultPerRequestPrice: 0.07, } cost, err := bs.CalculateCostUnified(CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: UsageTokens{InputTokens: 100}, RequestCount: 2, RateMultiplier: 1.0, Resolver: resolver, Resolved: preResolved, }) require.NoError(t, err) require.NotNil(t, cost) // 2 * $0.07 = $0.14 require.InDelta(t, 0.14, cost.TotalCost, 1e-10) require.Equal(t, string(BillingModePerRequest), cost.BillingMode) } // --------------------------------------------------------------------------- // helpers // --------------------------------------------------------------------------- // newTestChannelServiceWithCache creates a ChannelService with a pre-populated // cache snapshot, bypassing the repository layer entirely. func newTestChannelServiceWithCache(t *testing.T, cache *channelCache) *ChannelService { t.Helper() cs := &ChannelService{} cache.loadedAt = time.Now() cs.cache.Store(cache) return cs }