Billing (25 tests): - CalculateCostUnified: nil resolver fallback, token/per_request/image modes - GetModelPricingWithChannel: nil/partial/full channel overrides - resolveAccountStatsCost: four-level priority chain integration tests WebSearch (18 tests): - PopulateWebSearchUsage: nil input, manager states, QuotaLimit nil/*int64 - ResetWebSearchUsage: nil manager error - Manager.ResetUsage: nil Redis - shouldEmulateWebSearch: full decision chain (8 scenarios) Notify (36 tests): - ParseNotifyEmails/MarshalNotifyEmails: old/new format, roundtrip - crossedDownward: boundary values, threshold semantics - checkQuotaDimCrossings: mixed dimensions, disabled/zero skip
259 lines
7.9 KiB
Go
259 lines
7.9 KiB
Go
//go:build unit
|
|
|
|
package service
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// CalculateCostUnified
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestCalculateCostUnified_NilResolver_FallsBackToOldPath(t *testing.T) {
|
|
svc := newTestBillingService()
|
|
|
|
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
|
input := CostInput{
|
|
Model: "claude-sonnet-4",
|
|
Tokens: tokens,
|
|
RateMultiplier: 1.0,
|
|
Resolver: nil, // no resolver
|
|
}
|
|
cost, err := svc.CalculateCostUnified(input)
|
|
require.NoError(t, err)
|
|
|
|
// Should match the old-path result exactly
|
|
expected, err := svc.calculateCostInternal("claude-sonnet-4", tokens, 1.0, "", nil)
|
|
require.NoError(t, err)
|
|
require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-10)
|
|
require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10)
|
|
// BillingMode is NOT set by old path through CalculateCostUnified (resolver == nil)
|
|
require.Empty(t, cost.BillingMode)
|
|
}
|
|
|
|
func TestCalculateCostUnified_TokenMode(t *testing.T) {
|
|
bs := newTestBillingService()
|
|
resolver := NewModelPricingResolver(nil, bs)
|
|
|
|
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
|
input := CostInput{
|
|
Ctx: context.Background(),
|
|
Model: "claude-sonnet-4",
|
|
Tokens: tokens,
|
|
RateMultiplier: 1.5,
|
|
Resolver: resolver,
|
|
}
|
|
cost, err := bs.CalculateCostUnified(input)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, cost)
|
|
|
|
// Verify token billing: Input: 1000*3e-6=0.003, Output: 500*15e-6=0.0075
|
|
expectedTotal := 1000*3e-6 + 500*15e-6
|
|
require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10)
|
|
require.InDelta(t, expectedTotal*1.5, cost.ActualCost, 1e-10)
|
|
require.Equal(t, string(BillingModeToken), cost.BillingMode)
|
|
}
|
|
|
|
func TestCalculateCostUnified_PerRequestMode(t *testing.T) {
|
|
// Set up a ChannelService with a per-request pricing channel
|
|
cs := newTestChannelServiceWithCache(t, &channelCache{
|
|
pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{
|
|
{groupID: 1, model: "claude-sonnet-4"}: {
|
|
BillingMode: BillingModePerRequest,
|
|
PerRequestPrice: testPtrFloat64(0.05),
|
|
},
|
|
},
|
|
channelByGroupID: map[int64]*Channel{
|
|
1: {ID: 1, Status: StatusActive},
|
|
},
|
|
groupPlatform: map[int64]string{1: ""},
|
|
wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{},
|
|
mappingByGroupModel: map[channelModelKey]string{},
|
|
wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{},
|
|
byID: map[int64]*Channel{},
|
|
})
|
|
|
|
bs := newTestBillingService()
|
|
resolver := NewModelPricingResolver(cs, bs)
|
|
groupID := int64(1)
|
|
|
|
input := CostInput{
|
|
Ctx: context.Background(),
|
|
Model: "claude-sonnet-4",
|
|
GroupID: &groupID,
|
|
Tokens: UsageTokens{InputTokens: 100, OutputTokens: 50},
|
|
RequestCount: 3,
|
|
RateMultiplier: 2.0,
|
|
Resolver: resolver,
|
|
}
|
|
cost, err := bs.CalculateCostUnified(input)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, cost)
|
|
|
|
// 3 requests * $0.05 = $0.15
|
|
require.InDelta(t, 0.15, cost.TotalCost, 1e-10)
|
|
// ActualCost = 0.15 * 2.0 = 0.30
|
|
require.InDelta(t, 0.30, cost.ActualCost, 1e-10)
|
|
require.Equal(t, string(BillingModePerRequest), cost.BillingMode)
|
|
}
|
|
|
|
func TestCalculateCostUnified_ImageMode(t *testing.T) {
|
|
cs := newTestChannelServiceWithCache(t, &channelCache{
|
|
pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{
|
|
{groupID: 2, model: "gemini-image"}: {
|
|
BillingMode: BillingModeImage,
|
|
PerRequestPrice: testPtrFloat64(0.10),
|
|
},
|
|
},
|
|
channelByGroupID: map[int64]*Channel{
|
|
2: {ID: 2, Status: StatusActive},
|
|
},
|
|
groupPlatform: map[int64]string{2: ""},
|
|
wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{},
|
|
mappingByGroupModel: map[channelModelKey]string{},
|
|
wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{},
|
|
byID: map[int64]*Channel{},
|
|
})
|
|
|
|
bs := &BillingService{
|
|
cfg: &config.Config{},
|
|
fallbackPrices: map[string]*ModelPricing{},
|
|
}
|
|
resolver := NewModelPricingResolver(cs, bs)
|
|
groupID := int64(2)
|
|
|
|
input := CostInput{
|
|
Ctx: context.Background(),
|
|
Model: "gemini-image",
|
|
GroupID: &groupID,
|
|
Tokens: UsageTokens{},
|
|
RequestCount: 2,
|
|
RateMultiplier: 1.0,
|
|
Resolver: resolver,
|
|
}
|
|
cost, err := bs.CalculateCostUnified(input)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, cost)
|
|
|
|
// 2 * $0.10 = $0.20
|
|
require.InDelta(t, 0.20, cost.TotalCost, 1e-10)
|
|
require.InDelta(t, 0.20, cost.ActualCost, 1e-10)
|
|
require.Equal(t, string(BillingModeImage), cost.BillingMode)
|
|
}
|
|
|
|
func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) {
|
|
bs := newTestBillingService()
|
|
resolver := NewModelPricingResolver(nil, bs)
|
|
|
|
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
|
|
|
costZero, err := bs.CalculateCostUnified(CostInput{
|
|
Ctx: context.Background(),
|
|
Model: "claude-sonnet-4",
|
|
Tokens: tokens,
|
|
RateMultiplier: 0, // should default to 1.0
|
|
Resolver: resolver,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
costOne, err := bs.CalculateCostUnified(CostInput{
|
|
Ctx: context.Background(),
|
|
Model: "claude-sonnet-4",
|
|
Tokens: tokens,
|
|
RateMultiplier: 1.0,
|
|
Resolver: resolver,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
|
|
}
|
|
|
|
func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) {
|
|
bs := newTestBillingService()
|
|
resolver := NewModelPricingResolver(nil, bs)
|
|
|
|
tokens := UsageTokens{InputTokens: 1000}
|
|
|
|
costNeg, err := bs.CalculateCostUnified(CostInput{
|
|
Ctx: context.Background(),
|
|
Model: "claude-sonnet-4",
|
|
Tokens: tokens,
|
|
RateMultiplier: -5.0,
|
|
Resolver: resolver,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
costOne, err := bs.CalculateCostUnified(CostInput{
|
|
Ctx: context.Background(),
|
|
Model: "claude-sonnet-4",
|
|
Tokens: tokens,
|
|
RateMultiplier: 1.0,
|
|
Resolver: resolver,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
|
|
}
|
|
|
|
func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) {
|
|
bs := newTestBillingService()
|
|
resolver := NewModelPricingResolver(nil, bs)
|
|
|
|
cost, err := bs.CalculateCostUnified(CostInput{
|
|
Ctx: context.Background(),
|
|
Model: "claude-sonnet-4",
|
|
Tokens: UsageTokens{InputTokens: 100},
|
|
RateMultiplier: 1.0,
|
|
Resolver: resolver,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, "token", cost.BillingMode)
|
|
}
|
|
|
|
func TestCalculateCostUnified_UsesPreResolvedPricing(t *testing.T) {
|
|
bs := newTestBillingService()
|
|
resolver := NewModelPricingResolver(nil, bs)
|
|
|
|
// Pre-resolve with per_request mode to verify it's used instead of re-resolving
|
|
preResolved := &ResolvedPricing{
|
|
Mode: BillingModePerRequest,
|
|
DefaultPerRequestPrice: 0.07,
|
|
}
|
|
|
|
cost, err := bs.CalculateCostUnified(CostInput{
|
|
Ctx: context.Background(),
|
|
Model: "claude-sonnet-4",
|
|
Tokens: UsageTokens{InputTokens: 100},
|
|
RequestCount: 2,
|
|
RateMultiplier: 1.0,
|
|
Resolver: resolver,
|
|
Resolved: preResolved,
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotNil(t, cost)
|
|
|
|
// 2 * $0.07 = $0.14
|
|
require.InDelta(t, 0.14, cost.TotalCost, 1e-10)
|
|
require.Equal(t, string(BillingModePerRequest), cost.BillingMode)
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// helpers
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// newTestChannelServiceWithCache creates a ChannelService with a pre-populated
|
|
// cache snapshot, bypassing the repository layer entirely.
|
|
func newTestChannelServiceWithCache(t *testing.T, cache *channelCache) *ChannelService {
|
|
t.Helper()
|
|
cs := &ChannelService{}
|
|
cache.loadedAt = time.Now()
|
|
cs.cache.Store(cache)
|
|
return cs
|
|
}
|