test: add unit tests for billing, websearch, and notify systems
Billing (25 tests): - CalculateCostUnified: nil resolver fallback, token/per_request/image modes - GetModelPricingWithChannel: nil/partial/full channel overrides - resolveAccountStatsCost: four-level priority chain integration tests WebSearch (18 tests): - PopulateWebSearchUsage: nil input, manager states, QuotaLimit nil/*int64 - ResetWebSearchUsage: nil manager error - Manager.ResetUsage: nil Redis - shouldEmulateWebSearch: full decision chain (8 scenarios) Notify (36 tests): - ParseNotifyEmails/MarshalNotifyEmails: old/new format, roundtrip - crossedDownward: boundary values, threshold semantics - checkQuotaDimCrossings: mixed dimensions, disabled/zero skip
This commit is contained in:
@@ -3,7 +3,9 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -527,3 +529,243 @@ func TestTryModelFilePricing_WithCacheTokens(t *testing.T) {
|
||||
// = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
|
||||
require.InDelta(t, 0.95, *result, 1e-12)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// resolveAccountStatsCost — integration tests covering the 4-level priority chain
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolveAccountStatsCost_NilChannelService(t *testing.T) {
|
||||
result := resolveAccountStatsCost(
|
||||
context.Background(),
|
||||
nil, // channelService is nil
|
||||
newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
|
||||
1, 1, "claude-sonnet-4",
|
||||
UsageTokens{InputTokens: 100}, 1, 0.5,
|
||||
)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestResolveAccountStatsCost_EmptyUpstreamModel(t *testing.T) {
|
||||
cs := newTestChannelServiceForStats(t, &Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
}, 1, "")
|
||||
|
||||
result := resolveAccountStatsCost(
|
||||
context.Background(),
|
||||
cs,
|
||||
newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
|
||||
1, 1, "", // empty upstream model
|
||||
UsageTokens{InputTokens: 100}, 1, 0.5,
|
||||
)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestResolveAccountStatsCost_GetChannelForGroupReturnsNil(t *testing.T) {
|
||||
// Group 99 is NOT in the cache, so GetChannelForGroup returns nil
|
||||
cs := newTestChannelServiceForStats(t, &Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
}, 1, "")
|
||||
|
||||
result := resolveAccountStatsCost(
|
||||
context.Background(),
|
||||
cs,
|
||||
newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
|
||||
1, 99, "claude-sonnet-4", // groupID 99 has no channel
|
||||
UsageTokens{InputTokens: 100}, 1, 0.5,
|
||||
)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestResolveAccountStatsCost_HitsCustomRule(t *testing.T) {
|
||||
channel := &Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||
{
|
||||
GroupIDs: []int64{10},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{
|
||||
ID: 100,
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
InputPrice: testPtrFloat64(0.01),
|
||||
OutputPrice: testPtrFloat64(0.02),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
|
||||
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
|
||||
|
||||
result := resolveAccountStatsCost(
|
||||
context.Background(),
|
||||
cs, nil, // billingService not needed when custom rule hits
|
||||
1, 10, "claude-sonnet-4",
|
||||
tokens, 1, 999.0, // totalCost ignored because custom rule hits
|
||||
)
|
||||
require.NotNil(t, result)
|
||||
// 100*0.01 + 50*0.02 = 1.0 + 1.0 = 2.0
|
||||
require.InDelta(t, 2.0, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestResolveAccountStatsCost_ApplyPricingToAccountStats_UsesTotalCost(t *testing.T) {
|
||||
channel := &Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
ApplyPricingToAccountStats: true,
|
||||
// No custom rules
|
||||
}
|
||||
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
|
||||
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
|
||||
|
||||
result := resolveAccountStatsCost(
|
||||
context.Background(),
|
||||
cs, nil,
|
||||
1, 10, "claude-sonnet-4",
|
||||
tokens, 1, 0.75, // totalCost = 0.75
|
||||
)
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 0.75, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestResolveAccountStatsCost_ApplyPricingToAccountStats_ZeroTotalCost_ReturnsNil(t *testing.T) {
|
||||
channel := &Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
ApplyPricingToAccountStats: true,
|
||||
}
|
||||
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
|
||||
|
||||
result := resolveAccountStatsCost(
|
||||
context.Background(),
|
||||
cs, nil,
|
||||
1, 10, "claude-sonnet-4",
|
||||
UsageTokens{}, 1, 0.0, // totalCost = 0
|
||||
)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestResolveAccountStatsCost_FallsBackToLiteLLM(t *testing.T) {
|
||||
channel := &Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
ApplyPricingToAccountStats: false, // not enabled
|
||||
// No custom rules
|
||||
}
|
||||
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
|
||||
|
||||
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
|
||||
"claude-sonnet-4": {
|
||||
InputPricePerToken: 0.001,
|
||||
OutputPricePerToken: 0.002,
|
||||
},
|
||||
})
|
||||
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
|
||||
|
||||
result := resolveAccountStatsCost(
|
||||
context.Background(),
|
||||
cs, bs,
|
||||
1, 10, "claude-sonnet-4",
|
||||
tokens, 1, 999.0, // totalCost ignored
|
||||
)
|
||||
require.NotNil(t, result)
|
||||
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
|
||||
require.InDelta(t, 0.2, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestResolveAccountStatsCost_AllMiss_ReturnsNil(t *testing.T) {
|
||||
channel := &Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
ApplyPricingToAccountStats: false,
|
||||
// No custom rules
|
||||
}
|
||||
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
|
||||
|
||||
// BillingService with no pricing for the model
|
||||
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{})
|
||||
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
|
||||
|
||||
result := resolveAccountStatsCost(
|
||||
context.Background(),
|
||||
cs, bs,
|
||||
1, 10, "totally-unknown-model",
|
||||
tokens, 1, 0.0,
|
||||
)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestResolveAccountStatsCost_NilBillingService_SkipsLiteLLM(t *testing.T) {
|
||||
channel := &Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
ApplyPricingToAccountStats: false,
|
||||
}
|
||||
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
|
||||
|
||||
result := resolveAccountStatsCost(
|
||||
context.Background(),
|
||||
cs, nil, // billingService is nil
|
||||
1, 10, "claude-sonnet-4",
|
||||
UsageTokens{InputTokens: 100}, 1, 0.0,
|
||||
)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestResolveAccountStatsCost_CustomRulePriorityOverApplyPricing(t *testing.T) {
|
||||
// Both custom rule and ApplyPricingToAccountStats are configured;
|
||||
// custom rule should take precedence.
|
||||
channel := &Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
ApplyPricingToAccountStats: true,
|
||||
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||
{
|
||||
GroupIDs: []int64{10},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{
|
||||
ID: 100,
|
||||
Models: []string{"claude-sonnet-4"},
|
||||
InputPrice: testPtrFloat64(0.05),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
|
||||
|
||||
tokens := UsageTokens{InputTokens: 100}
|
||||
|
||||
result := resolveAccountStatsCost(
|
||||
context.Background(),
|
||||
cs, nil,
|
||||
1, 10, "claude-sonnet-4",
|
||||
tokens, 1, 99.0, // totalCost = 99.0 (would be used if ApplyPricing wins)
|
||||
)
|
||||
require.NotNil(t, result)
|
||||
// Custom rule: 100*0.05 = 5.0 (NOT 99.0 from totalCost)
|
||||
require.InDelta(t, 5.0, *result, 1e-12)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// helpers for resolveAccountStatsCost tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// newTestChannelServiceForStats creates a ChannelService with a single channel
|
||||
// mapped to the given groupID, suitable for resolveAccountStatsCost tests.
|
||||
func newTestChannelServiceForStats(t *testing.T, channel *Channel, groupID int64, platform string) *ChannelService {
|
||||
t.Helper()
|
||||
cache := newEmptyChannelCache()
|
||||
cache.channelByGroupID[groupID] = channel
|
||||
cache.groupPlatform[groupID] = platform
|
||||
cs := &ChannelService{}
|
||||
cache.loadedAt = time.Now()
|
||||
cs.cache.Store(cache)
|
||||
return cs
|
||||
}
|
||||
|
||||
@@ -178,3 +178,227 @@ func TestGetSiteName_Configured(t *testing.T) {
|
||||
repo.data[SettingKeySiteName] = "My Site"
|
||||
require.Equal(t, "My Site", s.getSiteName(context.Background()))
|
||||
}
|
||||
|
||||
// ---------- crossedDownward ----------
|
||||
|
||||
func TestCrossedDownward_CrossesBelow(t *testing.T) {
|
||||
// oldBalance > threshold, newBalance < threshold → true
|
||||
require.True(t, crossedDownward(100, 5, 10))
|
||||
}
|
||||
|
||||
func TestCrossedDownward_ExactlyAtThreshold(t *testing.T) {
|
||||
// oldBalance > threshold, newBalance == threshold → false (not below)
|
||||
require.False(t, crossedDownward(100, 10, 10))
|
||||
}
|
||||
|
||||
func TestCrossedDownward_OldExactlyAtThreshold_NewBelow(t *testing.T) {
|
||||
// oldBalance == threshold, newBalance < threshold → true
|
||||
// (at-or-above → below counts as a crossing)
|
||||
require.True(t, crossedDownward(10, 5, 10))
|
||||
}
|
||||
|
||||
func TestCrossedDownward_AlreadyBelow(t *testing.T) {
|
||||
// oldBalance < threshold → false (already below, no new crossing)
|
||||
require.False(t, crossedDownward(5, 3, 10))
|
||||
}
|
||||
|
||||
func TestCrossedDownward_BothAbove(t *testing.T) {
|
||||
// oldBalance > threshold, newBalance > threshold → false (no crossing)
|
||||
require.False(t, crossedDownward(100, 50, 10))
|
||||
}
|
||||
|
||||
func TestCrossedDownward_ZeroThreshold(t *testing.T) {
|
||||
// threshold == 0 → oldV >= 0 is always true, but newV < 0 only for negatives
|
||||
// Typical case: positive balances should not fire when threshold is 0.
|
||||
require.False(t, crossedDownward(10, 5, 0))
|
||||
require.False(t, crossedDownward(0, 0, 0))
|
||||
}
|
||||
|
||||
func TestCrossedDownward_ZeroThreshold_NegativeNew(t *testing.T) {
|
||||
// Edge case: newBalance goes negative with threshold=0.
|
||||
require.True(t, crossedDownward(5, -1, 0))
|
||||
}
|
||||
|
||||
func TestCrossedDownward_NegativeValues(t *testing.T) {
|
||||
// Both already negative, threshold is positive → no crossing (already below).
|
||||
require.False(t, crossedDownward(-5, -10, 10))
|
||||
}
|
||||
|
||||
func TestCrossedDownward_LargeDecrement(t *testing.T) {
|
||||
// A single large deduction crosses the threshold.
|
||||
require.True(t, crossedDownward(1000, 0.5, 100))
|
||||
}
|
||||
|
||||
func TestCrossedDownward_SmallDecrement_NoCrossing(t *testing.T) {
|
||||
// A tiny deduction stays above threshold.
|
||||
require.False(t, crossedDownward(100, 99.99, 10))
|
||||
}
|
||||
|
||||
// ---------- checkQuotaDimCrossings ----------
|
||||
|
||||
func TestCheckQuotaDimCrossings_NoDimensions(t *testing.T) {
|
||||
s, _ := newBalanceNotifyServiceForTest()
|
||||
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||
// Empty dims → no crossing, no panic.
|
||||
s.checkQuotaDimCrossings(account, nil, 10, []string{"admin@example.com"}, "TestSite")
|
||||
s.checkQuotaDimCrossings(account, []quotaDim{}, 10, []string{"admin@example.com"}, "TestSite")
|
||||
}
|
||||
|
||||
func TestCheckQuotaDimCrossings_DisabledDimension(t *testing.T) {
|
||||
s, _ := newBalanceNotifyServiceForTest()
|
||||
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||
dims := []quotaDim{
|
||||
{
|
||||
name: quotaDimDaily,
|
||||
enabled: false, // disabled
|
||||
threshold: 100,
|
||||
thresholdType: thresholdTypeFixed,
|
||||
currentUsed: 950,
|
||||
limit: 1000,
|
||||
},
|
||||
}
|
||||
// Disabled dimension should be skipped even if crossing would occur.
|
||||
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
|
||||
}
|
||||
|
||||
func TestCheckQuotaDimCrossings_ZeroThresholdSkipped(t *testing.T) {
|
||||
s, _ := newBalanceNotifyServiceForTest()
|
||||
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||
dims := []quotaDim{
|
||||
{
|
||||
name: quotaDimDaily,
|
||||
enabled: true,
|
||||
threshold: 0, // zero threshold
|
||||
thresholdType: thresholdTypeFixed,
|
||||
currentUsed: 950,
|
||||
limit: 1000,
|
||||
},
|
||||
}
|
||||
// Zero threshold → skipped.
|
||||
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
|
||||
}
|
||||
|
||||
func TestCheckQuotaDimCrossings_NoCrossing_BothBelowThreshold(t *testing.T) {
|
||||
s, _ := newBalanceNotifyServiceForTest()
|
||||
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||
// threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger)
|
||||
// currentUsed=300 (after), oldUsed=300-50=250 (before). Both < 600, no crossing.
|
||||
dims := []quotaDim{
|
||||
{
|
||||
name: quotaDimDaily,
|
||||
enabled: true,
|
||||
threshold: 400,
|
||||
thresholdType: thresholdTypeFixed,
|
||||
currentUsed: 300,
|
||||
limit: 1000,
|
||||
},
|
||||
}
|
||||
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
|
||||
}
|
||||
|
||||
func TestCheckQuotaDimCrossings_NoCrossing_BothAboveThreshold(t *testing.T) {
|
||||
s, _ := newBalanceNotifyServiceForTest()
|
||||
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||
// threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger)
|
||||
// currentUsed=800 (after), oldUsed=800-50=750 (before). Both >= 600, no crossing.
|
||||
dims := []quotaDim{
|
||||
{
|
||||
name: quotaDimDaily,
|
||||
enabled: true,
|
||||
threshold: 400,
|
||||
thresholdType: thresholdTypeFixed,
|
||||
currentUsed: 800,
|
||||
limit: 1000,
|
||||
},
|
||||
}
|
||||
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
|
||||
}
|
||||
|
||||
func TestCheckQuotaDimCrossings_NegativeResolvedThreshold_Skipped(t *testing.T) {
|
||||
s, _ := newBalanceNotifyServiceForTest()
|
||||
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||
// threshold=1200 remaining, limit=1000 → effectiveThreshold = 1000-1200 = -200
|
||||
// Negative resolved threshold → skipped.
|
||||
dims := []quotaDim{
|
||||
{
|
||||
name: quotaDimDaily,
|
||||
enabled: true,
|
||||
threshold: 1200,
|
||||
thresholdType: thresholdTypeFixed,
|
||||
currentUsed: 950,
|
||||
limit: 1000,
|
||||
},
|
||||
}
|
||||
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
|
||||
}
|
||||
|
||||
func TestCheckQuotaDimCrossings_PercentageThreshold_NoCrossing(t *testing.T) {
|
||||
s, _ := newBalanceNotifyServiceForTest()
|
||||
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||
// threshold=30%, limit=1000 → effectiveThreshold = 1000 * (1 - 0.30) = 700
|
||||
// currentUsed=500, oldUsed=500-50=450. Both < 700, no crossing.
|
||||
dims := []quotaDim{
|
||||
{
|
||||
name: quotaDimWeekly,
|
||||
enabled: true,
|
||||
threshold: 30,
|
||||
thresholdType: thresholdTypePercentage,
|
||||
currentUsed: 500,
|
||||
limit: 1000,
|
||||
},
|
||||
}
|
||||
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
|
||||
}
|
||||
|
||||
func TestCheckQuotaDimCrossings_ZeroLimit_Skipped(t *testing.T) {
|
||||
s, _ := newBalanceNotifyServiceForTest()
|
||||
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||
// limit=0 → resolvedThreshold returns 0 → skipped.
|
||||
dims := []quotaDim{
|
||||
{
|
||||
name: quotaDimTotal,
|
||||
enabled: true,
|
||||
threshold: 100,
|
||||
thresholdType: thresholdTypeFixed,
|
||||
currentUsed: 50,
|
||||
limit: 0,
|
||||
},
|
||||
}
|
||||
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
|
||||
}
|
||||
|
||||
func TestCheckQuotaDimCrossings_MultipleDims_MixedResults(t *testing.T) {
|
||||
s, _ := newBalanceNotifyServiceForTest()
|
||||
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||
// dim1: no crossing (both below effective threshold)
|
||||
// dim2: disabled (skipped)
|
||||
// dim3: zero threshold (skipped)
|
||||
dims := []quotaDim{
|
||||
{
|
||||
name: quotaDimDaily,
|
||||
enabled: true,
|
||||
threshold: 400,
|
||||
thresholdType: thresholdTypeFixed,
|
||||
currentUsed: 300, // oldUsed=250, effectiveThreshold=600, both below
|
||||
limit: 1000,
|
||||
},
|
||||
{
|
||||
name: quotaDimWeekly,
|
||||
enabled: false,
|
||||
threshold: 100,
|
||||
thresholdType: thresholdTypeFixed,
|
||||
currentUsed: 900,
|
||||
limit: 1000,
|
||||
},
|
||||
{
|
||||
name: quotaDimTotal,
|
||||
enabled: true,
|
||||
threshold: 0,
|
||||
thresholdType: thresholdTypeFixed,
|
||||
currentUsed: 500,
|
||||
limit: 1000,
|
||||
},
|
||||
}
|
||||
// None should trigger. No panic expected.
|
||||
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
|
||||
}
|
||||
|
||||
@@ -718,3 +718,123 @@ func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing.
|
||||
require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12)
|
||||
require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GetModelPricingWithChannel
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGetModelPricingWithChannel_NilChannelPricing_ReturnsOriginal(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pricing)
|
||||
|
||||
// Should be identical to GetModelPricing
|
||||
original, err := svc.GetModelPricing("claude-sonnet-4")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, original.InputPricePerToken, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, original.OutputPricePerToken, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, original.CacheCreationPricePerToken, pricing.CacheCreationPricePerToken, 1e-12)
|
||||
require.InDelta(t, original.CacheReadPricePerToken, pricing.CacheReadPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricingWithChannel_OverrideInputPriceOnly(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
chPricing := &ChannelModelPricing{
|
||||
InputPrice: testPtrFloat64(99e-6),
|
||||
}
|
||||
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
|
||||
require.NoError(t, err)
|
||||
|
||||
// InputPrice overridden (both normal and priority)
|
||||
require.InDelta(t, 99e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 99e-6, pricing.InputPricePerTokenPriority, 1e-12)
|
||||
|
||||
// OutputPrice unchanged (claude-sonnet-4 fallback = 15e-6)
|
||||
require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricingWithChannel_OverrideOutputPriceOnly(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
chPricing := &ChannelModelPricing{
|
||||
OutputPrice: testPtrFloat64(88e-6),
|
||||
}
|
||||
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
|
||||
require.NoError(t, err)
|
||||
|
||||
// OutputPrice overridden
|
||||
require.InDelta(t, 88e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 88e-6, pricing.OutputPricePerTokenPriority, 1e-12)
|
||||
|
||||
// InputPrice unchanged (claude-sonnet-4 fallback = 3e-6)
|
||||
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricingWithChannel_OverrideAllFields(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
chPricing := &ChannelModelPricing{
|
||||
InputPrice: testPtrFloat64(10e-6),
|
||||
OutputPrice: testPtrFloat64(20e-6),
|
||||
CacheWritePrice: testPtrFloat64(5e-6),
|
||||
CacheReadPrice: testPtrFloat64(1e-6),
|
||||
ImageOutputPrice: testPtrFloat64(50e-6),
|
||||
}
|
||||
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, 10e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 10e-6, pricing.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 20e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 20e-6, pricing.OutputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.CacheCreationPricePerToken, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.CacheCreation5mPrice, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.CacheCreation1hPrice, 1e-12)
|
||||
require.InDelta(t, 1e-6, pricing.CacheReadPricePerToken, 1e-12)
|
||||
require.InDelta(t, 1e-6, pricing.CacheReadPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 50e-6, pricing.ImageOutputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricingWithChannel_CacheWritePriceAffects5mAnd1h(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
chPricing := &ChannelModelPricing{
|
||||
CacheWritePrice: testPtrFloat64(7e-6),
|
||||
}
|
||||
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
|
||||
require.NoError(t, err)
|
||||
|
||||
// CacheWritePrice should set all three: CacheCreationPricePerToken, 5m, and 1h
|
||||
require.InDelta(t, 7e-6, pricing.CacheCreationPricePerToken, 1e-12)
|
||||
require.InDelta(t, 7e-6, pricing.CacheCreation5mPrice, 1e-12)
|
||||
require.InDelta(t, 7e-6, pricing.CacheCreation1hPrice, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricingWithChannel_CacheReadPriceAffectsPriority(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
chPricing := &ChannelModelPricing{
|
||||
CacheReadPrice: testPtrFloat64(2e-6),
|
||||
}
|
||||
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
|
||||
require.NoError(t, err)
|
||||
|
||||
// CacheReadPrice should set both normal and priority
|
||||
require.InDelta(t, 2e-6, pricing.CacheReadPricePerToken, 1e-12)
|
||||
require.InDelta(t, 2e-6, pricing.CacheReadPricePerTokenPriority, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricingWithChannel_UnknownModelReturnsError(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
chPricing := &ChannelModelPricing{
|
||||
InputPrice: testPtrFloat64(1e-6),
|
||||
}
|
||||
pricing, err := svc.GetModelPricingWithChannel("totally-unknown-model", chPricing)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, pricing)
|
||||
require.Contains(t, err.Error(), "pricing not found")
|
||||
}
|
||||
|
||||
258
backend/internal/service/billing_service_unified_test.go
Normal file
258
backend/internal/service/billing_service_unified_test.go
Normal file
@@ -0,0 +1,258 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CalculateCostUnified
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestCalculateCostUnified_NilResolver_FallsBackToOldPath(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
||||
input := CostInput{
|
||||
Model: "claude-sonnet-4",
|
||||
Tokens: tokens,
|
||||
RateMultiplier: 1.0,
|
||||
Resolver: nil, // no resolver
|
||||
}
|
||||
cost, err := svc.CalculateCostUnified(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should match the old-path result exactly
|
||||
expected, err := svc.calculateCostInternal("claude-sonnet-4", tokens, 1.0, "", nil)
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-10)
|
||||
require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10)
|
||||
// BillingMode is NOT set by old path through CalculateCostUnified (resolver == nil)
|
||||
require.Empty(t, cost.BillingMode)
|
||||
}
|
||||
|
||||
func TestCalculateCostUnified_TokenMode(t *testing.T) {
|
||||
bs := newTestBillingService()
|
||||
resolver := NewModelPricingResolver(nil, bs)
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
||||
input := CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "claude-sonnet-4",
|
||||
Tokens: tokens,
|
||||
RateMultiplier: 1.5,
|
||||
Resolver: resolver,
|
||||
}
|
||||
cost, err := bs.CalculateCostUnified(input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cost)
|
||||
|
||||
// Verify token billing: Input: 1000*3e-6=0.003, Output: 500*15e-6=0.0075
|
||||
expectedTotal := 1000*3e-6 + 500*15e-6
|
||||
require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10)
|
||||
require.InDelta(t, expectedTotal*1.5, cost.ActualCost, 1e-10)
|
||||
require.Equal(t, string(BillingModeToken), cost.BillingMode)
|
||||
}
|
||||
|
||||
func TestCalculateCostUnified_PerRequestMode(t *testing.T) {
|
||||
// Set up a ChannelService with a per-request pricing channel
|
||||
cs := newTestChannelServiceWithCache(t, &channelCache{
|
||||
pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{
|
||||
{groupID: 1, model: "claude-sonnet-4"}: {
|
||||
BillingMode: BillingModePerRequest,
|
||||
PerRequestPrice: testPtrFloat64(0.05),
|
||||
},
|
||||
},
|
||||
channelByGroupID: map[int64]*Channel{
|
||||
1: {ID: 1, Status: StatusActive},
|
||||
},
|
||||
groupPlatform: map[int64]string{1: ""},
|
||||
wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{},
|
||||
mappingByGroupModel: map[channelModelKey]string{},
|
||||
wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{},
|
||||
byID: map[int64]*Channel{},
|
||||
})
|
||||
|
||||
bs := newTestBillingService()
|
||||
resolver := NewModelPricingResolver(cs, bs)
|
||||
groupID := int64(1)
|
||||
|
||||
input := CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: &groupID,
|
||||
Tokens: UsageTokens{InputTokens: 100, OutputTokens: 50},
|
||||
RequestCount: 3,
|
||||
RateMultiplier: 2.0,
|
||||
Resolver: resolver,
|
||||
}
|
||||
cost, err := bs.CalculateCostUnified(input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cost)
|
||||
|
||||
// 3 requests * $0.05 = $0.15
|
||||
require.InDelta(t, 0.15, cost.TotalCost, 1e-10)
|
||||
// ActualCost = 0.15 * 2.0 = 0.30
|
||||
require.InDelta(t, 0.30, cost.ActualCost, 1e-10)
|
||||
require.Equal(t, string(BillingModePerRequest), cost.BillingMode)
|
||||
}
|
||||
|
||||
func TestCalculateCostUnified_ImageMode(t *testing.T) {
|
||||
cs := newTestChannelServiceWithCache(t, &channelCache{
|
||||
pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{
|
||||
{groupID: 2, model: "gemini-image"}: {
|
||||
BillingMode: BillingModeImage,
|
||||
PerRequestPrice: testPtrFloat64(0.10),
|
||||
},
|
||||
},
|
||||
channelByGroupID: map[int64]*Channel{
|
||||
2: {ID: 2, Status: StatusActive},
|
||||
},
|
||||
groupPlatform: map[int64]string{2: ""},
|
||||
wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{},
|
||||
mappingByGroupModel: map[channelModelKey]string{},
|
||||
wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{},
|
||||
byID: map[int64]*Channel{},
|
||||
})
|
||||
|
||||
bs := &BillingService{
|
||||
cfg: &config.Config{},
|
||||
fallbackPrices: map[string]*ModelPricing{},
|
||||
}
|
||||
resolver := NewModelPricingResolver(cs, bs)
|
||||
groupID := int64(2)
|
||||
|
||||
input := CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "gemini-image",
|
||||
GroupID: &groupID,
|
||||
Tokens: UsageTokens{},
|
||||
RequestCount: 2,
|
||||
RateMultiplier: 1.0,
|
||||
Resolver: resolver,
|
||||
}
|
||||
cost, err := bs.CalculateCostUnified(input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cost)
|
||||
|
||||
// 2 * $0.10 = $0.20
|
||||
require.InDelta(t, 0.20, cost.TotalCost, 1e-10)
|
||||
require.InDelta(t, 0.20, cost.ActualCost, 1e-10)
|
||||
require.Equal(t, string(BillingModeImage), cost.BillingMode)
|
||||
}
|
||||
|
||||
func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) {
|
||||
bs := newTestBillingService()
|
||||
resolver := NewModelPricingResolver(nil, bs)
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
||||
|
||||
costZero, err := bs.CalculateCostUnified(CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "claude-sonnet-4",
|
||||
Tokens: tokens,
|
||||
RateMultiplier: 0, // should default to 1.0
|
||||
Resolver: resolver,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
costOne, err := bs.CalculateCostUnified(CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "claude-sonnet-4",
|
||||
Tokens: tokens,
|
||||
RateMultiplier: 1.0,
|
||||
Resolver: resolver,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) {
|
||||
bs := newTestBillingService()
|
||||
resolver := NewModelPricingResolver(nil, bs)
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000}
|
||||
|
||||
costNeg, err := bs.CalculateCostUnified(CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "claude-sonnet-4",
|
||||
Tokens: tokens,
|
||||
RateMultiplier: -5.0,
|
||||
Resolver: resolver,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
costOne, err := bs.CalculateCostUnified(CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "claude-sonnet-4",
|
||||
Tokens: tokens,
|
||||
RateMultiplier: 1.0,
|
||||
Resolver: resolver,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) {
|
||||
bs := newTestBillingService()
|
||||
resolver := NewModelPricingResolver(nil, bs)
|
||||
|
||||
cost, err := bs.CalculateCostUnified(CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "claude-sonnet-4",
|
||||
Tokens: UsageTokens{InputTokens: 100},
|
||||
RateMultiplier: 1.0,
|
||||
Resolver: resolver,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "token", cost.BillingMode)
|
||||
}
|
||||
|
||||
func TestCalculateCostUnified_UsesPreResolvedPricing(t *testing.T) {
|
||||
bs := newTestBillingService()
|
||||
resolver := NewModelPricingResolver(nil, bs)
|
||||
|
||||
// Pre-resolve with per_request mode to verify it's used instead of re-resolving
|
||||
preResolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
DefaultPerRequestPrice: 0.07,
|
||||
}
|
||||
|
||||
cost, err := bs.CalculateCostUnified(CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "claude-sonnet-4",
|
||||
Tokens: UsageTokens{InputTokens: 100},
|
||||
RequestCount: 2,
|
||||
RateMultiplier: 1.0,
|
||||
Resolver: resolver,
|
||||
Resolved: preResolved,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cost)
|
||||
|
||||
// 2 * $0.07 = $0.14
|
||||
require.InDelta(t, 0.14, cost.TotalCost, 1e-10)
|
||||
require.Equal(t, string(BillingModePerRequest), cost.BillingMode)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// newTestChannelServiceWithCache creates a ChannelService with a pre-populated
|
||||
// cache snapshot, bypassing the repository layer entirely.
|
||||
func newTestChannelServiceWithCache(t *testing.T, cache *channelCache) *ChannelService {
|
||||
t.Helper()
|
||||
cs := &ChannelService{}
|
||||
cache.loadedAt = time.Now()
|
||||
cs.cache.Store(cache)
|
||||
return cs
|
||||
}
|
||||
@@ -1,8 +1,14 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -140,3 +146,235 @@ func TestBuildTextSummary_NoResults(t *testing.T) {
|
||||
summary := buildTextSummary("test", nil)
|
||||
require.Contains(t, summary, "No search results found for: test")
|
||||
}
|
||||
|
||||
// --- shouldEmulateWebSearch ---
|
||||
|
||||
// webSearchToolBody is a valid request body with exactly one web_search tool.
|
||||
var webSearchToolBody = []byte(`{"tools":[{"type":"web_search"}],"messages":[{"role":"user","content":"test"}]}`)
|
||||
|
||||
// nonWebSearchToolBody is a request body without web_search tool.
|
||||
var nonWebSearchToolBody = []byte(`{"tools":[{"type":"text_editor"}],"messages":[{"role":"user","content":"test"}]}`)
|
||||
|
||||
// newAnthropicAPIKeyAccount creates a test Account with the given web search emulation mode.
|
||||
func newAnthropicAPIKeyAccount(mode string) *Account {
|
||||
return &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{featureKeyWebSearchEmulation: mode},
|
||||
}
|
||||
}
|
||||
|
||||
// setGlobalWebSearchConfig stores a config in the global cache used by SettingService.IsWebSearchEmulationEnabled.
|
||||
func setGlobalWebSearchConfig(cfg *WebSearchEmulationConfig) {
|
||||
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
|
||||
config: cfg,
|
||||
expiresAt: time.Now().Add(10 * time.Minute).UnixNano(),
|
||||
})
|
||||
}
|
||||
|
||||
// clearGlobalWebSearchConfig resets the global cache to force re-read.
|
||||
func clearGlobalWebSearchConfig() {
|
||||
webSearchEmulationCache.Store((*cachedWebSearchEmulationConfig)(nil))
|
||||
}
|
||||
|
||||
// newSettingServiceForWebSearchTest creates a SettingService with a mock repo pre-loaded with config.
|
||||
func newSettingServiceForWebSearchTest(enabled bool) *SettingService {
|
||||
repo := newMockSettingRepo()
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Enabled: enabled,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "sk-test"}},
|
||||
}
|
||||
data, _ := json.Marshal(cfg)
|
||||
repo.data[SettingKeyWebSearchEmulationConfig] = string(data)
|
||||
return NewSettingService(repo, &config.Config{})
|
||||
}
|
||||
|
||||
// newChannelServiceWithCache creates a ChannelService with a pre-built cache containing the channel.
|
||||
func newChannelServiceWithCache(groupID int64, ch *Channel) *ChannelService {
|
||||
svc := &ChannelService{}
|
||||
cache := &channelCache{
|
||||
channelByGroupID: map[int64]*Channel{groupID: ch},
|
||||
byID: map[int64]*Channel{ch.ID: ch},
|
||||
groupPlatform: map[int64]string{},
|
||||
loadedAt: time.Now(),
|
||||
}
|
||||
svc.cache.Store(cache)
|
||||
return svc
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_NilManager(t *testing.T) {
|
||||
SetWebSearchManager(nil)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
svc := &GatewayService{settingService: settingSvc}
|
||||
account := newAnthropicAPIKeyAccount(WebSearchModeEnabled)
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_NotOnlyWebSearchTool(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
svc := &GatewayService{settingService: settingSvc}
|
||||
account := newAnthropicAPIKeyAccount(WebSearchModeEnabled)
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, nonWebSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_GlobalDisabled(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
// Global config disabled
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: false,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(false)
|
||||
svc := &GatewayService{settingService: settingSvc}
|
||||
account := newAnthropicAPIKeyAccount(WebSearchModeEnabled)
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_AccountDisabled(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
svc := &GatewayService{settingService: settingSvc}
|
||||
account := newAnthropicAPIKeyAccount(WebSearchModeDisabled)
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_AccountEnabled(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
svc := &GatewayService{settingService: settingSvc}
|
||||
account := newAnthropicAPIKeyAccount(WebSearchModeEnabled)
|
||||
require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_DefaultMode_ChannelEnabled(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
ch := &Channel{
|
||||
ID: 10,
|
||||
Status: StatusActive,
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{PlatformAnthropic: true},
|
||||
},
|
||||
}
|
||||
channelSvc := newChannelServiceWithCache(42, ch)
|
||||
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
|
||||
|
||||
account := newAnthropicAPIKeyAccount(WebSearchModeDefault)
|
||||
groupID := int64(42)
|
||||
require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_DefaultMode_ChannelDisabled(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
ch := &Channel{
|
||||
ID: 10,
|
||||
Status: StatusActive,
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{PlatformAnthropic: false},
|
||||
},
|
||||
}
|
||||
channelSvc := newChannelServiceWithCache(42, ch)
|
||||
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
|
||||
|
||||
account := newAnthropicAPIKeyAccount(WebSearchModeDefault)
|
||||
groupID := int64(42)
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_DefaultMode_NilGroupID(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
svc := &GatewayService{settingService: settingSvc}
|
||||
account := newAnthropicAPIKeyAccount(WebSearchModeDefault)
|
||||
// nil groupID + default mode → falls through to channel check → returns false
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
|
||||
}
|
||||
|
||||
func TestShouldEmulateWebSearch_DefaultMode_NilChannelService(t *testing.T) {
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
|
||||
})
|
||||
defer clearGlobalWebSearchConfig()
|
||||
|
||||
settingSvc := newSettingServiceForWebSearchTest(true)
|
||||
svc := &GatewayService{settingService: settingSvc, channelService: nil}
|
||||
account := newAnthropicAPIKeyAccount(WebSearchModeDefault)
|
||||
groupID := int64(42)
|
||||
// nil channelService + default mode → returns false
|
||||
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
|
||||
}
|
||||
|
||||
156
backend/internal/service/notify_email_entry_test.go
Normal file
156
backend/internal/service/notify_email_entry_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------- ParseNotifyEmails ----------
|
||||
|
||||
func TestParseNotifyEmails_EmptyString(t *testing.T) {
|
||||
result := ParseNotifyEmails("")
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestParseNotifyEmails_EmptyArray(t *testing.T) {
|
||||
result := ParseNotifyEmails("[]")
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestParseNotifyEmails_Null(t *testing.T) {
|
||||
// "null" is valid JSON that unmarshals into a nil string slice.
|
||||
// The old-format branch then returns an empty (non-nil) slice.
|
||||
result := ParseNotifyEmails("null")
|
||||
require.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestParseNotifyEmails_WhitespaceOnly(t *testing.T) {
|
||||
result := ParseNotifyEmails(" ")
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestParseNotifyEmails_OldFormat(t *testing.T) {
|
||||
raw := `["alice@example.com", "bob@example.com"]`
|
||||
result := ParseNotifyEmails(raw)
|
||||
require.Len(t, result, 2)
|
||||
|
||||
require.Equal(t, "alice@example.com", result[0].Email)
|
||||
require.False(t, result[0].Verified, "old format emails should default to unverified")
|
||||
require.False(t, result[0].Disabled)
|
||||
|
||||
require.Equal(t, "bob@example.com", result[1].Email)
|
||||
require.False(t, result[1].Verified)
|
||||
require.False(t, result[1].Disabled)
|
||||
}
|
||||
|
||||
func TestParseNotifyEmails_OldFormat_SkipsEmptyEntries(t *testing.T) {
|
||||
raw := `["alice@example.com", "", " ", "bob@example.com"]`
|
||||
result := ParseNotifyEmails(raw)
|
||||
require.Len(t, result, 2)
|
||||
require.Equal(t, "alice@example.com", result[0].Email)
|
||||
require.Equal(t, "bob@example.com", result[1].Email)
|
||||
}
|
||||
|
||||
func TestParseNotifyEmails_NewFormat(t *testing.T) {
|
||||
raw := `[{"email":"alice@example.com","verified":true,"disabled":false},{"email":"bob@example.com","verified":false,"disabled":true}]`
|
||||
result := ParseNotifyEmails(raw)
|
||||
require.Len(t, result, 2)
|
||||
|
||||
require.Equal(t, "alice@example.com", result[0].Email)
|
||||
require.True(t, result[0].Verified)
|
||||
require.False(t, result[0].Disabled)
|
||||
|
||||
require.Equal(t, "bob@example.com", result[1].Email)
|
||||
require.False(t, result[1].Verified)
|
||||
require.True(t, result[1].Disabled)
|
||||
}
|
||||
|
||||
func TestParseNotifyEmails_NewFormat_SingleEntry(t *testing.T) {
|
||||
raw := `[{"email":"solo@example.com","verified":true,"disabled":false}]`
|
||||
result := ParseNotifyEmails(raw)
|
||||
require.Len(t, result, 1)
|
||||
require.Equal(t, "solo@example.com", result[0].Email)
|
||||
require.True(t, result[0].Verified)
|
||||
}
|
||||
|
||||
func TestParseNotifyEmails_InvalidJSON(t *testing.T) {
|
||||
result := ParseNotifyEmails(`{not valid json`)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestParseNotifyEmails_InvalidJSONObject(t *testing.T) {
|
||||
// A plain JSON object (not array) should return nil.
|
||||
result := ParseNotifyEmails(`{"email":"a@b.com"}`)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestParseNotifyEmails_WhitespacePadding(t *testing.T) {
|
||||
raw := ` ["padded@example.com"] `
|
||||
result := ParseNotifyEmails(raw)
|
||||
require.Len(t, result, 1)
|
||||
require.Equal(t, "padded@example.com", result[0].Email)
|
||||
}
|
||||
|
||||
// ---------- MarshalNotifyEmails ----------
|
||||
|
||||
func TestMarshalNotifyEmails_EmptySlice(t *testing.T) {
|
||||
result := MarshalNotifyEmails([]NotifyEmailEntry{})
|
||||
require.Equal(t, "[]", result)
|
||||
}
|
||||
|
||||
func TestMarshalNotifyEmails_NilSlice(t *testing.T) {
|
||||
result := MarshalNotifyEmails(nil)
|
||||
require.Equal(t, "[]", result)
|
||||
}
|
||||
|
||||
func TestMarshalNotifyEmails_SingleEntry(t *testing.T) {
|
||||
entries := []NotifyEmailEntry{
|
||||
{Email: "test@example.com", Verified: true, Disabled: false},
|
||||
}
|
||||
result := MarshalNotifyEmails(entries)
|
||||
require.Contains(t, result, `"email":"test@example.com"`)
|
||||
require.Contains(t, result, `"verified":true`)
|
||||
require.Contains(t, result, `"disabled":false`)
|
||||
|
||||
// Round-trip: parsing the marshalled result should produce the original entries.
|
||||
parsed := ParseNotifyEmails(result)
|
||||
require.Len(t, parsed, 1)
|
||||
require.Equal(t, entries[0], parsed[0])
|
||||
}
|
||||
|
||||
func TestMarshalNotifyEmails_MultipleEntries(t *testing.T) {
|
||||
entries := []NotifyEmailEntry{
|
||||
{Email: "a@example.com", Verified: true, Disabled: false},
|
||||
{Email: "b@example.com", Verified: false, Disabled: true},
|
||||
}
|
||||
result := MarshalNotifyEmails(entries)
|
||||
|
||||
// Round-trip verification.
|
||||
parsed := ParseNotifyEmails(result)
|
||||
require.Len(t, parsed, 2)
|
||||
require.Equal(t, entries[0], parsed[0])
|
||||
require.Equal(t, entries[1], parsed[1])
|
||||
}
|
||||
|
||||
func TestMarshalNotifyEmails_RoundTrip_NewFormat(t *testing.T) {
|
||||
original := []NotifyEmailEntry{
|
||||
{Email: "x@example.com", Verified: true, Disabled: true},
|
||||
{Email: "y@example.com", Verified: false, Disabled: false},
|
||||
}
|
||||
marshalled := MarshalNotifyEmails(original)
|
||||
parsed := ParseNotifyEmails(marshalled)
|
||||
require.Equal(t, original, parsed)
|
||||
}
|
||||
|
||||
// ---------- isOldStringArrayFormat (indirectly via ParseNotifyEmails) ----------
|
||||
|
||||
func TestParseNotifyEmails_MixedOldFormatWithWhitespace(t *testing.T) {
|
||||
// Emails with leading/trailing whitespace in old format should be trimmed.
|
||||
raw := `[" alice@example.com "]`
|
||||
result := ParseNotifyEmails(raw)
|
||||
require.Len(t, result, 1)
|
||||
require.Equal(t, "alice@example.com", result[0].Email)
|
||||
}
|
||||
@@ -1,9 +1,12 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -141,3 +144,123 @@ func TestSanitizeWebSearchConfig_DoesNotMutateOriginal(t *testing.T) {
|
||||
_ = SanitizeWebSearchConfig(context.Background(), cfg)
|
||||
require.Equal(t, "secret", cfg.Providers[0].APIKey)
|
||||
}
|
||||
|
||||
// --- PopulateWebSearchUsage ---
|
||||
|
||||
func TestPopulateWebSearchUsage_NilInput(t *testing.T) {
|
||||
require.Nil(t, PopulateWebSearchUsage(context.Background(), nil))
|
||||
}
|
||||
|
||||
func TestPopulateWebSearchUsage_NoManager_QuotaUsedZero(t *testing.T) {
|
||||
// Ensure no global manager is set
|
||||
SetWebSearchManager(nil)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{
|
||||
{Type: "brave", APIKey: "sk-key", QuotaLimit: int64Ptr(1000)},
|
||||
},
|
||||
}
|
||||
out := PopulateWebSearchUsage(context.Background(), cfg)
|
||||
require.NotNil(t, out)
|
||||
require.Len(t, out.Providers, 1)
|
||||
require.Equal(t, int64(0), out.Providers[0].QuotaUsed)
|
||||
}
|
||||
|
||||
func TestPopulateWebSearchUsage_APIKeyConfigured_True(t *testing.T) {
|
||||
SetWebSearchManager(nil)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{
|
||||
{Type: "brave", APIKey: "sk-key"},
|
||||
},
|
||||
}
|
||||
out := PopulateWebSearchUsage(context.Background(), cfg)
|
||||
require.True(t, out.Providers[0].APIKeyConfigured)
|
||||
}
|
||||
|
||||
func TestPopulateWebSearchUsage_APIKeyConfigured_False(t *testing.T) {
|
||||
SetWebSearchManager(nil)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{
|
||||
{Type: "brave", APIKey: ""},
|
||||
},
|
||||
}
|
||||
out := PopulateWebSearchUsage(context.Background(), cfg)
|
||||
require.False(t, out.Providers[0].APIKeyConfigured)
|
||||
}
|
||||
|
||||
func TestPopulateWebSearchUsage_NilQuotaLimit(t *testing.T) {
|
||||
SetWebSearchManager(nil)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{
|
||||
{Type: "brave", APIKey: "sk-key", QuotaLimit: nil},
|
||||
},
|
||||
}
|
||||
out := PopulateWebSearchUsage(context.Background(), cfg)
|
||||
require.Nil(t, out.Providers[0].QuotaLimit)
|
||||
}
|
||||
|
||||
func TestPopulateWebSearchUsage_NonNilQuotaLimit(t *testing.T) {
|
||||
SetWebSearchManager(nil)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{
|
||||
{Type: "brave", APIKey: "sk-key", QuotaLimit: int64Ptr(500)},
|
||||
},
|
||||
}
|
||||
out := PopulateWebSearchUsage(context.Background(), cfg)
|
||||
require.NotNil(t, out.Providers[0].QuotaLimit)
|
||||
require.Equal(t, int64(500), *out.Providers[0].QuotaLimit)
|
||||
}
|
||||
|
||||
func TestPopulateWebSearchUsage_WithManager_NilRedis(t *testing.T) {
|
||||
// Manager with nil Redis returns 0 usage without error
|
||||
mgr := websearch.NewManager([]websearch.ProviderConfig{
|
||||
{Type: "brave", APIKey: "k"},
|
||||
}, nil)
|
||||
SetWebSearchManager(mgr)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{
|
||||
{Type: "brave", APIKey: "sk-key", QuotaLimit: int64Ptr(1000)},
|
||||
},
|
||||
}
|
||||
out := PopulateWebSearchUsage(context.Background(), cfg)
|
||||
require.Equal(t, int64(0), out.Providers[0].QuotaUsed)
|
||||
require.True(t, out.Providers[0].APIKeyConfigured)
|
||||
}
|
||||
|
||||
func TestPopulateWebSearchUsage_DoesNotMutateOriginal(t *testing.T) {
|
||||
SetWebSearchManager(nil)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{
|
||||
{Type: "brave", APIKey: "secret", QuotaLimit: int64Ptr(100)},
|
||||
},
|
||||
}
|
||||
_ = PopulateWebSearchUsage(context.Background(), cfg)
|
||||
// Original should be unchanged
|
||||
require.Equal(t, "secret", cfg.Providers[0].APIKey)
|
||||
require.Equal(t, int64(0), cfg.Providers[0].QuotaUsed)
|
||||
}
|
||||
|
||||
// --- ResetWebSearchUsage ---
|
||||
|
||||
func TestResetWebSearchUsage_NilManager(t *testing.T) {
|
||||
SetWebSearchManager(nil)
|
||||
defer SetWebSearchManager(nil)
|
||||
|
||||
err := ResetWebSearchUsage(context.Background(), "brave")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not initialized")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user