diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go index e88f7f8c..cbe9c76c 100644 --- a/backend/internal/service/account_stats_pricing.go +++ b/backend/internal/service/account_stats_pricing.go @@ -57,7 +57,8 @@ func tryModelFilePricing(billingService *BillingService, model string, tokens Us cost := float64(tokens.InputTokens)*pricing.InputPricePerToken + float64(tokens.OutputTokens)*pricing.OutputPricePerToken + float64(tokens.CacheCreationTokens)*pricing.CacheCreationPricePerToken + - float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken + float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken + + float64(tokens.ImageOutputTokens)*pricing.ImageOutputPricePerToken if cost <= 0 { return nil } @@ -194,7 +195,7 @@ func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) * float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) + float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) + float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice) - if cost == 0 { + if cost <= 0 { return nil } return &cost diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go index bc3db251..bf9da978 100644 --- a/backend/internal/service/account_stats_pricing_test.go +++ b/backend/internal/service/account_stats_pricing_test.go @@ -428,3 +428,102 @@ func TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext(t *testing.T) { require.NotNil(t, result) require.InDelta(t, 5.0, *result, 1e-12) // 使用规则2 } + +// --------------------------------------------------------------------------- +// tryModelFilePricing +// --------------------------------------------------------------------------- + +// newTestBillingServiceWithPrices creates a BillingService with pre-populated +// fallback prices for testing. No config or pricing service is needed. +// The key must match what getFallbackPricing resolves to for a given model name. +// E.g., model "claude-sonnet-4" resolves to key "claude-sonnet-4". +func newTestBillingServiceWithPrices(prices map[string]*ModelPricing) *BillingService { + return &BillingService{ + fallbackPrices: prices, + } +} + +func TestTryModelFilePricing_Success(t *testing.T) { + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 0.001, + OutputPricePerToken: 0.002, + }, + }) + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + result := tryModelFilePricing(bs, "claude-sonnet-4", tokens) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2 + require.InDelta(t, 0.2, *result, 1e-12) +} + +func TestTryModelFilePricing_PricingNotFound(t *testing.T) { + // "nonexistent-model" does not match any fallback pattern + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{}) + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + result := tryModelFilePricing(bs, "nonexistent-model", tokens) + require.Nil(t, result) +} + +func TestTryModelFilePricing_NilFallback(t *testing.T) { + // getFallbackPricing returns nil when key maps to nil + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": nil, + }) + tokens := UsageTokens{InputTokens: 100} + result := tryModelFilePricing(bs, "claude-sonnet-4", tokens) + require.Nil(t, result) +} + +func TestTryModelFilePricing_ZeroCost(t *testing.T) { + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 0.001, + OutputPricePerToken: 0.002, + }, + }) + tokens := UsageTokens{} // all zero tokens → cost = 0 → nil + result := tryModelFilePricing(bs, "claude-sonnet-4", tokens) + require.Nil(t, result) +} + +func TestTryModelFilePricing_WithImageOutput(t *testing.T) { + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 0.001, + OutputPricePerToken: 0.002, + ImageOutputPricePerToken: 0.01, + }, + }) + tokens := UsageTokens{ + InputTokens: 100, + OutputTokens: 50, + ImageOutputTokens: 10, + } + result := tryModelFilePricing(bs, "claude-sonnet-4", tokens) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3 + require.InDelta(t, 0.3, *result, 1e-12) +} + +func TestTryModelFilePricing_WithCacheTokens(t *testing.T) { + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 0.001, + OutputPricePerToken: 0.002, + CacheCreationPricePerToken: 0.003, + CacheReadPricePerToken: 0.0005, + }, + }) + tokens := UsageTokens{ + InputTokens: 100, + OutputTokens: 50, + CacheCreationTokens: 200, + CacheReadTokens: 300, + } + result := tryModelFilePricing(bs, "claude-sonnet-4", tokens) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005 + // = 0.1 + 0.1 + 0.6 + 0.15 = 0.95 + require.InDelta(t, 0.95, *result, 1e-12) +} diff --git a/backend/internal/service/account_websearch_test.go b/backend/internal/service/account_websearch_test.go index fe742ebf..b4d23c6b 100644 --- a/backend/internal/service/account_websearch_test.go +++ b/backend/internal/service/account_websearch_test.go @@ -1,3 +1,5 @@ +//go:build unit + package service import ( @@ -6,66 +8,98 @@ import ( "github.com/stretchr/testify/require" ) -func TestAccount_IsWebSearchEmulationEnabled_Enabled(t *testing.T) { +func TestGetWebSearchEmulationMode_Enabled(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"}, + } + require.Equal(t, WebSearchModeEnabled, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_Disabled(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: "disabled"}, + } + require.Equal(t, WebSearchModeDisabled, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_Default(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: "default"}, + } + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_UnknownString(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: "unknown"}, + } + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_OldBoolTrue(t *testing.T) { a := &Account{ Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Extra: map[string]any{featureKeyWebSearchEmulation: true}, } - require.True(t, a.IsWebSearchEmulationEnabled()) + // bool is not a string, type assertion fails → default + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) } -func TestAccount_IsWebSearchEmulationEnabled_Disabled(t *testing.T) { +func TestGetWebSearchEmulationMode_OldBoolFalse(t *testing.T) { a := &Account{ Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Extra: map[string]any{featureKeyWebSearchEmulation: false}, } - require.False(t, a.IsWebSearchEmulationEnabled()) + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) } -func TestAccount_IsWebSearchEmulationEnabled_MissingField(t *testing.T) { +func TestGetWebSearchEmulationMode_NilAccount(t *testing.T) { + var a *Account + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_NilExtra(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: nil, + } + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_MissingField(t *testing.T) { a := &Account{ Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Extra: map[string]any{}, } - require.False(t, a.IsWebSearchEmulationEnabled()) + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) } -func TestAccount_IsWebSearchEmulationEnabled_WrongType(t *testing.T) { - a := &Account{ - Platform: PlatformAnthropic, - Type: AccountTypeAPIKey, - Extra: map[string]any{featureKeyWebSearchEmulation: "true"}, - } - require.False(t, a.IsWebSearchEmulationEnabled()) -} - -func TestAccount_IsWebSearchEmulationEnabled_NilExtra(t *testing.T) { - a := &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Extra: nil} - require.False(t, a.IsWebSearchEmulationEnabled()) -} - -func TestAccount_IsWebSearchEmulationEnabled_NilAccount(t *testing.T) { - var a *Account - require.False(t, a.IsWebSearchEmulationEnabled()) -} - -func TestAccount_IsWebSearchEmulationEnabled_NonAnthropicPlatform(t *testing.T) { +func TestGetWebSearchEmulationMode_NonAnthropicPlatform(t *testing.T) { a := &Account{ Platform: PlatformOpenAI, Type: AccountTypeAPIKey, - Extra: map[string]any{featureKeyWebSearchEmulation: true}, + Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"}, } - require.False(t, a.IsWebSearchEmulationEnabled()) + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) } -func TestAccount_IsWebSearchEmulationEnabled_NonAPIKeyType(t *testing.T) { +func TestGetWebSearchEmulationMode_NonAPIKeyType(t *testing.T) { a := &Account{ Platform: PlatformAnthropic, Type: AccountTypeOAuth, - Extra: map[string]any{featureKeyWebSearchEmulation: true}, + Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"}, } - require.False(t, a.IsWebSearchEmulationEnabled()) + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) } diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 23411ed5..1e4d8ff6 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -2,7 +2,6 @@ package service import ( "context" - "encoding/json" "fmt" "html" "log/slog" @@ -14,6 +13,10 @@ import ( const ( emailSendTimeout = 30 * time.Second + // Threshold type values + thresholdTypeFixed = "fixed" + thresholdTypePercentage = "percentage" + // Quota dimension labels quotaDimDaily = "daily" quotaDimWeekly = "weekly" @@ -48,6 +51,15 @@ func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepo } } +// resolveBalanceThreshold returns the effective balance threshold. +// For percentage type, it computes threshold = totalRecharged * percentage / 100. +func resolveBalanceThreshold(threshold float64, thresholdType string, totalRecharged float64) float64 { + if thresholdType == thresholdTypePercentage && totalRecharged > 0 { + return totalRecharged * threshold / 100 + } + return threshold +} + // CheckBalanceAfterDeduction checks if balance crossed below threshold after deduction. // oldBalance is the balance before deduction, cost is the amount deducted. // Notification is sent only on first crossing: oldBalance >= threshold && newBalance < threshold. @@ -73,8 +85,13 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u return } + effectiveThreshold := resolveBalanceThreshold(threshold, user.BalanceNotifyThresholdType, user.TotalRecharged) + if effectiveThreshold <= 0 { + return + } + newBalance := oldBalance - cost - if oldBalance >= threshold && newBalance < threshold { + if oldBalance >= effectiveThreshold && newBalance < effectiveThreshold { siteName := s.getSiteName(ctx) recipients := s.collectBalanceNotifyRecipients(user) go func() { @@ -83,7 +100,7 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u slog.Error("panic in balance notification", "recover", r) } }() - s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, threshold, siteName) + s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, effectiveThreshold, siteName) }() } } @@ -94,14 +111,14 @@ type quotaDim struct { enabled bool threshold float64 thresholdType string // "fixed" (default) or "percentage" - oldUsed float64 + currentUsed float64 limit float64 } // resolvedThreshold returns the effective threshold value. // For percentage type, it computes threshold = limit * percentage / 100. func (d quotaDim) resolvedThreshold() float64 { - if d.thresholdType == "percentage" && d.limit > 0 { + if d.thresholdType == thresholdTypePercentage && d.limit > 0 { return d.limit * d.threshold / 100 } return d.threshold @@ -150,7 +167,7 @@ func (s *BalanceNotifyService) fetchFreshAccount(ctx context.Context, snapshot * } // checkQuotaDimCrossings iterates quota dimensions and sends alerts for threshold crossings. -// freshAccount has post-increment values; oldUsed is reconstructed as freshUsed - cost. +// freshAccount has post-increment values; pre-increment is reconstructed as currentUsed - cost. func (s *BalanceNotifyService) checkQuotaDimCrossings(freshAccount *Account, cost float64, adminEmails []string, siteName string) { for _, dim := range buildQuotaDims(freshAccount) { if !dim.enabled || dim.threshold <= 0 { @@ -160,10 +177,10 @@ func (s *BalanceNotifyService) checkQuotaDimCrossings(freshAccount *Account, cos if effectiveThreshold <= 0 { continue } - // dim.oldUsed is actually the post-increment value from fresh DB data; + // currentUsed is the post-increment value from fresh DB data; // reconstruct pre-increment value to detect threshold crossing. - newUsed := dim.oldUsed - oldUsed := dim.oldUsed - cost + newUsed := dim.currentUsed + oldUsed := dim.currentUsed - cost if oldUsed < effectiveThreshold && newUsed >= effectiveThreshold { s.asyncSendQuotaAlert(adminEmails, freshAccount.Name, dim, newUsed, effectiveThreshold, siteName) } @@ -309,10 +326,9 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dimension) } -// buildBalanceLowEmailBody builds HTML email for balance low notification. -// Lines exceed 30 due to inline HTML template (not splittable). -func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance, threshold float64, siteName string) string { - return fmt.Sprintf(` +// balanceLowEmailTemplate is the HTML template for balance low notifications. +// Format args: siteName, userName, userName, balance, threshold, threshold. +const balanceLowEmailTemplate = `
@@ -344,17 +360,11 @@ func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance -`, siteName, userName, userName, balance, threshold, threshold) -} +