diff --git a/backend/internal/pkg/websearch/manager_test.go b/backend/internal/pkg/websearch/manager_test.go index a4beef68..cbcf1b76 100644 --- a/backend/internal/pkg/websearch/manager_test.go +++ b/backend/internal/pkg/websearch/manager_test.go @@ -313,3 +313,11 @@ func TestNewHTTPClient_ValidSOCKS5Proxy(t *testing.T) { require.NoError(t, err) require.NotNil(t, c) } + +// --- ResetUsage --- + +func TestManager_ResetUsage_NilRedis(t *testing.T) { + m := NewManager(nil, nil) + err := m.ResetUsage(context.Background(), "brave") + require.NoError(t, err) +} diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go index 23409d5e..36e5eb74 100644 --- a/backend/internal/service/account_stats_pricing_test.go +++ b/backend/internal/service/account_stats_pricing_test.go @@ -3,7 +3,9 @@ package service import ( + "context" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -527,3 +529,243 @@ func TestTryModelFilePricing_WithCacheTokens(t *testing.T) { // = 0.1 + 0.1 + 0.6 + 0.15 = 0.95 require.InDelta(t, 0.95, *result, 1e-12) } + +// --------------------------------------------------------------------------- +// resolveAccountStatsCost — integration tests covering the 4-level priority chain +// --------------------------------------------------------------------------- + +func TestResolveAccountStatsCost_NilChannelService(t *testing.T) { + result := resolveAccountStatsCost( + context.Background(), + nil, // channelService is nil + newTestBillingServiceWithPrices(map[string]*ModelPricing{}), + 1, 1, "claude-sonnet-4", + UsageTokens{InputTokens: 100}, 1, 0.5, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_EmptyUpstreamModel(t *testing.T) { + cs := newTestChannelServiceForStats(t, &Channel{ + ID: 1, + Status: StatusActive, + }, 1, "") + + result := resolveAccountStatsCost( + context.Background(), + cs, + newTestBillingServiceWithPrices(map[string]*ModelPricing{}), + 1, 1, "", // empty upstream model + UsageTokens{InputTokens: 100}, 1, 0.5, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_GetChannelForGroupReturnsNil(t *testing.T) { + // Group 99 is NOT in the cache, so GetChannelForGroup returns nil + cs := newTestChannelServiceForStats(t, &Channel{ + ID: 1, + Status: StatusActive, + }, 1, "") + + result := resolveAccountStatsCost( + context.Background(), + cs, + newTestBillingServiceWithPrices(map[string]*ModelPricing{}), + 1, 99, "claude-sonnet-4", // groupID 99 has no channel + UsageTokens{InputTokens: 100}, 1, 0.5, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_HitsCustomRule(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + AccountStatsPricingRules: []AccountStatsPricingRule{ + { + GroupIDs: []int64{10}, + Pricing: []ChannelModelPricing{ + { + ID: 100, + Models: []string{"claude-sonnet-4"}, + InputPrice: testPtrFloat64(0.01), + OutputPrice: testPtrFloat64(0.02), + }, + }, + }, + }, + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, // billingService not needed when custom rule hits + 1, 10, "claude-sonnet-4", + tokens, 1, 999.0, // totalCost ignored because custom rule hits + ) + require.NotNil(t, result) + // 100*0.01 + 50*0.02 = 1.0 + 1.0 = 2.0 + require.InDelta(t, 2.0, *result, 1e-12) +} + +func TestResolveAccountStatsCost_ApplyPricingToAccountStats_UsesTotalCost(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: true, + // No custom rules + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, + 1, 10, "claude-sonnet-4", + tokens, 1, 0.75, // totalCost = 0.75 + ) + require.NotNil(t, result) + require.InDelta(t, 0.75, *result, 1e-12) +} + +func TestResolveAccountStatsCost_ApplyPricingToAccountStats_ZeroTotalCost_ReturnsNil(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: true, + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, + 1, 10, "claude-sonnet-4", + UsageTokens{}, 1, 0.0, // totalCost = 0 + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_FallsBackToLiteLLM(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: false, // not enabled + // No custom rules + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 0.001, + OutputPricePerToken: 0.002, + }, + }) + + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + + result := resolveAccountStatsCost( + context.Background(), + cs, bs, + 1, 10, "claude-sonnet-4", + tokens, 1, 999.0, // totalCost ignored + ) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2 + require.InDelta(t, 0.2, *result, 1e-12) +} + +func TestResolveAccountStatsCost_AllMiss_ReturnsNil(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: false, + // No custom rules + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + // BillingService with no pricing for the model + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{}) + + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + + result := resolveAccountStatsCost( + context.Background(), + cs, bs, + 1, 10, "totally-unknown-model", + tokens, 1, 0.0, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_NilBillingService_SkipsLiteLLM(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: false, + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, // billingService is nil + 1, 10, "claude-sonnet-4", + UsageTokens{InputTokens: 100}, 1, 0.0, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_CustomRulePriorityOverApplyPricing(t *testing.T) { + // Both custom rule and ApplyPricingToAccountStats are configured; + // custom rule should take precedence. + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: true, + AccountStatsPricingRules: []AccountStatsPricingRule{ + { + GroupIDs: []int64{10}, + Pricing: []ChannelModelPricing{ + { + ID: 100, + Models: []string{"claude-sonnet-4"}, + InputPrice: testPtrFloat64(0.05), + }, + }, + }, + }, + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + tokens := UsageTokens{InputTokens: 100} + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, + 1, 10, "claude-sonnet-4", + tokens, 1, 99.0, // totalCost = 99.0 (would be used if ApplyPricing wins) + ) + require.NotNil(t, result) + // Custom rule: 100*0.05 = 5.0 (NOT 99.0 from totalCost) + require.InDelta(t, 5.0, *result, 1e-12) +} + +// --------------------------------------------------------------------------- +// helpers for resolveAccountStatsCost tests +// --------------------------------------------------------------------------- + +// newTestChannelServiceForStats creates a ChannelService with a single channel +// mapped to the given groupID, suitable for resolveAccountStatsCost tests. +func newTestChannelServiceForStats(t *testing.T, channel *Channel, groupID int64, platform string) *ChannelService { + t.Helper() + cache := newEmptyChannelCache() + cache.channelByGroupID[groupID] = channel + cache.groupPlatform[groupID] = platform + cs := &ChannelService{} + cache.loadedAt = time.Now() + cs.cache.Store(cache) + return cs +} diff --git a/backend/internal/service/balance_notify_check_test.go b/backend/internal/service/balance_notify_check_test.go index 955f3129..7bb4cf9e 100644 --- a/backend/internal/service/balance_notify_check_test.go +++ b/backend/internal/service/balance_notify_check_test.go @@ -178,3 +178,227 @@ func TestGetSiteName_Configured(t *testing.T) { repo.data[SettingKeySiteName] = "My Site" require.Equal(t, "My Site", s.getSiteName(context.Background())) } + +// ---------- crossedDownward ---------- + +func TestCrossedDownward_CrossesBelow(t *testing.T) { + // oldBalance > threshold, newBalance < threshold → true + require.True(t, crossedDownward(100, 5, 10)) +} + +func TestCrossedDownward_ExactlyAtThreshold(t *testing.T) { + // oldBalance > threshold, newBalance == threshold → false (not below) + require.False(t, crossedDownward(100, 10, 10)) +} + +func TestCrossedDownward_OldExactlyAtThreshold_NewBelow(t *testing.T) { + // oldBalance == threshold, newBalance < threshold → true + // (at-or-above → below counts as a crossing) + require.True(t, crossedDownward(10, 5, 10)) +} + +func TestCrossedDownward_AlreadyBelow(t *testing.T) { + // oldBalance < threshold → false (already below, no new crossing) + require.False(t, crossedDownward(5, 3, 10)) +} + +func TestCrossedDownward_BothAbove(t *testing.T) { + // oldBalance > threshold, newBalance > threshold → false (no crossing) + require.False(t, crossedDownward(100, 50, 10)) +} + +func TestCrossedDownward_ZeroThreshold(t *testing.T) { + // threshold == 0 → oldV >= 0 is always true, but newV < 0 only for negatives + // Typical case: positive balances should not fire when threshold is 0. + require.False(t, crossedDownward(10, 5, 0)) + require.False(t, crossedDownward(0, 0, 0)) +} + +func TestCrossedDownward_ZeroThreshold_NegativeNew(t *testing.T) { + // Edge case: newBalance goes negative with threshold=0. + require.True(t, crossedDownward(5, -1, 0)) +} + +func TestCrossedDownward_NegativeValues(t *testing.T) { + // Both already negative, threshold is positive → no crossing (already below). + require.False(t, crossedDownward(-5, -10, 10)) +} + +func TestCrossedDownward_LargeDecrement(t *testing.T) { + // A single large deduction crosses the threshold. + require.True(t, crossedDownward(1000, 0.5, 100)) +} + +func TestCrossedDownward_SmallDecrement_NoCrossing(t *testing.T) { + // A tiny deduction stays above threshold. + require.False(t, crossedDownward(100, 99.99, 10)) +} + +// ---------- checkQuotaDimCrossings ---------- + +func TestCheckQuotaDimCrossings_NoDimensions(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // Empty dims → no crossing, no panic. + s.checkQuotaDimCrossings(account, nil, 10, []string{"admin@example.com"}, "TestSite") + s.checkQuotaDimCrossings(account, []quotaDim{}, 10, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_DisabledDimension(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: false, // disabled + threshold: 100, + thresholdType: thresholdTypeFixed, + currentUsed: 950, + limit: 1000, + }, + } + // Disabled dimension should be skipped even if crossing would occur. + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_ZeroThresholdSkipped(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 0, // zero threshold + thresholdType: thresholdTypeFixed, + currentUsed: 950, + limit: 1000, + }, + } + // Zero threshold → skipped. + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_NoCrossing_BothBelowThreshold(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger) + // currentUsed=300 (after), oldUsed=300-50=250 (before). Both < 600, no crossing. + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 400, + thresholdType: thresholdTypeFixed, + currentUsed: 300, + limit: 1000, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_NoCrossing_BothAboveThreshold(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger) + // currentUsed=800 (after), oldUsed=800-50=750 (before). Both >= 600, no crossing. + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 400, + thresholdType: thresholdTypeFixed, + currentUsed: 800, + limit: 1000, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_NegativeResolvedThreshold_Skipped(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // threshold=1200 remaining, limit=1000 → effectiveThreshold = 1000-1200 = -200 + // Negative resolved threshold → skipped. + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 1200, + thresholdType: thresholdTypeFixed, + currentUsed: 950, + limit: 1000, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_PercentageThreshold_NoCrossing(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // threshold=30%, limit=1000 → effectiveThreshold = 1000 * (1 - 0.30) = 700 + // currentUsed=500, oldUsed=500-50=450. Both < 700, no crossing. + dims := []quotaDim{ + { + name: quotaDimWeekly, + enabled: true, + threshold: 30, + thresholdType: thresholdTypePercentage, + currentUsed: 500, + limit: 1000, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_ZeroLimit_Skipped(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // limit=0 → resolvedThreshold returns 0 → skipped. + dims := []quotaDim{ + { + name: quotaDimTotal, + enabled: true, + threshold: 100, + thresholdType: thresholdTypeFixed, + currentUsed: 50, + limit: 0, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_MultipleDims_MixedResults(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // dim1: no crossing (both below effective threshold) + // dim2: disabled (skipped) + // dim3: zero threshold (skipped) + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 400, + thresholdType: thresholdTypeFixed, + currentUsed: 300, // oldUsed=250, effectiveThreshold=600, both below + limit: 1000, + }, + { + name: quotaDimWeekly, + enabled: false, + threshold: 100, + thresholdType: thresholdTypeFixed, + currentUsed: 900, + limit: 1000, + }, + { + name: quotaDimTotal, + enabled: true, + threshold: 0, + thresholdType: thresholdTypeFixed, + currentUsed: 500, + limit: 1000, + }, + } + // None should trigger. No panic expected. + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 6f6c41ce..2cf134e2 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -718,3 +718,123 @@ func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing. require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12) require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12) } + +// --------------------------------------------------------------------------- +// GetModelPricingWithChannel +// --------------------------------------------------------------------------- + +func TestGetModelPricingWithChannel_NilChannelPricing_ReturnsOriginal(t *testing.T) { + svc := newTestBillingService() + + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", nil) + require.NoError(t, err) + require.NotNil(t, pricing) + + // Should be identical to GetModelPricing + original, err := svc.GetModelPricing("claude-sonnet-4") + require.NoError(t, err) + require.InDelta(t, original.InputPricePerToken, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, original.OutputPricePerToken, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, original.CacheCreationPricePerToken, pricing.CacheCreationPricePerToken, 1e-12) + require.InDelta(t, original.CacheReadPricePerToken, pricing.CacheReadPricePerToken, 1e-12) +} + +func TestGetModelPricingWithChannel_OverrideInputPriceOnly(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + InputPrice: testPtrFloat64(99e-6), + } + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing) + require.NoError(t, err) + + // InputPrice overridden (both normal and priority) + require.InDelta(t, 99e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 99e-6, pricing.InputPricePerTokenPriority, 1e-12) + + // OutputPrice unchanged (claude-sonnet-4 fallback = 15e-6) + require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12) +} + +func TestGetModelPricingWithChannel_OverrideOutputPriceOnly(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + OutputPrice: testPtrFloat64(88e-6), + } + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing) + require.NoError(t, err) + + // OutputPrice overridden + require.InDelta(t, 88e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 88e-6, pricing.OutputPricePerTokenPriority, 1e-12) + + // InputPrice unchanged (claude-sonnet-4 fallback = 3e-6) + require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) +} + +func TestGetModelPricingWithChannel_OverrideAllFields(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + InputPrice: testPtrFloat64(10e-6), + OutputPrice: testPtrFloat64(20e-6), + CacheWritePrice: testPtrFloat64(5e-6), + CacheReadPrice: testPtrFloat64(1e-6), + ImageOutputPrice: testPtrFloat64(50e-6), + } + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing) + require.NoError(t, err) + + require.InDelta(t, 10e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 10e-6, pricing.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 20e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 20e-6, pricing.OutputPricePerTokenPriority, 1e-12) + require.InDelta(t, 5e-6, pricing.CacheCreationPricePerToken, 1e-12) + require.InDelta(t, 5e-6, pricing.CacheCreation5mPrice, 1e-12) + require.InDelta(t, 5e-6, pricing.CacheCreation1hPrice, 1e-12) + require.InDelta(t, 1e-6, pricing.CacheReadPricePerToken, 1e-12) + require.InDelta(t, 1e-6, pricing.CacheReadPricePerTokenPriority, 1e-12) + require.InDelta(t, 50e-6, pricing.ImageOutputPricePerToken, 1e-12) +} + +func TestGetModelPricingWithChannel_CacheWritePriceAffects5mAnd1h(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + CacheWritePrice: testPtrFloat64(7e-6), + } + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing) + require.NoError(t, err) + + // CacheWritePrice should set all three: CacheCreationPricePerToken, 5m, and 1h + require.InDelta(t, 7e-6, pricing.CacheCreationPricePerToken, 1e-12) + require.InDelta(t, 7e-6, pricing.CacheCreation5mPrice, 1e-12) + require.InDelta(t, 7e-6, pricing.CacheCreation1hPrice, 1e-12) +} + +func TestGetModelPricingWithChannel_CacheReadPriceAffectsPriority(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + CacheReadPrice: testPtrFloat64(2e-6), + } + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing) + require.NoError(t, err) + + // CacheReadPrice should set both normal and priority + require.InDelta(t, 2e-6, pricing.CacheReadPricePerToken, 1e-12) + require.InDelta(t, 2e-6, pricing.CacheReadPricePerTokenPriority, 1e-12) +} + +func TestGetModelPricingWithChannel_UnknownModelReturnsError(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + InputPrice: testPtrFloat64(1e-6), + } + pricing, err := svc.GetModelPricingWithChannel("totally-unknown-model", chPricing) + require.Error(t, err) + require.Nil(t, pricing) + require.Contains(t, err.Error(), "pricing not found") +} diff --git a/backend/internal/service/billing_service_unified_test.go b/backend/internal/service/billing_service_unified_test.go new file mode 100644 index 00000000..694c3384 --- /dev/null +++ b/backend/internal/service/billing_service_unified_test.go @@ -0,0 +1,258 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// CalculateCostUnified +// --------------------------------------------------------------------------- + +func TestCalculateCostUnified_NilResolver_FallsBackToOldPath(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + input := CostInput{ + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: 1.0, + Resolver: nil, // no resolver + } + cost, err := svc.CalculateCostUnified(input) + require.NoError(t, err) + + // Should match the old-path result exactly + expected, err := svc.calculateCostInternal("claude-sonnet-4", tokens, 1.0, "", nil) + require.NoError(t, err) + require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-10) + require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10) + // BillingMode is NOT set by old path through CalculateCostUnified (resolver == nil) + require.Empty(t, cost.BillingMode) +} + +func TestCalculateCostUnified_TokenMode(t *testing.T) { + bs := newTestBillingService() + resolver := NewModelPricingResolver(nil, bs) + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + input := CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: 1.5, + Resolver: resolver, + } + cost, err := bs.CalculateCostUnified(input) + require.NoError(t, err) + require.NotNil(t, cost) + + // Verify token billing: Input: 1000*3e-6=0.003, Output: 500*15e-6=0.0075 + expectedTotal := 1000*3e-6 + 500*15e-6 + require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10) + require.InDelta(t, expectedTotal*1.5, cost.ActualCost, 1e-10) + require.Equal(t, string(BillingModeToken), cost.BillingMode) +} + +func TestCalculateCostUnified_PerRequestMode(t *testing.T) { + // Set up a ChannelService with a per-request pricing channel + cs := newTestChannelServiceWithCache(t, &channelCache{ + pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{ + {groupID: 1, model: "claude-sonnet-4"}: { + BillingMode: BillingModePerRequest, + PerRequestPrice: testPtrFloat64(0.05), + }, + }, + channelByGroupID: map[int64]*Channel{ + 1: {ID: 1, Status: StatusActive}, + }, + groupPlatform: map[int64]string{1: ""}, + wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{}, + mappingByGroupModel: map[channelModelKey]string{}, + wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{}, + byID: map[int64]*Channel{}, + }) + + bs := newTestBillingService() + resolver := NewModelPricingResolver(cs, bs) + groupID := int64(1) + + input := CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + GroupID: &groupID, + Tokens: UsageTokens{InputTokens: 100, OutputTokens: 50}, + RequestCount: 3, + RateMultiplier: 2.0, + Resolver: resolver, + } + cost, err := bs.CalculateCostUnified(input) + require.NoError(t, err) + require.NotNil(t, cost) + + // 3 requests * $0.05 = $0.15 + require.InDelta(t, 0.15, cost.TotalCost, 1e-10) + // ActualCost = 0.15 * 2.0 = 0.30 + require.InDelta(t, 0.30, cost.ActualCost, 1e-10) + require.Equal(t, string(BillingModePerRequest), cost.BillingMode) +} + +func TestCalculateCostUnified_ImageMode(t *testing.T) { + cs := newTestChannelServiceWithCache(t, &channelCache{ + pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{ + {groupID: 2, model: "gemini-image"}: { + BillingMode: BillingModeImage, + PerRequestPrice: testPtrFloat64(0.10), + }, + }, + channelByGroupID: map[int64]*Channel{ + 2: {ID: 2, Status: StatusActive}, + }, + groupPlatform: map[int64]string{2: ""}, + wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{}, + mappingByGroupModel: map[channelModelKey]string{}, + wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{}, + byID: map[int64]*Channel{}, + }) + + bs := &BillingService{ + cfg: &config.Config{}, + fallbackPrices: map[string]*ModelPricing{}, + } + resolver := NewModelPricingResolver(cs, bs) + groupID := int64(2) + + input := CostInput{ + Ctx: context.Background(), + Model: "gemini-image", + GroupID: &groupID, + Tokens: UsageTokens{}, + RequestCount: 2, + RateMultiplier: 1.0, + Resolver: resolver, + } + cost, err := bs.CalculateCostUnified(input) + require.NoError(t, err) + require.NotNil(t, cost) + + // 2 * $0.10 = $0.20 + require.InDelta(t, 0.20, cost.TotalCost, 1e-10) + require.InDelta(t, 0.20, cost.ActualCost, 1e-10) + require.Equal(t, string(BillingModeImage), cost.BillingMode) +} + +func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) { + bs := newTestBillingService() + resolver := NewModelPricingResolver(nil, bs) + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + + costZero, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: 0, // should default to 1.0 + Resolver: resolver, + }) + require.NoError(t, err) + + costOne, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: 1.0, + Resolver: resolver, + }) + require.NoError(t, err) + + require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) +} + +func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) { + bs := newTestBillingService() + resolver := NewModelPricingResolver(nil, bs) + + tokens := UsageTokens{InputTokens: 1000} + + costNeg, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: -5.0, + Resolver: resolver, + }) + require.NoError(t, err) + + costOne, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: 1.0, + Resolver: resolver, + }) + require.NoError(t, err) + + require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) +} + +func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) { + bs := newTestBillingService() + resolver := NewModelPricingResolver(nil, bs) + + cost, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: UsageTokens{InputTokens: 100}, + RateMultiplier: 1.0, + Resolver: resolver, + }) + require.NoError(t, err) + require.Equal(t, "token", cost.BillingMode) +} + +func TestCalculateCostUnified_UsesPreResolvedPricing(t *testing.T) { + bs := newTestBillingService() + resolver := NewModelPricingResolver(nil, bs) + + // Pre-resolve with per_request mode to verify it's used instead of re-resolving + preResolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + DefaultPerRequestPrice: 0.07, + } + + cost, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: UsageTokens{InputTokens: 100}, + RequestCount: 2, + RateMultiplier: 1.0, + Resolver: resolver, + Resolved: preResolved, + }) + require.NoError(t, err) + require.NotNil(t, cost) + + // 2 * $0.07 = $0.14 + require.InDelta(t, 0.14, cost.TotalCost, 1e-10) + require.Equal(t, string(BillingModePerRequest), cost.BillingMode) +} + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +// newTestChannelServiceWithCache creates a ChannelService with a pre-populated +// cache snapshot, bypassing the repository layer entirely. +func newTestChannelServiceWithCache(t *testing.T, cache *channelCache) *ChannelService { + t.Helper() + cs := &ChannelService{} + cache.loadedAt = time.Now() + cs.cache.Store(cache) + return cs +} diff --git a/backend/internal/service/gateway_websearch_emulation_test.go b/backend/internal/service/gateway_websearch_emulation_test.go index b606c748..de1f0014 100644 --- a/backend/internal/service/gateway_websearch_emulation_test.go +++ b/backend/internal/service/gateway_websearch_emulation_test.go @@ -1,8 +1,14 @@ +//go:build unit + package service import ( + "context" + "encoding/json" "testing" + "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/websearch" "github.com/stretchr/testify/require" ) @@ -140,3 +146,235 @@ func TestBuildTextSummary_NoResults(t *testing.T) { summary := buildTextSummary("test", nil) require.Contains(t, summary, "No search results found for: test") } + +// --- shouldEmulateWebSearch --- + +// webSearchToolBody is a valid request body with exactly one web_search tool. +var webSearchToolBody = []byte(`{"tools":[{"type":"web_search"}],"messages":[{"role":"user","content":"test"}]}`) + +// nonWebSearchToolBody is a request body without web_search tool. +var nonWebSearchToolBody = []byte(`{"tools":[{"type":"text_editor"}],"messages":[{"role":"user","content":"test"}]}`) + +// newAnthropicAPIKeyAccount creates a test Account with the given web search emulation mode. +func newAnthropicAPIKeyAccount(mode string) *Account { + return &Account{ + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: mode}, + } +} + +// setGlobalWebSearchConfig stores a config in the global cache used by SettingService.IsWebSearchEmulationEnabled. +func setGlobalWebSearchConfig(cfg *WebSearchEmulationConfig) { + webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{ + config: cfg, + expiresAt: time.Now().Add(10 * time.Minute).UnixNano(), + }) +} + +// clearGlobalWebSearchConfig resets the global cache to force re-read. +func clearGlobalWebSearchConfig() { + webSearchEmulationCache.Store((*cachedWebSearchEmulationConfig)(nil)) +} + +// newSettingServiceForWebSearchTest creates a SettingService with a mock repo pre-loaded with config. +func newSettingServiceForWebSearchTest(enabled bool) *SettingService { + repo := newMockSettingRepo() + cfg := &WebSearchEmulationConfig{ + Enabled: enabled, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "sk-test"}}, + } + data, _ := json.Marshal(cfg) + repo.data[SettingKeyWebSearchEmulationConfig] = string(data) + return NewSettingService(repo, &config.Config{}) +} + +// newChannelServiceWithCache creates a ChannelService with a pre-built cache containing the channel. +func newChannelServiceWithCache(groupID int64, ch *Channel) *ChannelService { + svc := &ChannelService{} + cache := &channelCache{ + channelByGroupID: map[int64]*Channel{groupID: ch}, + byID: map[int64]*Channel{ch.ID: ch}, + groupPlatform: map[int64]string{}, + loadedAt: time.Now(), + } + svc.cache.Store(cache) + return svc +} + +func TestShouldEmulateWebSearch_NilManager(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + settingSvc := newSettingServiceForWebSearchTest(true) + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeEnabled) + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_NotOnlyWebSearchTool(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + settingSvc := newSettingServiceForWebSearchTest(true) + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeEnabled) + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, nonWebSearchToolBody)) +} + +func TestShouldEmulateWebSearch_GlobalDisabled(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + // Global config disabled + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: false, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(false) + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeEnabled) + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_AccountDisabled(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeDisabled) + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_AccountEnabled(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeEnabled) + require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_DefaultMode_ChannelEnabled(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + ch := &Channel{ + ID: 10, + Status: StatusActive, + FeaturesConfig: map[string]any{ + featureKeyWebSearchEmulation: map[string]any{PlatformAnthropic: true}, + }, + } + channelSvc := newChannelServiceWithCache(42, ch) + svc := &GatewayService{settingService: settingSvc, channelService: channelSvc} + + account := newAnthropicAPIKeyAccount(WebSearchModeDefault) + groupID := int64(42) + require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_DefaultMode_ChannelDisabled(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + ch := &Channel{ + ID: 10, + Status: StatusActive, + FeaturesConfig: map[string]any{ + featureKeyWebSearchEmulation: map[string]any{PlatformAnthropic: false}, + }, + } + channelSvc := newChannelServiceWithCache(42, ch) + svc := &GatewayService{settingService: settingSvc, channelService: channelSvc} + + account := newAnthropicAPIKeyAccount(WebSearchModeDefault) + groupID := int64(42) + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_DefaultMode_NilGroupID(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeDefault) + // nil groupID + default mode → falls through to channel check → returns false + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_DefaultMode_NilChannelService(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + svc := &GatewayService{settingService: settingSvc, channelService: nil} + account := newAnthropicAPIKeyAccount(WebSearchModeDefault) + groupID := int64(42) + // nil channelService + default mode → returns false + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody)) +} diff --git a/backend/internal/service/notify_email_entry_test.go b/backend/internal/service/notify_email_entry_test.go new file mode 100644 index 00000000..0f4bb12e --- /dev/null +++ b/backend/internal/service/notify_email_entry_test.go @@ -0,0 +1,156 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// ---------- ParseNotifyEmails ---------- + +func TestParseNotifyEmails_EmptyString(t *testing.T) { + result := ParseNotifyEmails("") + require.Nil(t, result) +} + +func TestParseNotifyEmails_EmptyArray(t *testing.T) { + result := ParseNotifyEmails("[]") + require.Nil(t, result) +} + +func TestParseNotifyEmails_Null(t *testing.T) { + // "null" is valid JSON that unmarshals into a nil string slice. + // The old-format branch then returns an empty (non-nil) slice. + result := ParseNotifyEmails("null") + require.Empty(t, result) +} + +func TestParseNotifyEmails_WhitespaceOnly(t *testing.T) { + result := ParseNotifyEmails(" ") + require.Nil(t, result) +} + +func TestParseNotifyEmails_OldFormat(t *testing.T) { + raw := `["alice@example.com", "bob@example.com"]` + result := ParseNotifyEmails(raw) + require.Len(t, result, 2) + + require.Equal(t, "alice@example.com", result[0].Email) + require.False(t, result[0].Verified, "old format emails should default to unverified") + require.False(t, result[0].Disabled) + + require.Equal(t, "bob@example.com", result[1].Email) + require.False(t, result[1].Verified) + require.False(t, result[1].Disabled) +} + +func TestParseNotifyEmails_OldFormat_SkipsEmptyEntries(t *testing.T) { + raw := `["alice@example.com", "", " ", "bob@example.com"]` + result := ParseNotifyEmails(raw) + require.Len(t, result, 2) + require.Equal(t, "alice@example.com", result[0].Email) + require.Equal(t, "bob@example.com", result[1].Email) +} + +func TestParseNotifyEmails_NewFormat(t *testing.T) { + raw := `[{"email":"alice@example.com","verified":true,"disabled":false},{"email":"bob@example.com","verified":false,"disabled":true}]` + result := ParseNotifyEmails(raw) + require.Len(t, result, 2) + + require.Equal(t, "alice@example.com", result[0].Email) + require.True(t, result[0].Verified) + require.False(t, result[0].Disabled) + + require.Equal(t, "bob@example.com", result[1].Email) + require.False(t, result[1].Verified) + require.True(t, result[1].Disabled) +} + +func TestParseNotifyEmails_NewFormat_SingleEntry(t *testing.T) { + raw := `[{"email":"solo@example.com","verified":true,"disabled":false}]` + result := ParseNotifyEmails(raw) + require.Len(t, result, 1) + require.Equal(t, "solo@example.com", result[0].Email) + require.True(t, result[0].Verified) +} + +func TestParseNotifyEmails_InvalidJSON(t *testing.T) { + result := ParseNotifyEmails(`{not valid json`) + require.Nil(t, result) +} + +func TestParseNotifyEmails_InvalidJSONObject(t *testing.T) { + // A plain JSON object (not array) should return nil. + result := ParseNotifyEmails(`{"email":"a@b.com"}`) + require.Nil(t, result) +} + +func TestParseNotifyEmails_WhitespacePadding(t *testing.T) { + raw := ` ["padded@example.com"] ` + result := ParseNotifyEmails(raw) + require.Len(t, result, 1) + require.Equal(t, "padded@example.com", result[0].Email) +} + +// ---------- MarshalNotifyEmails ---------- + +func TestMarshalNotifyEmails_EmptySlice(t *testing.T) { + result := MarshalNotifyEmails([]NotifyEmailEntry{}) + require.Equal(t, "[]", result) +} + +func TestMarshalNotifyEmails_NilSlice(t *testing.T) { + result := MarshalNotifyEmails(nil) + require.Equal(t, "[]", result) +} + +func TestMarshalNotifyEmails_SingleEntry(t *testing.T) { + entries := []NotifyEmailEntry{ + {Email: "test@example.com", Verified: true, Disabled: false}, + } + result := MarshalNotifyEmails(entries) + require.Contains(t, result, `"email":"test@example.com"`) + require.Contains(t, result, `"verified":true`) + require.Contains(t, result, `"disabled":false`) + + // Round-trip: parsing the marshalled result should produce the original entries. + parsed := ParseNotifyEmails(result) + require.Len(t, parsed, 1) + require.Equal(t, entries[0], parsed[0]) +} + +func TestMarshalNotifyEmails_MultipleEntries(t *testing.T) { + entries := []NotifyEmailEntry{ + {Email: "a@example.com", Verified: true, Disabled: false}, + {Email: "b@example.com", Verified: false, Disabled: true}, + } + result := MarshalNotifyEmails(entries) + + // Round-trip verification. + parsed := ParseNotifyEmails(result) + require.Len(t, parsed, 2) + require.Equal(t, entries[0], parsed[0]) + require.Equal(t, entries[1], parsed[1]) +} + +func TestMarshalNotifyEmails_RoundTrip_NewFormat(t *testing.T) { + original := []NotifyEmailEntry{ + {Email: "x@example.com", Verified: true, Disabled: true}, + {Email: "y@example.com", Verified: false, Disabled: false}, + } + marshalled := MarshalNotifyEmails(original) + parsed := ParseNotifyEmails(marshalled) + require.Equal(t, original, parsed) +} + +// ---------- isOldStringArrayFormat (indirectly via ParseNotifyEmails) ---------- + +func TestParseNotifyEmails_MixedOldFormatWithWhitespace(t *testing.T) { + // Emails with leading/trailing whitespace in old format should be trimmed. + raw := `[" alice@example.com "]` + result := ParseNotifyEmails(raw) + require.Len(t, result, 1) + require.Equal(t, "alice@example.com", result[0].Email) +} diff --git a/backend/internal/service/websearch_config_test.go b/backend/internal/service/websearch_config_test.go index 8cd50d0d..c5b96e01 100644 --- a/backend/internal/service/websearch_config_test.go +++ b/backend/internal/service/websearch_config_test.go @@ -1,9 +1,12 @@ +//go:build unit + package service import ( "context" "testing" + "github.com/Wei-Shaw/sub2api/internal/pkg/websearch" "github.com/stretchr/testify/require" ) @@ -141,3 +144,123 @@ func TestSanitizeWebSearchConfig_DoesNotMutateOriginal(t *testing.T) { _ = SanitizeWebSearchConfig(context.Background(), cfg) require.Equal(t, "secret", cfg.Providers[0].APIKey) } + +// --- PopulateWebSearchUsage --- + +func TestPopulateWebSearchUsage_NilInput(t *testing.T) { + require.Nil(t, PopulateWebSearchUsage(context.Background(), nil)) +} + +func TestPopulateWebSearchUsage_NoManager_QuotaUsedZero(t *testing.T) { + // Ensure no global manager is set + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "sk-key", QuotaLimit: int64Ptr(1000)}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.NotNil(t, out) + require.Len(t, out.Providers, 1) + require.Equal(t, int64(0), out.Providers[0].QuotaUsed) +} + +func TestPopulateWebSearchUsage_APIKeyConfigured_True(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "sk-key"}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.True(t, out.Providers[0].APIKeyConfigured) +} + +func TestPopulateWebSearchUsage_APIKeyConfigured_False(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: ""}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.False(t, out.Providers[0].APIKeyConfigured) +} + +func TestPopulateWebSearchUsage_NilQuotaLimit(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "sk-key", QuotaLimit: nil}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.Nil(t, out.Providers[0].QuotaLimit) +} + +func TestPopulateWebSearchUsage_NonNilQuotaLimit(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "sk-key", QuotaLimit: int64Ptr(500)}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.NotNil(t, out.Providers[0].QuotaLimit) + require.Equal(t, int64(500), *out.Providers[0].QuotaLimit) +} + +func TestPopulateWebSearchUsage_WithManager_NilRedis(t *testing.T) { + // Manager with nil Redis returns 0 usage without error + mgr := websearch.NewManager([]websearch.ProviderConfig{ + {Type: "brave", APIKey: "k"}, + }, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "sk-key", QuotaLimit: int64Ptr(1000)}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.Equal(t, int64(0), out.Providers[0].QuotaUsed) + require.True(t, out.Providers[0].APIKeyConfigured) +} + +func TestPopulateWebSearchUsage_DoesNotMutateOriginal(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "secret", QuotaLimit: int64Ptr(100)}, + }, + } + _ = PopulateWebSearchUsage(context.Background(), cfg) + // Original should be unchanged + require.Equal(t, "secret", cfg.Providers[0].APIKey) + require.Equal(t, int64(0), cfg.Providers[0].QuotaUsed) +} + +// --- ResetWebSearchUsage --- + +func TestResetWebSearchUsage_NilManager(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + err := ResetWebSearchUsage(context.Background(), "brave") + require.Error(t, err) + require.Contains(t, err.Error(), "not initialized") +}