Files
sub2api/backend/internal/service/account_stats_pricing.go
erio a68df457d8 fix: address audit findings across websearch, notify, and channel pricing
Backend fixes:
- Fix balance notify ignoring percentage threshold type (was treating
  percentage value as fixed USD amount)
- Remove dead code parseJSONStringArray
- Add ImageOutputTokens to tryModelFilePricing calculation
- Unify zero-value check: cost == 0 → cost <= 0 in calculateTokenStatsCost
- Use MarshalNotifyEmails instead of json.Marshal for consistency
- Rename quotaDim.oldUsed → currentUsed for clarity
- Extract HTML email templates to const variables (function ≤30 lines)

Test fixes:
- Rewrite account_websearch_test.go for GetWebSearchEmulationMode tri-state
- Add 6 tryModelFilePricing test cases

Frontend fixes:
- Replace hardcoded '未命名' with i18n key
- Extract getBillingModeLabel/getBillingModeBadgeClass to shared utils
- Replace inline type with imported NotifyEmailEntry
- Pass platform to AccountStats pricing rules via inferRulePlatform()
- Add billing mode constants (BILLING_MODE_TOKEN/PER_REQUEST/IMAGE)
2026-04-14 09:26:08 +08:00

203 lines
6.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"sort"
"strings"
)
// resolveAccountStatsCost computes the cost to record in account statistics
// for a single usage event.
//
// A nil result means "do not override": per the original note, the default
// formula (total_cost × account_rate_multiplier) is used instead.
//
// Pricing sources are consulted in order and the first one that yields a cost
// wins:
//  1. The channel's custom rules. These are always tried and do not depend on
//     the ApplyPricingToAccountStats switch.
//  2. When ApplyPricingToAccountStats is enabled and billingService is not
//     nil, the standard price of the upstream model from the model pricing
//     file (LiteLLM/fallback).
//  3. Otherwise nil.
//
// upstreamModel is the model ID that is ultimately sent upstream. nil is also
// returned when channelService is nil, upstreamModel is empty, or the channel
// for groupID cannot be obtained.
func resolveAccountStatsCost(
	ctx context.Context,
	channelService *ChannelService,
	billingService *BillingService,
	accountID int64,
	groupID int64,
	upstreamModel string,
	tokens UsageTokens,
	requestCount int,
) *float64 {
	if upstreamModel == "" || channelService == nil {
		return nil
	}

	channel, lookupErr := channelService.GetChannelForGroup(ctx, groupID)
	if lookupErr != nil {
		return nil
	}
	if channel == nil {
		return nil
	}

	// Priority 1: custom rules (always attempted).
	groupPlatform := channelService.GetGroupPlatform(ctx, groupID)
	custom := tryCustomRules(channel, accountID, groupID, groupPlatform, upstreamModel, tokens, requestCount)
	if custom != nil {
		return custom
	}

	// Priority 2: standard upstream-model price from the model pricing file,
	// only when the channel opted in and a billing service is available.
	if !channel.ApplyPricingToAccountStats || billingService == nil {
		return nil
	}
	return tryModelFilePricing(billingService, upstreamModel, tokens)
}
// tryModelFilePricing prices the given token usage with the standard
// per-token prices that billingService.GetModelPricing reports for model
// (the model pricing file, LiteLLM/fallback).
//
// It returns nil when the pricing lookup fails, when no pricing is returned,
// or when the computed cost is not positive.
func tryModelFilePricing(billingService *BillingService, model string, tokens UsageTokens) *float64 {
	modelPricing, lookupErr := billingService.GetModelPricing(model)
	if lookupErr != nil {
		return nil
	}
	if modelPricing == nil {
		return nil
	}

	// Sum of (token count × per-token price) over every token category.
	total := float64(tokens.InputTokens)*modelPricing.InputPricePerToken +
		float64(tokens.OutputTokens)*modelPricing.OutputPricePerToken +
		float64(tokens.CacheCreationTokens)*modelPricing.CacheCreationPricePerToken +
		float64(tokens.CacheReadTokens)*modelPricing.CacheReadPricePerToken +
		float64(tokens.ImageOutputTokens)*modelPricing.ImageOutputPricePerToken
	if total <= 0 {
		return nil
	}
	return &total
}
// tryCustomRules walks the channel's custom account-stats pricing rules in
// slice order and prices the usage with the first rule that both matches the
// account/group and carries pricing for the model.
//
// A rule that matches the account/group but has no pricing for the model is
// skipped and the search continues with the next rule. nil is returned when
// no rule applies, or when the selected pricing yields no cost.
func tryCustomRules(
	channel *Channel, accountID, groupID int64,
	platform, model string, tokens UsageTokens, requestCount int,
) *float64 {
	rules := channel.AccountStatsPricingRules
	lowered := strings.ToLower(model)

	for i := range rules {
		rule := &rules[i]
		if !matchAccountStatsRule(rule, accountID, groupID) {
			continue
		}
		if matched := findPricingForModel(rule.Pricing, platform, lowered); matched != nil {
			return calculateStatsCost(matched, tokens, requestCount)
		}
		// The rule applies to this account/group but does not price this
		// model; fall through to the next rule.
	}
	return nil
}
// matchAccountStatsRule reports whether rule applies to the given account or
// group: it matches when accountID is listed in rule.AccountIDs or groupID is
// listed in rule.GroupIDs.
//
// A rule whose AccountIDs and GroupIDs are both empty never matches.
func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool {
	listed := func(ids []int64, want int64) bool {
		for i := range ids {
			if ids[i] == want {
				return true
			}
		}
		return false
	}
	// Empty lists contain nothing, so a rule with no IDs yields false here.
	return listed(rule.AccountIDs, accountID) || listed(rule.GroupIDs, groupID)
}
// wildcardMatch is one wildcard-pattern candidate collected by
// findPricingForModel so that candidates can be ranked by prefix length.
type wildcardMatch struct {
// prefixLen is the byte length of the lower-cased pattern with its
// trailing "*" removed; a longer prefix means a more specific match.
prefixLen int
// pricing points at the pricing entry whose pattern matched the model.
pricing *ChannelModelPricing
}
// findPricingForModel looks up the pricing entry that applies to a model.
//
// modelLower must already be lower-cased; configured model names are
// lower-cased before comparison, so matching is case-insensitive. Entries
// whose Platform is not compatible with platform (see isPlatformMatch) are
// ignored.
//
// Lookup order:
//  1. Exact match: the first entry, in list order, that lists the model.
//  2. Wildcard match: patterns ending in "*" match when the model starts with
//     the text before the "*". The longest prefix wins; among equally long
//     prefixes the entry that appears first in pricingList wins.
//
// The returned pointer refers to an element of pricingList. nil is returned
// when nothing matches.
func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
	// Pass 1: exact matches take precedence over any wildcard.
	for i := range pricingList {
		p := &pricingList[i]
		if !isPlatformMatch(platform, p.Platform) {
			continue
		}
		for _, m := range p.Models {
			if strings.ToLower(m) == modelLower {
				return p
			}
		}
	}

	// Pass 2: collect every wildcard pattern that matches the model.
	var matches []wildcardMatch
	for i := range pricingList {
		p := &pricingList[i]
		if !isPlatformMatch(platform, p.Platform) {
			continue
		}
		for _, m := range p.Models {
			ml := strings.ToLower(m)
			if !strings.HasSuffix(ml, "*") {
				continue
			}
			prefix := strings.TrimSuffix(ml, "*")
			if strings.HasPrefix(modelLower, prefix) {
				matches = append(matches, wildcardMatch{prefixLen: len(prefix), pricing: p})
			}
		}
	}
	if len(matches) == 0 {
		return nil
	}

	// Rank by prefix length, longest first. The sort must be stable:
	// sort.Slice gives no ordering guarantee for equal elements, which would
	// let a later entry beat an earlier one with the same prefix length.
	sort.SliceStable(matches, func(i, j int) bool {
		return matches[i].prefixLen > matches[j].prefixLen
	})
	return matches[0].pricing
}
// isPlatformMatch reports whether a pricing entry's platform is compatible
// with the platform being queried. An empty platform on either side acts as
// "any platform"; otherwise the two must be exactly equal (case-sensitive).
func isPlatformMatch(queryPlatform, pricingPlatform string) bool {
	switch {
	case queryPlatform == "", pricingPlatform == "":
		// An unspecified side places no restriction on the platform.
		return true
	default:
		return queryPlatform == pricingPlatform
	}
}
// calculateStatsCost computes the raw cost (no multiplier applied) for the
// given pricing entry, dispatching on its billing mode: per-request and image
// modes are charged by requestCount, every other mode by token usage.
//
// It returns nil when pricing is nil or when the selected calculation yields
// no cost.
func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 {
	if pricing == nil {
		return nil
	}

	mode := pricing.BillingMode
	if mode == BillingModePerRequest || mode == BillingModeImage {
		return calculatePerRequestStatsCost(pricing, requestCount)
	}
	// Any other billing mode is billed by tokens.
	return calculateTokenStatsCost(pricing, tokens)
}
// calculatePerRequestStatsCost computes the cost for per-request / per-image
// billing as PerRequestPrice × requestCount.
//
// It returns nil when PerRequestPrice is unset or not positive.
//
// NOTE(review): unlike calculateTokenStatsCost, a requestCount <= 0 produces a
// non-nil zero or negative cost rather than nil — confirm this is intended.
func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 {
	unitPrice := pricing.PerRequestPrice
	if unitPrice == nil {
		return nil
	}
	if *unitPrice <= 0 {
		return nil
	}

	total := *unitPrice * float64(requestCount)
	return &total
}
// calculateTokenStatsCost computes the cost for token-based billing as the
// sum of (token count × configured price) over input, output, cache-write,
// cache-read and image-output tokens. A price that is not configured (nil)
// contributes nothing.
//
// It returns nil when the resulting cost is not positive.
func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
	// priceOrZero treats an unset price as 0.
	priceOrZero := func(price *float64) float64 {
		if price != nil {
			return *price
		}
		return 0
	}

	total := float64(tokens.InputTokens)*priceOrZero(pricing.InputPrice) +
		float64(tokens.OutputTokens)*priceOrZero(pricing.OutputPrice) +
		float64(tokens.CacheCreationTokens)*priceOrZero(pricing.CacheWritePrice) +
		float64(tokens.CacheReadTokens)*priceOrZero(pricing.CacheReadPrice) +
		float64(tokens.ImageOutputTokens)*priceOrZero(pricing.ImageOutputPrice)
	if total <= 0 {
		return nil
	}
	return &total
}