Allow channels to configure independent model pricing for account statistics cost calculation, decoupled from user billing. Backend: - Migration 101: adds the channels.apply_pricing_to_account_stats toggle, the channel_account_stats_pricing_rules/model_pricing tables, and the usage_logs.account_stats_cost column - resolveAccountStatsCost: match custom rules by account/group first, then channel model pricing, and fall back to the original formula when neither is configured - Integrate into both GatewayService.recordUsageCore and OpenAIGatewayService.RecordUsage - Update 8 account stats SQL queries to use COALESCE(account_stats_cost, total_cost) * account_rate_multiplier - 23 unit tests for rule matching, pricing lookup, and cost calculation Frontend: - Channel edit dialog: toggle plus custom rules UI with group/account multi-select and pricing entry cards - API types and i18n (zh/en)
193 lines
5.7 KiB
Go
193 lines
5.7 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"sort"
|
||
"strings"
|
||
)
|
||
|
||
// resolveAccountStatsCost 计算账号统计定价费用。
|
||
// 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。
|
||
//
|
||
// 匹配优先级(先命中为准):
|
||
// 1. 自定义规则(AccountStatsPricingRules,按数组顺序遍历)
|
||
// 2. 渠道已有的模型定价(ApplyPricingToAccountStats 开启时)
|
||
// 3. nil → 走默认公式
|
||
func resolveAccountStatsCost(
|
||
ctx context.Context,
|
||
channelService *ChannelService,
|
||
billingService *BillingService,
|
||
accountID int64,
|
||
groupID int64,
|
||
billingModel string,
|
||
tokens UsageTokens,
|
||
requestCount int,
|
||
serviceTier string,
|
||
) *float64 {
|
||
if channelService == nil || billingService == nil {
|
||
return nil
|
||
}
|
||
channel, err := channelService.GetChannelForGroup(ctx, groupID)
|
||
if err != nil || channel == nil || !channel.ApplyPricingToAccountStats {
|
||
return nil
|
||
}
|
||
|
||
platform := channelService.GetGroupPlatform(ctx, groupID)
|
||
modelLower := strings.ToLower(billingModel)
|
||
|
||
// 优先级 1:自定义规则
|
||
if cost := tryCustomRules(channel, accountID, groupID, platform, modelLower, tokens, requestCount); cost != nil {
|
||
return cost
|
||
}
|
||
|
||
// 优先级 2:渠道已有模型定价
|
||
return tryChannelPricing(ctx, channelService, groupID, billingModel, tokens, requestCount)
|
||
}
|
||
|
||
// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。
|
||
func tryCustomRules(
|
||
channel *Channel, accountID, groupID int64,
|
||
platform, modelLower string, tokens UsageTokens, requestCount int,
|
||
) *float64 {
|
||
for _, rule := range channel.AccountStatsPricingRules {
|
||
if !matchAccountStatsRule(&rule, accountID, groupID) {
|
||
continue
|
||
}
|
||
pricing := findPricingForModel(rule.Pricing, platform, modelLower)
|
||
if pricing == nil {
|
||
continue // 规则匹配但模型不在规则定价中,继续下一条
|
||
}
|
||
return calculateStatsCost(pricing, tokens, requestCount)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// tryChannelPricing 使用渠道已有的模型定价计算账号统计费用。
|
||
func tryChannelPricing(
|
||
ctx context.Context, channelService *ChannelService,
|
||
groupID int64, billingModel string, tokens UsageTokens, requestCount int,
|
||
) *float64 {
|
||
pricing := channelService.GetChannelModelPricing(ctx, groupID, billingModel)
|
||
if pricing == nil {
|
||
return nil
|
||
}
|
||
return calculateStatsCost(pricing, tokens, requestCount)
|
||
}
|
||
|
||
// matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。
|
||
// 匹配条件:accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。
|
||
// 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。
|
||
func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool {
|
||
if len(rule.AccountIDs) == 0 && len(rule.GroupIDs) == 0 {
|
||
return false
|
||
}
|
||
for _, id := range rule.AccountIDs {
|
||
if id == accountID {
|
||
return true
|
||
}
|
||
}
|
||
for _, id := range rule.GroupIDs {
|
||
if id == groupID {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// wildcardMatch 通配符匹配候选项(用于排序)
|
||
type wildcardMatch struct {
|
||
prefixLen int
|
||
pricing *ChannelModelPricing
|
||
}
|
||
|
||
// findPricingForModel 在定价列表中查找匹配的模型定价。
|
||
// 先精确匹配,再通配符匹配(前缀越长优先级越高)。
|
||
func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
|
||
// 精确匹配优先
|
||
for i := range pricingList {
|
||
p := &pricingList[i]
|
||
if !isPlatformMatch(platform, p.Platform) {
|
||
continue
|
||
}
|
||
for _, m := range p.Models {
|
||
if strings.ToLower(m) == modelLower {
|
||
return p
|
||
}
|
||
}
|
||
}
|
||
// 通配符匹配:收集所有匹配项,按前缀长度降序取最长
|
||
var matches []wildcardMatch
|
||
for i := range pricingList {
|
||
p := &pricingList[i]
|
||
if !isPlatformMatch(platform, p.Platform) {
|
||
continue
|
||
}
|
||
for _, m := range p.Models {
|
||
ml := strings.ToLower(m)
|
||
if !strings.HasSuffix(ml, "*") {
|
||
continue
|
||
}
|
||
prefix := strings.TrimSuffix(ml, "*")
|
||
if strings.HasPrefix(modelLower, prefix) {
|
||
matches = append(matches, wildcardMatch{prefixLen: len(prefix), pricing: p})
|
||
}
|
||
}
|
||
}
|
||
if len(matches) == 0 {
|
||
return nil
|
||
}
|
||
sort.Slice(matches, func(i, j int) bool {
|
||
return matches[i].prefixLen > matches[j].prefixLen
|
||
})
|
||
return matches[0].pricing
|
||
}
|
||
|
||
// isPlatformMatch reports whether a pricing entry declared for pricingPlatform
// applies to a query for queryPlatform. An empty platform on either side acts
// as a wildcard; otherwise the two strings must be identical (case-sensitive).
func isPlatformMatch(queryPlatform, pricingPlatform string) bool {
	eitherUnrestricted := queryPlatform == "" || pricingPlatform == ""
	return eitherUnrestricted || queryPlatform == pricingPlatform
}
|
||
|
||
// calculateStatsCost 使用给定的定价计算费用(不含任何倍率,原始费用)。
|
||
func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 {
|
||
if pricing == nil {
|
||
return nil
|
||
}
|
||
switch pricing.BillingMode {
|
||
case BillingModePerRequest, BillingModeImage:
|
||
return calculatePerRequestStatsCost(pricing, requestCount)
|
||
default:
|
||
return calculateTokenStatsCost(pricing, tokens)
|
||
}
|
||
}
|
||
|
||
// calculatePerRequestStatsCost 按次/图片计费。
|
||
func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 {
|
||
if pricing.PerRequestPrice == nil || *pricing.PerRequestPrice <= 0 {
|
||
return nil
|
||
}
|
||
cost := *pricing.PerRequestPrice * float64(requestCount)
|
||
return &cost
|
||
}
|
||
|
||
// calculateTokenStatsCost Token 计费。
|
||
func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
|
||
deref := func(p *float64) float64 {
|
||
if p == nil {
|
||
return 0
|
||
}
|
||
return *p
|
||
}
|
||
cost := float64(tokens.InputTokens)*deref(pricing.InputPrice) +
|
||
float64(tokens.OutputTokens)*deref(pricing.OutputPrice) +
|
||
float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) +
|
||
float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) +
|
||
float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice)
|
||
if cost == 0 {
|
||
return nil
|
||
}
|
||
return &cost
|
||
}
|