165 lines
5.1 KiB
Go
165 lines
5.1 KiB
Go
//go:build unit
|
|
|
|
package service
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// resolverPtrFloat64 returns a pointer to a freshly allocated copy of v, for
// filling optional *float64 fields in test fixtures.
func resolverPtrFloat64(v float64) *float64 {
	p := new(float64)
	*p = v
	return p
}
|
|
// resolverPtrInt returns a pointer to a freshly allocated copy of v, for
// filling optional *int fields in test fixtures.
func resolverPtrInt(v int) *int {
	p := new(int)
	*p = v
	return p
}
|
|
|
|
func newTestBillingServiceForResolver() *BillingService {
|
|
bs := &BillingService{
|
|
fallbackPrices: make(map[string]*ModelPricing),
|
|
}
|
|
bs.fallbackPrices["claude-sonnet-4"] = &ModelPricing{
|
|
InputPricePerToken: 3e-6,
|
|
OutputPricePerToken: 15e-6,
|
|
CacheCreationPricePerToken: 3.75e-6,
|
|
CacheReadPricePerToken: 0.3e-6,
|
|
SupportsCacheBreakdown: false,
|
|
}
|
|
return bs
|
|
}
|
|
|
|
func TestResolve_NoGroupID(t *testing.T) {
|
|
bs := newTestBillingServiceForResolver()
|
|
r := NewModelPricingResolver(&ChannelService{}, bs)
|
|
|
|
resolved := r.Resolve(context.Background(), PricingInput{
|
|
Model: "claude-sonnet-4",
|
|
GroupID: nil,
|
|
})
|
|
|
|
require.NotNil(t, resolved)
|
|
require.Equal(t, BillingModeToken, resolved.Mode)
|
|
require.NotNil(t, resolved.BasePricing)
|
|
require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
|
// BillingService.GetModelPricing uses fallback internally, but resolveBasePricing
|
|
// reports "litellm" when GetModelPricing succeeds (regardless of internal source)
|
|
require.Equal(t, "litellm", resolved.Source)
|
|
}
|
|
|
|
func TestResolve_UnknownModel(t *testing.T) {
|
|
bs := newTestBillingServiceForResolver()
|
|
r := NewModelPricingResolver(&ChannelService{}, bs)
|
|
|
|
resolved := r.Resolve(context.Background(), PricingInput{
|
|
Model: "unknown-model-xyz",
|
|
GroupID: nil,
|
|
})
|
|
|
|
require.NotNil(t, resolved)
|
|
require.Nil(t, resolved.BasePricing)
|
|
// Unknown model: GetModelPricing returns error, source is "fallback"
|
|
require.Equal(t, "fallback", resolved.Source)
|
|
}
|
|
|
|
func TestGetIntervalPricing_NoIntervals(t *testing.T) {
|
|
bs := newTestBillingServiceForResolver()
|
|
r := NewModelPricingResolver(&ChannelService{}, bs)
|
|
|
|
basePricing := &ModelPricing{InputPricePerToken: 5e-6}
|
|
resolved := &ResolvedPricing{
|
|
Mode: BillingModeToken,
|
|
BasePricing: basePricing,
|
|
Intervals: nil,
|
|
}
|
|
|
|
result := r.GetIntervalPricing(resolved, 50000)
|
|
require.Equal(t, basePricing, result)
|
|
}
|
|
|
|
func TestGetIntervalPricing_MatchesInterval(t *testing.T) {
|
|
bs := newTestBillingServiceForResolver()
|
|
r := NewModelPricingResolver(&ChannelService{}, bs)
|
|
|
|
resolved := &ResolvedPricing{
|
|
Mode: BillingModeToken,
|
|
BasePricing: &ModelPricing{InputPricePerToken: 5e-6},
|
|
SupportsCacheBreakdown: true,
|
|
Intervals: []PricingInterval{
|
|
{MinTokens: 0, MaxTokens: resolverPtrInt(128000), InputPrice: resolverPtrFloat64(1e-6), OutputPrice: resolverPtrFloat64(2e-6)},
|
|
{MinTokens: 128000, MaxTokens: nil, InputPrice: resolverPtrFloat64(3e-6), OutputPrice: resolverPtrFloat64(6e-6)},
|
|
},
|
|
}
|
|
|
|
result := r.GetIntervalPricing(resolved, 50000)
|
|
require.NotNil(t, result)
|
|
require.InDelta(t, 1e-6, result.InputPricePerToken, 1e-12)
|
|
require.InDelta(t, 2e-6, result.OutputPricePerToken, 1e-12)
|
|
require.True(t, result.SupportsCacheBreakdown)
|
|
|
|
result2 := r.GetIntervalPricing(resolved, 200000)
|
|
require.NotNil(t, result2)
|
|
require.InDelta(t, 3e-6, result2.InputPricePerToken, 1e-12)
|
|
}
|
|
|
|
func TestGetIntervalPricing_NoMatch_FallsBackToBase(t *testing.T) {
|
|
bs := newTestBillingServiceForResolver()
|
|
r := NewModelPricingResolver(&ChannelService{}, bs)
|
|
|
|
basePricing := &ModelPricing{InputPricePerToken: 99e-6}
|
|
resolved := &ResolvedPricing{
|
|
Mode: BillingModeToken,
|
|
BasePricing: basePricing,
|
|
Intervals: []PricingInterval{
|
|
{MinTokens: 10000, MaxTokens: resolverPtrInt(50000), InputPrice: resolverPtrFloat64(1e-6)},
|
|
},
|
|
}
|
|
|
|
result := r.GetIntervalPricing(resolved, 5000)
|
|
require.Equal(t, basePricing, result)
|
|
}
|
|
|
|
func TestGetRequestTierPrice(t *testing.T) {
|
|
bs := newTestBillingServiceForResolver()
|
|
r := NewModelPricingResolver(&ChannelService{}, bs)
|
|
|
|
resolved := &ResolvedPricing{
|
|
Mode: BillingModePerRequest,
|
|
RequestTiers: []PricingInterval{
|
|
{TierLabel: "1K", PerRequestPrice: resolverPtrFloat64(0.04)},
|
|
{TierLabel: "2K", PerRequestPrice: resolverPtrFloat64(0.08)},
|
|
},
|
|
}
|
|
|
|
require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
|
|
require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12)
|
|
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "4K"), 1e-12)
|
|
}
|
|
|
|
func TestGetRequestTierPriceByContext(t *testing.T) {
|
|
bs := newTestBillingServiceForResolver()
|
|
r := NewModelPricingResolver(&ChannelService{}, bs)
|
|
|
|
resolved := &ResolvedPricing{
|
|
Mode: BillingModePerRequest,
|
|
RequestTiers: []PricingInterval{
|
|
{MinTokens: 0, MaxTokens: resolverPtrInt(128000), PerRequestPrice: resolverPtrFloat64(0.05)},
|
|
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: resolverPtrFloat64(0.10)},
|
|
},
|
|
}
|
|
|
|
require.InDelta(t, 0.05, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12)
|
|
require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12)
|
|
}
|
|
|
|
func TestGetRequestTierPrice_NilPerRequestPrice(t *testing.T) {
|
|
bs := newTestBillingServiceForResolver()
|
|
r := NewModelPricingResolver(&ChannelService{}, bs)
|
|
|
|
resolved := &ResolvedPricing{
|
|
Mode: BillingModePerRequest,
|
|
RequestTiers: []PricingInterval{
|
|
{TierLabel: "1K", PerRequestPrice: nil},
|
|
},
|
|
}
|
|
|
|
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
|
|
}
|