Files
sub2api/backend/internal/service/channel.go
erio 71f61bbc47 fix: resolve 5 audit findings in channel/credits/scheduling
P0-1: Credits degraded response retry + fail-open
- Add isAntigravityDegradedResponse() to detect transient API failures
- Retry up to 3 times with exponential backoff (500ms/1s/2s)
- Invalidate singleflight cache between retries
- Fail-open after exhausting retries instead of 5h circuit break

P1-1: Fix channel restriction pre-check timing conflict
- Swap checkClaudeCodeRestriction before checkChannelPricingRestriction
- Ensures channel restriction is checked against final fallback groupID

P1-2: Add interval pricing validation (frontend + backend)
- Backend: ValidateIntervals() with boundary, price, overlap checks
- Frontend: validateIntervals() with Chinese error messages
- Rules: MinTokens>=0, MaxTokens>MinTokens, prices>=0, no overlap

P2: Fix cross-platform same-model pricing/mapping override
- Store cache keys using original platform instead of group platform
- Lookup across matching platforms (antigravity→anthropic→gemini)
- Prevents anthropic/gemini same-name models from overwriting each other
2026-04-04 11:25:01 +08:00

278 lines
9.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"fmt"
"sort"
"strings"
"time"
)
// BillingMode 计费模式
type BillingMode string
const (
BillingModeToken BillingMode = "token" // 按 token 区间计费
BillingModePerRequest BillingMode = "per_request" // 按次计费(支持上下文窗口分层)
BillingModeImage BillingMode = "image" // 图片计费(当前按次,预留 token 计费)
)
// IsValid 检查 BillingMode 是否为合法值
func (m BillingMode) IsValid() bool {
switch m {
case BillingModeToken, BillingModePerRequest, BillingModeImage, "":
return true
}
return false
}
const (
BillingModelSourceRequested = "requested"
BillingModelSourceUpstream = "upstream"
BillingModelSourceChannelMapped = "channel_mapped"
)
// Channel 渠道实体
type Channel struct {
ID int64
Name string
Description string
Status string
BillingModelSource string // "requested", "upstream", or "channel_mapped"
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
CreatedAt time.Time
UpdatedAt time.Time
// 关联的分组 ID 列表
GroupIDs []int64
// 模型定价列表(每条含 Platform 字段)
ModelPricing []ChannelModelPricing
// 渠道级模型映射按平台分组platform → {src→dst}
ModelMapping map[string]map[string]string
}
// ChannelModelPricing 渠道模型定价条目
type ChannelModelPricing struct {
ID int64
ChannelID int64
Platform string // 所属平台anthropic/openai/gemini/...
Models []string // 绑定的模型列表
BillingMode BillingMode // 计费模式
InputPrice *float64 // 每 token 输入价格USD— 向后兼容 flat 定价
OutputPrice *float64 // 每 token 输出价格USD
CacheWritePrice *float64 // 缓存写入价格
CacheReadPrice *float64 // 缓存读取价格
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
PerRequestPrice *float64 // 默认按次计费价格USD
Intervals []PricingInterval // 区间定价列表
CreatedAt time.Time
UpdatedAt time.Time
}
// PricingInterval 定价区间token 区间 / 按次分层 / 图片分辨率分层)
type PricingInterval struct {
ID int64
PricingID int64
MinTokens int // 区间下界(含)
MaxTokens *int // 区间上界不含nil = 无上限
TierLabel string // 层级标签(按次/图片模式1K, 2K, 4K, HD 等)
InputPrice *float64 // token 模式:每 token 输入价
OutputPrice *float64 // token 模式:每 token 输出价
CacheWritePrice *float64 // token 模式:缓存写入价
CacheReadPrice *float64 // token 模式:缓存读取价
PerRequestPrice *float64 // 按次/图片模式:每次请求价格
SortOrder int
CreatedAt time.Time
UpdatedAt time.Time
}
// IsActive 判断渠道是否启用
func (c *Channel) IsActive() bool {
return c.Status == StatusActive
}
// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。
// 精确匹配,大小写不敏感。返回值拷贝,不污染缓存。
func (c *Channel) GetModelPricing(model string) *ChannelModelPricing {
modelLower := strings.ToLower(model)
for i := range c.ModelPricing {
for _, m := range c.ModelPricing[i].Models {
if strings.ToLower(m) == modelLower {
cp := c.ModelPricing[i].Clone()
return &cp
}
}
}
return nil
}
// FindMatchingInterval 在区间列表中查找匹配 totalTokens 的区间。
// 区间为左开右闭 (min, max]min 不含max 包含。
// 第一个区间 min=0 时0 token 不匹配任何区间(回退到默认价格)。
func FindMatchingInterval(intervals []PricingInterval, totalTokens int) *PricingInterval {
for i := range intervals {
iv := &intervals[i]
if totalTokens > iv.MinTokens && (iv.MaxTokens == nil || totalTokens <= *iv.MaxTokens) {
return iv
}
}
return nil
}
// GetIntervalForContext 根据总 context token 数查找匹配的区间。
func (p *ChannelModelPricing) GetIntervalForContext(totalTokens int) *PricingInterval {
return FindMatchingInterval(p.Intervals, totalTokens)
}
// GetTierByLabel 根据标签查找层级(用于 per_request / image 模式)
func (p *ChannelModelPricing) GetTierByLabel(label string) *PricingInterval {
labelLower := strings.ToLower(label)
for i := range p.Intervals {
if strings.ToLower(p.Intervals[i].TierLabel) == labelLower {
return &p.Intervals[i]
}
}
return nil
}
// Clone 返回 ChannelModelPricing 的拷贝(切片独立,指针字段共享,调用方只读安全)
func (p ChannelModelPricing) Clone() ChannelModelPricing {
cp := p
if p.Models != nil {
cp.Models = make([]string, len(p.Models))
copy(cp.Models, p.Models)
}
if p.Intervals != nil {
cp.Intervals = make([]PricingInterval, len(p.Intervals))
copy(cp.Intervals, p.Intervals)
}
return cp
}
// Clone 返回 Channel 的深拷贝
func (c *Channel) Clone() *Channel {
if c == nil {
return nil
}
cp := *c
if c.GroupIDs != nil {
cp.GroupIDs = make([]int64, len(c.GroupIDs))
copy(cp.GroupIDs, c.GroupIDs)
}
if c.ModelPricing != nil {
cp.ModelPricing = make([]ChannelModelPricing, len(c.ModelPricing))
for i := range c.ModelPricing {
cp.ModelPricing[i] = c.ModelPricing[i].Clone()
}
}
if c.ModelMapping != nil {
cp.ModelMapping = make(map[string]map[string]string, len(c.ModelMapping))
for platform, mapping := range c.ModelMapping {
inner := make(map[string]string, len(mapping))
for k, v := range mapping {
inner[k] = v
}
cp.ModelMapping[platform] = inner
}
}
return &cp
}
// ValidateIntervals 校验区间列表的合法性。
// 规则MinTokens >= 0MaxTokens 若非 nil 则 > 0 且 > MinTokens
// 所有价格字段 >= 0区间按 MinTokens 排序后无重叠((min, max] 语义);
// 无界区间MaxTokens=nil必须是最后一个。间隙允许回退默认价格
func ValidateIntervals(intervals []PricingInterval) error {
if len(intervals) == 0 {
return nil
}
sorted := make([]PricingInterval, len(intervals))
copy(sorted, intervals)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].MinTokens < sorted[j].MinTokens
})
for i := range sorted {
if err := validateSingleInterval(&sorted[i], i); err != nil {
return err
}
}
return validateIntervalOverlap(sorted)
}
// validateSingleInterval 校验单个区间的字段合法性
func validateSingleInterval(iv *PricingInterval, idx int) error {
if iv.MinTokens < 0 {
return fmt.Errorf("interval #%d: min_tokens (%d) must be >= 0", idx+1, iv.MinTokens)
}
if iv.MaxTokens != nil {
if *iv.MaxTokens <= 0 {
return fmt.Errorf("interval #%d: max_tokens (%d) must be > 0", idx+1, *iv.MaxTokens)
}
if *iv.MaxTokens <= iv.MinTokens {
return fmt.Errorf("interval #%d: max_tokens (%d) must be > min_tokens (%d)",
idx+1, *iv.MaxTokens, iv.MinTokens)
}
}
return validateIntervalPrices(iv, idx)
}
// validateIntervalPrices 校验区间内所有价格字段 >= 0
func validateIntervalPrices(iv *PricingInterval, idx int) error {
prices := []struct {
name string
val *float64
}{
{"input_price", iv.InputPrice},
{"output_price", iv.OutputPrice},
{"cache_write_price", iv.CacheWritePrice},
{"cache_read_price", iv.CacheReadPrice},
{"per_request_price", iv.PerRequestPrice},
}
for _, p := range prices {
if p.val != nil && *p.val < 0 {
return fmt.Errorf("interval #%d: %s must be >= 0", idx+1, p.name)
}
}
return nil
}
// validateIntervalOverlap 校验排序后的区间列表无重叠,且无界区间在最后
func validateIntervalOverlap(sorted []PricingInterval) error {
for i, iv := range sorted {
// 无界区间必须是最后一个
if iv.MaxTokens == nil && i < len(sorted)-1 {
return fmt.Errorf("interval #%d: unbounded interval (max_tokens=null) must be the last one",
i+1)
}
if i == 0 {
continue
}
prev := sorted[i-1]
// 检查重叠:前一个区间的上界 > 当前区间的下界则重叠
// (min, max] 语义prev 覆盖 (prev.Min, prev.Max]cur 覆盖 (cur.Min, cur.Max]
if prev.MaxTokens == nil || *prev.MaxTokens > iv.MinTokens {
return fmt.Errorf("interval #%d and #%d overlap: prev max=%s > cur min=%d",
i, i+1, formatMaxTokensLabel(prev.MaxTokens), iv.MinTokens)
}
}
return nil
}
func formatMaxTokensLabel(max *int) string {
if max == nil {
return "∞"
}
return fmt.Sprintf("%d", *max)
}
// ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中)
type ChannelUsageFields struct {
ChannelID int64 // 渠道 ID0 = 无渠道)
OriginalModel string // 用户原始请求模型(渠道映射前)
ChannelMappedModel string // 渠道映射后的模型名(无映射时等于 OriginalModel
BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped"
ModelMappingChain string // 映射链描述,如 "a→b→c"
}