feat(channels): add custom account stats pricing rules
Allow channels to configure independent model pricing for account statistics cost calculation, decoupled from user billing. Backend: - Migration 101: channels.apply_pricing_to_account_stats toggle, channel_account_stats_pricing_rules/model_pricing tables, usage_logs.account_stats_cost column - resolveAccountStatsCost: match rules by group/account, then channel pricing, fallback to original formula when unconfigured - Integrate into both GatewayService.recordUsageCore and OpenAIGatewayService.RecordUsage - Update 8 account stats SQL queries to use COALESCE(account_stats_cost, total_cost) * account_rate_multiplier - 23 unit tests for matching, pricing lookup, and cost calculation Frontend: - Channel edit dialog: toggle + custom rules UI with group/account multi-select and pricing entry cards - API types and i18n (zh/en)
This commit is contained in:
192
backend/internal/service/account_stats_pricing.go
Normal file
192
backend/internal/service/account_stats_pricing.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// resolveAccountStatsCost 计算账号统计定价费用。
|
||||
// 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。
|
||||
//
|
||||
// 匹配优先级(先命中为准):
|
||||
// 1. 自定义规则(AccountStatsPricingRules,按数组顺序遍历)
|
||||
// 2. 渠道已有的模型定价(ApplyPricingToAccountStats 开启时)
|
||||
// 3. nil → 走默认公式
|
||||
func resolveAccountStatsCost(
|
||||
ctx context.Context,
|
||||
channelService *ChannelService,
|
||||
billingService *BillingService,
|
||||
accountID int64,
|
||||
groupID int64,
|
||||
billingModel string,
|
||||
tokens UsageTokens,
|
||||
requestCount int,
|
||||
serviceTier string,
|
||||
) *float64 {
|
||||
if channelService == nil || billingService == nil {
|
||||
return nil
|
||||
}
|
||||
channel, err := channelService.GetChannelForGroup(ctx, groupID)
|
||||
if err != nil || channel == nil || !channel.ApplyPricingToAccountStats {
|
||||
return nil
|
||||
}
|
||||
|
||||
platform := channelService.GetGroupPlatform(ctx, groupID)
|
||||
modelLower := strings.ToLower(billingModel)
|
||||
|
||||
// 优先级 1:自定义规则
|
||||
if cost := tryCustomRules(channel, accountID, groupID, platform, modelLower, tokens, requestCount); cost != nil {
|
||||
return cost
|
||||
}
|
||||
|
||||
// 优先级 2:渠道已有模型定价
|
||||
return tryChannelPricing(ctx, channelService, groupID, billingModel, tokens, requestCount)
|
||||
}
|
||||
|
||||
// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。
|
||||
func tryCustomRules(
|
||||
channel *Channel, accountID, groupID int64,
|
||||
platform, modelLower string, tokens UsageTokens, requestCount int,
|
||||
) *float64 {
|
||||
for _, rule := range channel.AccountStatsPricingRules {
|
||||
if !matchAccountStatsRule(&rule, accountID, groupID) {
|
||||
continue
|
||||
}
|
||||
pricing := findPricingForModel(rule.Pricing, platform, modelLower)
|
||||
if pricing == nil {
|
||||
continue // 规则匹配但模型不在规则定价中,继续下一条
|
||||
}
|
||||
return calculateStatsCost(pricing, tokens, requestCount)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// tryChannelPricing 使用渠道已有的模型定价计算账号统计费用。
|
||||
func tryChannelPricing(
|
||||
ctx context.Context, channelService *ChannelService,
|
||||
groupID int64, billingModel string, tokens UsageTokens, requestCount int,
|
||||
) *float64 {
|
||||
pricing := channelService.GetChannelModelPricing(ctx, groupID, billingModel)
|
||||
if pricing == nil {
|
||||
return nil
|
||||
}
|
||||
return calculateStatsCost(pricing, tokens, requestCount)
|
||||
}
|
||||
|
||||
// matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。
|
||||
// 匹配条件:accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。
|
||||
// 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。
|
||||
func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool {
|
||||
if len(rule.AccountIDs) == 0 && len(rule.GroupIDs) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, id := range rule.AccountIDs {
|
||||
if id == accountID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, id := range rule.GroupIDs {
|
||||
if id == groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// wildcardMatch 通配符匹配候选项(用于排序)
|
||||
type wildcardMatch struct {
|
||||
prefixLen int
|
||||
pricing *ChannelModelPricing
|
||||
}
|
||||
|
||||
// findPricingForModel 在定价列表中查找匹配的模型定价。
|
||||
// 先精确匹配,再通配符匹配(前缀越长优先级越高)。
|
||||
func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
|
||||
// 精确匹配优先
|
||||
for i := range pricingList {
|
||||
p := &pricingList[i]
|
||||
if !isPlatformMatch(platform, p.Platform) {
|
||||
continue
|
||||
}
|
||||
for _, m := range p.Models {
|
||||
if strings.ToLower(m) == modelLower {
|
||||
return p
|
||||
}
|
||||
}
|
||||
}
|
||||
// 通配符匹配:收集所有匹配项,按前缀长度降序取最长
|
||||
var matches []wildcardMatch
|
||||
for i := range pricingList {
|
||||
p := &pricingList[i]
|
||||
if !isPlatformMatch(platform, p.Platform) {
|
||||
continue
|
||||
}
|
||||
for _, m := range p.Models {
|
||||
ml := strings.ToLower(m)
|
||||
if !strings.HasSuffix(ml, "*") {
|
||||
continue
|
||||
}
|
||||
prefix := strings.TrimSuffix(ml, "*")
|
||||
if strings.HasPrefix(modelLower, prefix) {
|
||||
matches = append(matches, wildcardMatch{prefixLen: len(prefix), pricing: p})
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(matches) == 0 {
|
||||
return nil
|
||||
}
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
return matches[i].prefixLen > matches[j].prefixLen
|
||||
})
|
||||
return matches[0].pricing
|
||||
}
|
||||
|
||||
// isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。
|
||||
func isPlatformMatch(queryPlatform, pricingPlatform string) bool {
|
||||
if queryPlatform == "" || pricingPlatform == "" {
|
||||
return true
|
||||
}
|
||||
return queryPlatform == pricingPlatform
|
||||
}
|
||||
|
||||
// calculateStatsCost 使用给定的定价计算费用(不含任何倍率,原始费用)。
|
||||
func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 {
|
||||
if pricing == nil {
|
||||
return nil
|
||||
}
|
||||
switch pricing.BillingMode {
|
||||
case BillingModePerRequest, BillingModeImage:
|
||||
return calculatePerRequestStatsCost(pricing, requestCount)
|
||||
default:
|
||||
return calculateTokenStatsCost(pricing, tokens)
|
||||
}
|
||||
}
|
||||
|
||||
// calculatePerRequestStatsCost 按次/图片计费。
|
||||
func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 {
|
||||
if pricing.PerRequestPrice == nil || *pricing.PerRequestPrice <= 0 {
|
||||
return nil
|
||||
}
|
||||
cost := *pricing.PerRequestPrice * float64(requestCount)
|
||||
return &cost
|
||||
}
|
||||
|
||||
// calculateTokenStatsCost Token 计费。
|
||||
func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
|
||||
deref := func(p *float64) float64 {
|
||||
if p == nil {
|
||||
return 0
|
||||
}
|
||||
return *p
|
||||
}
|
||||
cost := float64(tokens.InputTokens)*deref(pricing.InputPrice) +
|
||||
float64(tokens.OutputTokens)*deref(pricing.OutputPrice) +
|
||||
float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) +
|
||||
float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) +
|
||||
float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice)
|
||||
if cost == 0 {
|
||||
return nil
|
||||
}
|
||||
return &cost
|
||||
}
|
||||
430
backend/internal/service/account_stats_pricing_test.go
Normal file
430
backend/internal/service/account_stats_pricing_test.go
Normal file
@@ -0,0 +1,430 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// matchAccountStatsRule
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestMatchAccountStatsRule_BothEmpty_NoMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{}
|
||||
require.False(t, matchAccountStatsRule(rule, 1, 10))
|
||||
}
|
||||
|
||||
func TestMatchAccountStatsRule_AccountIDMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{AccountIDs: []int64{1, 2, 3}}
|
||||
require.True(t, matchAccountStatsRule(rule, 2, 999))
|
||||
}
|
||||
|
||||
func TestMatchAccountStatsRule_GroupIDMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{GroupIDs: []int64{10, 20}}
|
||||
require.True(t, matchAccountStatsRule(rule, 999, 20))
|
||||
}
|
||||
|
||||
func TestMatchAccountStatsRule_BothConfigured_AccountMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{
|
||||
AccountIDs: []int64{1, 2},
|
||||
GroupIDs: []int64{10, 20},
|
||||
}
|
||||
require.True(t, matchAccountStatsRule(rule, 2, 999))
|
||||
}
|
||||
|
||||
func TestMatchAccountStatsRule_BothConfigured_GroupMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{
|
||||
AccountIDs: []int64{1, 2},
|
||||
GroupIDs: []int64{10, 20},
|
||||
}
|
||||
require.True(t, matchAccountStatsRule(rule, 999, 10))
|
||||
}
|
||||
|
||||
func TestMatchAccountStatsRule_BothConfigured_NeitherMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{
|
||||
AccountIDs: []int64{1, 2},
|
||||
GroupIDs: []int64{10, 20},
|
||||
}
|
||||
require.False(t, matchAccountStatsRule(rule, 999, 999))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// findPricingForModel
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestFindPricingForModel(t *testing.T) {
|
||||
exactPricing := ChannelModelPricing{
|
||||
ID: 1,
|
||||
Models: []string{"claude-opus-4"},
|
||||
}
|
||||
wildcardPricing := ChannelModelPricing{
|
||||
ID: 2,
|
||||
Models: []string{"claude-*"},
|
||||
}
|
||||
platformPricing := ChannelModelPricing{
|
||||
ID: 3,
|
||||
Platform: "openai",
|
||||
Models: []string{"gpt-4o"},
|
||||
}
|
||||
emptyPlatformPricing := ChannelModelPricing{
|
||||
ID: 4,
|
||||
Models: []string{"gemini-2.5-pro"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
list []ChannelModelPricing
|
||||
platform string
|
||||
model string
|
||||
wantID int64
|
||||
wantNil bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
list: []ChannelModelPricing{exactPricing},
|
||||
platform: "anthropic",
|
||||
model: "claude-opus-4",
|
||||
wantID: 1,
|
||||
},
|
||||
{
|
||||
name: "exact match case insensitive",
|
||||
list: []ChannelModelPricing{{ID: 5, Models: []string{"Claude-Opus-4"}}},
|
||||
platform: "",
|
||||
model: "claude-opus-4",
|
||||
wantID: 5,
|
||||
},
|
||||
{
|
||||
name: "wildcard match",
|
||||
list: []ChannelModelPricing{wildcardPricing},
|
||||
platform: "anthropic",
|
||||
model: "claude-opus-4",
|
||||
wantID: 2,
|
||||
},
|
||||
{
|
||||
name: "exact match takes priority over wildcard",
|
||||
list: []ChannelModelPricing{wildcardPricing, exactPricing},
|
||||
platform: "anthropic",
|
||||
model: "claude-opus-4",
|
||||
wantID: 1,
|
||||
},
|
||||
{
|
||||
name: "platform mismatch skipped",
|
||||
list: []ChannelModelPricing{platformPricing},
|
||||
platform: "anthropic",
|
||||
model: "gpt-4o",
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "empty platform in pricing matches any",
|
||||
list: []ChannelModelPricing{emptyPlatformPricing},
|
||||
platform: "gemini",
|
||||
model: "gemini-2.5-pro",
|
||||
wantID: 4,
|
||||
},
|
||||
{
|
||||
name: "empty platform in query matches any pricing platform",
|
||||
list: []ChannelModelPricing{platformPricing},
|
||||
platform: "",
|
||||
model: "gpt-4o",
|
||||
wantID: 3,
|
||||
},
|
||||
{
|
||||
name: "no match at all",
|
||||
list: []ChannelModelPricing{exactPricing, wildcardPricing},
|
||||
platform: "anthropic",
|
||||
model: "gpt-4o",
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "empty list returns nil",
|
||||
list: nil,
|
||||
model: "claude-opus-4",
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "longer wildcard prefix wins over shorter",
|
||||
list: []ChannelModelPricing{
|
||||
{ID: 10, Models: []string{"claude-*"}},
|
||||
{ID: 11, Models: []string{"claude-opus-*"}},
|
||||
},
|
||||
platform: "",
|
||||
model: "claude-opus-4",
|
||||
wantID: 11, // "claude-opus-" (12 chars) > "claude-" (7 chars)
|
||||
},
|
||||
{
|
||||
name: "shorter wildcard used when longer does not match",
|
||||
list: []ChannelModelPricing{
|
||||
{ID: 10, Models: []string{"claude-*"}},
|
||||
{ID: 11, Models: []string{"claude-opus-*"}},
|
||||
},
|
||||
platform: "",
|
||||
model: "claude-sonnet-4",
|
||||
wantID: 10, // only "claude-*" matches
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := findPricingForModel(tt.list, tt.platform, tt.model)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, result)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, tt.wantID, result.ID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// calculateStatsCost
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestCalculateStatsCost_NilPricing(t *testing.T) {
|
||||
result := calculateStatsCost(nil, UsageTokens{}, 1)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_TokenBilling(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
OutputPrice: testPtrFloat64(0.002),
|
||||
}
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
}
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
|
||||
require.InDelta(t, 0.2, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_TokenBilling_WithCache(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
OutputPrice: testPtrFloat64(0.002),
|
||||
CacheWritePrice: testPtrFloat64(0.003),
|
||||
CacheReadPrice: testPtrFloat64(0.0005),
|
||||
}
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheCreationTokens: 200,
|
||||
CacheReadTokens: 300,
|
||||
}
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
|
||||
// = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
|
||||
require.InDelta(t, 0.95, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_TokenBilling_WithImageOutput(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
OutputPrice: testPtrFloat64(0.002),
|
||||
ImageOutputPrice: testPtrFloat64(0.01),
|
||||
}
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
ImageOutputTokens: 10,
|
||||
}
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
|
||||
require.InDelta(t, 0.3, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_TokenBilling_PartialPricesNil(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
// OutputPrice, CacheWritePrice, etc. are all nil → treated as 0
|
||||
}
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheCreationTokens: 200,
|
||||
}
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// Only input contributes: 100*0.001 = 0.1
|
||||
require.InDelta(t, 0.1, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_TokenBilling_AllTokensZero(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
OutputPrice: testPtrFloat64(0.002),
|
||||
}
|
||||
tokens := UsageTokens{} // all zeros
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
// totalCost == 0 → returns nil (does not override, falls back to default formula)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_PerRequestBilling(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModePerRequest,
|
||||
PerRequestPrice: testPtrFloat64(0.05),
|
||||
}
|
||||
tokens := UsageTokens{InputTokens: 999, OutputTokens: 999}
|
||||
result := calculateStatsCost(pricing, tokens, 3)
|
||||
require.NotNil(t, result)
|
||||
// 0.05 * 3 = 0.15
|
||||
require.InDelta(t, 0.15, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_PerRequestBilling_PriceNil(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModePerRequest,
|
||||
// PerRequestPrice is nil
|
||||
}
|
||||
result := calculateStatsCost(pricing, UsageTokens{}, 1)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_PerRequestBilling_PriceZero(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModePerRequest,
|
||||
PerRequestPrice: testPtrFloat64(0),
|
||||
}
|
||||
result := calculateStatsCost(pricing, UsageTokens{}, 1)
|
||||
// price == 0 → condition *pricing.PerRequestPrice > 0 is false → returns nil
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_ImageBilling(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeImage,
|
||||
PerRequestPrice: testPtrFloat64(0.10),
|
||||
}
|
||||
result := calculateStatsCost(pricing, UsageTokens{}, 2)
|
||||
require.NotNil(t, result)
|
||||
// 0.10 * 2 = 0.20
|
||||
require.InDelta(t, 0.20, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_ImageBilling_PriceNil(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeImage,
|
||||
// PerRequestPrice is nil
|
||||
}
|
||||
result := calculateStatsCost(pricing, UsageTokens{}, 1)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_DefaultBillingMode_FallsToToken(t *testing.T) {
|
||||
// BillingMode is empty string (default) → falls into token billing
|
||||
pricing := &ChannelModelPricing{
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
OutputPrice: testPtrFloat64(0.002),
|
||||
}
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
}
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 0.2, *result, 1e-12)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// tryCustomRules — 多规则顺序测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestTryCustomRules_FirstMatchWins(t *testing.T) {
|
||||
channel := &Channel{
|
||||
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||
{
|
||||
GroupIDs: []int64{1},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01), OutputPrice: testPtrFloat64(0.02)},
|
||||
},
|
||||
},
|
||||
{
|
||||
GroupIDs: []int64{1},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99), OutputPrice: testPtrFloat64(0.99)},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
|
||||
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// 应使用第一条规则的价格:100*0.01 + 50*0.02 = 2.0
|
||||
require.InDelta(t, 2.0, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestTryCustomRules_SkipsNonMatchingRules(t *testing.T) {
|
||||
channel := &Channel{
|
||||
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||
{
|
||||
AccountIDs: []int64{888}, // 不匹配
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99)},
|
||||
},
|
||||
},
|
||||
{
|
||||
GroupIDs: []int64{1}, // 匹配
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
tokens := UsageTokens{InputTokens: 100}
|
||||
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// 跳过规则1(账号不匹配),使用规则2:100*0.05 = 5.0
|
||||
require.InDelta(t, 5.0, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestTryCustomRules_NoMatch_ReturnsNil(t *testing.T) {
|
||||
channel := &Channel{
|
||||
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||
{
|
||||
AccountIDs: []int64{888},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01)},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
tokens := UsageTokens{InputTokens: 100}
|
||||
result := tryCustomRules(channel, 999, 2, "", "claude-opus-4", tokens, 1)
|
||||
require.Nil(t, result) // 账号和分组都不匹配
|
||||
}
|
||||
|
||||
func TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext(t *testing.T) {
|
||||
channel := &Channel{
|
||||
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||
{
|
||||
GroupIDs: []int64{1},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 100, Models: []string{"gpt-4o"}, InputPrice: testPtrFloat64(0.01)}, // 模型不匹配
|
||||
},
|
||||
},
|
||||
{
|
||||
GroupIDs: []int64{1},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)}, // 模型匹配
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
tokens := UsageTokens{InputTokens: 100}
|
||||
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 5.0, *result, 1e-12) // 使用规则2
|
||||
}
|
||||
@@ -49,21 +49,25 @@ type Channel struct {
|
||||
ModelPricing []ChannelModelPricing
|
||||
// 渠道级模型映射(按平台分组:platform → {src→dst})
|
||||
ModelMapping map[string]map[string]string
|
||||
// 渠道特性配置(如 {"web_search_emulation": {"anthropic": true}})
|
||||
FeaturesConfig map[string]any
|
||||
|
||||
// 账号统计定价
|
||||
ApplyPricingToAccountStats bool // 是否应用渠道模型定价到账号统计
|
||||
AccountStatsPricingRules []AccountStatsPricingRule // 自定义账号统计定价规则(按 SortOrder 排序,先命中为准)
|
||||
}
|
||||
|
||||
// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
|
||||
func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
|
||||
if c == nil || c.FeaturesConfig == nil {
|
||||
return false
|
||||
}
|
||||
wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
enabled, ok := wse[platform].(bool)
|
||||
return ok && enabled
|
||||
// AccountStatsPricingRule 账号统计定价规则
|
||||
// 每条规则包含匹配条件(分组/账号)和独立的模型定价。
|
||||
// 多条规则按 SortOrder 排序,先命中为准。
|
||||
type AccountStatsPricingRule struct {
|
||||
ID int64
|
||||
ChannelID int64
|
||||
Name string
|
||||
GroupIDs []int64
|
||||
AccountIDs []int64
|
||||
SortOrder int
|
||||
Pricing []ChannelModelPricing // 规则内的模型定价(复用现有定价结构)
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// ChannelModelPricing 渠道模型定价条目
|
||||
@@ -192,6 +196,26 @@ func (c *Channel) Clone() *Channel {
|
||||
cp.ModelMapping[platform] = inner
|
||||
}
|
||||
}
|
||||
if c.AccountStatsPricingRules != nil {
|
||||
cp.AccountStatsPricingRules = make([]AccountStatsPricingRule, len(c.AccountStatsPricingRules))
|
||||
for i, rule := range c.AccountStatsPricingRules {
|
||||
cp.AccountStatsPricingRules[i] = rule
|
||||
if rule.GroupIDs != nil {
|
||||
cp.AccountStatsPricingRules[i].GroupIDs = make([]int64, len(rule.GroupIDs))
|
||||
copy(cp.AccountStatsPricingRules[i].GroupIDs, rule.GroupIDs)
|
||||
}
|
||||
if rule.AccountIDs != nil {
|
||||
cp.AccountStatsPricingRules[i].AccountIDs = make([]int64, len(rule.AccountIDs))
|
||||
copy(cp.AccountStatsPricingRules[i].AccountIDs, rule.AccountIDs)
|
||||
}
|
||||
if rule.Pricing != nil {
|
||||
cp.AccountStatsPricingRules[i].Pricing = make([]ChannelModelPricing, len(rule.Pricing))
|
||||
for j := range rule.Pricing {
|
||||
cp.AccountStatsPricingRules[i].Pricing[j] = rule.Pricing[j].Clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return &cp
|
||||
}
|
||||
|
||||
|
||||
@@ -416,6 +416,15 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64)
|
||||
return ch.Clone(), nil
|
||||
}
|
||||
|
||||
// GetGroupPlatform 获取分组的平台标识(从缓存)
|
||||
func (s *ChannelService) GetGroupPlatform(ctx context.Context, groupID int64) string {
|
||||
cache, err := s.loadCache(ctx)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cache.groupPlatform[groupID]
|
||||
}
|
||||
|
||||
// channelLookup 热路径公共查找结果
|
||||
type channelLookup struct {
|
||||
cache *channelCache
|
||||
@@ -656,16 +665,17 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
}
|
||||
|
||||
channel := &Channel{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Status: StatusActive,
|
||||
BillingModelSource: input.BillingModelSource,
|
||||
RestrictModels: input.RestrictModels,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
Features: input.Features,
|
||||
FeaturesConfig: input.FeaturesConfig,
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Status: StatusActive,
|
||||
BillingModelSource: input.BillingModelSource,
|
||||
RestrictModels: input.RestrictModels,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
Features: input.Features,
|
||||
ApplyPricingToAccountStats: input.ApplyPricingToAccountStats,
|
||||
AccountStatsPricingRules: input.AccountStatsPricingRules,
|
||||
}
|
||||
if channel.BillingModelSource == "" {
|
||||
channel.BillingModelSource = BillingModelSourceChannelMapped
|
||||
@@ -754,8 +764,11 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
|
||||
if input.BillingModelSource != "" {
|
||||
channel.BillingModelSource = input.BillingModelSource
|
||||
}
|
||||
if input.FeaturesConfig != nil {
|
||||
channel.FeaturesConfig = input.FeaturesConfig
|
||||
if input.ApplyPricingToAccountStats != nil {
|
||||
channel.ApplyPricingToAccountStats = *input.ApplyPricingToAccountStats
|
||||
}
|
||||
if input.AccountStatsPricingRules != nil {
|
||||
channel.AccountStatsPricingRules = *input.AccountStatsPricingRules
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -922,27 +935,29 @@ func detectConflicts(entries []modelEntry, platform, errCode, label string) erro
|
||||
|
||||
// CreateChannelInput 创建渠道输入
|
||||
type CreateChannelInput struct {
|
||||
Name string
|
||||
Description string
|
||||
GroupIDs []int64
|
||||
ModelPricing []ChannelModelPricing
|
||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||
BillingModelSource string
|
||||
RestrictModels bool
|
||||
Features string
|
||||
FeaturesConfig map[string]any
|
||||
Name string
|
||||
Description string
|
||||
GroupIDs []int64
|
||||
ModelPricing []ChannelModelPricing
|
||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||
BillingModelSource string
|
||||
RestrictModels bool
|
||||
Features string
|
||||
ApplyPricingToAccountStats bool
|
||||
AccountStatsPricingRules []AccountStatsPricingRule
|
||||
}
|
||||
|
||||
// UpdateChannelInput 更新渠道输入
|
||||
type UpdateChannelInput struct {
|
||||
Name string
|
||||
Description *string
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ModelPricing *[]ChannelModelPricing
|
||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||
BillingModelSource string
|
||||
RestrictModels *bool
|
||||
Features *string
|
||||
FeaturesConfig map[string]any
|
||||
Name string
|
||||
Description *string
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ModelPricing *[]ChannelModelPricing
|
||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||
BillingModelSource string
|
||||
RestrictModels *bool
|
||||
Features *string
|
||||
ApplyPricingToAccountStats *bool
|
||||
AccountStatsPricingRules *[]AccountStatsPricingRule
|
||||
}
|
||||
|
||||
@@ -7559,6 +7559,23 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
|
||||
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
|
||||
|
||||
// 计算账号统计定价费用
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.AccountStatsCost = resolveAccountStatsCost(
|
||||
ctx, s.channelService, s.billingService,
|
||||
account.ID, *apiKey.GroupID, billingModel,
|
||||
UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||
},
|
||||
1, // requestCount
|
||||
"", // serviceTier: Anthropic 平台不使用 service tier
|
||||
)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
|
||||
@@ -4569,6 +4569,15 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
// 计算账号统计定价费用
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.AccountStatsCost = resolveAccountStatsCost(
|
||||
ctx, s.channelService, s.billingService,
|
||||
account.ID, *apiKey.GroupID, billingModel,
|
||||
tokens, 1, serviceTier,
|
||||
)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
|
||||
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
|
||||
@@ -146,6 +146,8 @@ type UsageLog struct {
|
||||
RateMultiplier float64
|
||||
// AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理)
|
||||
AccountRateMultiplier *float64
|
||||
// AccountStatsCost 账号统计定价预计算费用(nil = 使用默认公式 total_cost × account_rate_multiplier)
|
||||
AccountStatsCost *float64
|
||||
|
||||
BillingType int8
|
||||
RequestType RequestType
|
||||
|
||||
Reference in New Issue
Block a user