Files
sub2api/backend/internal/service/account_stats_pricing.go
erio f1297a3694 feat: add per-provider allow_user_refund control and align wildcard matching
allow_user_refund:
- Add allow_user_refund field to PaymentProviderInstance ent schema
- Migration 103: ALTER TABLE payment_provider_instances ADD COLUMN
- Cascade logic: disabling refund_enabled auto-disables allow_user_refund
- User refund validation: check provider instance allows user refund
- Admin refund validation: check provider instance allows admin refund
- Subscription refund: deduct days on refund, rollback on failure
- New endpoint: GET /payment/orders/refund-eligible-providers
- Frontend: ToggleSwitch in ProviderCard/Dialog, cascade in SettingsView

Wildcard matching:
- Change findPricingForModel from "longest prefix wins" to "config order
  priority (first match wins)", aligning with channel service behavior
2026-04-14 16:26:46 +08:00

237 lines
7.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"strings"
)
// resolveAccountStatsCost computes the cost recorded for account statistics.
// A nil result means "do not override": per the original author's note the
// default formula (total_cost × account_rate_multiplier) is then used instead.
//
// Resolution order (first hit wins):
//  1. Custom rules — always tried, independent of ApplyPricingToAccountStats.
//  2. If the channel enables ApplyPricingToAccountStats, the customer charge of
//     this request before multipliers (totalCost) is used directly. A
//     non-positive totalCost yields nil here; step 3 is NOT attempted.
//  3. The price computed from the model pricing file (LiteLLM) for the
//     upstream model, when a billing service is available.
//  4. nil — fall back to the default formula.
//
// upstreamModel is the model ID finally sent upstream. totalCost is the
// pre-multiplier customer charge and is only consulted in step 2.
// nil is also returned when channelService is nil, upstreamModel is empty, or
// the channel for groupID cannot be resolved.
func resolveAccountStatsCost(
	ctx context.Context,
	channelService *ChannelService,
	billingService *BillingService,
	accountID int64,
	groupID int64,
	upstreamModel string,
	tokens UsageTokens,
	requestCount int,
	totalCost float64,
) *float64 {
	if upstreamModel == "" || channelService == nil {
		return nil
	}

	channel, err := channelService.GetChannelForGroup(ctx, groupID)
	if channel == nil || err != nil {
		return nil
	}
	platform := channelService.GetGroupPlatform(ctx, groupID)

	// Step 1: custom rules are always consulted first.
	ruleCost := tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount)
	if ruleCost != nil {
		return ruleCost
	}

	switch {
	case channel.ApplyPricingToAccountStats:
		// Step 2: reuse the pre-multiplier customer charge.
		if totalCost <= 0 {
			return nil
		}
		billed := totalCost
		return &billed
	case billingService != nil:
		// Step 3: default price from the model pricing file.
		return tryModelFilePricing(billingService, upstreamModel, tokens)
	default:
		// Step 4: nothing applies; the default formula is used.
		return nil
	}
}
// tryModelFilePricing prices the request with the standard per-token prices
// that billingService.GetModelPricing returns for model (the original note
// describes these as coming from the LiteLLM/fallback pricing file).
//
// It returns nil when the lookup fails or yields no pricing, and also when the
// resulting cost is not positive.
func tryModelFilePricing(billingService *BillingService, model string, tokens UsageTokens) *float64 {
	pricing, err := billingService.GetModelPricing(model)
	if pricing == nil || err != nil {
		return nil
	}

	// Accumulate in the same order as the individual price components.
	total := float64(tokens.InputTokens) * pricing.InputPricePerToken
	total += float64(tokens.OutputTokens) * pricing.OutputPricePerToken
	total += float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
	total += float64(tokens.CacheReadTokens) * pricing.CacheReadPricePerToken
	total += float64(tokens.ImageOutputTokens) * pricing.ImageOutputPricePerToken

	if total <= 0 {
		return nil
	}
	return &total
}
// tryCustomRules walks the channel's account-stats pricing rules in slice
// order and prices the request with the first rule that both matches the
// account/group and contains pricing for the model.
//
// The model name is lower-cased once before lookup. A rule that matches but
// has no pricing for the model is skipped so later rules can still apply.
// Once pricing is found the search stops, even if the computed cost is nil.
// Returns nil when no rule yields pricing.
func tryCustomRules(
	channel *Channel, accountID, groupID int64,
	platform, model string, tokens UsageTokens, requestCount int,
) *float64 {
	lowered := strings.ToLower(model)
	rules := channel.AccountStatsPricingRules
	for i := range rules {
		candidate := rules[i]
		if !matchAccountStatsRule(&candidate, accountID, groupID) {
			continue
		}
		if hit := findPricingForModel(candidate.Pricing, platform, lowered); hit != nil {
			return calculateStatsCost(hit, tokens, requestCount)
		}
	}
	return nil
}
// matchAccountStatsRule reports whether a rule applies to the given account
// or group: it matches when accountID is listed in rule.AccountIDs or groupID
// is listed in rule.GroupIDs.
//
// A rule whose AccountIDs and GroupIDs are both empty never matches.
func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool {
	// listed is false for an empty list, which also covers the
	// "both lists empty" case.
	listed := func(ids []int64, want int64) bool {
		for i := range ids {
			if ids[i] == want {
				return true
			}
		}
		return false
	}
	return listed(rule.AccountIDs, accountID) || listed(rule.GroupIDs, groupID)
}
// findPricingForModel returns the pricing entry that applies to modelLower
// (expected to be lower-cased already by the caller) on the given platform,
// or nil when none does.
//
// Two passes are made over pricingList, both in configuration order and both
// skipping entries whose platform does not match:
//  1. exact match of a lower-cased model name;
//  2. wildcard match, where a name ending in "*" matches any model that starts
//     with the text before that "*". The first matching entry wins; prefix
//     length is not considered.
//
// The result points into pricingList's backing array.
func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
	// firstEntry returns the first platform-compatible entry that has at least
	// one lower-cased model pattern accepted by accept.
	firstEntry := func(accept func(pattern string) bool) *ChannelModelPricing {
		for idx := range pricingList {
			entry := &pricingList[idx]
			if !isPlatformMatch(platform, entry.Platform) {
				continue
			}
			for _, name := range entry.Models {
				if accept(strings.ToLower(name)) {
					return entry
				}
			}
		}
		return nil
	}

	// Pass 1: exact names take precedence over any wildcard.
	exact := firstEntry(func(pattern string) bool {
		return pattern == modelLower
	})
	if exact != nil {
		return exact
	}

	// Pass 2: trailing-"*" prefix patterns, first configured match wins.
	return firstEntry(func(pattern string) bool {
		last := len(pattern) - 1
		return last >= 0 && pattern[last] == '*' && strings.HasPrefix(modelLower, pattern[:last])
	})
}
// isPlatformMatch reports whether a pricing entry's platform is compatible
// with the queried platform. An empty value on either side means "any
// platform" and always matches; otherwise the two must be equal
// (case-sensitive).
func isPlatformMatch(queryPlatform, pricingPlatform string) bool {
	unrestricted := queryPlatform == "" || pricingPlatform == ""
	return unrestricted || queryPlatform == pricingPlatform
}
// calculateStatsCost computes the raw cost (no multiplier applied) for the
// given pricing. Per-request and image billing modes are priced by request
// count; every other mode is priced by token counts.
// A nil pricing yields nil.
func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 {
	if pricing == nil {
		return nil
	}
	if mode := pricing.BillingMode; mode == BillingModePerRequest || mode == BillingModeImage {
		return calculatePerRequestStatsCost(pricing, requestCount)
	}
	return calculateTokenStatsCost(pricing, tokens)
}
// calculatePerRequestStatsCost prices per-request / per-image billing as
// PerRequestPrice × requestCount. It returns nil when no positive
// PerRequestPrice is configured.
func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 {
	unit := pricing.PerRequestPrice
	if unit == nil || *unit <= 0 {
		return nil
	}
	total := *unit * float64(requestCount)
	return &total
}
// calculateTokenStatsCost prices a request by its token counts.
//
// If the pricing defines Intervals, the interval that FindMatchingInterval
// selects for the sum of input, output, cache-creation and cache-read tokens
// supplies the input/output/cache prices in place of the flat pricing fields.
// Image-output tokens are not part of that sum.
//
// The flat ImageOutputPrice is carried over when an interval is applied.
// Without that, image-output tokens would be priced at zero whenever an
// interval matched.
//
// A nil price is treated as 0. Returns nil when the resulting cost is not
// positive.
func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
	p := pricing
	if len(pricing.Intervals) > 0 {
		totalTokens := tokens.InputTokens + tokens.OutputTokens + tokens.CacheCreationTokens + tokens.CacheReadTokens
		if iv := FindMatchingInterval(pricing.Intervals, totalTokens); iv != nil {
			p = &ChannelModelPricing{
				InputPrice:      iv.InputPrice,
				OutputPrice:     iv.OutputPrice,
				CacheWritePrice: iv.CacheWritePrice,
				CacheReadPrice:  iv.CacheReadPrice,
				PerRequestPrice: iv.PerRequestPrice,
				// Keep the flat image-output price so image tokens stay billable.
				ImageOutputPrice: pricing.ImageOutputPrice,
			}
		}
	}

	deref := func(ptr *float64) float64 {
		if ptr == nil {
			return 0
		}
		return *ptr
	}

	cost := float64(tokens.InputTokens)*deref(p.InputPrice) +
		float64(tokens.OutputTokens)*deref(p.OutputPrice) +
		float64(tokens.CacheCreationTokens)*deref(p.CacheWritePrice) +
		float64(tokens.CacheReadTokens)*deref(p.CacheReadPrice) +
		float64(tokens.ImageOutputTokens)*deref(p.ImageOutputPrice)
	if cost <= 0 {
		return nil
	}
	return &cost
}
// applyAccountStatsCost fills usageLog.AccountStatsCost for one usage log
// entry. The model used for pricing is upstreamModel, or requestedModel when
// upstreamModel is empty. The value is whatever resolveAccountStatsCost
// returns for a single request (request count fixed at 1); it may be nil.
func applyAccountStatsCost(
	ctx context.Context,
	usageLog *UsageLog,
	cs *ChannelService, bs *BillingService,
	accountID int64, groupID int64,
	upstreamModel, requestedModel string,
	tokens UsageTokens,
	totalCost float64,
) {
	effectiveModel := requestedModel
	if upstreamModel != "" {
		effectiveModel = upstreamModel
	}

	const singleRequest = 1
	resolved := resolveAccountStatsCost(
		ctx, cs, bs, accountID, groupID, effectiveModel, tokens, singleRequest, totalCost,
	)
	usageLog.AccountStatsCost = resolved
}