fix: address audit findings across websearch, notify, and channel pricing

Backend fixes:
- Fix balance notify ignoring percentage threshold type (was treating
  percentage value as fixed USD amount)
- Remove dead code parseJSONStringArray
- Add ImageOutputTokens to tryModelFilePricing calculation
- Unify zero-value check: cost == 0 → cost <= 0 in calculateTokenStatsCost
- Use MarshalNotifyEmails instead of json.Marshal for consistency
- Rename quotaDim.oldUsed → currentUsed for clarity
- Extract HTML email templates to const variables (function ≤30 lines)

Test fixes:
- Rewrite account_websearch_test.go for GetWebSearchEmulationMode tri-state
- Add 6 tryModelFilePricing test cases

Frontend fixes:
- Replace hardcoded '未命名' with i18n key
- Extract getBillingModeLabel/getBillingModeBadgeClass to shared utils
- Replace inline type with imported NotifyEmailEntry
- Pass platform to AccountStats pricing rules via inferRulePlatform()
- Add billing mode constants (BILLING_MODE_TOKEN/PER_REQUEST/IMAGE)
This commit is contained in:
erio
2026-04-13 12:07:09 +08:00
parent 1262654d97
commit a68df457d8
12 changed files with 275 additions and 121 deletions

View File

@@ -57,7 +57,8 @@ func tryModelFilePricing(billingService *BillingService, model string, tokens Us
cost := float64(tokens.InputTokens)*pricing.InputPricePerToken +
float64(tokens.OutputTokens)*pricing.OutputPricePerToken +
float64(tokens.CacheCreationTokens)*pricing.CacheCreationPricePerToken +
float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken
float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken +
float64(tokens.ImageOutputTokens)*pricing.ImageOutputPricePerToken
if cost <= 0 {
return nil
}
@@ -194,7 +195,7 @@ func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *
float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) +
float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) +
float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice)
if cost == 0 {
if cost <= 0 {
return nil
}
return &cost

View File

@@ -428,3 +428,102 @@ func TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext(t *testing.T) {
require.NotNil(t, result)
require.InDelta(t, 5.0, *result, 1e-12) // 使用规则2
}
// ---------------------------------------------------------------------------
// tryModelFilePricing
// ---------------------------------------------------------------------------
// newTestBillingServiceWithPrices creates a BillingService with pre-populated
// fallback prices for testing. No config or pricing service is needed.
// The key must match what getFallbackPricing resolves to for a given model name.
// E.g., model "claude-sonnet-4" resolves to key "claude-sonnet-4".
func newTestBillingServiceWithPrices(prices map[string]*ModelPricing) *BillingService {
return &BillingService{
fallbackPrices: prices,
}
}
func TestTryModelFilePricing_Success(t *testing.T) {
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": {
InputPricePerToken: 0.001,
OutputPricePerToken: 0.002,
},
})
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
require.InDelta(t, 0.2, *result, 1e-12)
}
func TestTryModelFilePricing_PricingNotFound(t *testing.T) {
// "nonexistent-model" does not match any fallback pattern
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{})
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
result := tryModelFilePricing(bs, "nonexistent-model", tokens)
require.Nil(t, result)
}
func TestTryModelFilePricing_NilFallback(t *testing.T) {
// getFallbackPricing returns nil when key maps to nil
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": nil,
})
tokens := UsageTokens{InputTokens: 100}
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
require.Nil(t, result)
}
func TestTryModelFilePricing_ZeroCost(t *testing.T) {
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": {
InputPricePerToken: 0.001,
OutputPricePerToken: 0.002,
},
})
tokens := UsageTokens{} // all zero tokens → cost = 0 → nil
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
require.Nil(t, result)
}
func TestTryModelFilePricing_WithImageOutput(t *testing.T) {
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": {
InputPricePerToken: 0.001,
OutputPricePerToken: 0.002,
ImageOutputPricePerToken: 0.01,
},
})
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 50,
ImageOutputTokens: 10,
}
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
require.InDelta(t, 0.3, *result, 1e-12)
}
func TestTryModelFilePricing_WithCacheTokens(t *testing.T) {
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": {
InputPricePerToken: 0.001,
OutputPricePerToken: 0.002,
CacheCreationPricePerToken: 0.003,
CacheReadPricePerToken: 0.0005,
},
})
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 50,
CacheCreationTokens: 200,
CacheReadTokens: 300,
}
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
// = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
require.InDelta(t, 0.95, *result, 1e-12)
}

View File

@@ -1,3 +1,5 @@
//go:build unit
package service
import (
@@ -6,66 +8,98 @@ import (
"github.com/stretchr/testify/require"
)
func TestAccount_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
func TestGetWebSearchEmulationMode_Enabled(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
}
require.Equal(t, WebSearchModeEnabled, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_Disabled(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: "disabled"},
}
require.Equal(t, WebSearchModeDisabled, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_Default(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: "default"},
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_UnknownString(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: "unknown"},
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_OldBoolTrue(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: true},
}
require.True(t, a.IsWebSearchEmulationEnabled())
// bool is not a string, type assertion fails → default
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestAccount_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
func TestGetWebSearchEmulationMode_OldBoolFalse(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: false},
}
require.False(t, a.IsWebSearchEmulationEnabled())
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestAccount_IsWebSearchEmulationEnabled_MissingField(t *testing.T) {
func TestGetWebSearchEmulationMode_NilAccount(t *testing.T) {
var a *Account
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_NilExtra(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: nil,
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_MissingField(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{},
}
require.False(t, a.IsWebSearchEmulationEnabled())
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestAccount_IsWebSearchEmulationEnabled_WrongType(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: "true"},
}
require.False(t, a.IsWebSearchEmulationEnabled())
}
func TestAccount_IsWebSearchEmulationEnabled_NilExtra(t *testing.T) {
a := &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Extra: nil}
require.False(t, a.IsWebSearchEmulationEnabled())
}
func TestAccount_IsWebSearchEmulationEnabled_NilAccount(t *testing.T) {
var a *Account
require.False(t, a.IsWebSearchEmulationEnabled())
}
func TestAccount_IsWebSearchEmulationEnabled_NonAnthropicPlatform(t *testing.T) {
func TestGetWebSearchEmulationMode_NonAnthropicPlatform(t *testing.T) {
a := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: true},
Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
}
require.False(t, a.IsWebSearchEmulationEnabled())
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestAccount_IsWebSearchEmulationEnabled_NonAPIKeyType(t *testing.T) {
func TestGetWebSearchEmulationMode_NonAPIKeyType(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{featureKeyWebSearchEmulation: true},
Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
}
require.False(t, a.IsWebSearchEmulationEnabled())
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}

View File

@@ -2,7 +2,6 @@ package service
import (
"context"
"encoding/json"
"fmt"
"html"
"log/slog"
@@ -14,6 +13,10 @@ import (
const (
emailSendTimeout = 30 * time.Second
// Threshold type values
thresholdTypeFixed = "fixed"
thresholdTypePercentage = "percentage"
// Quota dimension labels
quotaDimDaily = "daily"
quotaDimWeekly = "weekly"
@@ -48,6 +51,15 @@ func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepo
}
}
// resolveBalanceThreshold returns the effective balance threshold.
// For percentage type, it computes threshold = totalRecharged * percentage / 100.
func resolveBalanceThreshold(threshold float64, thresholdType string, totalRecharged float64) float64 {
if thresholdType == thresholdTypePercentage && totalRecharged > 0 {
return totalRecharged * threshold / 100
}
return threshold
}
// CheckBalanceAfterDeduction checks if balance crossed below threshold after deduction.
// oldBalance is the balance before deduction, cost is the amount deducted.
// Notification is sent only on first crossing: oldBalance >= threshold && newBalance < threshold.
@@ -73,8 +85,13 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u
return
}
effectiveThreshold := resolveBalanceThreshold(threshold, user.BalanceNotifyThresholdType, user.TotalRecharged)
if effectiveThreshold <= 0 {
return
}
newBalance := oldBalance - cost
if oldBalance >= threshold && newBalance < threshold {
if oldBalance >= effectiveThreshold && newBalance < effectiveThreshold {
siteName := s.getSiteName(ctx)
recipients := s.collectBalanceNotifyRecipients(user)
go func() {
@@ -83,7 +100,7 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u
slog.Error("panic in balance notification", "recover", r)
}
}()
s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, threshold, siteName)
s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, effectiveThreshold, siteName)
}()
}
}
@@ -94,14 +111,14 @@ type quotaDim struct {
enabled bool
threshold float64
thresholdType string // "fixed" (default) or "percentage"
oldUsed float64
currentUsed float64
limit float64
}
// resolvedThreshold returns the effective threshold value.
// For percentage type, it computes threshold = limit * percentage / 100.
func (d quotaDim) resolvedThreshold() float64 {
if d.thresholdType == "percentage" && d.limit > 0 {
if d.thresholdType == thresholdTypePercentage && d.limit > 0 {
return d.limit * d.threshold / 100
}
return d.threshold
@@ -150,7 +167,7 @@ func (s *BalanceNotifyService) fetchFreshAccount(ctx context.Context, snapshot *
}
// checkQuotaDimCrossings iterates quota dimensions and sends alerts for threshold crossings.
// freshAccount has post-increment values; oldUsed is reconstructed as freshUsed - cost.
// freshAccount has post-increment values; pre-increment is reconstructed as currentUsed - cost.
func (s *BalanceNotifyService) checkQuotaDimCrossings(freshAccount *Account, cost float64, adminEmails []string, siteName string) {
for _, dim := range buildQuotaDims(freshAccount) {
if !dim.enabled || dim.threshold <= 0 {
@@ -160,10 +177,10 @@ func (s *BalanceNotifyService) checkQuotaDimCrossings(freshAccount *Account, cos
if effectiveThreshold <= 0 {
continue
}
// dim.oldUsed is actually the post-increment value from fresh DB data;
// currentUsed is the post-increment value from fresh DB data;
// reconstruct pre-increment value to detect threshold crossing.
newUsed := dim.oldUsed
oldUsed := dim.oldUsed - cost
newUsed := dim.currentUsed
oldUsed := dim.currentUsed - cost
if oldUsed < effectiveThreshold && newUsed >= effectiveThreshold {
s.asyncSendQuotaAlert(adminEmails, freshAccount.Name, dim, newUsed, effectiveThreshold, siteName)
}
@@ -309,10 +326,9 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun
s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dimension)
}
// buildBalanceLowEmailBody builds HTML email for balance low notification.
// Lines exceed 30 due to inline HTML template (not splittable).
func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance, threshold float64, siteName string) string {
return fmt.Sprintf(`<!DOCTYPE html>
// balanceLowEmailTemplate is the HTML template for balance low notifications.
// Format args: siteName, userName, userName, balance, threshold, threshold.
const balanceLowEmailTemplate = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
@@ -344,17 +360,11 @@ func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance
<div class="footer"><p>此邮件由系统自动发送,请勿回复。</p></div>
</div>
</body>
</html>`, siteName, userName, userName, balance, threshold, threshold)
}
</html>`
// buildQuotaAlertEmailBody builds HTML email for account quota alert.
// Lines exceed 30 due to inline HTML template (not splittable).
func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountName, dimLabel string, used, limit, threshold float64, siteName string) string {
limitStr := fmt.Sprintf("$%.2f", limit)
if limit <= 0 {
limitStr = "无限制 / Unlimited"
}
return fmt.Sprintf(`<!DOCTYPE html>
// quotaAlertEmailTemplate is the HTML template for account quota alert notifications.
// Format args: siteName, accountName, dimLabel, used, limitStr, threshold.
const quotaAlertEmailTemplate = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
@@ -389,18 +399,19 @@ func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountName, dimLabel st
<div class="footer"><p>此邮件由系统自动发送,请勿回复。</p></div>
</div>
</body>
</html>`, siteName, accountName, dimLabel, used, limitStr, threshold)
</html>`
// buildBalanceLowEmailBody builds HTML email for balance low notification.
func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance, threshold float64, siteName string) string {
return fmt.Sprintf(balanceLowEmailTemplate, siteName, userName, userName, balance, threshold, threshold)
}
// parseJSONStringArray parses a JSON string array, returns nil on error.
func parseJSONStringArray(raw string) []string {
raw = strings.TrimSpace(raw)
if raw == "" || raw == "[]" {
return nil
// buildQuotaAlertEmailBody builds HTML email for account quota alert.
func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountName, dimLabel string, used, limit, threshold float64, siteName string) string {
limitStr := fmt.Sprintf("$%.2f", limit)
if limit <= 0 {
limitStr = "无限制 / Unlimited"
}
var result []string
if err := json.Unmarshal([]byte(raw), &result); err != nil {
return nil
}
return result
return fmt.Sprintf(quotaAlertEmailTemplate, siteName, accountName, dimLabel, used, limitStr, threshold)
}

View File

@@ -627,11 +627,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64)
updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled)
accountQuotaNotifyEmailsJSON, err := json.Marshal(settings.AccountQuotaNotifyEmails)
if err != nil {
return fmt.Errorf("marshal account quota notify emails: %w", err)
}
updates[SettingKeyAccountQuotaNotifyEmails] = string(accountQuotaNotifyEmailsJSON)
updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails)
err = s.settingRepo.SetMultiple(ctx, updates)
if err == nil {