allow_user_refund: - Add allow_user_refund field to PaymentProviderInstance ent schema - Migration 103: ALTER TABLE payment_provider_instances ADD COLUMN - Cascade logic: disabling refund_enabled auto-disables allow_user_refund - User refund validation: check provider instance allows user refund - Admin refund validation: check provider instance allows admin refund - Subscription refund: deduct days on refund, rollback on failure - New endpoint: GET /payment/orders/refund-eligible-providers - Frontend: ToggleSwitch in ProviderCard/Dialog, cascade in SettingsView Wildcard matching: - Change findPricingForModel from "longest prefix wins" to "config order priority (first match wins)", aligning with channel service behavior
237 lines
7.3 KiB
Go
237 lines
7.3 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"strings"
|
||
)
|
||
|
||
// resolveAccountStatsCost 计算账号统计定价费用。
|
||
// 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。
|
||
//
|
||
// 优先级(先命中为准):
|
||
// 1. 自定义规则(始终尝试,不依赖 ApplyPricingToAccountStats 开关)
|
||
// 2. ApplyPricingToAccountStats 启用时,直接使用本次请求的客户计费(倍率前的 totalCost)
|
||
// 3. 模型定价文件(LiteLLM)中上游模型的默认价格
|
||
// 4. nil → 走默认公式(total_cost × account_rate_multiplier)
|
||
//
|
||
// upstreamModel 是最终发往上游的模型 ID。
|
||
// totalCost 是本次请求的客户计费(倍率前),用于优先级 2。
|
||
func resolveAccountStatsCost(
|
||
ctx context.Context,
|
||
channelService *ChannelService,
|
||
billingService *BillingService,
|
||
accountID int64,
|
||
groupID int64,
|
||
upstreamModel string,
|
||
tokens UsageTokens,
|
||
requestCount int,
|
||
totalCost float64,
|
||
) *float64 {
|
||
if channelService == nil || upstreamModel == "" {
|
||
return nil
|
||
}
|
||
channel, err := channelService.GetChannelForGroup(ctx, groupID)
|
||
if err != nil || channel == nil {
|
||
return nil
|
||
}
|
||
|
||
platform := channelService.GetGroupPlatform(ctx, groupID)
|
||
|
||
// 优先级 1:自定义规则(始终尝试)
|
||
if cost := tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount); cost != nil {
|
||
return cost
|
||
}
|
||
|
||
// 优先级 2:渠道开启"应用模型定价到账号统计"时,直接使用客户计费(倍率前)
|
||
if channel.ApplyPricingToAccountStats {
|
||
cost := totalCost
|
||
if cost <= 0 {
|
||
return nil
|
||
}
|
||
return &cost
|
||
}
|
||
|
||
// 优先级 3:模型定价文件(LiteLLM)默认价格
|
||
if billingService != nil {
|
||
return tryModelFilePricing(billingService, upstreamModel, tokens)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// tryModelFilePricing 使用模型定价文件(LiteLLM/fallback)中的标准价格计算费用。
|
||
func tryModelFilePricing(billingService *BillingService, model string, tokens UsageTokens) *float64 {
|
||
pricing, err := billingService.GetModelPricing(model)
|
||
if err != nil || pricing == nil {
|
||
return nil
|
||
}
|
||
cost := float64(tokens.InputTokens)*pricing.InputPricePerToken +
|
||
float64(tokens.OutputTokens)*pricing.OutputPricePerToken +
|
||
float64(tokens.CacheCreationTokens)*pricing.CacheCreationPricePerToken +
|
||
float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken +
|
||
float64(tokens.ImageOutputTokens)*pricing.ImageOutputPricePerToken
|
||
if cost <= 0 {
|
||
return nil
|
||
}
|
||
return &cost
|
||
}
|
||
|
||
// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。
|
||
func tryCustomRules(
|
||
channel *Channel, accountID, groupID int64,
|
||
platform, model string, tokens UsageTokens, requestCount int,
|
||
) *float64 {
|
||
modelLower := strings.ToLower(model)
|
||
for _, rule := range channel.AccountStatsPricingRules {
|
||
if !matchAccountStatsRule(&rule, accountID, groupID) {
|
||
continue
|
||
}
|
||
pricing := findPricingForModel(rule.Pricing, platform, modelLower)
|
||
if pricing == nil {
|
||
continue // 规则匹配但模型不在规则定价中,继续下一条
|
||
}
|
||
return calculateStatsCost(pricing, tokens, requestCount)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。
|
||
// 匹配条件:accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。
|
||
// 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。
|
||
func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool {
|
||
if len(rule.AccountIDs) == 0 && len(rule.GroupIDs) == 0 {
|
||
return false
|
||
}
|
||
for _, id := range rule.AccountIDs {
|
||
if id == accountID {
|
||
return true
|
||
}
|
||
}
|
||
for _, id := range rule.GroupIDs {
|
||
if id == groupID {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// findPricingForModel 在定价列表中查找匹配的模型定价。
|
||
// 先精确匹配,再通配符匹配(按配置顺序,先匹配先使用)。
|
||
func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
|
||
// 精确匹配优先
|
||
for i := range pricingList {
|
||
p := &pricingList[i]
|
||
if !isPlatformMatch(platform, p.Platform) {
|
||
continue
|
||
}
|
||
for _, m := range p.Models {
|
||
if strings.ToLower(m) == modelLower {
|
||
return p
|
||
}
|
||
}
|
||
}
|
||
// 通配符匹配:按配置顺序,先匹配先使用
|
||
for i := range pricingList {
|
||
p := &pricingList[i]
|
||
if !isPlatformMatch(platform, p.Platform) {
|
||
continue
|
||
}
|
||
for _, m := range p.Models {
|
||
ml := strings.ToLower(m)
|
||
if !strings.HasSuffix(ml, "*") {
|
||
continue
|
||
}
|
||
prefix := strings.TrimSuffix(ml, "*")
|
||
if strings.HasPrefix(modelLower, prefix) {
|
||
return p
|
||
}
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// isPlatformMatch reports whether a pricing entry's platform is compatible
// with the queried platform. An empty platform on either side is treated as
// "any platform"; otherwise the two strings must be equal (case-sensitive).
func isPlatformMatch(queryPlatform, pricingPlatform string) bool {
	unrestricted := queryPlatform == "" || pricingPlatform == ""
	return unrestricted || queryPlatform == pricingPlatform
}
|
||
|
||
// calculateStatsCost 使用给定的定价计算费用(不含任何倍率,原始费用)。
|
||
func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 {
|
||
if pricing == nil {
|
||
return nil
|
||
}
|
||
switch pricing.BillingMode {
|
||
case BillingModePerRequest, BillingModeImage:
|
||
return calculatePerRequestStatsCost(pricing, requestCount)
|
||
default:
|
||
return calculateTokenStatsCost(pricing, tokens)
|
||
}
|
||
}
|
||
|
||
// calculatePerRequestStatsCost 按次/图片计费。
|
||
func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 {
|
||
if pricing.PerRequestPrice == nil || *pricing.PerRequestPrice <= 0 {
|
||
return nil
|
||
}
|
||
cost := *pricing.PerRequestPrice * float64(requestCount)
|
||
return &cost
|
||
}
|
||
|
||
// calculateTokenStatsCost Token 计费。
|
||
// If the pricing has intervals, find the matching interval by total token count
|
||
// and use its prices instead of the flat pricing fields.
|
||
func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
|
||
p := pricing
|
||
if len(pricing.Intervals) > 0 {
|
||
totalTokens := tokens.InputTokens + tokens.OutputTokens + tokens.CacheCreationTokens + tokens.CacheReadTokens
|
||
if iv := FindMatchingInterval(pricing.Intervals, totalTokens); iv != nil {
|
||
p = &ChannelModelPricing{
|
||
InputPrice: iv.InputPrice,
|
||
OutputPrice: iv.OutputPrice,
|
||
CacheWritePrice: iv.CacheWritePrice,
|
||
CacheReadPrice: iv.CacheReadPrice,
|
||
PerRequestPrice: iv.PerRequestPrice,
|
||
}
|
||
}
|
||
}
|
||
deref := func(ptr *float64) float64 {
|
||
if ptr == nil {
|
||
return 0
|
||
}
|
||
return *ptr
|
||
}
|
||
cost := float64(tokens.InputTokens)*deref(p.InputPrice) +
|
||
float64(tokens.OutputTokens)*deref(p.OutputPrice) +
|
||
float64(tokens.CacheCreationTokens)*deref(p.CacheWritePrice) +
|
||
float64(tokens.CacheReadTokens)*deref(p.CacheReadPrice) +
|
||
float64(tokens.ImageOutputTokens)*deref(p.ImageOutputPrice)
|
||
if cost <= 0 {
|
||
return nil
|
||
}
|
||
return &cost
|
||
}
|
||
|
||
// applyAccountStatsCost resolves the account stats cost for a usage log entry.
|
||
// It resolves the upstream model (falling back to the requested model) and calls
|
||
// the 4-level priority chain via resolveAccountStatsCost.
|
||
func applyAccountStatsCost(
|
||
ctx context.Context,
|
||
usageLog *UsageLog,
|
||
cs *ChannelService, bs *BillingService,
|
||
accountID int64, groupID int64,
|
||
upstreamModel, requestedModel string,
|
||
tokens UsageTokens,
|
||
totalCost float64,
|
||
) {
|
||
model := upstreamModel
|
||
if model == "" {
|
||
model = requestedModel
|
||
}
|
||
usageLog.AccountStatsCost = resolveAccountStatsCost(
|
||
ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost,
|
||
)
|
||
}
|