refactor: replace magic strings with named constants

- PricingSourceChannel/LiteLLM/Fallback for resolver source
- MediaTypeImage/Video/Prompt for result.MediaType
- Reuse BillingModeToken/BillingModeImage for billing mode
- Reuse BillingModelSourceChannelMapped/PlatformAnthropic in handler
This commit is contained in:
erio
2026-04-02 02:22:15 +08:00
parent 212eaa3a05
commit 0d241d52eb
4 changed files with 39 additions and 19 deletions

View File

@@ -130,7 +130,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
} }
resp.BillingModelSource = ch.BillingModelSource resp.BillingModelSource = ch.BillingModelSource
if resp.BillingModelSource == "" { if resp.BillingModelSource == "" {
resp.BillingModelSource = "channel_mapped" resp.BillingModelSource = service.BillingModelSourceChannelMapped
} }
if resp.GroupIDs == nil { if resp.GroupIDs == nil {
resp.GroupIDs = []int64{} resp.GroupIDs = []int64{}
@@ -147,11 +147,11 @@ func channelToResponse(ch *service.Channel) *channelResponse {
} }
billingMode := string(p.BillingMode) billingMode := string(p.BillingMode)
if billingMode == "" { if billingMode == "" {
billingMode = "token" billingMode = string(service.BillingModeToken)
} }
platform := p.Platform platform := p.Platform
if platform == "" { if platform == "" {
platform = "anthropic" platform = service.PlatformAnthropic
} }
intervals := make([]pricingIntervalResponse, 0, len(p.Intervals)) intervals := make([]pricingIntervalResponse, 0, len(p.Intervals))
for _, iv := range p.Intervals { for _, iv := range p.Intervals {
@@ -194,7 +194,7 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
} }
platform := r.Platform platform := r.Platform
if platform == "" { if platform == "" {
platform = "anthropic" platform = service.PlatformAnthropic
} }
intervals := make([]service.PricingInterval, 0, len(r.Intervals)) intervals := make([]service.PricingInterval, 0, len(r.Intervals))
for _, iv := range r.Intervals { for _, iv := range r.Intervals {

View File

@@ -60,6 +60,19 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info" claudeMimicDebugInfoKey = "claude_mimic_debug_info"
) )
// MediaType 媒体类型常量
const (
MediaTypeImage = "image"
MediaTypeVideo = "video"
MediaTypePrompt = "prompt"
)
const (
claudeMaxMessageOverheadTokens = 3
claudeMaxBlockOverheadTokens = 1
claudeMaxUnknownContentTokens = 4
)
// ForceCacheBillingContextKey 强制缓存计费上下文键 // ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 // 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{} type forceCacheBillingKeyType struct{}
@@ -7744,7 +7757,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} }
// 根据请求类型选择计费方式 // 根据请求类型选择计费方式
if result.MediaType == "image" || result.MediaType == "video" { if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
var soraConfig *SoraPriceConfig var soraConfig *SoraPriceConfig
if apiKey.Group != nil { if apiKey.Group != nil {
soraConfig = &SoraPriceConfig{ soraConfig = &SoraPriceConfig{
@@ -7754,12 +7767,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
} }
} }
if result.MediaType == "image" { if result.MediaType == MediaTypeImage {
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
} else { } else {
cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
} }
} else if result.MediaType == "prompt" { } else if result.MediaType == MediaTypePrompt {
cost = &CostBreakdown{} cost = &CostBreakdown{}
} else if result.ImageCount > 0 { } else if result.ImageCount > 0 {
// 图片生成计费:渠道级别定价优先,否则走按次计费(兼容旧版本) // 图片生成计费:渠道级别定价优先,否则走按次计费(兼容旧版本)
@@ -7767,7 +7780,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if s.resolver != nil && apiKey.Group != nil { if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == "channel" { if resolved.Source == PricingSourceChannel {
hasChannelPricing = true hasChannelPricing = true
} }
} }
@@ -7900,15 +7913,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} }
// 设置计费模式 // 设置计费模式
if result.MediaType != "image" && result.MediaType != "video" && result.MediaType != "prompt" { if result.MediaType != MediaTypeImage && result.MediaType != MediaTypeVideo && result.MediaType != MediaTypePrompt {
if cost != nil && cost.BillingMode != "" { if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} else if result.ImageCount > 0 { } else if result.ImageCount > 0 {
billingMode := "image" billingMode := string(BillingModeImage)
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} else { } else {
billingMode := "token" billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} }
} }
@@ -8038,7 +8051,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
if s.resolver != nil && apiKey.Group != nil { if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == "channel" { if resolved.Source == PricingSourceChannel {
hasChannelPricing = true hasChannelPricing = true
} }
} }
@@ -8094,7 +8107,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
Model: billingModel, Model: billingModel,
GroupID: &gid, GroupID: &gid,
}) })
if resolved.Source == "channel" { if resolved.Source == PricingSourceChannel {
// 有渠道定价,渠道区间已包含上下文分层 // 有渠道定价,渠道区间已包含上下文分层
cost, err = s.billingService.CalculateCostUnified(CostInput{ cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx, Ctx: ctx,
@@ -8179,10 +8192,10 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
billingMode := cost.BillingMode billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} else if result.ImageCount > 0 { } else if result.ImageCount > 0 {
billingMode := "image" billingMode := string(BillingModeImage)
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} else { } else {
billingMode := "token" billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} }

View File

@@ -5,6 +5,13 @@ import (
"log/slog" "log/slog"
) )
// PricingSource 定价来源标识
const (
PricingSourceChannel = "channel"
PricingSourceLiteLLM = "litellm"
PricingSourceFallback = "fallback"
)
// ResolvedPricing 统一定价解析结果 // ResolvedPricing 统一定价解析结果
type ResolvedPricing struct { type ResolvedPricing struct {
// Mode 计费模式 // Mode 计费模式
@@ -78,9 +85,9 @@ func (r *ModelPricingResolver) resolveBasePricing(model string) (*ModelPricing,
if err != nil { if err != nil {
slog.Debug("failed to get model pricing from LiteLLM, using fallback", slog.Debug("failed to get model pricing from LiteLLM, using fallback",
"model", model, "error", err) "model", model, "error", err)
return nil, "fallback" return nil, PricingSourceFallback
} }
return pricing, "litellm" return pricing, PricingSourceLiteLLM
} }
// applyChannelOverrides 应用渠道定价覆盖 // applyChannelOverrides 应用渠道定价覆盖
@@ -90,7 +97,7 @@ func (r *ModelPricingResolver) applyChannelOverrides(ctx context.Context, groupI
return return
} }
resolved.Source = "channel" resolved.Source = PricingSourceChannel
resolved.Mode = chPricing.BillingMode resolved.Mode = chPricing.BillingMode
if resolved.Mode == "" { if resolved.Mode == "" {
resolved.Mode = BillingModeToken resolved.Mode = BillingModeToken

View File

@@ -4290,7 +4290,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
billingMode := cost.BillingMode billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} else { } else {
billingMode := "token" billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} }
// 添加 UserAgent // 添加 UserAgent